Commit 002a837

Support encoder-decoder models
1 parent 9a501f9 commit 002a837

File tree

2 files changed: 22 additions & 2 deletions


src/main/scala/com/donderom/llm4s/Llama.scala

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,8 @@ import fr.hammons.slinc.types.*
 import fr.hammons.slinc.{CUnion, FSet, Ptr, Struct, Transform}
 
 object Llama:
+  val nullToken = -1
+
   type Pos = CInt
   type Token = CInt
   type SeqId = CInt
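
Context for the new constant (not part of the diff): llama.cpp uses -1 as its null-token sentinel, and llama_model_decoder_start_token returns it when a model defines no dedicated decoder start token. A minimal usage sketch, assuming model, vocab, and the llama binding set are in scope as they are in SlincLlm:

    // Sketch: resolve the first decoder token, falling back to BOS when the
    // model reports no decoder start token (Llama.nullToken, i.e. -1).
    val start = llama.llama_model_decoder_start_token(model)
    val first =
      if start == Llama.nullToken then llama.llama_vocab_bos(vocab)
      else start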

src/main/scala/com/donderom/llm4s/SlincLlm.scala

Lines changed: 20 additions & 2 deletions
@@ -125,19 +125,35 @@ private class SlincLlm private[llm4s] (private[llm4s] val ctx: Llama.Ctx):
 
     ids.foreach(lastTokens.append)
 
+    // Support encoder-decoder models
+    val encoder = llama.llama_model_has_encoder(model)
+    if encoder then
+      Scope.confined:
+        llama.llama_encode(
+          ctx = ctx,
+          batch = llama.llama_batch_get_one(Ptr.copy(ids), ids.size)
+        )
+      val decStartToken = llama.llama_model_decoder_start_token(model)
+      if !nullToken(decStartToken) then lastTokens.append(decStartToken)
+      else lastTokens.append(llama.llama_vocab_bos(vocab))
+
     val gen = (e: Evaluated) => tokens(State[Token](params.predictTokens, e))
     Usage(
       ids.size,
-      if params.echo then promptTokens(ids, Array()) #::: gen(Evaluated.none)
+      if encoder then gen(Evaluated(ids.size))
+      else if params.echo then promptTokens(ids) #::: gen(Evaluated.none)
       else gen(evaluate(ids, Evaluated.none, params.context.batch))
     )
   end generate
 
+  def promptTokens(ids: Array[Int]): LazyList[Token] =
+    promptTokens(ids, Array())
+
   def promptTokens(ids: Array[Int], pending: Array[Byte]): LazyList[Token] =
     if ids.isEmpty then LazyList.empty
     else
       decode(ids.head, pending) match
-        case token: String => Token(token) #:: promptTokens(ids.tail, Array())
+        case token: String => Token(token) #:: promptTokens(ids.tail)
         case partial: Array[Byte] => promptTokens(ids.tail, partial)
 
   def embeddings(prompt: String, params: EmbeddingParams): Array[Float] =
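
Why the encoder branch seeds gen with Evaluated(ids.size): for encoder-decoder models the whole prompt is consumed by llama_encode up front, so decoding starts with every prompt token already accounted for, and neither prompt echo nor the incremental evaluate loop applies. A hedged caller-side sketch (the model file name is an assumption; the entry point follows the project's README-style API):

    import java.nio.file.Paths
    import com.donderom.llm4s.*

    // Sketch: with an encoder-decoder GGUF (e.g. a T5 checkpoint), generate()
    // now runs the prompt through llama_encode before sampling decoder tokens.
    val llm = Llm(Paths.get("models/t5-base.gguf"))
    llm("translate English to German: Hello", LlmParams()).foreach: stream =>
      stream.foreach(print)
    llm.close()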
@@ -184,6 +200,8 @@ private class SlincLlm private[llm4s] (private[llm4s] val ctx: Llama.Ctx):
   lazy val vocabSize: Int = llama.llama_vocab_n_tokens(vocab)
   lazy val addBos: Boolean = llama.llama_vocab_get_add_bos(vocab)
 
+  def nullToken(token: Int): Boolean = token == Llama.nullToken
+
   def keepGenerating(token: Int): Boolean =
     !llama.llama_vocab_is_eog(vocab, token)
 

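The new nullToken helper pairs with the existing keepGenerating predicate in the decode loop. A hypothetical sketch (sampleNext is a stand-in for the real sampler, which this diff does not show):

    // Hypothetical: stream sampled tokens until an end-of-generation token
    // (per keepGenerating) or the -1 sentinel (per nullToken) appears.
    def decoded(sampleNext: () => Int): LazyList[Int] =
      LazyList
        .continually(sampleNext())
        .takeWhile(token => !nullToken(token) && keepGenerating(token))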