
Commit 03f8ba9

add benchmark function, used internally (#151)
* add benchmark function, used internally
* lint
* 2.1.4 (fix messed up publish command)
1 parent e05af9e commit 03f8ba9

File tree

14 files changed: 290 additions & 12 deletions


actions.hpp

Lines changed: 126 additions & 3 deletions
@@ -684,14 +684,137 @@ json action_current_status(app_t &app, json &body)
   };
 }
 
+//
+// benchmark & perplexity
+//
+
+json action_test_benchmark(app_t &app, json &body)
+{
+  std::string type = body.at("type");   // "pp" (prompt proc) or "tg" (tok gen)
+  int n_samples = body.at("n_samples"); // n_batch in pp and n_predict in pg
+
+  llama_kv_cache_clear(app.ctx);
+  int n_vocab = llama_vocab_n_tokens(app.vocab);
+  int64_t t_start = ggml_time_ms();
+
+  if (type == "pp")
+  {
+    llama_batch batch = llama_batch_init(n_samples, 0, 1);
+    for (int i = 0; i < n_samples; i++)
+    {
+      common_batch_add(batch, i % n_vocab, i, {0}, i == n_samples - 1);
+    }
+    int ret = llama_decode(app.ctx, batch);
+    llama_batch_free(batch);
+    if (ret != 0)
+    {
+      return json{{"error", "llama_decode failed with status = " + std::to_string(ret)}};
+    }
+  }
+  else if (type == "tg")
+  {
+    llama_batch batch = llama_batch_init(1, 0, 1);
+    for (int i = 0; i < n_samples; i++)
+    {
+      common_batch_clear(batch);
+      common_batch_add(batch, i % n_vocab, i, {0}, true);
+      int ret = llama_decode(app.ctx, batch);
+      if (ret != 0)
+      {
+        return json{{"error", "llama_decode failed with status = " + std::to_string(ret)}};
+      }
+    }
+    llama_batch_free(batch);
+  }
+  else
+  {
+    return json{{"error", "unknown type: " + type}};
+  }
+
+  int64_t t_end = ggml_time_ms();
+  return json{
+      {"success", true},
+      {"t_ms", t_end - t_start},
+  };
+}
+
+json action_test_perplexity(app_t &app, json &body)
+{
+  llama_tokens input = body["input"];
+  const size_t n = input.size();
+
+  int64_t t_start = ggml_time_ms();
+
+  if (n < 2)
+  {
+    return json{{"error", "Input must contain at least two tokens"}};
+  }
+
+  // Clear existing context to start fresh
+  llama_kv_cache_clear(app.ctx);
+  app.tokens.clear();
+
+  const int32_t n_vocab = llama_vocab_n_tokens(app.vocab);
+  double nll = 0.0;
+
+  static auto log_softmax = [](int n_vocab, const float *logits, int tok) -> double
+  {
+    float max_logit = logits[0];
+    for (int i = 1; i < n_vocab; ++i)
+    {
+      max_logit = std::max(max_logit, logits[i]);
+    }
+    double sum_exp = 0.0;
+    for (int i = 0; i < n_vocab; ++i)
+    {
+      sum_exp += expf(logits[i] - max_logit);
+    }
+    return logits[tok] - max_logit - log(sum_exp);
+  };
+
+  for (size_t i = 0; i < n - 1; ++i)
+  {
+    // Prepare batch with current token (input[i])
+    common_batch_clear(app.batch);
+    common_batch_add(app.batch, input[i], i, {0}, true); // Enable logits for this token
+
+    if (llama_decode(app.ctx, app.batch) != 0)
+    {
+      return json{{"error", "Decoding failed at position " + std::to_string(i)}};
+    }
+
+    float *logits = llama_get_logits_ith(app.ctx, 0);
+
+    // Get true next token (input[i+1])
+    const int32_t true_token = input[i + 1];
+
+    nll += -log_softmax(n_vocab, logits, true_token);
+  }
+
+  // Calculate final metrics
+  const double cross_entropy = nll / (n - 1);
+  const double ppl = std::exp(cross_entropy);
+
+  int64_t t_end = ggml_time_ms();
+
+  return json{
+      {"success", true},
+      {"ppl", ppl},
+      {"nll", nll},
+      {"cross_entropy", cross_entropy},
+      {"n_tokens", n - 1},
+      {"t_ms", t_end - t_start},
+  };
+}
+
 //////////////////////////////////////////
 
 // because we can't support jinja for now, we temporary use an old version of common_chat_apply_template
 // TODO: support jinja
 std::string common_chat_apply_template_old(const struct llama_model *model,
-    const std::string &tmpl,
-    const std::vector<common_chat_msg> &msgs,
-    bool add_ass)
+                                           const std::string &tmpl,
+                                           const std::vector<common_chat_msg> &msgs,
+                                           bool add_ass)
 {
   int alloc_size = 0;
   bool fallback = false; // indicate if we must fallback to default chatml
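
Editor's note on the perplexity math above (a restatement of the code, not part of the diff): log_softmax subtracts the max logit purely for numerical stability, the negative log-probability of each true next token is accumulated into nll, and the average is exponentiated. With z^{(i)} the logits produced after decoding token x_i and x_{i+1} the true next token:

    \log p(x_{i+1} \mid x_{\le i}) = z^{(i)}_{x_{i+1}} - \max_j z^{(i)}_j - \log \sum_j \exp\big(z^{(i)}_j - \max_k z^{(i)}_k\big)

    \mathrm{nll} = -\sum_{i=1}^{n-1} \log p(x_{i+1} \mid x_{\le i}), \qquad \mathrm{cross\_entropy} = \frac{\mathrm{nll}}{n-1}, \qquad \mathrm{ppl} = \exp(\mathrm{cross\_entropy})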

examples/main/src/App.tsx

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ import Sidebar from './components/Sidebar';
 import { MessagesProvider } from './utils/messages.context';
 import { Screen } from './utils/types';
 import { useWllama, WllamaProvider } from './utils/wllama.context';
+import './utils/benchmark';
 
 function App() {
   return (
examples/main/src/utils/benchmark.ts

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
+import { Wllama } from '@wllama/wllama';
+import { WLLAMA_CONFIG_PATHS } from '../config';
+import { delay } from './utils';
+
+// TODO: this is console-only for now, should we implement a GUI in the future?
+
+const WIKITEXT_URL =
+  'https://raw.githubusercontent.com/wangfin/QAsystem/refs/heads/master/QAManagement/language_model/data/wikitext-2/valid.txt';
+
+const BENCH_MODELS = [
+  'https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q8_0.gguf',
+  'https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q4_0.gguf',
+  'https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q4_K_M.gguf',
+  'https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q5_K_L.gguf',
+];
+
+const BENCH_N_REPEATED = 4;
+
+const BENCH_CONFIGS: { type: 'pp' | 'tg'; n_samples: number }[] = [
+  { type: 'pp', n_samples: 32 },
+  { type: 'pp', n_samples: 64 },
+  { type: 'pp', n_samples: 128 },
+  { type: 'pp', n_samples: 256 },
+  { type: 'tg', n_samples: 32 },
+  { type: 'tg', n_samples: 64 },
+  { type: 'tg', n_samples: 128 },
+  { type: 'tg', n_samples: 256 },
+];
+
+async function loadModel(modelUrl: string) {
+  const modelFile = modelUrl.split('/').pop();
+  const wllama = new Wllama(WLLAMA_CONFIG_PATHS);
+  await wllama.loadModelFromUrl(modelUrl, {
+    n_batch: 512,
+    n_ctx: 4096,
+    progressCallback: ({ total, loaded }) => {
+      console.log(`Model ${modelFile}: ${Math.round((100 * loaded) / total)}%`);
+    },
+  });
+  return { wllama, modelFile };
+}
+
+async function benchmark() {
+  const output: any[][] = [
+    ['model', 'threads', 'test', 't/s'],
+    ['---', '---', '---', '---'],
+  ];
+  for (const modelUrl of BENCH_MODELS) {
+    const [{ wllama, modelFile }] = await Promise.all([
+      loadModel(modelUrl),
+      delay(10000), // force delay for CPU to cool down
+    ]);
+    console.clear();
+    const nThreads = wllama.getNumThreads();
+    for (const config of BENCH_CONFIGS) {
+      const { type, n_samples } = config;
+      const results: number[] = [];
+      for (let i = 0; i < BENCH_N_REPEATED; i++) {
+        console.log('Running', modelFile, config);
+        const { t_ms } = await wllama._testBenchmark(type, n_samples);
+        const t_per_tok = n_samples / (t_ms / 1000);
+        results.push(t_per_tok);
+        console.log('Run ', i, 'pref:', t_per_tok, 't/s');
+      }
+      const t_avg = results.reduce((a, b) => a + b, 0) / results.length;
+      const t_plus_minus = Math.abs(
+        Math.max(...results) - Math.min(...results)
+      );
+      output.push([
+        modelFile,
+        nThreads,
+        `${type} ${n_samples}`,
+        `${t_avg.toFixed(2)} ± ${t_plus_minus.toFixed(2)}`,
+      ]);
+    }
+    wllama.exit();
+  }
+
+  console.table(output);
+  const markdown = output
+    .map((row) => '| ' + row.join(' | ') + ' |')
+    .join('\n');
+  console.log(markdown);
+}
+
+async function perplexity() {
+  const output: any[][] = [
+    ['model', 'PPL', 'n_tokens'],
+    ['---', '---', '---'],
+  ];
+  const LIMIT_TOKENS = 2048;
+  const wikitext = await fetch(WIKITEXT_URL).then((res) => res.text());
+  console.log('Loaded wikitext:', wikitext.substring(0, 100), '...');
+  for (const modelUrl of BENCH_MODELS) {
+    const { wllama, modelFile } = await loadModel(modelUrl);
+    console.clear();
+    let tokens = await wllama.tokenize(
+      wikitext.substring(0, LIMIT_TOKENS * 16)
+    );
+    tokens = tokens.slice(0, LIMIT_TOKENS);
+    console.log('Running', modelFile, 'n_tokens', tokens.length);
+    const { ppl } = await wllama._testPerplexity(tokens);
+    console.log('PPL:', ppl);
+    output.push([modelFile, ppl, tokens.length]);
+    wllama.exit();
+  }
+
+  console.table(output);
+  const markdown = output
+    .map((row) => '| ' + row.join(' | ') + ' |')
+    .join('\n');
+  console.log(markdown);
+}
+
+(window as any).__benchmark = benchmark;
+(window as any).__perplexity = perplexity;
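
Editor's note: since this helper is console-only, the intended entry point appears to be the browser DevTools console of the running example app, once the module above has been imported and the two globals are registered. A minimal sketch of invoking them (in the console the cast through any can be dropped and plain window.__benchmark() works):

// Kick off the throughput benchmark: loads each model in BENCH_MODELS,
// runs the pp/tg configs, and prints a markdown-ready table to the console.
await (window as any).__benchmark();

// Run the wikitext-2 perplexity pass over the same models.
await (window as any).__perplexity();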

llama.cpp (submodule)

package.json

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 {
   "name": "@wllama/wllama",
-  "version": "2.1.3",
+  "version": "2.1.4",
   "description": "WebAssembly binding for llama.cpp - Enabling on-browser LLM inference",
   "main": "index.js",
   "type": "module",

scripts/docker-compose.yml

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ services:
 mkdir -p wasm/single-thread
 cd wasm/single-thread
 
-export SHARED_EMCC_CFLAGS="--no-entry -O3 -msimd128 -fno-rtti -DNDEBUG -flto=full -frtti -fwasm-exceptions -sEXPORT_ALL=1 -sEXPORT_ES6=0 -sMODULARIZE=0 -sINITIAL_MEMORY=128MB -sMAXIMUM_MEMORY=4096MB -sALLOW_MEMORY_GROWTH=1 -sFORCE_FILESYSTEM=1 -sEXPORTED_FUNCTIONS=_main,_wllama_start,_wllama_action,_wllama_exit,_wllama_debug -sEXPORTED_RUNTIME_METHODS=ccall,cwrap -sNO_EXIT_RUNTIME=1"
+export SHARED_EMCC_CFLAGS="--no-entry -O3 -msimd128 -DNDEBUG -flto=full -frtti -fwasm-exceptions -sEXPORT_ALL=1 -sEXPORT_ES6=0 -sMODULARIZE=0 -sINITIAL_MEMORY=128MB -sMAXIMUM_MEMORY=4096MB -sALLOW_MEMORY_GROWTH=1 -sFORCE_FILESYSTEM=1 -sEXPORTED_FUNCTIONS=_main,_wllama_start,_wllama_action,_wllama_exit,_wllama_debug -sEXPORTED_RUNTIME_METHODS=ccall,cwrap -sNO_EXIT_RUNTIME=1"
 
 # emcc --clear-cache

src/multi-thread/wllama.js

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default.

src/multi-thread/wllama.wasm

10.4 KB
Binary file not shown.

src/single-thread/wllama.js

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default.

src/single-thread/wllama.wasm

11.1 KB
Binary file not shown.

0 commit comments