4 changes: 3 additions & 1 deletion LLama.Examples/Examples/QuantizeModel.cs
@@ -2,7 +2,7 @@ namespace LLama.Examples.Examples
{
public class QuantizeModel
{
public static async Task Run()
public static Task Run()
{
string inputPath = UserSettings.GetModelPath();

@@ -20,6 +20,8 @@ public static async Task Run()
{
Console.WriteLine("Quantization failed!");
}

return Task.CompletedTask;
}
}
}
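The QuantizeModel change follows the usual pattern for a Task-returning method that never awaits: drop the async modifier and hand back a completed task, which removes the CS1998 compiler warning and the async state-machine overhead without changing the public signature. A minimal sketch of the pattern (the helper name is illustrative, not from the example project):

    public static Task Run()
    {
        DoSynchronousWork();        // e.g. the blocking quantization call
        return Task.CompletedTask;  // nothing is awaited, so 'async' is unnecessary
    }
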
23 changes: 13 additions & 10 deletions LLama/ChatSession.cs
@@ -76,9 +76,10 @@ public class ChatSession
/// <param name="executor">The executor for this session</param>
/// <param name="history">History for this session</param>
/// <param name="transform">History Transform for this session</param>
/// <param name="cancellationToken">A token that cancels the operation</param>
/// <returns>A new chat session.</returns>
public static async Task<ChatSession> InitializeSessionFromHistoryAsync(
ILLamaExecutor executor, ChatHistory history, IHistoryTransform? transform = null)
ILLamaExecutor executor, ChatHistory history, IHistoryTransform? transform = null, CancellationToken cancellationToken = default)
{
if (executor is not StatefulExecutorBase statefulExecutor)
{
@@ -90,7 +91,7 @@ public static async Task<ChatSession> InitializeSessionFromHistoryAsync(
session = session.WithHistoryTransform(transform);
}

await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history));
await statefulExecutor.PrefillPromptAsync(session.HistoryTransform.HistoryToText(history), cancellationToken);
return session;
}

@@ -311,13 +312,15 @@ public ChatSession RemoveLastMessage()
/// Compute KV cache for the message and add it to the chat history.
/// </summary>
/// <param name="message"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message)
public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message, CancellationToken cancellationToken = default)
{
if (Executor is not StatefulExecutorBase statefulExecutor)
{
throw new InvalidOperationException("Executor must be a StatefulExecutorBase to support pre-processing of system messages.");
}

AddMessage(message);
var content = message.Content;
if (message.AuthorRole != AuthorRole.Assistant)
@@ -328,27 +331,27 @@ public async Task<ChatSession> AddAndProcessMessage(ChatHistory.Message message)
}
}

await statefulExecutor.PrefillPromptAsync(content);
await statefulExecutor.PrefillPromptAsync(content, cancellationToken);
return this;
}

/// <summary>
/// Compute KV cache for the system message and add it to the chat history.
/// </summary>
public Task<ChatSession> AddAndProcessSystemMessage(string content)
=> AddAndProcessMessage(new ChatHistory.Message(AuthorRole.System, content));
public Task<ChatSession> AddAndProcessSystemMessage(string content, CancellationToken cancellationToken = default)
=> AddAndProcessMessage(new ChatHistory.Message(AuthorRole.System, content), cancellationToken);

/// <summary>
/// Compute KV cache for the user message and add it to the chat history.
/// </summary>
public Task<ChatSession> AddAndProcessUserMessage(string content)
=> AddAndProcessMessage(new ChatHistory.Message(AuthorRole.User, content));
public Task<ChatSession> AddAndProcessUserMessage(string content, CancellationToken cancellationToken = default)
=> AddAndProcessMessage(new ChatHistory.Message(AuthorRole.User, content), cancellationToken);

/// <summary>
/// Compute KV cache for the assistant message and add it to the chat history.
/// </summary>
public Task<ChatSession> AddAndProcessAssistantMessage(string content)
=> AddAndProcessMessage(new ChatHistory.Message(AuthorRole.Assistant, content));
public Task<ChatSession> AddAndProcessAssistantMessage(string content, CancellationToken cancellationToken = default)
=> AddAndProcessMessage(new ChatHistory.Message(AuthorRole.Assistant, content), cancellationToken);

/// <summary>
/// Replace a user message with a new message and remove all messages after the new message.
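With the overloads above, a caller can thread a single CancellationToken through session initialization and the KV-cache prefill helpers. A short usage sketch, assuming an existing executor (a StatefulExecutorBase) and a populated history; the 30-second timeout is illustrative:

    // Abort the history prefill if it takes longer than 30 seconds.
    using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));

    var session = await ChatSession.InitializeSessionFromHistoryAsync(
        executor, history, transform: null, cancellationToken: cts.Token);

    // The AddAndProcess* helpers forward the same token to PrefillPromptAsync.
    await session.AddAndProcessUserMessage("Hello!", cts.Token);
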
19 changes: 10 additions & 9 deletions LLama/LLamaContext.cs
@@ -1,14 +1,14 @@
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using System.IO;
using System.IO.MemoryMappedFiles;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Native;
using Microsoft.Extensions.Logging;
using System.Threading;

namespace LLama
{
@@ -73,7 +73,7 @@ public int BatchThreads
/// Get the special tokens for the model associated with this context
/// </summary>
public SafeLlamaModelHandle.Vocabulary Vocab { get; }

/// <summary>
/// Create a new LLamaContext for the given LLamaWeights
/// </summary>
@@ -396,7 +396,7 @@ public Task<DecodeResult> DecodeAsync(LLamaBatch batch, CancellationToken cancel
{
return Task.Run(() => Decode(batch), cancellationToken);
}

/// <summary>
/// </summary>
/// <param name="batch"></param>
@@ -406,10 +406,10 @@ public DecodeResult Decode(LLamaBatchEmbeddings batch)
return 0;
if (batch.EmbeddingsCount > BatchSize)
throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));

return (DecodeResult)NativeHandle.Decode(batch);
}

/// <summary>
/// </summary>
/// <param name="batch"></param>
@@ -425,15 +425,16 @@ public Task<DecodeResult> DecodeAsync(LLamaBatchEmbeddings batch, CancellationTo
/// <param name="id"></param>
/// <param name="batch"></param>
/// <param name="n_past"></param>
/// <param name="cancellationToken"></param>
/// <returns>A tuple, containing the decode result, the number of tokens that have <b>not</b> been decoded yet and the total number of tokens that have been decoded.</returns>
public Task<(DecodeResult, int, int)> DecodeAsync(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, int n_past)
public Task<(DecodeResult, int, int)> DecodeAsync(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, int n_past, CancellationToken cancellationToken = default)
{
return Task.Run(() =>
{
var past = n_past;
var res = NativeHandle.Decode(tokens, id, batch, ref past);
return (res.Item1, res.Item2, past);
});
}, cancellationToken);
}
#endregion

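The token added to the list-based DecodeAsync overload is forwarded to Task.Run, so cancellation is only observed before the work item starts; it does not interrupt a native decode that is already running. A usage sketch, assuming an existing context, token list, sequence id and batch:

    using var cts = new CancellationTokenSource();

    // Awaiting throws an OperationCanceledException if the token is cancelled
    // before the decode task has been scheduled.
    var (result, notDecoded, nPast) = await context.DecodeAsync(
        tokens, seqId, batch, n_past: 0, cts.Token);
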
38 changes: 23 additions & 15 deletions LLama/LLamaExecutorBase.cs
@@ -239,36 +239,41 @@ protected virtual void TryReuseMatchingPrefix()
/// Decide whether to continue the loop.
/// </summary>
/// <param name="args"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
protected abstract Task<bool> GetLoopCondition(InferStateArgs args);
protected abstract Task<bool> GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken = default);

/// <summary>
/// Preprocess the inputs before the inference.
/// </summary>
/// <param name="text"></param>
/// <param name="args"></param>
protected abstract Task PreprocessInputs(string? text, InferStateArgs args);
/// <param name="cancellationToken"></param>
protected abstract Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken = default);

/// <summary>
/// Do some post processing after the inference.
/// </summary>
/// <param name="inferenceParams"></param>
/// <param name="args"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args);
protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken = default);

/// <summary>
/// The core inference logic.
/// </summary>
/// <param name="inferenceParams"></param>
/// <param name="args"></param>
protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args);
/// <param name="cancellationToken"></param>
protected abstract Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken = default);

/// <summary>
/// Save the current state to a file.
/// </summary>
/// <param name="filename"></param>
public abstract Task SaveState(string filename);
/// <param name="cancellationToken"></param>
public abstract Task SaveState(string filename, CancellationToken cancellationToken = default);

/// <summary>
/// Get the current state data.
@@ -280,13 +285,15 @@ protected virtual void TryReuseMatchingPrefix()
/// Load the state from data.
/// </summary>
/// <param name="data"></param>
public abstract Task LoadState(ExecutorBaseState data);
/// <param name="cancellationToken"></param>
public abstract Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default);

/// <summary>
/// Load the state from a file.
/// </summary>
/// <param name="filename"></param>
public abstract Task LoadState(string filename);
/// <param name="cancellationToken"></param>
public abstract Task LoadState(string filename, CancellationToken cancellationToken = default);


/// <summary>
Expand All @@ -310,23 +317,23 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
NeedToSaveSession = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count
};

await PreprocessInputs(text, args);
await PreprocessInputs(text, args, cancellationToken);

while (await GetLoopCondition(args))
while (await GetLoopCondition(args, cancellationToken))
{
if (cancellationToken.IsCancellationRequested)
{
break;
}
await InferInternal(inferenceParams, args);
await InferInternal(inferenceParams, args, cancellationToken);

if (args.ReturnValue)
{
_decoder.AddRange(_embeds);
yield return _decoder.Read();
}

var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args, cancellationToken);
if (extraOutputs is { Count: > 0 })
{
foreach (var item in extraOutputs)
@@ -346,8 +353,9 @@ public virtual async Task PrefillPromptAsync(string prompt)
/// It can reduce the latency of the first response when the user's first input does not arrive immediately.
/// </summary>
/// <param name="prompt">Prompt to process</param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public virtual async Task PrefillPromptAsync(string prompt)
public virtual async Task PrefillPromptAsync(string prompt, CancellationToken cancellationToken = default)
{
var inferenceParams = new InferenceParams
{
@@ -362,11 +370,11 @@ public virtual async Task PrefillPromptAsync(string prompt)
NeedToSaveSession = false
};

await PreprocessInputs(prompt, args);
await PreprocessInputs(prompt, args, cancellationToken);
// First run adds the prompt to the _embeds
await InferInternal(inferenceParams, args);
await InferInternal(inferenceParams, args, cancellationToken);
// Second run puts it through decode
await InferInternal(inferenceParams, args);
await InferInternal(inferenceParams, args, cancellationToken);
}

/// <summary>
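Taken together, the executor-base changes let one token flow from PrefillPromptAsync and InferAsync down into PreprocessInputs, GetLoopCondition, InferInternal and PostProcess, so the token now reaches the decoding stages instead of only being checked at the top of the generation loop. A sketch of caller-side usage, assuming executor is a StatefulExecutorBase such as InteractiveExecutor and inferenceParams is already configured:

    using var cts = new CancellationTokenSource();

    // Pre-compute the KV cache for the system prompt; cancellable.
    await executor.PrefillPromptAsync(systemPrompt, cts.Token);

    // Stream a reply; cancelling the token ends the loop after the current step.
    await foreach (var piece in executor.InferAsync(userInput, inferenceParams, cts.Token))
    {
        Console.Write(piece);
    }
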
36 changes: 19 additions & 17 deletions LLama/LLamaInstructExecutor.cs
@@ -1,14 +1,15 @@
using LLama.Abstractions;
using LLama.Common;
using LLama.Native;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Common;
using LLama.Exceptions;
using LLama.Native;
using LLama.Sampling;
using Microsoft.Extensions.Logging;

@@ -65,9 +66,9 @@
return state;
}
/// <inheritdoc />
public override Task LoadState(ExecutorBaseState data)
public override Task LoadState(ExecutorBaseState data, CancellationToken cancellationToken = default)
{
if(data is InstructExecutorState state)
if (data is InstructExecutorState state)
{
_n_session_consumed = state.ConsumedSessionCount;
_embed_inps = state.EmbedInps!.ToList();
@@ -91,34 +92,34 @@
}

/// <inheritdoc />
public override async Task SaveState(string filename)
public override async Task SaveState(string filename, CancellationToken cancellationToken = default)
{
var state = (InstructExecutorState)GetStateData();
using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
{
await JsonSerializer.SerializeAsync(fs, state);
await JsonSerializer.SerializeAsync(fs, state, cancellationToken: cancellationToken);
}
}
/// <inheritdoc />
public override async Task LoadState(string filename)
public override async Task LoadState(string filename, CancellationToken cancellationToken = default)
{
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
{
var state = await JsonSerializer.DeserializeAsync<InstructExecutorState>(fs);
await LoadState(state!);
await LoadState(state!, cancellationToken);
}
}

/// <inheritdoc />
protected override Task<bool> GetLoopCondition(InferStateArgs args)
protected override Task<bool> GetLoopCondition(InferStateArgs args, CancellationToken cancellationToken)
{
return Task.FromResult(args.RemainedTokens != 0 || _is_prompt_run);
}

/// <inheritdoc />
protected override Task PreprocessInputs(string? text, InferStateArgs args)
protected override Task PreprocessInputs(string? text, InferStateArgs args, CancellationToken cancellationToken)
{
args.Antiprompts ??= [ ];
args.Antiprompts ??= [];
if (!args.Antiprompts.Contains(_instructionPrefix))
args.Antiprompts.Add(_instructionPrefix);

Expand Down Expand Up @@ -154,19 +155,19 @@
}

/// <inheritdoc />
protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
protected override Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
{
if (_embed_inps.Count <= _consumedTokensCount)
{
if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))

Check warning on line 162 in LLama/LLamaInstructExecutor.cs (GitHub Actions / Test on linux-release, windows-release and osx-release): 'IReadOnlyListExtensions.TokensEndsWithAnyString<TTokens>(TTokens, IList<string>?, SafeLlamaModelHandle, Encoding)' is obsolete: 'Use an Antiprompt processor instead'
{
args.WaitForInput = true;
return (true, Array.Empty<string>());
return Task.FromResult<(bool, IReadOnlyList<string>)>((true, []));
}

if (_pastTokensCount > 0 && args.WaitForInput)
{
return (true, new[] { "\n> " });
return Task.FromResult<(bool, IReadOnlyList<string>)>((true, ["\n> "]));
}
}

@@ -180,11 +181,12 @@
args.RemainedTokens = inferenceParams.MaxTokens;
args.WaitForInput = true;
}
return (false, Array.Empty<string>());

return Task.FromResult<(bool, IReadOnlyList<string>)>((false, []));
}

/// <inheritdoc />
protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
{
var batch = new LLamaBatch();

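PostProcess also changes from an async method with no awaits to a plain method that wraps its results in Task.FromResult, which satisfies the abstract Task-returning signature without the missing-await warning. A condensed sketch of that shape; the condition helper is an illustrative stand-in for the antiprompt checks in the real implementation:

    protected override Task<(bool, IReadOnlyList<string>)> PostProcess(
        IInferenceParams inferenceParams, InferStateArgs args, CancellationToken cancellationToken)
    {
        if (ShouldPromptForInput(args))    // illustrative helper, not part of the PR
            return Task.FromResult<(bool, IReadOnlyList<string>)>((true, ["\n> "]));

        return Task.FromResult<(bool, IReadOnlyList<string>)>((false, []));
    }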