55 changes: 55 additions & 0 deletions docs/guide/chat-session.md
@@ -898,3 +898,58 @@ const fullResponse = a1.response

console.log("Full response: " + fullResponse);
```

## Set Thinking Budget {#thinking-budget}
You can set a thinking budget to limit the number of tokens a thinking model can spend on [thought segments](#stream-response-segments).
```typescript
import {
getLlama, LlamaChatSession, resolveModelFile, Token
} from "node-llama-cpp";

const modelPath = await resolveModelFile("hf:Qwen/Qwen3-14B-GGUF:Q4_K_M");

const llama = await getLlama();
const model = await llama.loadModel({modelPath});
const context = await model.createContext();
const session = new LlamaChatSession({
contextSequence: context.getSequence()
});


const q1 = "Where do llamas come from?";
console.log("User: " + q1);

const maxThoughtTokens = 100;

let responseTokens = 0;
let thoughtTokens = 0;

process.stdout.write("AI: ");
const response = await session.prompt(q1, {
budgets: {
thoughtTokens: maxThoughtTokens
},
onResponseChunk(chunk) {
const isThoughtSegment = chunk.type === "segment" &&
chunk.segmentType === "thought";

if (chunk.type === "segment" && chunk.segmentStartTime != null)
process.stdout.write(` [segment start: ${chunk.segmentType}] `);

process.stdout.write(chunk.text);

if (chunk.type === "segment" && chunk.segmentEndTime != null)
process.stdout.write(` [segment end: ${chunk.segmentType}] `);

if (isThoughtSegment)
thoughtTokens += chunk.tokens.length;
else
responseTokens += chunk.tokens.length;
}
});

console.log("Response: " + response);

console.log("Response tokens: " + responseTokens);
console.log("Thought tokens: " + thoughtTokens);
```
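Setting the budget to `0` is meant to disable the thinking phase entirely (the `--thoughtBudget` CLI flag added in this PR documents `0` as disabling reasoning). A minimal sketch that continues the session from the example above:

```typescript
const q2 = "Now answer in one short sentence.";
console.log("User: " + q2);

// A thought budget of 0 should suppress thought segments altogether,
// so all generated tokens go to the visible response.
const a2 = await session.prompt(q2, {
    budgets: {
        thoughtTokens: 0
    }
});
console.log("AI: " + a2);
```
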
19 changes: 19 additions & 0 deletions llama/addon/AddonContext.cpp
@@ -393,6 +393,7 @@ AddonContext::AddonContext(const Napi::CallbackInfo& info) : Napi::ObjectWrap<Ad
context_params.n_threads = std::max(cpu_get_num_math(), 1);
context_params.n_threads_batch = context_params.n_threads;
context_params.no_perf = true;
context_params.swa_full = false;

if (info.Length() > 1 && info[1].IsObject()) {
Napi::Object options = info[1].As<Napi::Object>();
@@ -433,6 +434,10 @@ AddonContext::AddonContext(const Napi::CallbackInfo& info) : Napi::ObjectWrap<Ad
if (options.Has("performanceTracking")) {
context_params.no_perf = !(options.Get("performanceTracking").As<Napi::Boolean>().Value());
}

if (options.Has("swaFullCache")) {
context_params.swa_full = options.Get("swaFullCache").As<Napi::Boolean>().Value();
}
}
}
AddonContext::~AddonContext() {
@@ -620,6 +625,19 @@ Napi::Value AddonContext::ShiftSequenceTokenCells(const Napi::CallbackInfo& info

return info.Env().Undefined();
}
Napi::Value AddonContext::GetSequenceKvCacheMinPosition(const Napi::CallbackInfo& info) {
if (disposed) {
Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
return info.Env().Undefined();
}

int32_t sequenceId = info[0].As<Napi::Number>().Int32Value();


const auto minPosition = llama_kv_self_seq_pos_min(ctx, sequenceId);

return Napi::Number::New(info.Env(), minPosition);
}
Napi::Value AddonContext::DecodeBatch(const Napi::CallbackInfo& info) {
AddonContextDecodeBatchWorker* worker = new AddonContextDecodeBatchWorker(info.Env(), this);
worker->Queue();
@@ -926,6 +944,7 @@ void AddonContext::init(Napi::Object exports) {
InstanceMethod("disposeSequence", &AddonContext::DisposeSequence),
InstanceMethod("removeTokenCellsFromSequence", &AddonContext::RemoveTokenCellsFromSequence),
InstanceMethod("shiftSequenceTokenCells", &AddonContext::ShiftSequenceTokenCells),
InstanceMethod("getSequenceKvCacheMinPosition", &AddonContext::GetSequenceKvCacheMinPosition),
InstanceMethod("decodeBatch", &AddonContext::DecodeBatch),
InstanceMethod("sampleToken", &AddonContext::SampleToken),
InstanceMethod("getEmbedding", &AddonContext::GetEmbedding),
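A note on the `swa_full` / `swaFullCache` wiring above: enabling it makes the context keep a full-size KV cache instead of the sliding-window one, which is what the `--swaFullCache` (`--noSwa`) CLI flag further down exposes. A minimal sketch of setting the model-level default through the public loader, mirroring how `ChatCommand` forwards the flag below (the model URI is just a placeholder taken from the docs example):

```typescript
import {getLlama, resolveModelFile} from "node-llama-cpp";

const modelPath = await resolveModelFile("hf:Qwen/Qwen3-14B-GGUF:Q4_K_M");

const llama = await getLlama();
const model = await llama.loadModel({
    modelPath,
    // Contexts created from this model will default to a full KV cache,
    // effectively disabling SWA eviction on models that support it.
    defaultContextSwaFullCache: true
});
const context = await model.createContext();
console.log("Context size: " + context.contextSize);
```
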
1 change: 1 addition & 0 deletions llama/addon/AddonContext.h
@@ -36,6 +36,7 @@ class AddonContext : public Napi::ObjectWrap<AddonContext> {
Napi::Value DisposeSequence(const Napi::CallbackInfo& info);
Napi::Value RemoveTokenCellsFromSequence(const Napi::CallbackInfo& info);
Napi::Value ShiftSequenceTokenCells(const Napi::CallbackInfo& info);
Napi::Value GetSequenceKvCacheMinPosition(const Napi::CallbackInfo& info);
Napi::Value DecodeBatch(const Napi::CallbackInfo& info);
Napi::Value SampleToken(const Napi::CallbackInfo& info);

14 changes: 14 additions & 0 deletions llama/addon/addon.cpp
@@ -73,6 +73,19 @@ Napi::Value addonGetTypeSizeForGgmlType(const Napi::CallbackInfo& info) {
return Napi::Number::New(info.Env(), typeSize);
}

Napi::Value addonGetGgmlGraphOverheadCustom(const Napi::CallbackInfo& info) {
if (info.Length() < 2 || !info[0].IsNumber() || !info[1].IsBoolean()) {
return Napi::Number::New(info.Env(), 0);
}

const size_t size = info[0].As<Napi::Number>().Uint32Value();
const bool grads = info[1].As<Napi::Boolean>().Value();

const auto graphOverhead = ggml_graph_overhead_custom(size, grads);

return Napi::Number::New(info.Env(), graphOverhead);
}

Napi::Value addonGetConsts(const Napi::CallbackInfo& info) {
Napi::Object consts = Napi::Object::New(info.Env());
consts.Set("ggmlMaxDims", Napi::Number::New(info.Env(), GGML_MAX_DIMS));
@@ -231,6 +244,7 @@ Napi::Object registerCallback(Napi::Env env, Napi::Object exports) {
Napi::PropertyDescriptor::Function("getMathCores", addonGetMathCores),
Napi::PropertyDescriptor::Function("getBlockSizeForGgmlType", addonGetBlockSizeForGgmlType),
Napi::PropertyDescriptor::Function("getTypeSizeForGgmlType", addonGetTypeSizeForGgmlType),
Napi::PropertyDescriptor::Function("getGgmlGraphOverheadCustom", addonGetGgmlGraphOverheadCustom),
Napi::PropertyDescriptor::Function("getConsts", addonGetConsts),
Napi::PropertyDescriptor::Function("setLogger", setLogger),
Napi::PropertyDescriptor::Function("setLoggerLogLevel", setLoggerLogLevel),
5 changes: 4 additions & 1 deletion src/bindings/AddonTypes.ts
@@ -28,7 +28,8 @@ export type BindingModule = {
embeddings?: boolean,
ranking?: boolean,
threads?: number,
performanceTracking?: boolean
performanceTracking?: boolean,
swaFullCache?: boolean
}): AddonContext
},
AddonGrammar: {
@@ -54,6 +55,7 @@ export type BindingModule = {
getMathCores(): number,
getBlockSizeForGgmlType(ggmlType: number): number | undefined,
getTypeSizeForGgmlType(ggmlType: number): number | undefined,
getGgmlGraphOverheadCustom(size: number, grads: boolean): number,
getConsts(): {
ggmlMaxDims: number,
ggmlTypeF16Size: number,
@@ -143,6 +145,7 @@ export type AddonContext = {
// startPos is inclusive, endPos is exclusive
shiftSequenceTokenCells(sequenceId: number, startPos: number, endPos: number, shiftDelta: number): void,

getSequenceKvCacheMinPosition(sequenceId: number): number,
getEmbedding(inputTokensLength: number, maxVectorSize?: number): Float64Array,
getStateSize(): number,
getThreads(): number,
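These typings describe internal binding members rather than public API. `getGgmlGraphOverheadCustom` exposes ggml's `ggml_graph_overhead_custom`, which reports the bookkeeping overhead (in bytes) of a compute graph with a given node count, presumably for memory estimation. `getSequenceKvCacheMinPosition` returns the lowest KV-cache position still present for a sequence; a rough, purely hypothetical sketch of how that value could be interpreted (the helper name and the SWA-eviction reading are assumptions, not taken from the codebase):

```typescript
// Structural type mirroring the addition to AddonTypes.ts; nothing is imported
// from the package since AddonContext is an internal binding object.
type SequenceKvInspector = {
    getSequenceKvCacheMinPosition(sequenceId: number): number
};

// Hypothetical helper: if the lowest cached position for a sequence is above 0,
// earlier cells were presumably dropped (for example by sliding-window attention),
// so a saved prefix starting at position 0 can no longer be reused verbatim.
function sequenceHasFullPrefix(context: SequenceKvInspector, sequenceId: number): boolean {
    return context.getSequenceKvCacheMinPosition(sequenceId) === 0;
}
```
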
@@ -41,7 +41,8 @@ export function extractSegmentSettingsFromTokenizerAndChatTemplate(
return removeUndefinedFields({
thought: tryMatchPrefixSuffixPair([
["<think>", "</think>"], // DeepSeek, QwQ
["<thought>", "</thought>"] // EXAONE Deep
["<thought>", "</thought>"], // EXAONE Deep
["<|START_THINKING|>", "<|END_THINKING|>"] // Command R7B
])
});
}
39 changes: 32 additions & 7 deletions src/cli/commands/ChatCommand.ts
@@ -45,6 +45,7 @@ type ChatCommand = {
contextSize?: number,
batchSize?: number,
flashAttention?: boolean,
swaFullCache?: boolean,
noTrimWhitespace: boolean,
grammar: "text" | Parameters<typeof LlamaGrammar.getFor>[1],
jsonSchemaGrammarFile?: string,
@@ -61,6 +62,7 @@ type ChatCommand = {
repeatFrequencyPenalty?: number,
repeatPresencePenalty?: number,
maxTokens: number,
thoughtBudget?: number,
noHistory: boolean,
environmentFunctions: boolean,
tokenPredictionDraftModel?: string,
@@ -162,6 +164,12 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
default: false,
description: "Enable flash attention"
})
.option("swaFullCache", {
alias: "noSwa",
type: "boolean",
default: false,
description: "Disable SWA (Sliding Window Attention) on supported models"
})
.option("noTrimWhitespace", {
type: "boolean",
alias: ["noTrim"],
@@ -255,6 +263,13 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
default: 0,
description: "Maximum number of tokens to generate in responses. Set to `0` to disable. Set to `-1` to set to the context size"
})
.option("thoughtBudget", {
alias: ["tb", "thinkingBudget", "reasoningBudget"],
type: "number",
default: -1,
defaultDescription: "Unlimited",
description: "Maximum number of tokens the model can use for thoughts. Set to `0` to disable reasoning"
})
.option("noHistory", {
alias: "nh",
type: "boolean",
@@ -308,19 +323,20 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
},
async handler({
modelPath, header, gpu, systemInfo, systemPrompt, systemPromptFile, prompt,
promptFile, wrapper, noJinja, contextSize, batchSize, flashAttention,
promptFile, wrapper, noJinja, contextSize, batchSize, flashAttention, swaFullCache,
noTrimWhitespace, grammar, jsonSchemaGrammarFile, threads, temperature, minP, topK,
topP, seed, gpuLayers, repeatPenalty, lastTokensRepeatPenalty, penalizeRepeatingNewLine,
repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, noHistory,
repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, thoughtBudget, noHistory,
environmentFunctions, tokenPredictionDraftModel, tokenPredictionModelContextSize, debug, meter, timing, noMmap, printTimings
}) {
try {
await RunChat({
modelPath, header, gpu, systemInfo, systemPrompt, systemPromptFile, prompt, promptFile, wrapper, noJinja, contextSize,
batchSize, flashAttention, noTrimWhitespace, grammar, jsonSchemaGrammarFile, threads, temperature, minP, topK, topP, seed,
batchSize, flashAttention, swaFullCache, noTrimWhitespace, grammar, jsonSchemaGrammarFile, threads,
temperature, minP, topK, topP, seed,
gpuLayers, lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty,
maxTokens, noHistory, environmentFunctions, tokenPredictionDraftModel, tokenPredictionModelContextSize, debug, meter,
timing, noMmap, printTimings
maxTokens, thoughtBudget, noHistory, environmentFunctions, tokenPredictionDraftModel, tokenPredictionModelContextSize,
debug, meter, timing, noMmap, printTimings
});
} catch (err) {
await new Promise((accept) => setTimeout(accept, 0)); // wait for logs to finish printing
@@ -333,13 +349,15 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {

async function RunChat({
modelPath: modelArg, header: headerArg, gpu, systemInfo, systemPrompt, systemPromptFile, prompt, promptFile, wrapper, noJinja,
contextSize, batchSize, flashAttention, noTrimWhitespace, grammar: grammarArg, jsonSchemaGrammarFile: jsonSchemaGrammarFilePath,
contextSize, batchSize, flashAttention, swaFullCache, noTrimWhitespace, grammar: grammarArg,
jsonSchemaGrammarFile: jsonSchemaGrammarFilePath,
threads, temperature, minP, topK, topP, seed, gpuLayers, lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine,
repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, noHistory, environmentFunctions, tokenPredictionDraftModel,
repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, thoughtBudget, noHistory, environmentFunctions, tokenPredictionDraftModel,
tokenPredictionModelContextSize, debug, meter, timing, noMmap, printTimings
}: ChatCommand) {
if (contextSize === -1) contextSize = undefined;
if (gpuLayers === -1) gpuLayers = undefined;
if (thoughtBudget === -1) thoughtBudget = undefined;

const headers = resolveHeaderFlag(headerArg);
const trimWhitespace = !noTrimWhitespace;
@@ -363,11 +381,13 @@ async function RunChat({

const resolvedModelPath = await resolveCommandGgufPath(modelArg, llama, headers, {
flashAttention,
swaFullCache,
useMmap
});
const resolvedDraftModelPath = (tokenPredictionDraftModel != null && tokenPredictionDraftModel !== "")
? await resolveCommandGgufPath(tokenPredictionDraftModel, llama, headers, {
flashAttention,
swaFullCache,
useMmap,
consoleTitle: "Draft model file"
})
@@ -413,6 +433,7 @@ async function RunChat({
? {fitContext: {contextSize}}
: undefined,
defaultContextFlashAttention: flashAttention,
defaultContextSwaFullCache: swaFullCache,
useMmap,
ignoreMemorySafetyChecks: gpuLayers != null,
onLoadProgress(loadProgress: number) {
@@ -446,6 +467,7 @@ async function RunChat({
return await llama.loadModel({
modelPath: resolvedDraftModelPath,
defaultContextFlashAttention: flashAttention,
defaultContextSwaFullCache: swaFullCache,
useMmap,
onLoadProgress(loadProgress: number) {
progressUpdater.setProgress(loadProgress);
@@ -673,6 +695,9 @@ async function RunChat({
seed: seed ?? undefined,
signal: abortController.signal,
stopOnAbortSignal: true,
budgets: {
thoughtTokens: thoughtBudget
},
repeatPenalty: {
penalty: repeatPenalty,
frequencyPenalty: repeatFrequencyPenalty != null ? repeatFrequencyPenalty : undefined,
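For clarity, the `-1` sentinel above means the default `--thoughtBudget` leaves the budget undefined (unlimited) before it reaches `budgets.thoughtTokens`. A small illustrative sketch of that normalization (the function name is made up for this example):

```typescript
// Mirrors `if (thoughtBudget === -1) thoughtBudget = undefined;` followed by
// `budgets: {thoughtTokens: thoughtBudget}` in the handler above.
function resolveThoughtBudget(thoughtBudget: number): {thoughtTokens?: number} {
    if (thoughtBudget === -1)
        return {}; // unlimited: no thought budget is passed to the prompt call
    return {thoughtTokens: thoughtBudget};
}

console.log(resolveThoughtBudget(-1));  // {} (unlimited thinking, the CLI default)
console.log(resolveThoughtBudget(0));   // {thoughtTokens: 0} (reasoning disabled)
console.log(resolveThoughtBudget(256)); // {thoughtTokens: 256} (cap thought tokens)
```
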
17 changes: 14 additions & 3 deletions src/cli/commands/CompleteCommand.ts
@@ -32,6 +32,7 @@ type CompleteCommand = {
contextSize?: number,
batchSize?: number,
flashAttention?: boolean,
swaFullCache?: boolean,
threads?: number,
temperature: number,
minP: number,
@@ -119,6 +120,12 @@ export const CompleteCommand: CommandModule<object, CompleteCommand> = {
default: false,
description: "Enable flash attention"
})
.option("swaFullCache", {
alias: "noSwa",
type: "boolean",
default: false,
description: "Disable SWA (Sliding Window Attention) on supported models"
})
.option("threads", {
type: "number",
defaultDescription: "Number of cores that are useful for math on the current machine",
@@ -235,14 +242,14 @@ export const CompleteCommand: CommandModule<object, CompleteCommand> = {
},
async handler({
modelPath, header, gpu, systemInfo, text, textFile, contextSize, batchSize,
flashAttention, threads, temperature, minP, topK,
flashAttention, swaFullCache, threads, temperature, minP, topK,
topP, seed, gpuLayers, repeatPenalty, lastTokensRepeatPenalty, penalizeRepeatingNewLine,
repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, tokenPredictionDraftModel, tokenPredictionModelContextSize,
debug, meter, timing, noMmap, printTimings
}) {
try {
await RunCompletion({
modelPath, header, gpu, systemInfo, text, textFile, contextSize, batchSize, flashAttention,
modelPath, header, gpu, systemInfo, text, textFile, contextSize, batchSize, flashAttention, swaFullCache,
threads, temperature, minP, topK, topP, seed, gpuLayers, lastTokensRepeatPenalty,
repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens,
tokenPredictionDraftModel, tokenPredictionModelContextSize, debug, meter, timing, noMmap, printTimings
@@ -257,7 +264,7 @@ export const CompleteCommand: CommandModule<object, CompleteCommand> = {


async function RunCompletion({
modelPath: modelArg, header: headerArg, gpu, systemInfo, text, textFile, contextSize, batchSize, flashAttention,
modelPath: modelArg, header: headerArg, gpu, systemInfo, text, textFile, contextSize, batchSize, flashAttention, swaFullCache,
threads, temperature, minP, topK, topP, seed, gpuLayers,
lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty,
tokenPredictionDraftModel, tokenPredictionModelContextSize, maxTokens, debug, meter, timing, noMmap, printTimings
@@ -286,11 +293,13 @@ async function RunCompletion({

const resolvedModelPath = await resolveCommandGgufPath(modelArg, llama, headers, {
flashAttention,
swaFullCache,
useMmap
});
const resolvedDraftModelPath = (tokenPredictionDraftModel != null && tokenPredictionDraftModel !== "")
? await resolveCommandGgufPath(tokenPredictionDraftModel, llama, headers, {
flashAttention,
swaFullCache,
useMmap,
consoleTitle: "Draft model file"
})
@@ -329,6 +338,7 @@ async function RunCompletion({
? {fitContext: {contextSize}}
: undefined,
defaultContextFlashAttention: flashAttention,
defaultContextSwaFullCache: swaFullCache,
useMmap,
ignoreMemorySafetyChecks: gpuLayers != null,
onLoadProgress(loadProgress: number) {
@@ -362,6 +372,7 @@ async function RunCompletion({
return await llama.loadModel({
modelPath: resolvedDraftModelPath,
defaultContextFlashAttention: flashAttention,
defaultContextSwaFullCache: swaFullCache,
useMmap,
onLoadProgress(loadProgress: number) {
progressUpdater.setProgress(loadProgress);
1 change: 1 addition & 0 deletions src/cli/commands/DebugCommand.ts
@@ -65,5 +65,6 @@ async function DebugCmakeOptionsFunction() {
console.info();

console.info(`${chalk.yellow("CMake options:")} ${prettyPrintObject(llama.cmakeOptions)}`);
console.info(`${chalk.yellow("Release:")} ${prettyPrintObject(llama.llamaCppRelease)}`);
}
