Commit 3ad4494

feat: token biases (#196)
1 parent b542b53 commit 3ad4494

16 files changed: +253 -22 lines changed


llama/addon.cpp

Lines changed: 41 additions & 3 deletions
@@ -3,6 +3,7 @@
 #include <algorithm>
 #include <sstream>
 #include <vector>
+#include <unordered_map>

 #include "common.h"
 #include "common/grammar-parser.h"
@@ -1334,6 +1335,8 @@ class AddonContextSampleTokenWorker : public Napi::AsyncWorker {
     float repeat_penalty_presence_penalty = 0.00f; // 0.0 = disabled
     float repeat_penalty_frequency_penalty = 0.00f; // 0.0 = disabled
     std::vector<llama_token> repeat_penalty_tokens;
+    std::unordered_map<llama_token, float> tokenBiases;
+    bool useTokenBiases = false;
     bool use_repeat_penalty = false;

     AddonContextSampleTokenWorker(const Napi::CallbackInfo& info, AddonContext* ctx)
@@ -1378,6 +1381,19 @@ class AddonContextSampleTokenWorker : public Napi::AsyncWorker {
             use_repeat_penalty = true;
         }

+        if (options.Has("tokenBiasKeys") && options.Has("tokenBiasValues")) {
+            Napi::Uint32Array tokenBiasKeys = options.Get("tokenBiasKeys").As<Napi::Uint32Array>();
+            Napi::Float32Array tokenBiasValues = options.Get("tokenBiasValues").As<Napi::Float32Array>();
+
+            if (tokenBiasKeys.ElementLength() == tokenBiasValues.ElementLength()) {
+                for (size_t i = 0; i < tokenBiasKeys.ElementLength(); i++) {
+                    tokenBiases[static_cast<llama_token>(tokenBiasKeys[i])] = tokenBiasValues[i];
+                }
+
+                useTokenBiases = true;
+            }
+        }
+
         if (options.Has("repeatPenaltyPresencePenalty")) {
             repeat_penalty_presence_penalty = options.Get("repeatPenaltyPresencePenalty").As<Napi::Number>().FloatValue();
         }
@@ -1426,18 +1442,33 @@ class AddonContextSampleTokenWorker : public Napi::AsyncWorker {
         // Select the best prediction.
         auto logits = llama_get_logits_ith(ctx->ctx, batchLogitIndex);
         auto n_vocab = llama_n_vocab(ctx->model->model);
+        auto eos_token = llama_token_eos(ctx->model->model);

         std::vector<llama_token_data> candidates;
         candidates.reserve(n_vocab);

         for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-            candidates.emplace_back(llama_token_data { token_id, logits[token_id], 0.0f });
+            auto logit = logits[token_id];
+
+            if (useTokenBiases) {
+                bool hasTokenBias = tokenBiases.find(token_id) != tokenBiases.end();
+                if (hasTokenBias) {
+                    auto logitBias = tokenBiases.at(token_id);
+                    if (logitBias == -INFINITY || logitBias < -INFINITY) {
+                        if (token_id != eos_token) {
+                            logit = -INFINITY;
+                        }
+                    } else {
+                        logit += logitBias;
+                    }
+                }
+            }
+
+            candidates.emplace_back(llama_token_data { token_id, logit, 0.0f });
         }

         llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

-        auto eos_token = llama_token_eos(ctx->model->model);
-
         if (use_repeat_penalty && !repeat_penalty_tokens.empty()) {
             llama_sample_repetition_penalties(
                 ctx->ctx,
@@ -1452,6 +1483,13 @@ class AddonContextSampleTokenWorker : public Napi::AsyncWorker {

         if (use_grammar && (grammar_evaluation_state)->grammar != nullptr) {
             llama_sample_grammar(ctx->ctx, &candidates_p, (grammar_evaluation_state)->grammar);
+
+            if ((candidates_p.size == 0 || candidates_p.data[0].logit == -INFINITY) && useTokenBiases) {
+                // logit biases caused grammar sampling to fail, so sampling again without logit biases
+                useTokenBiases = false;
+                SampleToken();
+                return;
+            }
         }

         if (temperature <= 0) {
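Note (not part of the diff): the sampler above receives token biases as two parallel typed arrays of equal length, tokenBiasKeys and tokenBiasValues, copies them into an unordered_map keyed by token id, and then adds each bias to that token's logit; a bias of -INFINITY bans the token outright unless it is the EOS token. A minimal TypeScript sketch of how a caller could flatten a token-to-bias map into that shape (the helper name and the input map are illustrative, not from this commit):

// Illustrative sketch only: flattens a token -> bias map into the parallel
// typed arrays consumed by the addon's sampleToken options (see AddonTypes.ts below).
function toTokenBiasArrays(biases: Map<number, number>) {
    const tokenBiasKeys = new Uint32Array(biases.size);
    const tokenBiasValues = new Float32Array(biases.size);

    let i = 0;
    for (const [token, bias] of biases) {
        tokenBiasKeys[i] = token;      // llama_token id
        tokenBiasValues[i] = bias;     // added to the token's logit; -Infinity bans the token
        i++;
    }

    return {tokenBiasKeys, tokenBiasValues};
}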

src/bindings/AddonTypes.ts

Lines changed: 3 additions & 1 deletion
@@ -108,7 +108,9 @@ export type AddonContext = {
         repeatPenaltyTokens?: Uint32Array,
         repeatPenaltyPresencePenalty?: number, // alpha_presence
         repeatPenaltyFrequencyPenalty?: number, // alpha_frequency
-        grammarEvaluationState?: AddonGrammarEvaluationState
+        grammarEvaluationState?: AddonGrammarEvaluationState,
+        tokenBiasKeys?: Uint32Array,
+        tokenBiasValues?: Float32Array
     }): Promise<Token>,
     disposeSequence(sequenceId: number): void,
114116

src/cli/commands/ChatCommand.ts

Lines changed: 9 additions & 2 deletions
@@ -323,7 +323,8 @@ async function RunChat({
         successText: chalk.blue("Model loaded"),
         failText: chalk.blue("Failed to load model"),
         liveUpdates: !debug,
-        noProgress: debug
+        noProgress: debug,
+        liveCtrlCSendsAbortSignal: true
     }, async (progressUpdater) => {
         try {
             return await llama.loadModel({
@@ -336,8 +337,14 @@ async function RunChat({
                 ignoreMemorySafetyChecks: gpuLayers != null,
                 onLoadProgress(loadProgress: number) {
                     progressUpdater.setProgress(loadProgress);
-                }
+                },
+                loadSignal: progressUpdater.abortSignal
             });
+        } catch (err) {
+            if (err === progressUpdater.abortSignal?.reason)
+                process.exit(0);
+
+            throw err;
         } finally {
             if (llama.logLevel === LlamaLogLevel.debug) {
                 await new Promise((accept) => setTimeout(accept, 0)); // wait for logs to finish printing
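Note (not part of the diff): with liveCtrlCSendsAbortSignal enabled, pressing Ctrl+C while the model is loading aborts the load through progressUpdater.abortSignal, and the command exits quietly when the rejection reason matches that signal instead of reporting an error. A standalone sketch of the same pattern with a plain AbortController, assuming only what the diff shows (llama.loadModel() accepts a loadSignal and rejects with the signal's reason when aborted); the function name and SIGINT wiring are illustrative:

import type {Llama, LlamaModel} from "node-llama-cpp";

async function loadModelWithCtrlCAbort(llama: Llama, modelPath: string): Promise<LlamaModel> {
    const abortController = new AbortController();
    process.once("SIGINT", () => abortController.abort());

    try {
        return await llama.loadModel({
            modelPath,
            loadSignal: abortController.signal
        });
    } catch (err) {
        if (err === abortController.signal.reason)
            process.exit(0); // aborted on purpose, exit without printing an error

        throw err;
    }
}

The same pattern is applied to CompleteCommand and InfillCommand below.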

src/cli/commands/CompleteCommand.ts

Lines changed: 9 additions & 2 deletions
@@ -238,7 +238,8 @@ async function RunCompletion({
         successText: chalk.blue("Model loaded"),
         failText: chalk.blue("Failed to load model"),
         liveUpdates: !debug,
-        noProgress: debug
+        noProgress: debug,
+        liveCtrlCSendsAbortSignal: true
     }, async (progressUpdater) => {
         try {
             return await llama.loadModel({
@@ -251,8 +252,14 @@ async function RunCompletion({
                 ignoreMemorySafetyChecks: gpuLayers != null,
                 onLoadProgress(loadProgress: number) {
                     progressUpdater.setProgress(loadProgress);
-                }
+                },
+                loadSignal: progressUpdater.abortSignal
             });
+        } catch (err) {
+            if (err === progressUpdater.abortSignal?.reason)
+                process.exit(0);
+
+            throw err;
         } finally {
             if (llama.logLevel === LlamaLogLevel.debug) {
                 await new Promise((accept) => setTimeout(accept, 0)); // wait for logs to finish printing

src/cli/commands/InfillCommand.ts

Lines changed: 9 additions & 2 deletions
@@ -262,7 +262,8 @@ async function RunInfill({
         successText: chalk.blue("Model loaded"),
         failText: chalk.blue("Failed to load model"),
         liveUpdates: !debug,
-        noProgress: debug
+        noProgress: debug,
+        liveCtrlCSendsAbortSignal: true
     }, async (progressUpdater) => {
         try {
             return await llama.loadModel({
@@ -275,8 +276,14 @@ async function RunInfill({
                 ignoreMemorySafetyChecks: gpuLayers != null,
                 onLoadProgress(loadProgress: number) {
                     progressUpdater.setProgress(loadProgress);
-                }
+                },
+                loadSignal: progressUpdater.abortSignal
             });
+        } catch (err) {
+            if (err === progressUpdater.abortSignal?.reason)
+                process.exit(0);
+
+            throw err;
         } finally {
             if (llama.logLevel === LlamaLogLevel.debug) {
                 await new Promise((accept) => setTimeout(accept, 0)); // wait for logs to finish printing

src/cli/utils/ConsoleInteraction.ts

Lines changed: 1 addition & 0 deletions
@@ -84,6 +84,7 @@ export class ConsoleInteraction {

         if (callbacks.length === 0 && key === ConsoleInteractionKey.ctrlC) {
             process.stdout.write("\n");
+            this.stop();
             process.exit(0);
         }

src/cli/utils/consolePromptQuestion.ts

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ export async function consolePromptQuestion(question: string, {
             clearLastLines(linesUsed);

             if (exitOnCtrlC) {
+                rl.close();
                 process.exit(0);
             } else
                 accept(null);

src/evaluator/LlamaChat/LlamaChat.ts

Lines changed: 10 additions & 0 deletions
@@ -14,6 +14,7 @@ import {UNKNOWN_UNICODE_CHAR} from "../../consts.js";
 import {getQueuedTokensBeforeStopTrigger} from "../../utils/getQueuedTokensBeforeStopTrigger.js";
 import {resolveChatWrapper} from "../../chatWrappers/utils/resolveChatWrapper.js";
 import {GeneralChatWrapper} from "../../chatWrappers/GeneralChatWrapper.js";
+import {TokenBias} from "../TokenBias.js";
 import {
     eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy
 } from "./utils/contextShiftStrategies/eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy.js";
@@ -85,6 +86,13 @@ export type LLamaChatGenerateResponseOptions<Functions extends ChatModelFunction

     repeatPenalty?: false | LLamaContextualRepeatPenalty,

+    /**
+     * Adjust the probability of tokens being generated.
+     * Can be used to bias the model to generate tokens that you want it to lean towards,
+     * or to avoid generating tokens that you want it to avoid.
+     */
+    tokenBias?: TokenBias | (() => TokenBias),
+
     /**
      * See the parameter `evaluationPriority` on the `LlamaContextSequence.evaluate()` function for more information.
      */
@@ -249,6 +257,7 @@ export class LlamaChat {
         grammar,
         trimWhitespaceSuffix = false,
         repeatPenalty = {},
+        tokenBias,
         evaluationPriority = 5,
         functions,
         documentFunctionParams,
@@ -532,6 +541,7 @@ export class LlamaChat {
                 frequencyPenalty,
                 presencePenalty
             },
+            tokenBias,
             evaluationPriority,
             yieldEosToken: true
         }));

src/evaluator/LlamaChatSession/LlamaChatSession.ts

Lines changed: 15 additions & 3 deletions
@@ -7,6 +7,7 @@ import {LlamaContextSequence} from "../LlamaContext/LlamaContext.js";
 import {LlamaGrammar} from "../LlamaGrammar.js";
 import {LlamaChat, LLamaChatContextShiftOptions, LlamaChatResponse} from "../LlamaChat/LlamaChat.js";
 import {EvaluationPriority} from "../LlamaContext/types.js";
+import {TokenBias} from "../TokenBias.js";


 export type LlamaChatSessionOptions = {
@@ -96,7 +97,14 @@ export type LLamaChatPromptOptions<Functions extends ChatSessionModelFunctions |
     */
    evaluationPriority?: EvaluationPriority,

-    repeatPenalty?: false | LlamaChatSessionRepeatPenalty
+    repeatPenalty?: false | LlamaChatSessionRepeatPenalty,
+
+    /**
+     * Adjust the probability of tokens being generated.
+     * Can be used to bias the model to generate tokens that you want it to lean towards,
+     * or to avoid generating tokens that you want it to avoid.
+     */
+    tokenBias?: TokenBias | (() => TokenBias)
 } & ({
     grammar?: LlamaGrammar,
     functions?: never,
@@ -249,14 +257,16 @@ export class LlamaChatSession {
         topP,
         grammar,
         trimWhitespaceSuffix = false,
-        repeatPenalty
+        repeatPenalty,
+        tokenBias
     }: LLamaChatPromptOptions<Functions> = {}) {
         const {responseText} = await this.promptWithMeta<Functions>(prompt, {
             // this is a workaround to allow passing both `functions` and `grammar`
             functions: functions as undefined,
             documentFunctionParams: documentFunctionParams as undefined,

-            onToken, signal, maxTokens, temperature, minP, topK, topP, grammar, trimWhitespaceSuffix, repeatPenalty
+            onToken, signal, maxTokens, temperature, minP, topK, topP, grammar, trimWhitespaceSuffix, repeatPenalty,
+            tokenBias
         });

         return responseText;
@@ -279,6 +289,7 @@ export class LlamaChatSession {
         grammar,
         trimWhitespaceSuffix = false,
         repeatPenalty,
+        tokenBias,
         evaluationPriority
     }: LLamaChatPromptOptions<Functions> = {}) {
         this._ensureNotDisposed();
@@ -325,6 +336,7 @@ export class LlamaChatSession {
             minP,
             topK,
             topP,
+            tokenBias,
             maxTokens,
             temperature,
             trimWhitespaceSuffix,
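Example (not part of the diff): with the new option, a chat prompt can pass a TokenBias instance, or a function that returns one. The TokenBias class added by this commit is imported above but its implementation is not shown in this diff, so its construction and package export in the sketch below are assumptions; the surrounding setup follows the v3-style API used elsewhere in this repository:

import {getLlama, LlamaChatSession, TokenBias} from "node-llama-cpp";

const llama = await getLlama();
const model = await llama.loadModel({modelPath: "path/to/model.gguf"});
const context = await model.createContext();
const session = new LlamaChatSession({contextSequence: context.getSequence()});

// Hypothetical construction: the exact TokenBias API is not shown in this diff.
// Assume it maps tokens to a bias added to their logits, where -Infinity bans
// a token outright (except the EOS token, per the addon.cpp change above).
const bias = new TokenBias(model);

const answer = await session.prompt("Tell me a short story", {
    tokenBias: bias // also accepts a function returning a TokenBias
});
console.log(answer);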

src/evaluator/LlamaCompletion.ts

Lines changed: 14 additions & 0 deletions
@@ -12,6 +12,7 @@ import {LlamaGrammarEvaluationState} from "./LlamaGrammarEvaluationState.js";
 import {LlamaGrammar} from "./LlamaGrammar.js";
 import {EvaluationPriority} from "./LlamaContext/types.js";
 import {LlamaContextSequence} from "./LlamaContext/LlamaContext.js";
+import {TokenBias} from "./TokenBias.js";

 export type LlamaCompletionOptions = {
     contextSequence: LlamaContextSequence,
@@ -76,6 +77,13 @@ export type LlamaCompletionGenerationOptions = {

     repeatPenalty?: false | LLamaContextualRepeatPenalty,

+    /**
+     * Adjust the probability of tokens being generated.
+     * Can be used to bias the model to generate tokens that you want it to lean towards,
+     * or to avoid generating tokens that you want it to avoid.
+     */
+    tokenBias?: TokenBias | (() => TokenBias),
+
     /**
      * See the parameter `evaluationPriority` on the `LlamaContextSequence.evaluate()` function for more information.
      */
@@ -195,6 +203,7 @@ export class LlamaCompletion {
         topP,
         trimWhitespaceSuffix = false,
         repeatPenalty = {},
+        tokenBias,
         evaluationPriority = 5,
         grammar,
         stopGenerationTriggers,
@@ -274,6 +283,7 @@ export class LlamaCompletion {
             topP,
             trimWhitespaceSuffix,
             repeatPenalty,
+            tokenBias,
             evaluationPriority,
             grammar,
             contextShiftSize,
@@ -326,6 +336,7 @@ export class LlamaCompletion {
         topP,
         trimWhitespaceSuffix = false,
         repeatPenalty = {},
+        tokenBias,
         evaluationPriority = 5,
         grammar,
         contextShiftSize = defaultContextShiftSize,
@@ -455,6 +466,7 @@ export class LlamaCompletion {
             topP,
             trimWhitespaceSuffix,
             repeatPenalty,
+            tokenBias,
             evaluationPriority,
             grammar,
             contextShiftSize,
@@ -489,6 +501,7 @@ export class LlamaCompletion {
         topP,
         trimWhitespaceSuffix = false,
         repeatPenalty = {},
+        tokenBias,
         evaluationPriority = 5,
         grammar,
         contextShiftSize = defaultContextShiftSize,
@@ -603,6 +616,7 @@ export class LlamaCompletion {
                 frequencyPenalty,
                 presencePenalty
             },
+            tokenBias,
             evaluationPriority,
             yieldEosToken: true
         }));
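Example (not part of the diff): LlamaCompletion gains the same tokenBias option, and the function form presumably defers building the bias until generation time. In the sketch below, generateCompletion() and maxTokens are assumed from the broader LlamaCompletion API (not shown in this diff), and buildBias() is a hypothetical helper returning a TokenBias:

// Sketch only: `completion` is an existing LlamaCompletion instance and
// `buildBias()` is a hypothetical helper that returns a TokenBias.
const text = await completion.generateCompletion("Once upon a time", {
    maxTokens: 128,
    tokenBias: () => buildBias()
});
console.log(text);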
