@@ -125,19 +125,35 @@ private class SlincLlm private[llm4s] (private[llm4s] val ctx: Llama.Ctx):
      ids.foreach(lastTokens.append)

+    // Support encoder-decoder models
+    val encoder = llama.llama_model_has_encoder(model)
+    if encoder then
+      Scope.confined:
+        llama.llama_encode(
+          ctx = ctx,
+          batch = llama.llama_batch_get_one(Ptr.copy(ids), ids.size)
+        )
+      val decStartToken = llama.llama_model_decoder_start_token(model)
+      if !nullToken(decStartToken) then lastTokens.append(decStartToken)
+      else lastTokens.append(llama.llama_vocab_bos(vocab))
+
     val gen = (e: Evaluated) => tokens(State[Token](params.predictTokens, e))
     Usage(
       ids.size,
-      if params.echo then promptTokens(ids, Array()) #::: gen(Evaluated.none)
+      if encoder then gen(Evaluated(ids.size))
+      else if params.echo then promptTokens(ids) #::: gen(Evaluated.none)
       else gen(evaluate(ids, Evaluated.none, params.context.batch))
     )
   end generate
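Aside on the memory handling in the encoder branch: `Scope.confined` is Slinc's arena-style allocation, so the native copy made by `Ptr.copy(ids)` lives only until the block exits, and `llama_batch_get_one` wraps that pointer as a single-sequence batch for the one-shot `llama_encode` call. The same pattern should apply to decoder-only evaluation; a minimal sketch, assuming the binding also exposes llama.cpp's `llama_decode` (the `decodePrompt` helper is hypothetical, not part of this commit):

    // Sketch: the same confined-allocation pattern around llama_decode.
    def decodePrompt(ids: Array[Int]): Int =
      Scope.confined:
        // The native copy of `ids` is freed when this block exits.
        // Returns 0 on success, per llama.cpp.
        llama.llama_decode(
          ctx = ctx,
          batch = llama.llama_batch_get_one(Ptr.copy(ids), ids.size)
        )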
+  def promptTokens(ids: Array[Int]): LazyList[Token] =
+    promptTokens(ids, Array())
+
   def promptTokens(ids: Array[Int], pending: Array[Byte]): LazyList[Token] =
     if ids.isEmpty then LazyList.empty
     else
       decode(ids.head, pending) match
-        case token: String => Token(token) #:: promptTokens(ids.tail, Array())
+        case token: String => Token(token) #:: promptTokens(ids.tail)
         case partial: Array[Byte] => promptTokens(ids.tail, partial)

   def embeddings(prompt: String, params: EmbeddingParams): Array[Float] =
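The new single-argument `promptTokens` above just seeds the recursion with an empty `pending` buffer, letting the recursive step drop its repeated `Array()` argument. `pending` exists because a token's byte piece can end in the middle of a multi-byte UTF-8 sequence, so `decode` returns either a complete `String` or the raw bytes to carry into the next step. A minimal sketch of that contract, assuming the `String | Array[Byte]` union seen in the match above (`pieceBytes` is a hypothetical accessor for a token's raw bytes):

    import java.nio.charset.StandardCharsets.UTF_8

    // Sketch: buffer bytes until they decode cleanly (simplified; a genuine
    // U+FFFD already present in the text would also trigger buffering here).
    def decodeStep(pieceBytes: Int => Array[Byte])(id: Int, pending: Array[Byte]): String | Array[Byte] =
      val bytes = pending ++ pieceBytes(id)
      val text  = new String(bytes, UTF_8)
      if text.contains('\uFFFD') then bytes else text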
@@ -184,6 +200,8 @@ private class SlincLlm private[llm4s] (private[llm4s] val ctx: Llama.Ctx):
   lazy val vocabSize: Int = llama.llama_vocab_n_tokens(vocab)
   lazy val addBos: Boolean = llama.llama_vocab_get_add_bos(vocab)

+  def nullToken(token: Int): Boolean = token == Llama.nullToken
+
   def keepGenerating(token: Int): Boolean =
     !llama.llama_vocab_is_eog(vocab, token)