Skip to content

Commit 02f24ca

Browse files
committed
Add support for embeddings pooling type
1 parent 80c3364 commit 02f24ca

File tree

3 files changed

+34
-11
lines changed

src/main/scala/com/donderom/llm4s/Llm.scala

Lines changed: 29 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,7 @@ object Llm:
4747
for
4848
llm <- llm
4949
config <- LlmParams.parse(params)
50-
ctx <- createContext(llm, config.context, false)
50+
ctx <- createContext(llm, config.context, llamaParams(false))
5151
_ <- loadLora(llm, ctx, config.lora)
5252
yield SlincLlm(ctx).generate(prompt, config)
5353

@@ -56,9 +56,18 @@ object Llm:
5656
params: EmbeddingParams
5757
): Result[Array[Float]] =
5858
for
59+
_ <- Either.cond(
60+
params.poolingType != Llama.PoolingType.RANK,
61+
params,
62+
LlmError.ConfigError("Rank pooling type is not supported")
63+
)
5964
llm <- llm
6065
config <- EmbeddingParams.parse(params)
61-
ctx <- createContext(llm, config.context, true)
66+
ctx <- createContext(
67+
llm,
68+
config.context,
69+
embeddingParams(params.poolingType)
70+
)
6271
yield SlincLlm(ctx).embeddings(prompt, config)
6372

6473
def close(): Unit =
@@ -104,19 +113,19 @@ object Llm:
104113
private def createContext(
105114
llm: Llama.Model,
106115
params: ContextParams,
107-
embedding: Boolean
116+
nativeParams: (
117+
Llama.ContextParams,
118+
ContextParams
119+
) => Llama.ContextParams
108120
): Result[Llama.Ctx] =
109121
val error = s"Cannot initialize model context ($params)"
110122
for
111123
llama <- api
112124
ctx <- catchNonFatal(
113125
llama.llama_init_from_model(
114126
model = llm,
115-
params = llamaParams(
116-
llama.llama_context_default_params(),
117-
params,
118-
embedding
119-
)
127+
params =
128+
nativeParams(llama.llama_context_default_params(), params)
120129
)
121130
)(error).filterOrElse(notNull, ModelError(error))
122131
yield ctx
@@ -156,9 +165,10 @@ object Llm:
156165
yield ()
157166

158167
private def llamaParams(
159-
defaultParams: Llama.ContextParams,
160-
params: ContextParams,
161168
embedding: Boolean
169+
)(
170+
defaultParams: Llama.ContextParams,
171+
params: ContextParams
162172
): Llama.ContextParams =
163173
defaultParams.copy(
164174
n_ctx = params.size,
@@ -178,6 +188,15 @@ object Llm:
178188
embeddings = embedding
179189
)
180190

191+
private def embeddingParams(
192+
poolingType: Llama.PoolingType
193+
)(
194+
defaultParams: Llama.ContextParams,
195+
params: ContextParams
196+
): Llama.ContextParams =
197+
llamaParams(true)(defaultParams, params)
198+
.copy(pooling_type = poolingType)
199+
181200
private def catchNonFatal[A](f: => A)(reason: => String): Result[A] =
182201
try Right(f)
183202
catch

src/main/scala/com/donderom/llm4s/Params.scala

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -218,6 +218,7 @@ enum Norm:
218218

219219
final case class EmbeddingParams(
220220
context: ContextParams = ContextParams(),
221+
poolingType: Llama.PoolingType = Llama.PoolingType.NONE,
221222
// Normalisation for embeddings
222223
norm: Option[Norm] = None
223224
)

src/main/scala/com/donderom/llm4s/SlincLlm.scala

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -161,7 +161,10 @@ private class SlincLlm private[llm4s] (private[llm4s] val ctx: Llama.Ctx):
161161
val ids = encode(prompt)
162162
val _ = evaluate(ids, Evaluated.none, params.context.batch)
163163
val size = llama.llama_model_n_embd(model)
164-
val embeddings = llama.llama_get_embeddings(ctx).asArray(size).unsafeArray
164+
val embeddings =
165+
if params.poolingType == Llama.PoolingType.NONE then
166+
llama.llama_get_embeddings(ctx).asArray(size).unsafeArray
167+
else llama.llama_get_embeddings_seq(ctx, 0).asArray(size).unsafeArray
165168
llama.llama_free(ctx)
166169

167170
def normalized(

0 commit comments

Comments
 (0)