
Commit fdbe257
feat: rendering chat_template
Parent: 5414e02

17 files changed: 4402 additions, 168 deletions

engine/cli/commands/chat_completion_cmd.cc (6 additions, 7 deletions)

@@ -50,7 +50,6 @@ size_t WriteCallback(char* ptr, size_t size, size_t nmemb, void* userdata) {
 
   return data_length;
 }
-
 }  // namespace
 
 void ChatCompletionCmd::Exec(const std::string& host, int port,
@@ -103,7 +102,7 @@ void ChatCompletionCmd::Exec(const std::string& host, int port,
     return;
   }
 
-  std::string url = "http://" + address + "/v1/chat/completions";
+  auto url = "http://" + address + "/v1/chat/completions";
   curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
   curl_easy_setopt(curl, CURLOPT_POST, 1L);
 
@@ -151,18 +150,18 @@ void ChatCompletionCmd::Exec(const std::string& host, int port,
   json_data["model"] = model_handle;
   json_data["stream"] = true;
 
-  std::string json_payload = json_data.toStyledString();
-
-  curl_easy_setopt(curl, CURLOPT_POSTFIELDS, json_payload.c_str());
+  auto json_str = json_data.toStyledString();
+  curl_easy_setopt(curl, CURLOPT_POSTFIELDS, json_str.c_str());
+  curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, json_str.length());
+  curl_easy_setopt(curl, CURLOPT_TCP_KEEPALIVE, 1L);
 
   std::string ai_chat;
   StreamingCallback callback;
   callback.ai_chat = &ai_chat;
 
   curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback);
   curl_easy_setopt(curl, CURLOPT_WRITEDATA, &callback);
-
-  CURLcode res = curl_easy_perform(curl);
+  auto res = curl_easy_perform(curl);
 
   if (res != CURLE_OK) {
     CLI_LOG("CURL request failed: " << curl_easy_strerror(res));

engine/common/model_metadata.h (new file, 29 additions)

@@ -0,0 +1,29 @@
+#pragma once
+
+#include <sstream>
+#include "common/tokenizer.h"
+
+struct ModelMetadata {
+  uint32_t version;
+  uint64_t tensor_count;
+  uint64_t metadata_kv_count;
+  std::shared_ptr<Tokenizer> tokenizer;
+
+  std::string ToString() const {
+    std::ostringstream ss;
+    ss << "ModelMetadata {\n"
+       << "version: " << version << "\n"
+       << "tensor_count: " << tensor_count << "\n"
+       << "metadata_kv_count: " << metadata_kv_count << "\n"
+       << "tokenizer: ";
+
+    if (tokenizer) {
+      ss << "\n" << tokenizer->ToString();
+    } else {
+      ss << "null";
+    }
+
+    ss << "\n}";
+    return ss.str();
+  }
+};

engine/common/tokenizer.h (new file, 72 additions)

@@ -0,0 +1,72 @@
+#pragma once
+
+#include <sstream>
+#include <string>
+
+struct Tokenizer {
+  std::string eos_token = "";
+  bool add_eos_token = true;
+
+  std::string bos_token = "";
+  bool add_bos_token = true;
+
+  std::string unknown_token = "";
+  std::string padding_token = "";
+
+  std::string chat_template = "";
+
+  bool add_generation_prompt = true;
+
+  // Helper function for common fields
+  std::string BaseToString() const {
+    std::ostringstream ss;
+    ss << "eos_token: \"" << eos_token << "\"\n"
+       << "add_eos_token: " << (add_eos_token ? "true" : "false") << "\n"
+       << "bos_token: \"" << bos_token << "\"\n"
+       << "add_bos_token: " << (add_bos_token ? "true" : "false") << "\n"
+       << "unknown_token: \"" << unknown_token << "\"\n"
+       << "padding_token: \"" << padding_token << "\"\n"
+       << "chat_template: \"" << chat_template << "\"\n"
+       << "add_generation_prompt: "
+       << (add_generation_prompt ? "true" : "false") << "\"";
+    return ss.str();
+  }
+
+  virtual ~Tokenizer() = default;
+
+  virtual std::string ToString() = 0;
+};
+
+struct GgufTokenizer : public Tokenizer {
+  std::string pre = "";
+
+  ~GgufTokenizer() override = default;
+
+  std::string ToString() override {
+    std::ostringstream ss;
+    ss << "GgufTokenizer {\n";
+    // Add base class members
+    ss << BaseToString() << "\n";
+    // Add derived class members
+    ss << "pre: \"" << pre << "\"\n";
+    ss << "}";
+    return ss.str();
+  }
+};
+
+struct SafeTensorTokenizer : public Tokenizer {
+  bool add_prefix_space = true;
+
+  ~SafeTensorTokenizer() = default;
+
+  std::string ToString() override {
+    std::ostringstream ss;
+    ss << "SafeTensorTokenizer {\n";
+    // Add base class members
+    ss << BaseToString() << "\n";
+    // Add derived class members
+    ss << "add_prefix_space: " << (add_prefix_space ? "true" : "false") << "\n";
+    ss << "}";
+    return ss.str();
+  }
+};
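A usage sketch tying the two new headers together (not part of this commit; all field values are invented for illustration): a GgufTokenizer is attached to a ModelMetadata and the nested ToString() output is printed.

// Illustrative only: populate the new structs and dump them.
#include <iostream>
#include <memory>
#include "common/model_metadata.h"  // also pulls in common/tokenizer.h

int main() {
  auto tokenizer = std::make_shared<GgufTokenizer>();
  tokenizer->bos_token = "<s>";
  tokenizer->eos_token = "</s>";
  tokenizer->pre = "llama-bpe";  // invented example value
  tokenizer->chat_template =
      "{% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}";

  ModelMetadata metadata;
  metadata.version = 3;  // invented example values
  metadata.tensor_count = 291;
  metadata.metadata_kv_count = 24;
  metadata.tokenizer = tokenizer;

  // Prints the ModelMetadata block with the GgufTokenizer dump nested inside.
  std::cout << metadata.ToString() << std::endl;
  return 0;
}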

engine/controllers/files.cc (5 additions, 12 deletions)

@@ -216,10 +216,8 @@ void Files::RetrieveFileContent(
       return;
     }
 
-    auto [buffer, size] = std::move(res.value());
-    auto resp = HttpResponse::newHttpResponse();
-    resp->setBody(std::string(buffer.get(), size));
-    resp->setContentTypeCode(CT_APPLICATION_OCTET_STREAM);
+    auto resp =
+        cortex_utils::CreateCortexContentResponse(std::move(res.value()));
     callback(resp);
   } else {
     if (!msg_res->rel_path.has_value()) {
@@ -243,10 +241,8 @@ void Files::RetrieveFileContent(
       return;
     }
 
-    auto [buffer, size] = std::move(content_res.value());
-    auto resp = HttpResponse::newHttpResponse();
-    resp->setBody(std::string(buffer.get(), size));
-    resp->setContentTypeCode(CT_APPLICATION_OCTET_STREAM);
+    auto resp = cortex_utils::CreateCortexContentResponse(
+        std::move(content_res.value()));
     callback(resp);
   }
 }
@@ -261,9 +257,6 @@ void Files::RetrieveFileContent(
     return;
   }
 
-  auto [buffer, size] = std::move(res.value());
-  auto resp = HttpResponse::newHttpResponse();
-  resp->setBody(std::string(buffer.get(), size));
-  resp->setContentTypeCode(CT_APPLICATION_OCTET_STREAM);
+  auto resp = cortex_utils::CreateCortexContentResponse(std::move(res.value()));
   callback(resp);
 }
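The three hand-rolled response blocks above collapse into a single call to cortex_utils::CreateCortexContentResponse. The helper itself is added elsewhere in this commit and is not shown in this excerpt; judging only from the code it replaces, its shape is roughly the following (illustrative sketch; the buffer/size pair type is an assumption):

// Illustrative sketch, not the actual utils/cortex_utils.h implementation.
// Assumes the cached-content result is a (buffer, size) pair, as the removed
// structured bindings suggest.
#include <memory>
#include <string>
#include <utility>
#include <drogon/HttpResponse.h>

namespace cortex_utils {
inline drogon::HttpResponsePtr CreateCortexContentResponse(
    std::pair<std::unique_ptr<char[]>, size_t> content) {
  auto [buffer, size] = std::move(content);
  auto resp = drogon::HttpResponse::newHttpResponse();
  resp->setBody(std::string(buffer.get(), size));
  resp->setContentTypeCode(drogon::CT_APPLICATION_OCTET_STREAM);
  return resp;
}
}  // namespace cortex_utils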

engine/controllers/server.cc (8 additions, 1 deletion)

@@ -3,7 +3,6 @@
 #include "trantor/utils/Logger.h"
 #include "utils/cortex_utils.h"
 #include "utils/function_calling/common.h"
-#include "utils/http_util.h"
 
 using namespace inferences;
 
@@ -27,6 +26,14 @@ void server::ChatCompletion(
     std::function<void(const HttpResponsePtr&)>&& callback) {
   LOG_DEBUG << "Start chat completion";
   auto json_body = req->getJsonObject();
+  if (json_body == nullptr) {
+    Json::Value ret;
+    ret["message"] = "Body can't be empty";
+    auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
+    resp->setStatusCode(k400BadRequest);
+    callback(resp);
+    return;
+  }
   bool is_stream = (*json_body).get("stream", false).asBool();
   auto model_id = (*json_body).get("model", "invalid_model").asString();
   auto engine_type = [this, &json_body]() -> std::string {

engine/main.cc (1 addition, 0 deletions)

@@ -159,6 +159,7 @@ void RunServer(std::optional<std::string> host, std::optional<int> port,
   auto model_src_svc = std::make_shared<services::ModelSourceService>();
   auto model_service = std::make_shared<ModelService>(
       download_service, inference_svc, engine_service);
+  inference_svc->SetModelService(model_service);
 
   auto file_watcher_srv = std::make_shared<FileWatcherService>(
       model_dir_path.string(), model_service);

engine/services/engine_service.h (8 additions, 12 deletions)

@@ -4,7 +4,6 @@
 #include <mutex>
 #include <optional>
 #include <string>
-#include <string_view>
 #include <unordered_map>
 #include <vector>
 
@@ -17,7 +16,6 @@
 #include "utils/cpuid/cpu_info.h"
 #include "utils/dylib.h"
 #include "utils/dylib_path_manager.h"
-#include "utils/engine_constants.h"
 #include "utils/github_release_utils.h"
 #include "utils/result.hpp"
 #include "utils/system_info_utils.h"
@@ -48,10 +46,6 @@ class EngineService : public EngineServiceI {
   struct EngineInfo {
     std::unique_ptr<cortex_cpp::dylib> dl;
     EngineV engine;
-#if defined(_WIN32)
-    DLL_DIRECTORY_COOKIE cookie;
-    DLL_DIRECTORY_COOKIE cuda_cookie;
-#endif
   };
 
   std::mutex engines_mutex_;
@@ -105,21 +99,23 @@ class EngineService : public EngineServiceI {
 
   cpp::result<DefaultEngineVariant, std::string> SetDefaultEngineVariant(
       const std::string& engine, const std::string& version,
-      const std::string& variant);
+      const std::string& variant) override;
 
   cpp::result<DefaultEngineVariant, std::string> GetDefaultEngineVariant(
-      const std::string& engine);
+      const std::string& engine) override;
 
   cpp::result<std::vector<EngineVariantResponse>, std::string>
-  GetInstalledEngineVariants(const std::string& engine) const;
+  GetInstalledEngineVariants(const std::string& engine) const override;
 
   cpp::result<EngineV, std::string> GetLoadedEngine(
       const std::string& engine_name);
 
   std::vector<EngineV> GetLoadedEngines();
 
-  cpp::result<void, std::string> LoadEngine(const std::string& engine_name);
-  cpp::result<void, std::string> UnloadEngine(const std::string& engine_name);
+  cpp::result<void, std::string> LoadEngine(
+      const std::string& engine_name) override;
+  cpp::result<void, std::string> UnloadEngine(
+      const std::string& engine_name) override;
 
   cpp::result<github_release_utils::GitHubRelease, std::string>
   GetLatestEngineVersion(const std::string& engine) const;
@@ -137,7 +133,7 @@ class EngineService : public EngineServiceI {
 
   cpp::result<cortex::db::EngineEntry, std::string> GetEngineByNameAndVariant(
       const std::string& engine_name,
-      const std::optional<std::string> variant = std::nullopt);
+      const std::optional<std::string> variant = std::nullopt) override;
 
   cpp::result<cortex::db::EngineEntry, std::string> UpsertEngine(
      const std::string& engine_name, const std::string& type,
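The signature changes in this header only add the override specifier to methods that implement the EngineServiceI interface; behavior is unchanged, but a mismatch against the interface now fails at compile time instead of silently declaring an unrelated overload. A generic illustration (toy types, not project code):

// Toy example of what `override` buys: the commented-out declaration would be
// rejected by the compiler because no base method has that signature.
#include <string>

struct ServiceInterface {
  virtual ~ServiceInterface() = default;
  virtual bool LoadEngine(const std::string& name) = 0;
};

struct ServiceImpl : ServiceInterface {
  bool LoadEngine(const std::string& name) override { return true; }
  // bool LoadEngine(std::string name) override;  // error: does not override
};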

engine/services/inference_service.cc (41 additions, 1 deletion)

@@ -2,6 +2,7 @@
 #include <drogon/HttpTypes.h>
 #include "utils/engine_constants.h"
 #include "utils/function_calling/common.h"
+#include "utils/jinja_utils.h"
 
 namespace services {
 cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
@@ -24,6 +25,45 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
     return cpp::fail(std::make_pair(stt, res));
   }
 
+  {
+    auto model_id = json_body->get("model", "").asString();
+    if (!model_id.empty()) {
+      if (auto model_service = model_service_.lock()) {
+        auto metadata_ptr = model_service->GetCachedModelMetadata(model_id);
+        if (metadata_ptr != nullptr &&
+            !metadata_ptr->tokenizer->chat_template.empty()) {
+          auto tokenizer = metadata_ptr->tokenizer;
+          auto messages = (*json_body)["messages"];
+          Json::Value messages_jsoncpp(Json::arrayValue);
+          for (auto message : messages) {
+            messages_jsoncpp.append(message);
+          }
+
+          Json::Value tools(Json::arrayValue);
+          Json::Value template_data_json;
+          template_data_json["messages"] = messages_jsoncpp;
+          // template_data_json["tools"] = tools;
+
+          auto prompt_result = jinja::RenderTemplate(
+              tokenizer->chat_template, template_data_json,
+              tokenizer->bos_token, tokenizer->eos_token,
+              tokenizer->add_bos_token, tokenizer->add_eos_token,
+              tokenizer->add_generation_prompt);
+          if (prompt_result.has_value()) {
+            (*json_body)["prompt"] = prompt_result.value();
+            Json::Value stops(Json::arrayValue);
+            stops.append(tokenizer->eos_token);
+            (*json_body)["stop"] = stops;
+          } else {
+            CTL_ERR("Failed to render prompt: " + prompt_result.error());
+          }
+        }
+      }
+    }
+  }
+
+  CTL_INF("Json body inference: " + json_body->toStyledString());
+
   auto cb = [q, tool_choice](Json::Value status, Json::Value res) {
     if (!tool_choice.isNull()) {
       res["tool_choice"] = tool_choice;
@@ -297,4 +337,4 @@ bool InferenceService::HasFieldInReq(std::shared_ptr<Json::Value> json_body,
   }
   return true;
 }
-} // namespace services
+} // namespace services
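Net effect of the new block: when cached metadata with a non-empty chat_template exists for the requested model, the OpenAI-style messages array is rendered into a single prompt string, and the tokenizer's EOS token is installed as a stop sequence, before the request is forwarded to the engine. A standalone sketch of the same rendering call (illustrative; the template string and tokens are invented, and the actual prompt depends entirely on the model's chat_template):

// Illustrative sketch using jinja::RenderTemplate the way the handler does.
// The call mirrors the signature used in the diff; the template and token
// values below are made up for the example.
#include <iostream>
#include <string>
#include <json/json.h>
#include "utils/jinja_utils.h"

int main() {
  Json::Value messages(Json::arrayValue);
  Json::Value user_msg;
  user_msg["role"] = "user";
  user_msg["content"] = "Hello!";
  messages.append(user_msg);

  Json::Value template_data;
  template_data["messages"] = messages;

  std::string chat_template =
      "{% for message in messages %}<|{{ message.role }}|>\n"
      "{{ message.content }}\n{% endfor %}";

  auto prompt_result = jinja::RenderTemplate(
      chat_template, template_data,
      /*bos_token=*/"<s>", /*eos_token=*/"</s>",
      /*add_bos_token=*/true, /*add_eos_token=*/false,
      /*add_generation_prompt=*/true);

  if (prompt_result.has_value()) {
    std::cout << prompt_result.value() << std::endl;  // rendered prompt
  } else {
    std::cout << "render failed: " << prompt_result.error() << std::endl;
  }
  return 0;
}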

engine/services/inference_service.h (8 additions, 1 deletion)

@@ -4,9 +4,11 @@
 #include <mutex>
 #include <queue>
 #include "services/engine_service.h"
+#include "services/model_service.h"
 #include "utils/result.hpp"
-#include "extensions/remote-engine/remote_engine.h"
+
 namespace services {
+
 // Status and result
 using InferResult = std::pair<Json::Value, Json::Value>;
 
@@ -58,7 +60,12 @@ class InferenceService {
   bool HasFieldInReq(std::shared_ptr<Json::Value> json_body,
                      const std::string& field);
 
+  void SetModelService(std::shared_ptr<ModelService> model_service) {
+    model_service_ = model_service;
+  }
+
 private:
   std::shared_ptr<EngineService> engine_service_;
+  std::weak_ptr<ModelService> model_service_;
 };
 } // namespace services
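SetModelService stores the pointer as std::weak_ptr, so the back-reference wired up in main.cc (ModelService is already constructed with the inference service) does not become a shared_ptr ownership cycle, and HandleChatCompletion has to lock() it before use, as shown in the inference_service.cc hunk above. The ownership pattern in isolation (toy types, not project code):

// Toy illustration: the owning edge is a shared_ptr, the back edge is a
// weak_ptr, so neither object keeps the other alive and there is no cycle.
#include <iostream>
#include <memory>

struct ModelServiceLike;

struct InferenceServiceLike {
  std::weak_ptr<ModelServiceLike> model_service_;  // non-owning back-reference
  void SetModelService(std::shared_ptr<ModelServiceLike> ms) { model_service_ = ms; }
  void Handle();
};

struct ModelServiceLike {
  std::shared_ptr<InferenceServiceLike> inference_svc_;  // owning forward reference
};

void InferenceServiceLike::Handle() {
  if (auto ms = model_service_.lock()) {  // promote weak_ptr; may fail
    std::cout << "model service alive, use_count=" << ms.use_count() << "\n";
  } else {
    std::cout << "model service gone\n";
  }
}

int main() {
  auto inference_svc = std::make_shared<InferenceServiceLike>();
  auto model_service = std::make_shared<ModelServiceLike>();
  model_service->inference_svc_ = inference_svc;  // ModelService owns InferenceService
  inference_svc->SetModelService(model_service);  // back edge is weak: no cycle
  inference_svc->Handle();                        // "model service alive"
  model_service.reset();                          // last owner released
  inference_svc->Handle();                        // "model service gone"
  return 0;
}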
