Skip to content

Commit 9cdbce9

Browse files
authored
feat(JSON Schema Grammar): $defs and $ref support with full inferred types (#472)
* feat(JSON Schema Grammar): `$defs` and `$ref` support with full inferred types
* feat(`inspect gguf` command): format and print the Jinja chat template with `--key .chatTemplate`
* fix(`JinjaTemplateChatWrapper`): first function call prefix detection
* fix(`QwenChatWrapper`): improve Qwen chat template detection
* fix: apply `maxTokens` on function calling parameters
* fix: adjust default prompt completion length based on SWA size when relevant
* fix: improve thought segmentation syntax extraction
* fix: adapt to `llama.cpp` changes
1 parent ea8d904 commit 9cdbce9

31 files changed

+1463
-139
lines changed

llama/addon/AddonContext.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,7 @@ Napi::Value AddonContext::DisposeSequence(const Napi::CallbackInfo& info) {
587587

588588
int32_t sequenceId = info[0].As<Napi::Number>().Int32Value();
589589

590-
bool result = llama_kv_self_seq_rm(ctx, sequenceId, -1, -1);
590+
bool result = llama_memory_seq_rm(llama_get_memory(ctx), sequenceId, -1, -1);
591591

592592
if (!result) {
593593
Napi::Error::New(info.Env(), "Failed to dispose sequence").ThrowAsJavaScriptException();
@@ -606,7 +606,7 @@ Napi::Value AddonContext::RemoveTokenCellsFromSequence(const Napi::CallbackInfo&
606606
int32_t startPos = info[1].As<Napi::Number>().Int32Value();
607607
int32_t endPos = info[2].As<Napi::Number>().Int32Value();
608608

609-
bool result = llama_kv_self_seq_rm(ctx, sequenceId, startPos, endPos);
609+
bool result = llama_memory_seq_rm(llama_get_memory(ctx), sequenceId, startPos, endPos);
610610

611611
return Napi::Boolean::New(info.Env(), result);
612612
}
@@ -621,7 +621,7 @@ Napi::Value AddonContext::ShiftSequenceTokenCells(const Napi::CallbackInfo& info
621621
int32_t endPos = info[2].As<Napi::Number>().Int32Value();
622622
int32_t shiftDelta = info[3].As<Napi::Number>().Int32Value();
623623

624-
llama_kv_self_seq_add(ctx, sequenceId, startPos, endPos, shiftDelta);
624+
llama_memory_seq_add(llama_get_memory(ctx), sequenceId, startPos, endPos, shiftDelta);
625625

626626
return info.Env().Undefined();
627627
}
@@ -634,7 +634,7 @@ Napi::Value AddonContext::GetSequenceKvCacheMinPosition(const Napi::CallbackInfo
634634
int32_t sequenceId = info[0].As<Napi::Number>().Int32Value();
635635

636636

637-
const auto minPosition = llama_kv_self_seq_pos_min(ctx, sequenceId);
637+
const auto minPosition = llama_memory_seq_pos_min(llama_get_memory(ctx), sequenceId);
638638

639639
return Napi::Number::New(info.Env(), minPosition);
640640
}
@@ -647,7 +647,7 @@ Napi::Value AddonContext::GetSequenceKvCacheMaxPosition(const Napi::CallbackInfo
647647
int32_t sequenceId = info[0].As<Napi::Number>().Int32Value();
648648

649649

650-
const auto maxPosition = llama_kv_self_seq_pos_max(ctx, sequenceId);
650+
const auto maxPosition = llama_memory_seq_pos_max(llama_get_memory(ctx), sequenceId);
651651

652652
return Napi::Number::New(info.Env(), maxPosition);
653653
}

package-lock.json

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@
197197
"ignore": "^7.0.4",
198198
"ipull": "^3.9.2",
199199
"is-unicode-supported": "^2.1.0",
200-
"lifecycle-utils": "^2.0.0",
200+
"lifecycle-utils": "^2.0.1",
201201
"log-symbols": "^7.0.0",
202202
"nanoid": "^5.1.5",
203203
"node-addon-api": "^8.3.1",

src/bindings/Llama.ts

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import {DisposedError, EventRelay, withLock} from "lifecycle-utils";
55
import {getConsoleLogPrefix} from "../utils/getConsoleLogPrefix.js";
66
import {LlamaModel, LlamaModelOptions} from "../evaluator/LlamaModel/LlamaModel.js";
77
import {DisposeGuard} from "../utils/DisposeGuard.js";
8-
import {GbnfJsonSchema} from "../utils/gbnfJson/types.js";
8+
import {GbnfJsonDefList, GbnfJsonSchema} from "../utils/gbnfJson/types.js";
99
import {LlamaJsonSchemaGrammar} from "../evaluator/LlamaJsonSchemaGrammar.js";
1010
import {LlamaGrammar, LlamaGrammarOptions} from "../evaluator/LlamaGrammar.js";
1111
import {ThreadsSplitter} from "../utils/ThreadsSplitter.js";
@@ -345,8 +345,11 @@ export class Llama {
345345
* @see [Using a JSON Schema Grammar](https://node-llama-cpp.withcat.ai/guide/grammar#json-schema) tutorial
346346
* @see [Reducing Hallucinations When Using JSON Schema Grammar](https://node-llama-cpp.withcat.ai/guide/grammar#reducing-json-schema-hallucinations) tutorial
347347
*/
348-
public async createGrammarForJsonSchema<const T extends GbnfJsonSchema>(schema: Readonly<T>) {
349-
return new LlamaJsonSchemaGrammar<T>(this, schema);
348+
public async createGrammarForJsonSchema<
349+
const T extends GbnfJsonSchema<Defs>,
350+
const Defs extends GbnfJsonDefList<Defs> = Record<any, any>
351+
>(schema: Readonly<T> & GbnfJsonSchema<Defs>) {
352+
return new LlamaJsonSchemaGrammar<T, Defs>(this, schema);
350353
}
351354
/* eslint-enable @stylistic/max-len */
352355

src/chatWrappers/QwenChatWrapper.ts

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ export class QwenChatWrapper extends ChatWrapper {
8484
segments: {
8585
reiterateStackAfterFunctionCalls: true,
8686
thought: {
87-
prefix: LlamaText(new SpecialTokensText("<think>")),
88-
suffix: LlamaText(new SpecialTokensText("</think>"))
87+
prefix: LlamaText(new SpecialTokensText("<think>\n")),
88+
suffix: LlamaText(new SpecialTokensText("\n</think>"))
8989
}
9090
}
9191
};
@@ -247,7 +247,9 @@ export class QwenChatWrapper extends ChatWrapper {
247247
public static override _getOptionConfigurationsToTestIfCanSupersedeJinjaTemplate(): ChatWrapperJinjaMatchConfiguration<typeof this> {
248248
return [
249249
[{}, {}, {_requireFunctionCallSettingsExtraction: true}],
250-
[{_lineBreakBeforeFunctionCallPrefix: true}, {}, {_requireFunctionCallSettingsExtraction: true}]
250+
[{_lineBreakBeforeFunctionCallPrefix: true}, {}, {_requireFunctionCallSettingsExtraction: true}],
251+
[{thoughts: "discourage"}, {}, {_requireFunctionCallSettingsExtraction: true}],
252+
[{thoughts: "discourage", _lineBreakBeforeFunctionCallPrefix: true}, {}, {_requireFunctionCallSettingsExtraction: true}]
251253
];
252254
}
253255
}

src/chatWrappers/generic/JinjaTemplateChatWrapper.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,7 @@ export class JinjaTemplateChatWrapper extends ChatWrapper {
671671
return res;
672672
};
673673

674-
const validateThatAllMessageIdsAreUsed = (parts: ReturnType<typeof splitText<string[]>>) => {
674+
const validateThatAllMessageIdsAreUsed = (parts: ReturnType<typeof splitText<string>>) => {
675675
const messageIdsLeft = new Set(messageIds);
676676

677677
for (const part of parts) {

src/chatWrappers/generic/utils/extractFunctionCallSettingsFromJinjaTemplate.ts

Lines changed: 97 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,22 @@ export function extractFunctionCallSettingsFromJinjaTemplate({
9494
modelMessage2
9595
]
9696
}];
97+
const chatHistoryOnlyCall: ChatHistoryItem[] = [...baseChatHistory, {
98+
type: "model",
99+
response: [
100+
{
101+
type: "functionCall",
102+
name: func1name,
103+
104+
// convert to number since this will go through JSON.stringify,
105+
// and we want to avoid escaping characters in the rendered output
106+
params: Number(func1params),
107+
result: Number(func1result),
108+
startsNewChunk: true
109+
},
110+
modelMessage2
111+
]
112+
}];
97113
const chatHistory2Calls: ChatHistoryItem[] = [...baseChatHistory, {
98114
type: "model",
99115
response: [
@@ -257,6 +273,17 @@ export function extractFunctionCallSettingsFromJinjaTemplate({
257273
stringifyFunctionResults: stringifyResult,
258274
combineModelMessageAndToolCalls
259275
});
276+
const renderedOnlyCall = getFirstValidResult([
277+
() => renderTemplate({
278+
chatHistory: chatHistoryOnlyCall,
279+
functions: functions1,
280+
additionalParams,
281+
stringifyFunctionParams: stringifyParams,
282+
stringifyFunctionResults: stringifyResult,
283+
combineModelMessageAndToolCalls
284+
}),
285+
() => undefined
286+
]);
260287
const rendered2Calls = getFirstValidResult([
261288
() => renderTemplate({
262289
chatHistory: chatHistory2Calls,
@@ -411,14 +438,46 @@ export function extractFunctionCallSettingsFromJinjaTemplate({
411438
parallelismResultPrefix
412439
} = resolveParallelismBetweenSectionsParts(func2ParamsToFunc1Result.text.slice(callSuffixLength, -resultPrefixLength));
413440

441+
let revivedCallPrefix = reviveSeparatorText(callPrefixText, idToStaticContent, contentIds);
442+
const revivedParallelismCallSectionPrefix = removeCommonRevivedPrefix(
443+
reviveSeparatorText(parallelismCallPrefix, idToStaticContent, contentIds),
444+
!combineModelMessageAndToolCalls
445+
? textBetween2TextualModelResponses
446+
: LlamaText()
447+
);
448+
let revivedParallelismCallBetweenCalls = reviveSeparatorText(parallelismBetweenCallsText, idToStaticContent, contentIds);
449+
450+
if (revivedParallelismCallSectionPrefix.values.length === 0 && renderedOnlyCall != null) {
451+
const userMessage1ToModelMessage1Start = getTextBetweenIds(rendered1Call, userMessage1, modelMessage1);
452+
const onlyCallUserMessage1ToFunc1Name = getTextBetweenIds(renderedOnlyCall, userMessage1, func1name);
453+
454+
if (userMessage1ToModelMessage1Start.text != null && onlyCallUserMessage1ToFunc1Name.text != null) {
455+
const onlyCallModelMessagePrefixLength = findCommandStartLength(
456+
userMessage1ToModelMessage1Start.text,
457+
onlyCallUserMessage1ToFunc1Name.text
458+
);
459+
const onlyCallCallPrefixText = onlyCallUserMessage1ToFunc1Name.text.slice(onlyCallModelMessagePrefixLength);
460+
const revivedOnlyCallCallPrefixText = reviveSeparatorText(onlyCallCallPrefixText, idToStaticContent, contentIds);
461+
462+
const optionalCallPrefix = removeCommonRevivedSuffix(revivedCallPrefix, revivedOnlyCallCallPrefixText);
463+
if (optionalCallPrefix.values.length > 0) {
464+
revivedCallPrefix = removeCommonRevivedPrefix(revivedCallPrefix, optionalCallPrefix);
465+
revivedParallelismCallBetweenCalls = LlamaText([
466+
optionalCallPrefix,
467+
revivedParallelismCallBetweenCalls
468+
]);
469+
}
470+
}
471+
}
472+
414473
return {
415474
stringifyParams,
416475
stringifyResult,
417476
combineModelMessageAndToolCalls,
418477
settings: {
419478
call: {
420479
optionalPrefixSpace: true,
421-
prefix: reviveSeparatorText(callPrefixText, idToStaticContent, contentIds),
480+
prefix: revivedCallPrefix,
422481
paramsPrefix: reviveSeparatorText(callParamsPrefixText, idToStaticContent, contentIds),
423482
suffix: reviveSeparatorText(callSuffixText, idToStaticContent, contentIds),
424483
emptyCallParamsPlaceholder: {}
@@ -445,13 +504,8 @@ export function extractFunctionCallSettingsFromJinjaTemplate({
445504
},
446505
parallelism: {
447506
call: {
448-
sectionPrefix: removeCommonRevivedPrefix(
449-
reviveSeparatorText(parallelismCallPrefix, idToStaticContent, contentIds),
450-
!combineModelMessageAndToolCalls
451-
? textBetween2TextualModelResponses
452-
: LlamaText()
453-
),
454-
betweenCalls: reviveSeparatorText(parallelismBetweenCallsText, idToStaticContent, contentIds),
507+
sectionPrefix: revivedParallelismCallSectionPrefix,
508+
betweenCalls: revivedParallelismCallBetweenCalls,
455509
sectionSuffix: reviveSeparatorText(parallelismCallSuffixText, idToStaticContent, contentIds)
456510
},
457511
result: {
@@ -524,14 +578,48 @@ function removeCommonRevivedPrefix(target: LlamaText, matchStart: LlamaText) {
524578
} else if (targetValue instanceof SpecialToken && matchStartValue instanceof SpecialToken) {
525579
if (targetValue.value === matchStartValue.value)
526580
continue;
527-
}
581+
} else if (LlamaText(targetValue ?? "").compare(LlamaText(matchStartValue ?? "")))
582+
continue;
528583

529584
return LlamaText(target.values.slice(commonStartLength));
530585
}
531586

532587
return LlamaText(target.values.slice(matchStart.values.length));
533588
}
534589

590+
function removeCommonRevivedSuffix(target: LlamaText, matchEnd: LlamaText) {
591+
for (
592+
let commonEndLength = 0;
593+
commonEndLength < target.values.length && commonEndLength < matchEnd.values.length;
594+
commonEndLength++
595+
) {
596+
const targetValue = target.values[target.values.length - commonEndLength - 1];
597+
const matchEndValue = matchEnd.values[matchEnd.values.length - commonEndLength - 1];
598+
599+
if (typeof targetValue === "string" && typeof matchEndValue === "string") {
600+
if (targetValue === matchEndValue)
601+
continue;
602+
} else if (targetValue instanceof SpecialTokensText && matchEndValue instanceof SpecialTokensText) {
603+
const commonLength = findCommonEndLength(targetValue.value, matchEndValue.value);
604+
if (commonLength === targetValue.value.length && commonLength === matchEndValue.value.length)
605+
continue;
606+
607+
return LlamaText([
608+
...target.values.slice(0, target.values.length - commonEndLength - 1),
609+
new SpecialTokensText(targetValue.value.slice(0, targetValue.value.length - commonLength))
610+
]);
611+
} else if (targetValue instanceof SpecialToken && matchEndValue instanceof SpecialToken) {
612+
if (targetValue.value === matchEndValue.value)
613+
continue;
614+
} else if (LlamaText(targetValue ?? "").compare(LlamaText(matchEndValue ?? "")))
615+
continue;
616+
617+
return LlamaText(target.values.slice(0, target.values.length - commonEndLength - 1));
618+
}
619+
620+
return LlamaText(target.values.slice(0, target.values.length - matchEnd.values.length));
621+
}
622+
535623
function findCommandStartLength(text1: string, text2: string) {
536624
let commonStartLength = 0;
537625
while (commonStartLength < text1.length && commonStartLength < text2.length) {

src/chatWrappers/generic/utils/extractSegmentSettingsFromTokenizerAndChatTemplate.ts

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,42 @@ export function extractSegmentSettingsFromTokenizerAndChatTemplate(
88
function tryMatchPrefixSuffixPair(tryMatchGroups: [prefix: string, suffix: string][]) {
99
if (chatTemplate != null) {
1010
for (const [prefix, suffix] of tryMatchGroups) {
11+
if (
12+
(
13+
hasAll(chatTemplate.replaceAll(prefix + "\\n\\n" + suffix, ""), [
14+
prefix + "\\n\\n",
15+
"\\n\\n" + suffix
16+
])
17+
) || (
18+
hasAll(chatTemplate.replaceAll(prefix + "\n\n" + suffix, ""), [
19+
prefix + "\n\n",
20+
"\n\n" + suffix
21+
])
22+
)
23+
)
24+
return {
25+
prefix: LlamaText(new SpecialTokensText(prefix + "\n\n")),
26+
suffix: LlamaText(new SpecialTokensText("\n\n" + suffix))
27+
};
28+
29+
if (
30+
(
31+
hasAll(chatTemplate.replaceAll(prefix + "\\n" + suffix, ""), [
32+
prefix + "\\n",
33+
"\\n" + suffix
34+
])
35+
) || (
36+
hasAll(chatTemplate.replaceAll(prefix + "\n" + suffix, ""), [
37+
prefix + "\n",
38+
"\n" + suffix
39+
])
40+
)
41+
)
42+
return {
43+
prefix: LlamaText(new SpecialTokensText(prefix + "\n")),
44+
suffix: LlamaText(new SpecialTokensText("\n" + suffix))
45+
};
46+
1147
if (chatTemplate.includes(prefix) && chatTemplate.includes(suffix))
1248
return {
1349
prefix: LlamaText(new SpecialTokensText(prefix)),
@@ -46,3 +82,7 @@ export function extractSegmentSettingsFromTokenizerAndChatTemplate(
4682
])
4783
});
4884
}
85+
86+
function hasAll(text: string, matches: string[]) {
87+
return matches.every((match) => text.includes(match));
88+
}

src/cli/commands/ChatCommand.ts

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import {defineChatSessionFunction} from "../../evaluator/LlamaChatSession/utils/
1212
import {getLlama} from "../../bindings/getLlama.js";
1313
import {LlamaGrammar} from "../../evaluator/LlamaGrammar.js";
1414
import {LlamaChatSession} from "../../evaluator/LlamaChatSession/LlamaChatSession.js";
15-
import {LlamaJsonSchemaGrammar} from "../../evaluator/LlamaJsonSchemaGrammar.js";
1615
import {
1716
BuildGpu, LlamaLogLevel, LlamaLogLevelGreaterThan, nodeLlamaCppGpuOptions, parseNodeLlamaCppGpuOption
1817
} from "../../bindings/types.js";
@@ -529,8 +528,7 @@ async function RunChat({
529528
});
530529

531530
const grammar = jsonSchemaGrammarFilePath != null
532-
? new LlamaJsonSchemaGrammar(
533-
llama,
531+
? await llama.createGrammarForJsonSchema(
534532
await fs.readJson(
535533
path.resolve(process.cwd(), jsonSchemaGrammarFilePath)
536534
)

0 commit comments

Comments (0)