Commit cc64dd6

Change the base effect from Try to Either

1 parent 002a837

File tree: 2 files changed, +146 -60 lines


src/main/scala/com/donderom/llm4s/Llm.scala

Lines changed: 96 additions & 59 deletions
@@ -4,55 +4,66 @@ import fr.hammons.slinc.runtime.given
 import fr.hammons.slinc.types.SizeT
 import fr.hammons.slinc.{FSet, Ptr, Scope, Slinc}
 
-import java.nio.file.Path
-
-import scala.util.{Success, Try}
+import java.nio.file.{Files, Path}
 
 final case class Logprob(token: String, value: Double)
 final case class Probability(logprob: Logprob, candidates: Array[Logprob])
 final case class Token(value: String, probs: Vector[Probability] = Vector.empty)
 final case class Usage(promptSize: Int, tokens: LazyList[Token])
 
+enum LlmError(message: String) extends Exception(message):
+  case ModelError(message: String) extends LlmError(message)
+  case ConfigError(message: String) extends LlmError(message)
+
+import LlmError.ModelError
+
+type Result[A] = Either[LlmError, A]
+object Result:
+  def unit: Result[Unit] = Right(())
+
 trait Llm(val modelPath: Path) extends AutoCloseable:
-  def generate(prompt: String, params: LlmParams): Try[Usage]
+  def generate(prompt: String, params: LlmParams): Result[Usage]
 
-  def embeddings(prompt: String): Try[Array[Float]] =
+  def embeddings(prompt: String): Result[Array[Float]] =
     embeddings(prompt, EmbeddingParams())
 
-  def embeddings(prompt: String, params: EmbeddingParams): Try[Array[Float]]
+  def embeddings(prompt: String, params: EmbeddingParams): Result[Array[Float]]
 
-  def apply(prompt: String): Try[LazyList[String]] = apply(prompt, LlmParams())
+  def apply(prompt: String): Result[LazyList[String]] =
+    apply(prompt, LlmParams())
 
-  def apply(prompt: String, params: LlmParams): Try[LazyList[String]] =
+  def apply(prompt: String, params: LlmParams): Result[LazyList[String]] =
     generate(prompt, params).map(_.tokens.map(_.value))
 
 object Llm:
   def apply(model: Path): Llm = apply(model, ModelParams())
 
   def apply(model: Path, params: ModelParams): Llm =
     new Llm(model):
-      val binding = Try(FSet.instance[Llama])
+      val api = catchNonFatal(FSet.instance[Llama])("Cannot load libllama")
       val llm = createModel(model, params)
 
-      def generate(prompt: String, params: LlmParams): Try[Usage] =
+      def generate(prompt: String, params: LlmParams): Result[Usage] =
         for
           llm <- llm
-          ctx <- createContext(llm, params.context, false)
-          _ <- loadLora(llm, ctx, params.lora)
-        yield SlincLlm(ctx).generate(prompt, params)
+          config <- LlmParams.parse(params)
+          ctx <- createContext(llm, config.context, false)
+          _ <- loadLora(llm, ctx, config.lora)
+        yield SlincLlm(ctx).generate(prompt, config)
 
       def embeddings(
           prompt: String,
           params: EmbeddingParams
-      ): Try[Array[Float]] =
+      ): Result[Array[Float]] =
         for
           llm <- llm
-          ctx <- createContext(llm, params.context, true)
-        yield SlincLlm(ctx).embeddings(prompt, params)
+          config <- EmbeddingParams.parse(params)
+          ctx <- createContext(llm, config.context, true)
+        yield SlincLlm(ctx).embeddings(prompt, config)
 
       def close(): Unit =
         for
-          llama <- binding
+          llama <- api
          llm <- llm
        do
          llama.llama_model_free(llm)
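
With the base effect switched from Try to the Either-based Result alias, callers
now match on typed LlmError values instead of inspecting arbitrary Throwables.
A minimal caller-side sketch (hypothetical usage written for this note, not part
of the commit; the model path and prompt are placeholders):

    import java.nio.file.Paths
    import com.donderom.llm4s.*

    val llm = Llm(Paths.get("models/model.gguf"))
    llm("Hello") match
      case Right(tokens)                   => tokens.foreach(print)
      case Left(LlmError.ModelError(msg))  => Console.err.println(s"model error: $msg")
      case Left(LlmError.ConfigError(msg)) => Console.err.println(s"config error: $msg")
    llm.close()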
@@ -61,70 +72,88 @@ object Llm:
   private def createModel(
       model: Path,
       params: ModelParams
-  ): Try[Llama.Model] =
-    binding.map: llama =>
-      llama.llama_backend_init()
-      llama.llama_numa_init(params.numa)
-      Scope.confined:
-        llama.llama_model_load_from_file(
-          path_model = Ptr.copy(model.toAbsolutePath.toString),
-          params = llama.llama_model_default_params().copy(
-            n_gpu_layers = params.gpuLayers,
-            main_gpu = params.mainGpu,
-            use_mmap = params.mmap,
-            use_mlock = params.mlock
+  ): Result[Llama.Model] =
+    val error = s"Cannot load the model $model"
+    for
+      llama <- api
+      path <- Either.cond(
+        Files.exists(model),
+        model,
+        ModelError(s"Model file $model does not exist")
+      )
+      _ <- catchNonFatal(llama.llama_backend_init())(
+        "Cannot load libllama backend"
+      )
+      _ <- catchNonFatal(llama.llama_numa_init(params.numa))(
+        s"Cannot init Numa (${params.numa})"
+      )
+      m <- catchNonFatal(
+        Scope.confined:
+          llama.llama_model_load_from_file(
+            path_model = Ptr.copy(path.toAbsolutePath.toString),
+            params = llama.llama_model_default_params().copy(
+              n_gpu_layers = params.gpuLayers,
+              main_gpu = params.mainGpu,
+              use_mmap = params.mmap,
+              use_mlock = params.mlock
+            )
           )
-        )
+      )(error).filterOrElse(notNull, ModelError(error))
+    yield m
 
   private def createContext(
       llm: Llama.Model,
-      contextParams: ContextParams,
+      params: ContextParams,
       embedding: Boolean
-  ): Try[Llama.Ctx] =
+  ): Result[Llama.Ctx] =
+    val error = s"Cannot initialize model context ($params)"
     for
-      llama <- binding
-      ctx = llama.llama_init_from_model(
-        model = llm,
-        params = llamaParams(
-          llama.llama_context_default_params(),
-          contextParams,
-          embedding
+      llama <- api
+      ctx <- catchNonFatal(
+        llama.llama_init_from_model(
+          model = llm,
+          params = llamaParams(
+            llama.llama_context_default_params(),
+            params,
+            embedding
+          )
         )
-      ) if ctx != Slinc.getRuntime().Null
+      )(error).filterOrElse(notNull, ModelError(error))
     yield ctx
 
   private def loadLora(
       llm: Llama.Model,
       ctx: Llama.Ctx,
       lora: List[AdapterParams]
-  ): Try[Unit] =
-    lora.map(loadAdapter(llm, ctx, _)).foldLeft(Try(())):
-      case (acc, Success(_)) => acc
-      case (_, failure) => failure
+  ): Result[Unit] =
+    lora.map(loadAdapter(llm, ctx, _)).foldLeft(Result.unit):
+      case (acc, Right(_)) => acc
+      case (_, failure) => failure
 
   private def loadAdapter(
       llm: Llama.Model,
       ctx: Llama.Ctx,
       params: AdapterParams
-  ): Try[Unit] =
-    Scope.confined:
-      for
-        llama <- binding
-        adapter <- Try(
+  ): Result[Unit] =
+    val error = s"Cannot initialize LoRA adapter ($params)"
+    for
+      llama <- api
+      config <- AdapterParams.parse(params)
+      adapter <- catchNonFatal(
+        Scope.confined:
           llama.llama_adapter_lora_init(
             model = llm,
-            path_lora = Ptr.copy(params.path.toAbsolutePath.toString)
-          )
-        )
-        if adapter != Slinc.getRuntime().Null
-        _ <- Try(
-          llama.llama_set_adapter_lora(
-            ctx = ctx,
-            adapter = adapter,
-            scale = params.scale
+            path_lora = Ptr.copy(config.path.toAbsolutePath.toString)
           )
+      )(error).filterOrElse(notNull, ModelError(error))
+      _ <- catchNonFatal(
+        llama.llama_set_adapter_lora(
+          ctx = ctx,
+          adapter = adapter,
+          scale = config.scale
        )
-      yield ()
+      )(error)
+    yield ()
 
   private def llamaParams(
       defaultParams: Llama.ContextParams,
@@ -148,3 +177,11 @@ object Llm:
       flash_attn = params.flashAttention,
       embeddings = embedding
     )
+
+  private def catchNonFatal[A](f: => A)(reason: => String): Result[A] =
+    try Right(f)
+    catch
+      case t if scala.util.control.NonFatal(t) =>
+        Left(ModelError(s"$reason: ${t.getMessage}"))
+
+  private def notNull(ptr: Ptr[Any]): Boolean = ptr != Slinc.getRuntime().Null
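
Note that catchNonFatal is the single bridge from the exception-throwing FFI
layer into Result, and filterOrElse(notNull, ...) replaces the old
"if ptr != Slinc.getRuntime().Null" guards that previously filtered the
for-comprehensions. A self-contained sketch of the same pattern (simplified for
this note: String errors stand in for LlmError, and a throwing function stands
in for the native calls):

    import scala.util.control.NonFatal

    def catchNonFatal[A](f: => A)(reason: => String): Either[String, A] =
      try Right(f)
      catch case t if NonFatal(t) => Left(s"$reason: ${t.getMessage}")

    // Stand-in for a native call that can blow up.
    def nativeInit(path: String): String =
      if path.isEmpty then throw new IllegalArgumentException("empty path")
      else path

    // An exception becomes Left; a "null-like" result is rejected via filterOrElse.
    catchNonFatal(nativeInit("model.gguf"))("Cannot load")
      .filterOrElse(_.nonEmpty, "Cannot load: null handle") // Right(model.gguf)
    catchNonFatal(nativeInit(""))("Cannot load")             // Left(Cannot load: empty path)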

src/main/scala/com/donderom/llm4s/Params.scala

Lines changed: 50 additions & 1 deletion
@@ -1,7 +1,8 @@
 package com.donderom.llm4s
 
-import java.nio.file.Path
+import java.nio.file.{Files, Path}
 
+import LlmError.ConfigError
 import Llama.{NumaStrategy, RopeScalingType}
 
 object Default:
@@ -20,6 +21,11 @@ final case class AdapterParams(
   scale: Float = 1.0f
 )
 
+object AdapterParams:
+  def parse(params: AdapterParams): Result[AdapterParams] =
+    if Files.exists(params.path) then Right(params)
+    else Left(ConfigError(s"LoRA adapter file ${params.path} does not exist"))
+
 final case class ModelParams(
   gpuLayers: Int = -1,
   mainGpu: Int = 0,
@@ -48,8 +54,26 @@ final case class BatchParams(
   threads: Int = Default.threads
 )
 
+object BatchParams:
+  def parse(params: BatchParams): Result[BatchParams] =
+    if params.logical < 1 then
+      Left(ConfigError("Logical batch size should be positive"))
+    else if params.physical < 1 then
+      Left(ConfigError("Batch size should be positive"))
+    else if params.threads < 1 then
+      Left(ConfigError("Batch threads should be positive"))
+    else Right(params)
+
 final case class GroupAttention(factor: Int = 1, width: Int = 512)
 
+object GroupAttention:
+  def parse(params: GroupAttention): Result[GroupAttention] =
+    if params.factor <= 0 then
+      Left(ConfigError("Group attention factor should be positive"))
+    else if params.width % params.factor != 0 then
+      Left(ConfigError("Group attention width should be a multiple of factor"))
+    else Right(params)
+
 final case class ContextParams(
   size: Int = 4096,
   threads: Int = Default.threads,
@@ -59,6 +83,19 @@ final case class ContextParams(
   flashAttention: Boolean = false
 )
 
+object ContextParams:
+  def parse(params: ContextParams): Result[ContextParams] =
+    val config =
+      if params.size < 0 then
+        Left(ConfigError("Context size should be positive"))
+      else if params.threads < 1 then
+        Left(ConfigError("Context threads should be positive"))
+      else Right(params)
+    for
+      _ <- BatchParams.parse(params.batch)
+      config <- config
+    yield config
+
 final case class Penalty(
   lastN: Int = 64,
   repeat: Float = 1.0f,
@@ -131,6 +168,11 @@ final case class EmbeddingParams(
   norm: Option[Norm] = None
 )
 
+object EmbeddingParams:
+  def parse(params: EmbeddingParams): Result[EmbeddingParams] =
+    for _ <- ContextParams.parse(params.context)
+    yield params
+
 final case class LlmParams(
   context: ContextParams = ContextParams(),
   sampling: Sampling = Sampling.Dist(),
@@ -142,3 +184,10 @@ final case class LlmParams(
   groupAttention: GroupAttention = GroupAttention(),
   lora: List[AdapterParams] = Nil
 )
+
+object LlmParams:
+  def parse(params: LlmParams): Result[LlmParams] =
+    for
+      _ <- ContextParams.parse(params.context)
+      _ <- GroupAttention.parse(params.groupAttention)
+    yield params
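
Each parse validator fails fast with the first ConfigError it finds, and the
composite validators (ContextParams.parse, EmbeddingParams.parse,
LlmParams.parse) thread the nested checks through a for-comprehension, so one
Left short-circuits the whole chain. A quick sketch of the observable behaviour
(example values chosen for this note):

    GroupAttention.parse(GroupAttention())           // Right(GroupAttention(1, 512))
    GroupAttention.parse(GroupAttention(factor = 0))
    // Left(ConfigError("Group attention factor should be positive"))
    GroupAttention.parse(GroupAttention(factor = 3)) // 512 % 3 != 0
    // Left(ConfigError("Group attention width should be a multiple of factor"))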
