diff --git a/LLama.Examples/Examples/QuantizeModel.cs b/LLama.Examples/Examples/QuantizeModel.cs
index a1f7ca1bd..863bb0c3a 100644
--- a/LLama.Examples/Examples/QuantizeModel.cs
+++ b/LLama.Examples/Examples/QuantizeModel.cs
@@ -2,7 +2,7 @@ namespace LLama.Examples.Examples
 {
     public class QuantizeModel
     {
-        public static async Task Run()
+        public static Task Run()
         {
             string inputPath = UserSettings.GetModelPath();
 
@@ -20,6 +20,8 @@ public static async Task Run()
             {
                 Console.WriteLine("Quantization failed!");
             }
+
+            return Task.CompletedTask;
         }
     }
 }
diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs
index 90119d4fe..bda7472d5 100644
--- a/LLama/ChatSession.cs
+++ b/LLama/ChatSession.cs
@@ -76,9 +76,10 @@ public class ChatSession
     /// <param name="executor">The executor for this session</param>
     /// <param name="history">History for this session</param>
    /// <param name="transform">History Transform for this session</param>
+    /// <param name="cancellationToken">A token that cancels the operation</param>
     /// <returns>A new chat session.</returns>
     public static async Task<ChatSession> InitializeSessionFromHistoryAsync(
-        ILLamaExecutor executor, ChatHistory history, IHistoryTransform? transform = null)
+        ILLamaExecutor executor, ChatHistory history, IHistoryTransform? transform = null, CancellationToken cancellationToken = default)
     {
         if (executor is not StatefulExecutorBase statefulExecutor)
         {
@@ -90,7 +91,7 @@ public static async Task<ChatSession> InitializeSessionFromHistoryAsync(
             session = session.WithHistoryTransform(transform);
         }
 
-        await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history));
+        await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history), cancellationToken);
 
         return session;
     }
@@ -311,13 +312,15 @@ public ChatSession RemoveLastMessage()
     /// Compute KV cache for the message and add it to the chat history.
     /// </summary>
     /// <param name="message"></param>
+    /// <param name="cancellationToken"></param>
     /// <returns></returns>
-    public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message)
+    public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message, CancellationToken cancellationToken = default)
     {
         if (Executor is not StatefulExecutorBase statefulExecutor)
         {
             throw new InvalidOperationException("Executor must be a StatefulExecutorBase to support pre-processing of system messages.");
         }
+        AddMessage(message);
         var content = message.Content;
         if (message.AuthorRole != AuthorRole.Assistant)
         {
@@ -328,27 +331,27 @@ public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message)
             }
         }
 
-        await statefulExecutor.PrefillPromptAsync(content);
+        await statefulExecutor.PrefillPromptAsync(content, cancellationToken);
         return this;
     }
 
     /// <summary>
     /// Compute KV cache for the system message and add it to the chat history.
     /// </summary>
-    public Task AddAndProcessSystemMessage(string content)
-        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.System, content));
+    public Task AddAndProcessSystemMessage(string content, CancellationToken cancellationToken = default)
+        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.System, content), cancellationToken);
 
     /// <summary>
     /// Compute KV cache for the user message and add it to the chat history.
     /// </summary>
-    public Task AddAndProcessUserMessage(string content)
-        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.User, content));
+    public Task AddAndProcessUserMessage(string content, CancellationToken cancellationToken = default)
+        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.User, content), cancellationToken);
 
     /// <summary>
     /// Compute KV cache for the assistant message and add it to the chat history.
     /// </summary>
-    public Task AddAndProcessAssistantMessage(string content)
-        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.Assistant, content));
+    public Task AddAndProcessAssistantMessage(string content, CancellationToken cancellationToken = default)
+        => AddAndProcessMessage(new ChatHistory.Message(AuthorRole.Assistant, content), cancellationToken);
 
     /// <summary>
     /// Replace a user message with a new message and remove all messages after the new message.
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 42d76c514..4188f9e5f 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -1,14 +1,14 @@
-using LLama.Native;
 using System;
 using System.Collections.Generic;
 using System.Diagnostics;
-using System.Text;
 using System.IO;
 using System.IO.MemoryMappedFiles;
+using System.Text;
+using System.Threading;
 using System.Threading.Tasks;
 using LLama.Abstractions;
+using LLama.Native;
 using Microsoft.Extensions.Logging;
-using System.Threading;
 
 namespace LLama
 {
@@ -73,7 +73,7 @@ public int BatchThreads
         /// Get the special tokens for the model associated with this context
         /// </summary>
         public SafeLlamaModelHandle.Vocabulary Vocab { get; }
-        
+
         /// <summary>
         /// Create a new LLamaContext for the given LLamaWeights
         /// </summary>
@@ -396,7 +396,7 @@ public Task<DecodeResult> DecodeAsync(LLamaBatch batch, CancellationToken cancel
         {
             return Task.Run(() => Decode(batch), cancellationToken);
         }
-        
+
         /// <summary>
         /// </summary>
         /// <param name="batch"></param>
@@ -406,10 +406,10 @@ public DecodeResult Decode(LLamaBatchEmbeddings batch)
                 return 0;
             if (batch.EmbeddingsCount > BatchSize)
                 throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));
-            
+
             return (DecodeResult)NativeHandle.Decode(batch);
         }
-        
+
         /// <summary>
         /// </summary>
         /// <param name="batch"></param>
@@ -425,15 +425,16 @@ public Task<DecodeResult> DecodeAsync(LLamaBatchEmbeddings batch, CancellationTo
         /// <param name="id"></param>
         /// <param name="batch"></param>
         /// <param name="n_past"></param>
+        /// <param name="cancellationToken"></param>
         /// <returns>A tuple, containing the decode result, the number of tokens that have not been decoded yet and the total number of tokens that have been decoded.</returns>
-        public Task<(DecodeResult, int, int)> DecodeAsync(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, int n_past)
+        public Task<(DecodeResult, int, int)> DecodeAsync(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, int n_past, CancellationToken cancellationToken = default)
         {
             return Task.Run(() =>
             {
                 var past = n_past;
                 var res = NativeHandle.Decode(tokens, id, batch, ref past);
                 return (res.Item1, res.Item2, past);
-            });
+            }, cancellationToken);
         }
 
         #endregion
diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index 36989006e..0e8d5f115 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -239,36 +239,41 @@ protected virtual void TryReuseMatchingPrefix()
         /// Decide whether to continue the loop.
         /// </summary>
         /// <param name="args"></param>
+        /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        protected abstract Task<bool> GetLoopCondition(InferStateArgs args);
+        protected abstract Task<bool> GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken = default);
 
         /// <summary>
         /// Preprocess the inputs before the inference.
         /// </summary>
         /// <param name="text"></param>
         /// <param name="args"></param>
-        protected abstract Task PreprocessInputs(string? text, InferStateArgs args);
+        /// <param name="cancellationToken"></param>
+        protected abstract Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken = default);
 
         /// <summary>
         /// Do some post processing after the inference.
         /// </summary>
         /// <param name="inferenceParams"></param>
         /// <param name="args"></param>
+        /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args);
+        protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken = default);
 
         /// <summary>
         /// The core inference logic.
         /// </summary>
         /// <param name="inferenceParams"></param>
         /// <param name="args"></param>
-        protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args);
+        /// <param name="cancellationToken"></param>
+        protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken = default);
 
         /// <summary>
         /// Save the current state to a file.
         /// </summary>
         /// <param name="filename"></param>
-        public abstract Task SaveState(string filename);
+        /// <param name="cancellationToken"></param>
+        public abstract Task SaveState(string filename, CancellationToken cancellationToken = default);
 
         /// <summary>
         /// Get the current state data.
@@ -280,13 +285,15 @@ protected virtual void TryReuseMatchingPrefix()
         /// Load the state from data.
         /// </summary>
         /// <param name="data"></param>
-        public abstract Task LoadState(ExecutorBaseState data);
+        /// <param name="cancellationToken"></param>
+        public abstract Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default);
 
         /// <summary>
         /// Load the state from a file.
         /// </summary>
         /// <param name="filename"></param>
-        public abstract Task LoadState(string filename);
+        /// <param name="cancellationToken"></param>
+        public abstract Task LoadState(string filename, CancellationToken cancellationToken = default);
 
 
         /// <summary>
@@ -310,15 +317,15 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
                 NeedToSaveSession = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count
             };
 
-            await PreprocessInputs(text, args);
+            await PreprocessInputs(text, args, cancellationToken);
 
-            while (await GetLoopCondition(args))
+            while (await GetLoopCondition(args, cancellationToken))
             {
                 if (cancellationToken.IsCancellationRequested)
                 {
                     break;
                 }
-                await InferInternal(inferenceParams, args);
+                await InferInternal(inferenceParams, args, cancellationToken);
 
                 if (args.ReturnValue)
                 {
@@ -326,7 +333,7 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
                     yield return _decoder.Read();
                 }
 
-                var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
+                var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args, cancellationToken);
                 if (extraOutputs is { Count: > 0 })
                 {
                     foreach (var item in extraOutputs)
@@ -346,8 +353,9 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
         /// It could reduce the latency of the first time response if the first input from the user is not immediate.
         /// </summary>
         /// <param name="prompt">Prompt to process</param>
+        /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        public virtual async Task PrefillPromptAsync(string prompt)
+        public virtual async Task PrefillPromptAsync(string prompt, CancellationToken cancellationToken = default)
         {
             var inferenceParams = new InferenceParams
             {
@@ -362,11 +370,11 @@ public virtual async Task PrefillPromptAsync(string prompt)
                 NeedToSaveSession = false
             };
 
-            await PreprocessInputs(prompt, args);
+            await PreprocessInputs(prompt, args, cancellationToken);
             // First run adds the prompt to the _embeds
-            await InferInternal(inferenceParams, args);
+            await InferInternal(inferenceParams, args, cancellationToken);
             // Second run puts it through decode
-            await InferInternal(inferenceParams, args);
+            await InferInternal(inferenceParams, args, cancellationToken);
         }
 
         /// <summary>
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 331591fba..d7a8c4a94 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -1,14 +1,15 @@
-using LLama.Abstractions;
-using LLama.Common;
-using LLama.Native;
 using System;
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
 using System.Text.Json;
 using System.Text.Json.Serialization;
+using System.Threading;
 using System.Threading.Tasks;
+using LLama.Abstractions;
+using LLama.Common;
 using LLama.Exceptions;
+using LLama.Native;
 using LLama.Sampling;
 using Microsoft.Extensions.Logging;
 
@@ -65,9 +66,9 @@ public override ExecutorBaseState GetStateData()
             return state;
         }
         /// <inheritdoc />
-        public override Task LoadState(ExecutorBaseState data)
+        public override Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default)
         {
-            if(data is InstructExecutorState state)
+            if (data is InstructExecutorState state)
             {
                 _n_session_consumed = state.ConsumedSessionCount;
                 _embed_inps = state.EmbedInps!.ToList();
@@ -91,34 +92,34 @@ public override Task LoadState(ExecutorBaseState data)
         }
         /// <inheritdoc />
-        public override async Task SaveState(string filename)
+        public override async Task SaveState(string filename, CancellationToken cancellationToken = default)
         {
             var state = (InstructExecutorState)GetStateData();
             using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
             {
-                await JsonSerializer.SerializeAsync(fs, state);
+                await JsonSerializer.SerializeAsync(fs, state, cancellationToken: cancellationToken);
             }
         }
 
         /// <inheritdoc />
-        public override async Task LoadState(string filename)
+        public override async Task LoadState(string filename, CancellationToken cancellationToken = default)
         {
             using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
             {
                 var state = await JsonSerializer.DeserializeAsync<InstructExecutorState>(fs);
-                await LoadState(state!);
+                await LoadState(state!, cancellationToken);
             }
         }
 
         /// <inheritdoc />
-        protected override Task<bool> GetLoopCondition(InferStateArgs args)
+        protected override Task<bool> GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken)
         {
             return Task.FromResult(args.RemainedTokens != 0 || _is_prompt_run);
         }
 
         /// <inheritdoc />
-        protected override Task PreprocessInputs(string? text, InferStateArgs args)
+        protected override Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken)
         {
-            args.Antiprompts ??= [ ];
+            args.Antiprompts ??= [];
 
             if (!args.Antiprompts.Contains(_instructionPrefix))
                 args.Antiprompts.Add(_instructionPrefix);
@@ -154,19 +155,19 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
         }
 
         /// <inheritdoc />
-        protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
         {
             if (_embed_inps.Count <= _consumedTokensCount)
             {
                 if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
                 {
                     args.WaitForInput = true;
-                    return (true, Array.Empty<string>());
+                    return Task.FromResult<(bool, IReadOnlyList<string>)>((true, []));
                 }
 
                 if (_pastTokensCount > 0 && args.WaitForInput)
                 {
-                    return (true, new[] { "\n> " });
+                    return Task.FromResult<(bool, IReadOnlyList<string>)>((true, ["\n> "]));
                 }
             }
 
@@ -180,11 +181,12 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
                 args.RemainedTokens = inferenceParams.MaxTokens;
                 args.WaitForInput = true;
             }
-            return (false, Array.Empty<string>());
+
+            return Task.FromResult<(bool, IReadOnlyList<string>)>((false, []));
         }
 
         /// <inheritdoc />
-        protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
         {
             var batch = new LLamaBatch();
 
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 7c9558ee3..f05b1c974 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -1,14 +1,15 @@
-using LLama.Common;
-using LLama.Native;
-using LLama.Abstractions;
 using System;
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
 using System.Text.Json;
 using System.Text.Json.Serialization;
+using System.Threading;
 using System.Threading.Tasks;
+using LLama.Abstractions;
+using LLama.Common;
 using LLama.Exceptions;
+using LLama.Native;
 using LLama.Sampling;
 using Microsoft.Extensions.Logging;
 
@@ -21,7 +22,7 @@ namespace LLama
     public class InteractiveExecutor : StatefulExecutorBase
     {
         private bool _is_prompt_run = true;
-        
+
         // LLava
         private int _EmbedImagePosition = -1;
         private List<SafeLlavaImageEmbedHandle> _imageEmbedHandles = new List<SafeLlavaImageEmbedHandle>();
@@ -36,7 +37,7 @@ public InteractiveExecutor(LLamaContext context, ILogger? logger = null)
             : base(context, logger)
         {
         }
-        
+
         /// <summary>
         /// 
         /// </summary>
@@ -46,7 +47,7 @@ public InteractiveExecutor(LLamaContext context, ILogger? logger = null)
         public InteractiveExecutor(LLamaContext context, LLavaWeights clipModel, ILogger? logger = null)
             : base(context, clipModel, logger)
         {
-        }   
+        }
 
         /// <inheritdoc />
         public override ExecutorBaseState GetStateData()
@@ -68,7 +69,7 @@ public override ExecutorBaseState GetStateData()
             return state;
         }
         /// <inheritdoc />
-        public override Task LoadState(ExecutorBaseState data)
+        public override Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default)
         {
             if (data is InteractiveExecutorState state)
             {
@@ -88,22 +89,24 @@ public override Task LoadState(ExecutorBaseState data)
             return Task.CompletedTask;
         }
 
+
         /// <inheritdoc />
-        public override async Task SaveState(string filename)
+        public override async Task SaveState(string filename, CancellationToken cancellationToken = default)
         {
             var state = (InteractiveExecutorState)GetStateData();
-            using(var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
+            using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
             {
-                await JsonSerializer.SerializeAsync(fs, state);
+                await JsonSerializer.SerializeAsync(fs, state, cancellationToken: cancellationToken);
             }
         }
 
+
         /// <inheritdoc />
-        public override async Task LoadState(string filename)
+        public override async Task LoadState(string filename, CancellationToken cancellationToken = default)
         {
             using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
             {
                 var state = await JsonSerializer.DeserializeAsync<InteractiveExecutorState>(fs);
-                await LoadState(state!);
+                await LoadState(state!, cancellationToken);
             }
         }
 
@@ -111,13 +114,13 @@ public override async Task LoadState(string filename)
         /// Define whether to continue the loop to generate responses.
         /// </summary>
         /// <returns></returns>
-        protected override Task<bool> GetLoopCondition(InferStateArgs args)
+        protected override Task<bool> GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken)
         {
             return Task.FromResult(args.RemainedTokens != 0 && !args.WaitForInput || _is_prompt_run);
         }
 
         /// <inheritdoc />
-        protected override Task PreprocessInputs(string? text, InferStateArgs args)
+        protected override Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken)
         {
             if (_is_prompt_run)
             {
@@ -159,8 +162,8 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
         }
 
         /// <inheritdoc />
-        private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = true )
-        {    
+        private void PreprocessLlava(string text, InferStateArgs args, bool addBos = true)
+        {
             // If the prompt contains the <image> tag extract this.
             _imageInPrompt = text.Contains("<image>");
             if (_imageInPrompt && IsMultiModal)
@@ -191,10 +194,9 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
                 {
                     var line_inp = Context.Tokenize(text, false, true);
                     _embed_inps.AddRange(line_inp);
-                    args.RemainedTokens -= line_inp.Length;    
+                    args.RemainedTokens -= line_inp.Length;
                 }
             }
-            return Task.CompletedTask;
         }
 
         /// <summary>
@@ -202,21 +204,26 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
         /// </summary>
         /// <param name="inferenceParams"></param>
         /// <param name="args"></param>
+        /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
         {
             if (_embed_inps.Count <= _consumedTokensCount)
             {
                 if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
+                {
                     args.WaitForInput = true;
+                }
 
                 if (_pastTokensCount > 0 && args.WaitForInput)
-                    return (true, Array.Empty<string>());
+                {
+                    return Task.FromResult<(bool, IReadOnlyList<string>)>((true, []));
+                }
             }
 
             if (_embeds.Count > 0 && _embeds.Last().IsEndOfGeneration(Context.Vocab))
             {
-                return (true, Array.Empty<string>());
+                return Task.FromResult<(bool, IReadOnlyList<string>)>((true, []));
             }
 
             if (args.RemainedTokens <= 0 && inferenceParams.MaxTokens != -1)
@@ -225,11 +232,11 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
                 args.WaitForInput = true;
             }
 
-            return (false, Array.Empty<string>());
+            return Task.FromResult<(bool, IReadOnlyList<string>)>((false, []));
         }
 
         /// <inheritdoc />
-        protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
         {
             var batch = new LLamaBatch();
 
@@ -258,18 +265,18 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
                 // Changes to support Multi-Modal LLMs.
                 //
                 (DecodeResult, int, int) header, end, result;
-                if (IsMultiModal && _EmbedImagePosition > 0) 
+                if (IsMultiModal && _EmbedImagePosition > 0)
                 {
                     // Tokens previous to the images
                     header = await Context.DecodeAsync(_embeds.GetRange(0, _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount);
                     _pastTokensCount = header.Item3;
 
                     if (header.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(header.Item1);
-                    
+
                     // Images
-                    foreach( var image in _imageEmbedHandles )
+                    foreach (var image in _imageEmbedHandles)
                         ClipModel!.EvalImageEmbed(Context, image, ref _pastTokensCount);
-                    
+
                     // Post-image Tokens
                     end = await Context.DecodeAsync(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount);
                     _pastTokensCount = end.Item3;
@@ -285,7 +292,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
 
                     if (result.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(result.Item1);
                 }
-                
+
                 if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
                 {
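Not part of the patch: a minimal, hypothetical usage sketch of the cancellation-aware overloads introduced above. The model path, prompt strings, timeout, and the class/method names CancellationUsageSketch/RunAsync are placeholders for illustration; the LLamaSharp calls themselves (ModelParams, LLamaWeights.LoadFromFile, CreateContext, InteractiveExecutor, ChatSession) are existing public API.

    using System;
    using System.Threading;
    using System.Threading.Tasks;
    using LLama;
    using LLama.Common;

    public static class CancellationUsageSketch
    {
        public static async Task RunAsync()
        {
            // Placeholder model path and parameters.
            var parameters = new ModelParams("model.gguf");
            using var weights = LLamaWeights.LoadFromFile(parameters);
            using var context = weights.CreateContext(parameters);
            var executor = new InteractiveExecutor(context);

            // Cancel everything below if it takes longer than 30 seconds.
            using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));

            var history = new ChatHistory();
            history.AddMessage(AuthorRole.System, "You are a helpful assistant.");

            // Prefilling the history into the KV cache can now be abandoned via the token.
            var session = await ChatSession.InitializeSessionFromHistoryAsync(
                executor, history, cancellationToken: cts.Token);

            // Pre-processing a single message honours the same token.
            await session.AddAndProcessUserMessage("Hello!", cts.Token);

            // InferAsync already accepted a CancellationToken before this change;
            // the generation loop stops once the token fires.
            await foreach (var piece in executor.InferAsync("Assistant:", new InferenceParams { MaxTokens = 64 }, cts.Token))
                Console.Write(piece);
        }
    }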