Commit 3a271c7

Add support for Amazon Nova (huggingface#1629)
* Add support for Amazon Nova
* fix linting issue
* fix: replace model.id check with `isNova` flag and remove debug logs

Co-authored-by: Nathan Sarrazin <[email protected]>
1 parent c3d680b · commit 3a271c7

File tree

1 file changed (+58, -24 lines)

1 file changed

+58
-24
lines changed

src/lib/server/endpoints/aws/endpointBedrock.ts

Lines changed: 58 additions & 24 deletions
@@ -11,6 +11,7 @@ export const endpointBedrockParametersSchema = z.object({
 	region: z.string().default("us-east-1"),
 	model: z.any(),
 	anthropicVersion: z.string().default("bedrock-2023-05-31"),
+	isNova: z.boolean().default(false),
 	multimodal: z
 		.object({
 			image: createImageProcessorOptionsValidator({
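
The flag defaults to false, so existing Anthropic-on-Bedrock configurations keep their current behavior. For context, a model entry could opt in along these lines (a minimal sketch of a MODELS config entry; the "bedrock" endpoint type name and the model id are illustrative assumptions, not taken from this diff):

{
  "name": "us.amazon.nova-lite-v1:0",
  "endpoints": [{ "type": "bedrock", "region": "us-east-1", "isNova": true }]
}
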
@@ -34,7 +35,7 @@ export const endpointBedrockParametersSchema = z.object({
 export async function endpointBedrock(
 	input: z.input<typeof endpointBedrockParametersSchema>
 ): Promise<Endpoint> {
-	const { region, model, anthropicVersion, multimodal } =
+	const { region, model, anthropicVersion, multimodal, isNova } =
 		endpointBedrockParametersSchema.parse(input);

 	let BedrockRuntimeClient, InvokeModelWithResponseStreamCommand;
@@ -59,24 +60,42 @@ export async function endpointBedrock(
 		messages = messages.slice(1); // Remove the first system message from the array
 	}

-	const formattedMessages = await prepareMessages(messages, imageProcessor);
+	const formattedMessages = await prepareMessages(messages, isNova, imageProcessor);

 	let tokenId = 0;
 	const parameters = { ...model.parameters, ...generateSettings };
 	return (async function* () {
-		const command = new InvokeModelWithResponseStreamCommand({
-			body: Buffer.from(
-				JSON.stringify({
-					anthropic_version: anthropicVersion,
-					max_tokens: parameters.max_new_tokens ? parameters.max_new_tokens : 4096,
-					messages: formattedMessages,
-					system,
-				}),
-				"utf-8"
-			),
+		const baseCommandParams = {
 			contentType: "application/json",
 			accept: "application/json",
 			modelId: model.id,
+		};
+
+		const maxTokens = parameters.max_new_tokens || 4096;
+
+		let bodyContent;
+		if (isNova) {
+			bodyContent = {
+				messages: formattedMessages,
+				inferenceConfig: {
+					maxTokens,
+					topP: 0.1,
+					temperature: 1.0,
+				},
+				system: [{ text: system }],
+			};
+		} else {
+			bodyContent = {
+				anthropic_version: anthropicVersion,
+				max_tokens: maxTokens,
+				messages: formattedMessages,
+				system,
+			};
+		}
+
+		const command = new InvokeModelWithResponseStreamCommand({
+			...baseCommandParams,
+			body: Buffer.from(JSON.stringify(bodyContent), "utf-8"),
 			trace: "DISABLED",
 		});

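With both branches in place, the same prompt serializes into two different request bodies. A sketch of each, assuming no max_new_tokens override and a one-turn conversation (the message and system text are placeholders):

// isNova: true — Nova's inferenceConfig schema with a structured system block:
{
  "messages": [{ "role": "user", "content": [{ "text": "Hello" }] }],
  "inferenceConfig": { "maxTokens": 4096, "topP": 0.1, "temperature": 1.0 },
  "system": [{ "text": "You are a helpful assistant." }]
}

// default — the existing Anthropic-on-Bedrock schema:
{
  "anthropic_version": "bedrock-2023-05-31",
  "max_tokens": 4096,
  "messages": [{ "role": "user", "content": [{ "type": "text", "text": "Hello" }] }],
  "system": "You are a helpful assistant."
}
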
@@ -86,21 +105,20 @@ export async function endpointBedrock(

 		for await (const item of response.body ?? []) {
 			const chunk = JSON.parse(new TextDecoder().decode(item.chunk?.bytes));
-			const chunk_type = chunk.type;
-
-			if (chunk_type === "content_block_delta") {
-				text += chunk.delta.text;
+			if ("contentBlockDelta" in chunk || chunk.type === "content_block_delta") {
+				const chunkText = chunk.contentBlockDelta?.delta?.text || chunk.delta?.text || "";
+				text += chunkText;
 				yield {
 					token: {
 						id: tokenId++,
-						text: chunk.delta.text,
+						text: chunkText,
 						logprob: 0,
 						special: false,
 					},
 					generated_text: null,
 					details: null,
 				} satisfies TextGenerationStreamOutput;
-			} else if (chunk_type === "message_stop") {
+			} else if ("messageStop" in chunk || chunk.type === "message_stop") {
 				yield {
 					token: {
 						id: tokenId++,
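
The widened conditions above accept both stream encodings. Showing only the fields the handler actually reads, the two chunk shapes look roughly like this (inferred from the parsing code, not from captured responses):

// Nova-style events — camelCase container keys:
{ "contentBlockDelta": { "delta": { "text": "Hel" } } }
{ "messageStop": { /* stop metadata */ } }

// Anthropic-style events — a snake_case "type" discriminator:
{ "type": "content_block_delta", "delta": { "text": "Hel" } }
{ "type": "message_stop" }
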
@@ -120,6 +138,7 @@ export async function endpointBedrock(
 // Prepare the messages excluding system prompts
 async function prepareMessages(
 	messages: EndpointMessage[],
+	isNova: boolean,
 	imageProcessor: ReturnType<typeof makeImageProcessor>
 ) {
 	const formattedMessages = [];
@@ -128,9 +147,13 @@ async function prepareMessages(
 		const content = [];

 		if (message.files?.length) {
-			content.push(...(await prepareFiles(imageProcessor, message.files)));
+			content.push(...(await prepareFiles(imageProcessor, isNova, message.files)));
+		}
+		if (isNova) {
+			content.push({ text: message.content });
+		} else {
+			content.push({ type: "text", text: message.content });
 		}
-		content.push({ type: "text", text: message.content });

 		const lastMessage = formattedMessages[formattedMessages.length - 1];
 		if (lastMessage && lastMessage.role === message.from) {
@@ -146,11 +169,22 @@ async function prepareMessages(
 // Process files and convert them to base64 encoded strings
 async function prepareFiles(
 	imageProcessor: ReturnType<typeof makeImageProcessor>,
+	isNova: boolean,
 	files: MessageFile[]
 ) {
 	const processedFiles = await Promise.all(files.map(imageProcessor));
-	return processedFiles.map((file) => ({
-		type: "image",
-		source: { type: "base64", media_type: "image/jpeg", data: file.image.toString("base64") },
-	}));
+
+	if (isNova) {
+		return processedFiles.map((file) => ({
+			image: {
+				format: file.mime.substring("image/".length),
+				source: { bytes: file.image.toString("base64") },
+			},
+		}));
+	} else {
+		return processedFiles.map((file) => ({
+			type: "image",
+			source: { type: "base64", media_type: file.mime, data: file.image.toString("base64") },
+		}));
+	}
 }
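
For a JPEG attachment, the two shapes produced by prepareFiles come out as follows (derived from the mapping above; the base64 payload is elided):

// Nova — format is the mime subtype, bytes carry the base64 string directly:
{ "image": { "format": "jpeg", "source": { "bytes": "<base64>" } } }

// Anthropic — a typed base64 source, now using the file's actual mime type instead of hardcoded image/jpeg:
{ "type": "image", "source": { "type": "base64", "media_type": "image/jpeg", "data": "<base64>" } }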
