@@ -47,7 +47,7 @@ object Llm:
47
47
for
48
48
llm <- llm
49
49
config <- LlmParams .parse(params)
50
- ctx <- createContext(llm, config.context, false )
50
+ ctx <- createContext(llm, config.context, llamaParams( false ) )
51
51
_ <- loadLora(llm, ctx, config.lora)
52
52
yield SlincLlm (ctx).generate(prompt, config)
53
53
@@ -56,9 +56,18 @@ object Llm:
56
56
params : EmbeddingParams
57
57
): Result [Array [Float ]] =
58
58
for
59
+ _ <- Either .cond(
60
+ params.poolingType != Llama .PoolingType .RANK ,
61
+ params,
62
+ LlmError .ConfigError (" Rank pooling type is not supported" )
63
+ )
59
64
llm <- llm
60
65
config <- EmbeddingParams .parse(params)
61
- ctx <- createContext(llm, config.context, true )
66
+ ctx <- createContext(
67
+ llm,
68
+ config.context,
69
+ embeddingParams(params.poolingType)
70
+ )
62
71
yield SlincLlm (ctx).embeddings(prompt, config)
63
72
64
73
def close (): Unit =
@@ -104,19 +113,19 @@ object Llm:
104
113
private def createContext (
105
114
llm : Llama .Model ,
106
115
params : ContextParams ,
107
- embedding : Boolean
116
+ nativeParams : (
117
+ Llama .ContextParams ,
118
+ ContextParams
119
+ ) => Llama .ContextParams
108
120
): Result [Llama .Ctx ] =
109
121
val error = s " Cannot initialize model context ( $params) "
110
122
for
111
123
llama <- api
112
124
ctx <- catchNonFatal(
113
125
llama.llama_init_from_model(
114
126
model = llm,
115
- params = llamaParams(
116
- llama.llama_context_default_params(),
117
- params,
118
- embedding
119
- )
127
+ params =
128
+ nativeParams(llama.llama_context_default_params(), params)
120
129
)
121
130
)(error).filterOrElse(notNull, ModelError (error))
122
131
yield ctx
@@ -156,9 +165,10 @@ object Llm:
156
165
yield ()
157
166
158
167
private def llamaParams (
159
- defaultParams : Llama .ContextParams ,
160
- params : ContextParams ,
161
168
embedding : Boolean
169
+ )(
170
+ defaultParams : Llama .ContextParams ,
171
+ params : ContextParams
162
172
): Llama .ContextParams =
163
173
defaultParams.copy(
164
174
n_ctx = params.size,
@@ -178,6 +188,15 @@ object Llm:
178
188
embeddings = embedding
179
189
)
180
190
191
+ private def embeddingParams (
192
+ poolingType : Llama .PoolingType
193
+ )(
194
+ defaultParams : Llama .ContextParams ,
195
+ params : ContextParams
196
+ ): Llama .ContextParams =
197
+ llamaParams(true )(defaultParams, params)
198
+ .copy(pooling_type = poolingType)
199
+
181
200
private def catchNonFatal [A ](f : => A )(reason : => String ): Result [A ] =
182
201
try Right (f)
183
202
catch
0 commit comments