diff --git a/LLama.Examples/Examples/QuantizeModel.cs b/LLama.Examples/Examples/QuantizeModel.cs
index a1f7ca1bd..863bb0c3a 100644
--- a/LLama.Examples/Examples/QuantizeModel.cs
+++ b/LLama.Examples/Examples/QuantizeModel.cs
@@ -2,7 +2,7 @@ namespace LLama.Examples.Examples
{
public class QuantizeModel
{
- public static async Task Run()
+ public static Task Run()
{
string inputPath = UserSettings.GetModelPath();
@@ -20,6 +20,8 @@ public static async Task Run()
{
Console.WriteLine("Quantization failed!");
}
+
+ return Task.CompletedTask;
}
}
}
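Since Run() no longer awaits anything, the example drops `async` and hands back an already-completed task. A minimal sketch of that pattern, outside the diff (class name is illustrative only):

```csharp
using System.Threading.Tasks;

// Illustrative only: a Task-returning method with no awaits avoids the async
// state machine by returning a completed task directly.
internal static class SyncTaskPattern
{
    public static Task Run()
    {
        // ... synchronous work (e.g. quantization) happens here ...
        return Task.CompletedTask;
    }
}
```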
diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs
index 90119d4fe..bda7472d5 100644
--- a/LLama/ChatSession.cs
+++ b/LLama/ChatSession.cs
@@ -76,9 +76,10 @@ public class ChatSession
/// <param name="executor">The executor for this session</param>
/// <param name="history">History for this session</param>
/// <param name="transform">History Transform for this session</param>
+ /// <param name="cancellationToken">A token that cancels the operation</param>
/// <returns>A new chat session.</returns>
public static async Task<ChatSession> InitializeSessionFromHistoryAsync(
- ILLamaExecutor executor, ChatHistory history, IHistoryTransform? transform = null)
+ ILLamaExecutor executor, ChatHistory history, IHistoryTransform? transform = null, CancellationToken cancellationToken = default)
{
if (executor is not StatefulExecutorBase statefulExecutor)
{
@@ -90,7 +91,7 @@ public static async Task InitializeSessionFromHistoryAsync(
session = session.WithHistoryTransform(transform);
}
- await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history));
+ await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history), cancellationToken);
return session;
}
@@ -311,13 +312,15 @@ public ChatSession RemoveLastMessage()
/// Compute KV cache for the message and add it to the chat history.
/// </summary>
/// <param name="message"></param>
+ /// <param name="cancellationToken"></param>
/// <returns></returns>
- public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message)
+ public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message, CancellationToken cancellationToken = default)
{
if (Executor is not StatefulExecutorBase statefulExecutor)
{
throw new InvalidOperationException("Executor must be a StatefulExecutorBase to support pre-processing of system messages.");
}
+
AddMessage(message);
var content = message.Content;
if (message.AuthorRole != AuthorRole.Assistant)
@@ -328,27 +331,27 @@ public async Task AddAndProcessMessage(ChatHistory.Message message)
}
}
- await statefulExecutor.PrefillPromptAsync(content);
+ await statefulExecutor.PrefillPromptAsync(content, cancellationToken);
return this;
}
/// <summary>
/// Compute KV cache for the system message and add it to the chat history.
/// </summary>
- public Task<ChatSession> AddAndProcessSystemMessage(string content)
- => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.System, content));
+ public Task<ChatSession> AddAndProcessSystemMessage(string content, CancellationToken cancellationToken = default)
+ => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.System, content), cancellationToken);
/// <summary>
/// Compute KV cache for the user message and add it to the chat history.
/// </summary>
- public Task<ChatSession> AddAndProcessUserMessage(string content)
- => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.User, content));
+ public Task<ChatSession> AddAndProcessUserMessage(string content, CancellationToken cancellationToken = default)
+ => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.User, content), cancellationToken);
/// <summary>
/// Compute KV cache for the assistant message and add it to the chat history.
/// </summary>
- public Task<ChatSession> AddAndProcessAssistantMessage(string content)
- => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.Assistant, content));
+ public Task<ChatSession> AddAndProcessAssistantMessage(string content, CancellationToken cancellationToken = default)
+ => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.Assistant, content), cancellationToken);
/// <summary>
/// Replace a user message with a new message and remove all messages after the new message.
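The ChatSession additions above only thread a token through to PrefillPromptAsync. A hedged usage sketch of the new overloads (the prompt text, the timeout value and the helper class are placeholders, not part of this diff):

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;
using LLama;
using LLama.Abstractions;
using LLama.Common;

internal static class ChatSessionCancellationSketch
{
    // Assumes `executor` wraps a StatefulExecutorBase (e.g. InteractiveExecutor),
    // which AddAndProcessMessage requires.
    public static async Task<ChatSession> BuildAsync(ILLamaExecutor executor, CancellationToken outerToken)
    {
        using var cts = CancellationTokenSource.CreateLinkedTokenSource(outerToken);
        cts.CancelAfter(TimeSpan.FromSeconds(30)); // bound how long the prefill may take

        var session = new ChatSession(executor);
        await session.AddAndProcessSystemMessage("You are a helpful assistant.", cts.Token);
        await session.AddAndProcessUserMessage("Hello!", cts.Token);
        return session;
    }
}
```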
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 42d76c514..4188f9e5f 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -1,14 +1,14 @@
-using LLama.Native;
using System;
using System.Collections.Generic;
using System.Diagnostics;
-using System.Text;
using System.IO;
using System.IO.MemoryMappedFiles;
+using System.Text;
+using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
+using LLama.Native;
using Microsoft.Extensions.Logging;
-using System.Threading;
namespace LLama
{
@@ -73,7 +73,7 @@ public int BatchThreads
/// Get the special tokens for the model associated with this context
/// </summary>
public SafeLlamaModelHandle.Vocabulary Vocab { get; }
-
+
/// <summary>
/// Create a new LLamaContext for the given LLamaWeights
/// </summary>
@@ -396,7 +396,7 @@ public Task<DecodeResult> DecodeAsync(LLamaBatch batch, CancellationToken cancel
{
return Task.Run(() => Decode(batch), cancellationToken);
}
-
+
/// <summary>
/// </summary>
/// <param name="batch"></param>
@@ -406,10 +406,10 @@ public DecodeResult Decode(LLamaBatchEmbeddings batch)
return 0;
if (batch.EmbeddingsCount > BatchSize)
throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));
-
+
return (DecodeResult)NativeHandle.Decode(batch);
}
-
+
/// <summary>
/// </summary>
/// <param name="batch"></param>
@@ -425,15 +425,16 @@ public Task<DecodeResult> DecodeAsync(LLamaBatchEmbeddings batch, CancellationTo
/// <param name="id"></param>
/// <param name="batch"></param>
/// <param name="n_past"></param>
+ /// <param name="cancellationToken"></param>
/// <returns>A tuple, containing the decode result, the number of tokens that have not been decoded yet and the total number of tokens that have been decoded.</returns>
- public Task<(DecodeResult, int, int)> DecodeAsync(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, int n_past)
+ public Task<(DecodeResult, int, int)> DecodeAsync(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, int n_past, CancellationToken cancellationToken = default)
{
return Task.Run(() =>
{
var past = n_past;
var res = NativeHandle.Decode(tokens, id, batch, ref past);
return (res.Item1, res.Item2, past);
- });
+ }, cancellationToken);
}
#endregion
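DecodeAsync now forwards the token to Task.Run, so cancellation prevents a queued decode from starting; it cannot interrupt the native call once it is running. A hedged usage sketch where `context` and `tokens` are assumed to exist already (they are not part of this diff; LLama.Native types are assumed in scope):

```csharp
// Illustrative only: bound a background decode with a 5 second timeout.
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
var batch = new LLamaBatch();

var (result, remaining, pastTokens) =
    await context.DecodeAsync(tokens, LLamaSeqId.Zero, batch, n_past: 0, cts.Token);

if (result != DecodeResult.Ok)
    throw new LLamaDecodeError(result);
```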
diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index 36989006e..0e8d5f115 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -239,36 +239,41 @@ protected virtual void TryReuseMatchingPrefix()
/// Decide whether to continue the loop.
/// </summary>
/// <param name="args"></param>
+ /// <param name="cancellationToken"></param>
/// <returns></returns>
- protected abstract Task<bool> GetLoopCondition(InferStateArgs args);
+ protected abstract Task<bool> GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken = default);
/// <summary>
/// Preprocess the inputs before the inference.
/// </summary>
/// <param name="text"></param>
/// <param name="args"></param>
- protected abstract Task PreprocessInputs(string? text, InferStateArgs args);
+ /// <param name="cancellationToken"></param>
+ protected abstract Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken = default);
/// <summary>
/// Do some post processing after the inference.
/// </summary>
/// <param name="inferenceParams"></param>
/// <param name="args"></param>
+ /// <param name="cancellationToken"></param>
/// <returns></returns>
- protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args);
+ protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken = default);
/// <summary>
/// The core inference logic.
/// </summary>
/// <param name="inferenceParams"></param>
/// <param name="args"></param>
- protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args);
+ /// <param name="cancellationToken"></param>
+ protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken = default);
/// <summary>
/// Save the current state to a file.
/// </summary>
/// <param name="filename"></param>
- public abstract Task SaveState(string filename);
+ /// <param name="cancellationToken"></param>
+ public abstract Task SaveState(string filename, CancellationToken cancellationToken = default);
/// <summary>
/// Get the current state data.
@@ -280,13 +285,15 @@ protected virtual void TryReuseMatchingPrefix()
/// Load the state from data.
/// </summary>
/// <param name="data"></param>
- public abstract Task LoadState(ExecutorBaseState data);
+ /// <param name="cancellationToken"></param>
+ public abstract Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default);
/// <summary>
/// Load the state from a file.
/// </summary>
/// <param name="filename"></param>
- public abstract Task LoadState(string filename);
+ /// <param name="cancellationToken"></param>
+ public abstract Task LoadState(string filename, CancellationToken cancellationToken = default);
/// <summary>
@@ -310,15 +317,15 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
NeedToSaveSession = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count
};
- await PreprocessInputs(text, args);
+ await PreprocessInputs(text, args, cancellationToken);
- while (await GetLoopCondition(args))
+ while (await GetLoopCondition(args, cancellationToken))
{
if (cancellationToken.IsCancellationRequested)
{
break;
}
- await InferInternal(inferenceParams, args);
+ await InferInternal(inferenceParams, args, cancellationToken);
if (args.ReturnValue)
{
@@ -326,7 +333,7 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
yield return _decoder.Read();
}
- var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
+ var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args, cancellationToken);
if (extraOutputs is { Count: > 0 })
{
foreach (var item in extraOutputs)
@@ -346,8 +353,9 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
/// It could reduce the latency of the first time response if the first input from the user is not immediate.
/// </summary>
/// <param name="prompt">Prompt to process</param>
+ /// <param name="cancellationToken"></param>
/// <returns></returns>
- public virtual async Task PrefillPromptAsync(string prompt)
+ public virtual async Task PrefillPromptAsync(string prompt, CancellationToken cancellationToken = default)
{
var inferenceParams = new InferenceParams
{
@@ -362,11 +370,11 @@ public virtual async Task PrefillPromptAsync(string prompt)
NeedToSaveSession = false
};
- await PreprocessInputs(prompt, args);
+ await PreprocessInputs(prompt, args, cancellationToken);
// First run adds the prompt to the _embeds
- await InferInternal(inferenceParams, args);
+ await InferInternal(inferenceParams, args, cancellationToken);
// Second run puts it through decode
- await InferInternal(inferenceParams, args);
+ await InferInternal(inferenceParams, args, cancellationToken);
}
/// <summary>
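From the caller's side, the same token now flows through PrefillPromptAsync and the InferAsync loop. A hedged sketch; the executor instance and prompt strings are placeholders, not code from this PR:

```csharp
using System;
using System.Threading;
using LLama.Common;

// Illustrative only: assumes an InteractiveExecutor named `executor` built elsewhere.
using var cts = new CancellationTokenSource();
cts.CancelAfter(TimeSpan.FromSeconds(10));

// Warm the KV cache; an OperationCanceledException may surface if the token fires mid-prefill.
await executor.PrefillPromptAsync("You are a terse assistant.", cts.Token);

// The same token also stops the InferAsync loop between iterations.
await foreach (var piece in executor.InferAsync("Hello", new InferenceParams(), cts.Token))
    Console.Write(piece);
```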
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 331591fba..d7a8c4a94 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -1,14 +1,15 @@
-using LLama.Abstractions;
-using LLama.Common;
-using LLama.Native;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
+using System.Threading;
using System.Threading.Tasks;
+using LLama.Abstractions;
+using LLama.Common;
using LLama.Exceptions;
+using LLama.Native;
using LLama.Sampling;
using Microsoft.Extensions.Logging;
@@ -65,9 +66,9 @@ public override ExecutorBaseState GetStateData()
return state;
}
/// <inheritdoc />
- public override Task LoadState(ExecutorBaseState data)
+ public override Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default)
{
- if(data is InstructExecutorState state)
+ if (data is InstructExecutorState state)
{
_n_session_consumed = state.ConsumedSessionCount;
_embed_inps = state.EmbedInps!.ToList();
@@ -91,34 +92,34 @@ public override Task LoadState(ExecutorBaseState data)
}
/// <inheritdoc />
- public override async Task SaveState(string filename)
+ public override async Task SaveState(string filename, CancellationToken cancellationToken = default)
{
var state = (InstructExecutorState)GetStateData();
using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
{
- await JsonSerializer.SerializeAsync(fs, state);
+ await JsonSerializer.SerializeAsync(fs, state, cancellationToken: cancellationToken);
}
}
/// <inheritdoc />
- public override async Task LoadState(string filename)
+ public override async Task LoadState(string filename, CancellationToken cancellationToken = default)
{
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
{
var state = await JsonSerializer.DeserializeAsync<InstructExecutorState>(fs);
- await LoadState(state!);
+ await LoadState(state!, cancellationToken);
}
}
/// <inheritdoc />
- protected override Task<bool> GetLoopCondition(InferStateArgs args)
+ protected override Task<bool> GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken)
{
return Task.FromResult(args.RemainedTokens != 0 || _is_prompt_run);
}
/// <inheritdoc />
- protected override Task PreprocessInputs(string? text, InferStateArgs args)
+ protected override Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken)
{
- args.Antiprompts ??= [ ];
+ args.Antiprompts ??= [];
if (!args.Antiprompts.Contains(_instructionPrefix))
args.Antiprompts.Add(_instructionPrefix);
@@ -154,19 +155,19 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
}
/// <inheritdoc />
- protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
+ protected override Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
{
if (_embed_inps.Count <= _consumedTokensCount)
{
if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
{
args.WaitForInput = true;
- return (true, Array.Empty<string>());
+ return Task.FromResult<(bool, IReadOnlyList<string>)>((true, []));
}
if (_pastTokensCount > 0 && args.WaitForInput)
{
- return (true, new[] { "\n> " });
+ return Task.FromResult<(bool, IReadOnlyList<string>)>((true, ["\n> "]));
}
}
@@ -180,11 +181,12 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
args.RemainedTokens = inferenceParams.MaxTokens;
args.WaitForInput = true;
}
- return (false, Array.Empty<string>());
+
+ return Task.FromResult<(bool, IReadOnlyList<string>)>((false, []));
}
/// <inheritdoc />
- protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+ protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
{
var batch = new LLamaBatch();
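With the new signature, PostProcess in these executors is no longer marked async; synchronous results are wrapped in Task.FromResult. A sketch of that pattern for a hypothetical custom executor (not code from this PR; the stopping condition shown is illustrative):

```csharp
// Illustrative only: a synchronous override that satisfies the new
// cancellation-aware, Task-returning contract.
protected override Task<(bool, IReadOnlyList<string>)> PostProcess(
    IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
{
    cancellationToken.ThrowIfCancellationRequested(); // cheap cooperative check

    var shouldStop = args.RemainedTokens <= 0 && inferenceParams.MaxTokens != -1;
    return Task.FromResult<(bool, IReadOnlyList<string>)>((shouldStop, Array.Empty<string>()));
}
```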
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 7c9558ee3..f05b1c974 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -1,14 +1,15 @@
-using LLama.Common;
-using LLama.Native;
-using LLama.Abstractions;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
+using System.Threading;
using System.Threading.Tasks;
+using LLama.Abstractions;
+using LLama.Common;
using LLama.Exceptions;
+using LLama.Native;
using LLama.Sampling;
using Microsoft.Extensions.Logging;
@@ -21,7 +22,7 @@ namespace LLama
public class InteractiveExecutor : StatefulExecutorBase
{
private bool _is_prompt_run = true;
-
+
// LLava
private int _EmbedImagePosition = -1;
private List<SafeLlavaImageEmbedHandle> _imageEmbedHandles = new List<SafeLlavaImageEmbedHandle>();
@@ -36,7 +37,7 @@ public InteractiveExecutor(LLamaContext context, ILogger? logger = null)
: base(context, logger)
{
}
-
+
/// <summary>
///
/// </summary>
@@ -46,7 +47,7 @@ public InteractiveExecutor(LLamaContext context, ILogger? logger = null)
public InteractiveExecutor(LLamaContext context, LLavaWeights clipModel, ILogger? logger = null)
: base(context, clipModel, logger)
{
- }
+ }
/// <inheritdoc />
public override ExecutorBaseState GetStateData()
@@ -68,7 +69,7 @@ public override ExecutorBaseState GetStateData()
return state;
}
/// <inheritdoc />
- public override Task LoadState(ExecutorBaseState data)
+ public override Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default)
{
if (data is InteractiveExecutorState state)
{
@@ -88,22 +89,24 @@ public override Task LoadState(ExecutorBaseState data)
return Task.CompletedTask;
}
+
/// <inheritdoc />
- public override async Task SaveState(string filename)
+ public override async Task SaveState(string filename, CancellationToken cancellationToken = default)
{
var state = (InteractiveExecutorState)GetStateData();
- using(var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
+ using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
{
- await JsonSerializer.SerializeAsync(fs, state);
+ await JsonSerializer.SerializeAsync(fs, state, cancellationToken: cancellationToken);
}
}
+
/// <inheritdoc />
- public override async Task LoadState(string filename)
+ public override async Task LoadState(string filename, CancellationToken cancellationToken = default)
{
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
{
var state = await JsonSerializer.DeserializeAsync<InteractiveExecutorState>(fs);
- await LoadState(state!);
+ await LoadState(state!, cancellationToken);
}
}
@@ -111,13 +114,13 @@ public override async Task LoadState(string filename)
/// Define whether to continue the loop to generate responses.
/// </summary>
/// <returns></returns>
- protected override Task<bool> GetLoopCondition(InferStateArgs args)
+ protected override Task<bool> GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken)
{
return Task.FromResult(args.RemainedTokens != 0 && !args.WaitForInput || _is_prompt_run);
}
/// <inheritdoc />
- protected override Task PreprocessInputs(string? text, InferStateArgs args)
+ protected override Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken)
{
if (_is_prompt_run)
{
@@ -159,8 +162,8 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
}
/// <inheritdoc />
- private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = true )
- {
+ private void PreprocessLlava(string text, InferStateArgs args, bool addBos = true)
+ {
// If the prompt contains the tag <image> extract this.
_imageInPrompt = text.Contains("<image>");
if (_imageInPrompt && IsMultiModal)
@@ -191,10 +194,9 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
{
var line_inp = Context.Tokenize(text, false, true);
_embed_inps.AddRange(line_inp);
- args.RemainedTokens -= line_inp.Length;
+ args.RemainedTokens -= line_inp.Length;
}
}
- return Task.CompletedTask;
}
/// <summary>
@@ -202,21 +204,26 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
/// </summary>
/// <param name="inferenceParams"></param>
/// <param name="args"></param>
+ /// <param name="cancellationToken"></param>
/// <returns></returns>
- protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
+ protected override Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
{
if (_embed_inps.Count <= _consumedTokensCount)
{
if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
+ {
args.WaitForInput = true;
+ }
if (_pastTokensCount > 0 && args.WaitForInput)
- return (true, Array.Empty<string>());
+ {
+ return Task.FromResult<(bool, IReadOnlyList<string>)>((true, []));
+ }
}
if (_embeds.Count > 0 && _embeds.Last().IsEndOfGeneration(Context.Vocab))
{
- return (true, Array.Empty<string>());
+ return Task.FromResult<(bool, IReadOnlyList<string>)>((true, []));
}
if (args.RemainedTokens <= 0 && inferenceParams.MaxTokens != -1)
@@ -225,11 +232,11 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
args.WaitForInput = true;
}
- return (false, Array.Empty<string>());
+ return Task.FromResult<(bool, IReadOnlyList<string>)>((false, []));
}
/// <inheritdoc />
- protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+ protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
{
var batch = new LLamaBatch();
@@ -258,18 +265,18 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
// Changes to support Multi-Modal LLMs.
//
(DecodeResult, int, int) header, end, result;
- if (IsMultiModal && _EmbedImagePosition > 0)
+ if (IsMultiModal && _EmbedImagePosition > 0)
{
// Tokens previous to the images
header = await Context.DecodeAsync(_embeds.GetRange(0, _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount);
_pastTokensCount = header.Item3;
if (header.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(header.Item1);
-
+
// Images
- foreach( var image in _imageEmbedHandles )
+ foreach (var image in _imageEmbedHandles)
ClipModel!.EvalImageEmbed(Context, image, ref _pastTokensCount);
-
+
// Post-image Tokens
end = await Context.DecodeAsync(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount);
_pastTokensCount = end.Item3;
@@ -285,7 +292,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
if (result.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(result.Item1);
}
-
+
if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
{