@@ -3,6 +3,7 @@
 #include <algorithm>
 #include <sstream>
 #include <vector>
+#include <unordered_map>
 
 #include "common.h"
 #include "common/grammar-parser.h"
@@ -1334,6 +1335,8 @@ class AddonContextSampleTokenWorker : public Napi::AsyncWorker {
     float repeat_penalty_presence_penalty = 0.00f; // 0.0 = disabled
     float repeat_penalty_frequency_penalty = 0.00f; // 0.0 = disabled
     std::vector<llama_token> repeat_penalty_tokens;
+    std::unordered_map<llama_token, float> tokenBiases;
+    bool useTokenBiases = false;
     bool use_repeat_penalty = false;
 
     AddonContextSampleTokenWorker(const Napi::CallbackInfo& info, AddonContext* ctx)
@@ -1378,6 +1381,19 @@ class AddonContextSampleTokenWorker : public Napi::AsyncWorker {
         use_repeat_penalty = true;
     }
 
+    if (options.Has("tokenBiasKeys") && options.Has("tokenBiasValues")) {
+        Napi::Uint32Array tokenBiasKeys = options.Get("tokenBiasKeys").As<Napi::Uint32Array>();
+        Napi::Float32Array tokenBiasValues = options.Get("tokenBiasValues").As<Napi::Float32Array>();
+
+        if (tokenBiasKeys.ElementLength() == tokenBiasValues.ElementLength()) {
+            for (size_t i = 0; i < tokenBiasKeys.ElementLength(); i++) {
+                tokenBiases[static_cast<llama_token>(tokenBiasKeys[i])] = tokenBiasValues[i];
+            }
+
+            useTokenBiases = true;
+        }
+    }
+
     if (options.Has("repeatPenaltyPresencePenalty")) {
         repeat_penalty_presence_penalty = options.Get("repeatPenaltyPresencePenalty").As<Napi::Number>().FloatValue();
     }
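Note on the wire format above: the biases arrive as two parallel typed arrays (token ids as a `Uint32Array`, bias values as a `Float32Array`) rather than as a JS object, which keeps the JS-to-C++ handoff cheap; mismatched lengths silently leave biasing disabled. A minimal plain-C++ sketch of the same zip-into-map step (no N-API; the driver values are made up):

```cpp
#include <cstdint>
#include <cstdio>
#include <unordered_map>
#include <vector>

using llama_token = std::int32_t;

int main() {
    // Parallel arrays, standing in for the Uint32Array / Float32Array
    // that JS would hand to the addon (values here are made up).
    std::vector<std::uint32_t> tokenBiasKeys = {15043, 450};
    std::vector<float> tokenBiasValues = {4.5f, -1.0f};

    std::unordered_map<llama_token, float> tokenBiases;
    if (tokenBiasKeys.size() == tokenBiasValues.size()) {  // mismatch => biases stay disabled
        for (std::size_t i = 0; i < tokenBiasKeys.size(); i++)
            tokenBiases[static_cast<llama_token>(tokenBiasKeys[i])] = tokenBiasValues[i];
    }

    for (const auto& [token, bias] : tokenBiases)
        std::printf("token %d -> bias %+.2f\n", static_cast<int>(token), static_cast<double>(bias));
}
```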
@@ -1426,18 +1442,33 @@ class AddonContextSampleTokenWorker : public Napi::AsyncWorker {
     // Select the best prediction.
     auto logits = llama_get_logits_ith(ctx->ctx, batchLogitIndex);
     auto n_vocab = llama_n_vocab(ctx->model->model);
+    auto eos_token = llama_token_eos(ctx->model->model);
 
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
 
     for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        candidates.emplace_back(llama_token_data { token_id, logits[token_id], 0.0f });
+        auto logit = logits[token_id];
+
+        if (useTokenBiases) {
+            bool hasTokenBias = tokenBiases.find(token_id) != tokenBiases.end();
+            if (hasTokenBias) {
+                auto logitBias = tokenBiases.at(token_id);
+                if (logitBias == -INFINITY || logitBias < -INFINITY) {
+                    if (token_id != eos_token) {
+                        logit = -INFINITY;
+                    }
+                } else {
+                    logit += logitBias;
+                }
+            }
+        }
+
+        candidates.emplace_back(llama_token_data { token_id, logit, 0.0f });
     }
 
     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 
-    auto eos_token = llama_token_eos(ctx->model->model);
-
     if (use_repeat_penalty && !repeat_penalty_tokens.empty()) {
         llama_sample_repetition_penalties(
             ctx->ctx,
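The bias rule in the loop above: a bias of `-INFINITY` masks the token outright, except for the EOS token, which is deliberately left unmasked so generation can always terminate; any finite bias is simply added to the raw logit before sampling. A self-contained sketch of that rule (`applyTokenBias` is a hypothetical helper, not part of the addon):

```cpp
#include <cmath>
#include <cstdio>
#include <unordered_map>

using llama_token = int;

// Hypothetical helper mirroring the loop body above: -INFINITY masks a token
// (unless it is EOS); any finite bias is added to the raw logit.
float applyTokenBias(
    llama_token token_id,
    float logit,
    const std::unordered_map<llama_token, float>& tokenBiases,
    llama_token eos_token
) {
    auto it = tokenBiases.find(token_id);
    if (it == tokenBiases.end())
        return logit;                                      // no bias for this token
    if (it->second == -INFINITY)
        return token_id == eos_token ? logit : -INFINITY;  // mask, but never mask EOS
    return logit + it->second;                             // finite bias: shift the logit
}

int main() {
    const llama_token eos = 2;
    std::unordered_map<llama_token, float> biases = {{7, -INFINITY}, {9, 2.5f}, {2, -INFINITY}};
    std::printf("%f\n", applyTokenBias(7, 1.0f, biases, eos)); // -inf: fully masked
    std::printf("%f\n", applyTokenBias(9, 1.0f, biases, eos)); // 3.5: logit shifted up
    std::printf("%f\n", applyTokenBias(2, 1.0f, biases, eos)); // 1.0: EOS kept reachable
}
```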
@@ -1452,6 +1483,13 @@ class AddonContextSampleTokenWorker : public Napi::AsyncWorker {
 
     if (use_grammar && (grammar_evaluation_state)->grammar != nullptr) {
         llama_sample_grammar(ctx->ctx, &candidates_p, (grammar_evaluation_state)->grammar);
+
+        if ((candidates_p.size == 0 || candidates_p.data[0].logit == -INFINITY) && useTokenBiases) {
+            // logit biases caused grammar sampling to fail, so sample again without logit biases
+            useTokenBiases = false;
+            SampleToken();
+            return;
+        }
     }
 
     if (temperature <= 0) {
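The fallback above handles the interaction between biases and grammar: if grammar filtering leaves no candidate, or only candidates whose logits the biases already pushed to `-INFINITY`, the worker drops the biases and re-enters `SampleToken()`, so grammar correctness wins over bias preferences; because `useTokenBiases` is cleared first, the retry happens at most once. A toy reduction of that retry-once control flow (`sampleOnce` is a hypothetical stand-in for grammar-filtered sampling):

```cpp
#include <cstdio>
#include <optional>

struct Sampler {
    bool useTokenBiases = true;

    // Hypothetical stand-in for grammar-filtered sampling; here it "fails"
    // whenever biases are active, to exercise the fallback path.
    std::optional<int> sampleOnce() {
        if (useTokenBiases)
            return std::nullopt;  // biases masked every grammar-legal token
        return 42;
    }

    int sampleToken() {
        if (auto token = sampleOnce())
            return *token;
        if (useTokenBiases) {        // first pass failed under biases:
            useTokenBiases = false;  // drop them so the grammar can win,
            return sampleToken();    // and sample again (at most one retry)
        }
        return -1;                   // grammar itself is unsatisfiable
    }
};

int main() {
    Sampler s;
    std::printf("sampled token: %d\n", s.sampleToken());  // prints 42 via the retry
}
```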