Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 46 additions & 3 deletions include/minja/minja.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline };

class TemplateToken {
public:
enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter };
enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue };

static std::string typeToString(Type t) {
switch (t) {
Expand All @@ -714,6 +714,8 @@ class TemplateToken {
case Type::EndFilter: return "endfilter";
case Type::Generation: return "generation";
case Type::EndGeneration: return "endgeneration";
case Type::Break: return "break";
case Type::Continue: return "continue";
}
return "Unknown";
}
Expand Down Expand Up @@ -815,6 +817,22 @@ struct CommentTemplateToken : public TemplateToken {
CommentTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Comment, location, pre, post), text(t) {}
};

/// Kind of loop-control statement: `{% break %}` or `{% continue %}`.
enum class LoopControlType { Break, Continue };

/// Exception used to signal a `break` / `continue` request while rendering.
///
/// It derives from std::runtime_error so that, when nothing intercepts it,
/// it surfaces as an ordinary error whose message names the offending keyword.
class LoopControlException : public std::runtime_error {
public:
    /// Which loop-control statement raised this exception.
    LoopControlType control_type;

    /// Builds the exception with a caller-supplied message.
    /// @param message       Text returned by what().
    /// @param control_type  The loop-control kind being signalled.
    LoopControlException(const std::string & message, LoopControlType control_type)
        : std::runtime_error(message), control_type(control_type) {}

    /// Builds the exception with the default message
    /// "break outside of a loop" or "continue outside of a loop".
    /// @param control_type  The loop-control kind being signalled.
    LoopControlException(LoopControlType control_type)
        // Plain string concatenation instead of `(std::ostringstream() << ...).str()`:
        // on standard libraries that predate LWG 1203, inserting into an rvalue
        // stream yields std::ostream&, which has no str(), so that form fails to compile.
        : std::runtime_error(std::string(control_type == LoopControlType::Continue ? "continue" : "break") + " outside of a loop"),
          control_type(control_type) {}
};

/// Token emitted by the tokenizer for a `{% break %}` or `{% continue %}` block.
struct LoopControlTemplateToken : public TemplateToken {
    /// Which loop-control keyword this token represents.
    LoopControlType control_type;

    /// @param location      Source location of the block.
    /// @param pre           Whitespace handling before the block.
    /// @param post          Whitespace handling after the block.
    /// @param control_type  Break or Continue.
    LoopControlTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, LoopControlType control_type)
        // Tag the base token with the matching Type so that typeToString()
        // reports "continue" for continue tokens instead of always "break".
        : TemplateToken(control_type == LoopControlType::Continue ? Type::Continue : Type::Break, location, pre, post),
          control_type(control_type) {}
};

class TemplateNode {
Location location_;
protected:
Expand All @@ -825,6 +843,12 @@ class TemplateNode {
void render(std::ostringstream & out, const std::shared_ptr<Context> & context) const {
try {
do_render(out, context);
} catch (const LoopControlException & e) {
// TODO: make stack creation lazy. Only needed if it was thrown outside of a loop.
std::ostringstream err;
err << e.what();
if (location_.source) err << error_location_suffix(*location_.source, location_.pos);
throw LoopControlException(err.str(), e.control_type);
} catch (const std::exception & e) {
std::ostringstream err;
err << e.what();
Expand Down Expand Up @@ -897,6 +921,15 @@ class IfNode : public TemplateNode {
}
};

/// Template node for a `{% break %}` / `{% continue %}` statement.
class LoopControlNode : public TemplateNode {
    // Loop-control kind carried by the exception raised on render.
    LoopControlType control_type_;

public:
    /// @param loc   Source location of the statement.
    /// @param kind  Break or Continue.
    LoopControlNode(const Location & loc, LoopControlType kind)
        : TemplateNode(loc),
          control_type_(kind) {}

    /// Writes nothing to the output stream; unconditionally throws a
    /// LoopControlException carrying this node's control type.
    void do_render(std::ostringstream & /*out*/, const std::shared_ptr<Context> & /*context*/) const override {
        throw LoopControlException(control_type_);
    }
};

class ForNode : public TemplateNode {
std::vector<std::string> var_names;
std::shared_ptr<Expression> iterable;
Expand Down Expand Up @@ -961,7 +994,12 @@ class ForNode : public TemplateNode {
loop.set("last", i == (n - 1));
loop.set("previtem", i > 0 ? filtered_items.at(i - 1) : Value());
loop.set("nextitem", i < n - 1 ? filtered_items.at(i + 1) : Value());
body->render(out, loop_context);
try {
body->render(out, loop_context);
} catch (const LoopControlException & e) {
if (e.control_type == LoopControlType::Break) break;
if (e.control_type == LoopControlType::Continue) continue;
}
}
}
};
Expand Down Expand Up @@ -2159,7 +2197,7 @@ class Parser {
static std::regex comment_tok(R"(\{#([-~]?)(.*?)([-~]?)#\})");
static std::regex expr_open_regex(R"(\{\{([-~])?)");
static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)");
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter)\b)");
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)");
static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})");
static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})");
Expand Down Expand Up @@ -2291,6 +2329,9 @@ class Parser {
} else if (keyword == "endfilter") {
auto post_space = parseBlockClose();
tokens.push_back(std::make_unique<EndFilterTemplateToken>(location, pre_space, post_space));
} else if (keyword == "break" || keyword == "continue") {
auto post_space = parseBlockClose();
tokens.push_back(std::make_unique<LoopControlTemplateToken>(location, pre_space, post_space, keyword == "break" ? LoopControlType::Break : LoopControlType::Continue));
} else {
throw std::runtime_error("Unexpected block: " + keyword);
}
Expand Down Expand Up @@ -2414,6 +2455,8 @@ class Parser {
children.emplace_back(std::make_shared<FilterNode>(token->location, std::move(filter_token->filter), std::move(body)));
} else if (dynamic_cast<CommentTemplateToken*>(token.get())) {
// Ignore comments
} else if (auto ctrl_token = dynamic_cast<LoopControlTemplateToken*>(token.get())) {
children.emplace_back(std::make_shared<LoopControlNode>(token->location, ctrl_token->control_type));
} else if (dynamic_cast<EndForTemplateToken*>(token.get())
|| dynamic_cast<EndSetTemplateToken*>(token.get())
|| dynamic_cast<EndMacroTemplateToken*>(token.get())
Expand Down
3 changes: 2 additions & 1 deletion tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ set(MODEL_IDS

if(NOT WIN32)
list(APPEND MODEL_IDS
# Needs investigation
# Needs investigation (https://github.com/google/minja/issues/40)
CohereForAI/c4ai-command-r7b-12-2024 # Gated
deepseek-ai/deepseek-coder-33b-instruct
deepseek-ai/DeepSeek-Coder-V2-Instruct
deepseek-ai/DeepSeek-V2.5
Expand Down
10 changes: 10 additions & 0 deletions tests/test-syntax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ TEST(SyntaxTest, SimpleCases) {
auto ThrowsWithSubstr = [](const std::string & expected_substr) {
return testing::Throws<std::runtime_error>(Property(&std::runtime_error::what, testing::HasSubstr(expected_substr)));
};

EXPECT_EQ(
" b",
render(R"( {% set _ = 1 %} {% set _ = 2 %}b)", {}, lstrip_trim_blocks));
Expand Down Expand Up @@ -486,10 +487,19 @@ TEST(SyntaxTest, SimpleCases) {
"",
render("{% if 1 %}{% elif 1 %}{% else %}{% endif %}", {}, {}));

EXPECT_EQ(
"0,1,2,",
render("{% for i in range(10) %}{{ i }},{% if i == 2 %}{% break %}{% endif %}{% endfor %}", {}, {}));
EXPECT_EQ(
"0,2,4,6,8,",
render("{% for i in range(10) %}{% if i % 2 %}{% continue %}{% endif %}{{ i }},{% endfor %}", {}, {}));

if (!getenv("USE_JINJA2")) {
// TODO: capture stderr from jinja2 and test these.

EXPECT_THAT([]() { render("{% break %}", {}, {}); }, ThrowsWithSubstr("break outside of a loop"));
EXPECT_THAT([]() { render("{% continue %}", {}, {}); }, ThrowsWithSubstr("continue outside of a loop"));

EXPECT_THAT([]() { render("{%- set _ = [].pop() -%}", {}, {}); }, ThrowsWithSubstr("pop from empty list"));
EXPECT_THAT([]() { render("{%- set _ = {}.pop() -%}", {}, {}); }, ThrowsWithSubstr("pop"));
EXPECT_THAT([]() { render("{%- set _ = {}.pop('foooo') -%}", {}, {}); }, ThrowsWithSubstr("foooo"));
Expand Down