Skip to content
94 changes: 94 additions & 0 deletions src/mongo/db/pipeline/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,34 @@ Value ExpressionArrayElemAt::evaluate(const Document& root) const {
return array[index];
}

intrusive_ptr<Expression> ExpressionArrayElemAt::optimize() {
// If ExpressionArrayElemAt is passed an ExpressionFilter as its first arugment
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to what you did in the ExpressionIndexOf optimization, I think you should first call ExpressionNary::optimize():

// This will optimize all arguments to this expression.
auto optimized = ExpressionNary::optimize();
if (optimized.get() != this) {
return optimized;
}

(Also, while you're there, it looks like the comment right below that got wrapped at a strange place, can you fix that?)

// set a limit on the filter so filter returns an array with the last element being the value we
// want.
if (dynamic_cast<ExpressionFilter*>(vpOperand[0].get())) {
if (auto expConstant = dynamic_cast<ExpressionConstant*>(vpOperand[1].get())) {
auto indexArg = expConstant->getValue();

uassert(50802,
str::stream() << getOpName() << "'s second argument must be a numeric value,"
<< " but is "
<< typeName(indexArg.getType()),
indexArg.numeric());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this check is redundant with the one below. If something is integral it must be numeric. https://github.com/mongodb/mongo/blob/master/src/mongo/db/pipeline/value.cpp#L1010

uassert(50803,
str::stream() << getOpName() << "'s second argument must be representable as"
<< " a 32-bit integer: "
<< indexArg.coerceToDouble(),
indexArg.integral());
auto ind = indexArg.coerceToInt();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use full names, index

// Can't optimize of the index is less that 0.
if (ind >= 0) {
dynamic_cast<ExpressionFilter*>(vpOperand[0].get())->setLimit(ind + 1);
}
}
}
return this;
};

REGISTER_EXPRESSION(arrayElemAt, ExpressionArrayElemAt::parse);
const char* ExpressionArrayElemAt::getOpName() const {
return "$arrayElemAt";
Expand Down Expand Up @@ -2206,12 +2234,19 @@ Value ExpressionFilter::evaluate(const Document& root) const {

if (_filter->evaluate(root).coerceToBool()) {
output.push_back(std::move(elem));
if(_limit && static_cast<int>(output.size()) == _limit.get()){
return Value(std::move(output));
}
}
}

return Value(std::move(output));
}

void ExpressionFilter::setLimit(int limit) {
_limit = boost::optional<int>(limit);
}

void ExpressionFilter::_doAddDependencies(DepsTracker* deps) const {
_input->addDependencies(deps);
_filter->addDependencies(deps);
Expand Down Expand Up @@ -3900,6 +3935,65 @@ Value ExpressionSlice::evaluate(const Document& root) const {
return Value(vector<Value>(array.begin() + start, array.begin() + end));
}

intrusive_ptr<Expression> ExpressionSlice::optimize() {
// If ExpressionSlice is passed an ExpressionFilter we can stop filtering once the size of
// the array returned by the filter is equal to the last arguement passed to ExpressionSlice.
if (dynamic_cast<ExpressionFilter*>(vpOperand[0].get())) {
if (auto secondArg = dynamic_cast<ExpressionConstant*>(vpOperand[1].get())) {
auto secondVal = secondArg->getValue();

uassert(50797,
str::stream() << "Second argument to $slice must be a numeric value,"
<< " but is of type: "
<< typeName(secondVal.getType()),
secondVal.numeric());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, I don't think you need the assertion about numeric().


uassert(50798,
str::stream() << "Second argument to $slice can't be represented as"
<< " a 32-bit integer: "
<< secondVal.coerceToDouble(),
secondVal.integral());

int arg2 = secondVal.coerceToInt();
if (vpOperand.size() == 2) {
// Can't set a limit if it is negative.
if (arg2 >= 0) {
// If slice is given two arguments set limit to the position we want to slice.
dynamic_cast<ExpressionFilter*>(vpOperand[0].get())->setLimit(arg2);
}
} else if (vpOperand.size() > 2) {
if (auto thirdArg = dynamic_cast<ExpressionConstant*>(vpOperand[2].get())) {
auto thirdVal = thirdArg->getValue();

uassert(50799,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can condense these three asserts into 'uassertIfNotIntegralAndNonNegative': https://github.com/mongodb/mongo/blob/master/src/mongo/db/pipeline/expression.cpp#L2704

str::stream() << "Third argument to $slice must be numeric, but "
<< "is of type: " << typeName(thirdVal.getType()),
thirdVal.numeric());

uassert(50800,
str::stream() << "Third argument to $slice can't be represented"
<< " as a 32-bit integer: "
<< thirdVal.coerceToDouble(),
thirdVal.integral());

uassert(50801,
str::stream() << "Third argument to $slice must be positive: "
<< thirdVal.coerceToInt(),
thirdVal.coerceToInt() > 0);

int arg3 = thirdVal.coerceToInt();
if (arg2 >= 0) {
// If ExpressionSlice is given three arguments set limit to 'firstArg' +
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you elaborate a little bit more here? Maybe something like If we have three arguments, then we can stop the $filter once we have seen 'firstArg' + 'secondArg' + 1. We need to skip 'firstArg' elements, then include 'secondArg' things after that. For example, if we want to slice starting from 2 with length 3, we skip the first 2, then include 3 after that.

Wait, I think we don't need the +1 here? Is that right?

// 'secondArg' + 1.
dynamic_cast<ExpressionFilter*>(vpOperand[0].get())
->setLimit(arg2 + arg3 + 1);
}
}
}
}
}
return this;
}
REGISTER_EXPRESSION(slice, ExpressionSlice::parse);
const char* ExpressionSlice::getOpName() const {
return "$slice";
Expand Down
6 changes: 6 additions & 0 deletions src/mongo/db/pipeline/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,7 @@ class ExpressionArrayElemAt final : public ExpressionFixedArity<ExpressionArrayE
public:
explicit ExpressionArrayElemAt(const boost::intrusive_ptr<ExpressionContext>& expCtx)
: ExpressionFixedArity<ExpressionArrayElemAt, 2>(expCtx) {}
boost::intrusive_ptr<Expression> optimize() final;

Value evaluate(const Document& root) const final;
const char* getOpName() const final;
Expand Down Expand Up @@ -1155,6 +1156,7 @@ class ExpressionFilter final : public Expression {
const boost::intrusive_ptr<ExpressionContext>& expCtx,
BSONElement expr,
const VariablesParseState& vps);
void setLimit(int limit);

protected:
void _doAddDependencies(DepsTracker* deps) const final;
Expand All @@ -1174,6 +1176,9 @@ class ExpressionFilter final : public Expression {
boost::intrusive_ptr<Expression> _input;
// The expression determining whether each element should be present in the result array.
boost::intrusive_ptr<Expression> _filter;
// For expression ArrayElemAt and Slice we can optimize filter by setting a limit to end the
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd make it clearer that this optimization applies when $filter is being used as an argument to either $slice or $arrayElemAt. Also, feel free to use the $slice name instead of ExpressionSlice here.

// filter once all values needed are filtered.
boost::optional<int> _limit;
};


Expand Down Expand Up @@ -1666,6 +1671,7 @@ class ExpressionSlice final : public ExpressionRangedArity<ExpressionSlice, 2, 3
: ExpressionRangedArity<ExpressionSlice, 2, 3>(expCtx) {}

Value evaluate(const Document& root) const final;
boost::intrusive_ptr<Expression> optimize() final;
const char* getOpName() const final;
};

Expand Down
86 changes: 86 additions & 0 deletions src/mongo/db/pipeline/expression_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2894,6 +2894,92 @@ TEST(ExpressionObjectOptimizations,

} // namespace Object

TEST(ExpressionFilter, ExpressionFilterWithASetLimitShouldReturnAArrayNoGreaterThanTheLimit) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use AnArray instead of AArray.

intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());
VariablesParseState vps = expCtx->variablesParseState;
auto filterSpec = BSON(
"$filter" << BSON("input" << BSON_ARRAY(1 << 2 << 3 << 4 << 5 << 6 << 7 << 8 << 9) << "as"
<< "arr"
<< "cond" << BSON("$gt" << BSON_ARRAY("$$arr" << 3))));


auto expFilter = ExpressionFilter::parse(expCtx, filterSpec.firstElement(), vps);
dynamic_cast<ExpressionFilter*>(expFilter.get())->setLimit(1);
auto oneElemArray = expFilter->evaluate(Document());
ASSERT_TRUE(oneElemArray.getArray().size() == 1);
ASSERT_VALUE_EQ(oneElemArray, Value(BSON_ARRAY(4)));

dynamic_cast<ExpressionFilter*>(expFilter.get())->setLimit(2);
auto twoElemArray = expFilter->evaluate(Document());
ASSERT_TRUE(twoElemArray.getArray().size() == 2);
ASSERT_VALUE_EQ(twoElemArray, Value(BSON_ARRAY(4 << 5)));

dynamic_cast<ExpressionFilter*>(expFilter.get())->setLimit(5);
auto fiveElemArray = expFilter->evaluate(Document());
ASSERT_TRUE(fiveElemArray.getArray().size() == 5);
ASSERT_VALUE_EQ(fiveElemArray, Value(BSON_ARRAY(4 << 5 << 6 << 7 << 8)));
// Filter runs out of elements before limit is reached
dynamic_cast<ExpressionFilter*>(expFilter.get())->setLimit(10);
auto sixElemArray = expFilter->evaluate(Document());
ASSERT_TRUE(sixElemArray.getArray().size() == 6);
ASSERT_VALUE_EQ(sixElemArray, Value(BSON_ARRAY(4 << 5 << 6 << 7 << 8 << 9)));
}
TEST(ExpressionArrayElemAt, ArrayElemAtWithFilterShouldEvaluateCorrectly) {
intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());
VariablesParseState vps = expCtx->variablesParseState;
// Returns an array with all values greater than 3.
auto filterSpec = BSON(
"$filter" << BSON("input" << BSON_ARRAY(1 << 2 << 3 << 4 << 5 << 6 << 7 << 8 << 9) << "as"
<< "arr"
<< "cond" << BSON("$gt" << BSON_ARRAY("$$arr" << 3))));

auto arrayElemAtSpec = BSON("$arrayElemAt" << BSON_ARRAY(filterSpec << 2));

auto expArrayElemAt = ExpressionArrayElemAt::parse(expCtx, arrayElemAtSpec.firstElement(), vps);
auto xpArrayElemAt = dynamic_cast<ExpressionArrayElemAt*>(expArrayElemAt.get())->optimize();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use a better variable name, maybe 'optimized'?

auto val = xpArrayElemAt->evaluate(Document());
ASSERT_VALUE_EQ(val, Value(6));

auto eA = ExpressionArrayElemAt::parse(expCtx, BSON("$arrayElemAt" << BSON_ARRAY(BSON_ARRAY(1 << 2 <<3 << 4 <<5) << 1)).firstElement(), vps);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can re-use the variable names from before, I think that'd be clearer than two-letter variable names.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same with below. Alternatively, you could separate each one into it's own test.

auto eAO = dynamic_cast<ExpressionArrayElemAt*>(eA.get())->optimize();
auto va = eAO->evaluate(Document());
ASSERT_VALUE_EQ(va, Value(2));

auto ElemAtNegativeIndex = BSON("$arrayElemAt" << BSON_ARRAY(filterSpec << -2));
auto expElemAtNegativeIndex =
ExpressionArrayElemAt::parse(expCtx, ElemAtNegativeIndex.firstElement(), vps);
auto elemAtNegativeIndexOptimized =
dynamic_cast<ExpressionArrayElemAt*>(expElemAtNegativeIndex.get())->optimize();
ASSERT_VALUE_EQ(elemAtNegativeIndexOptimized->evaluate(Document()), Value(8));
}

TEST(ExpressionSlice, SliceWithFilterShouldEvaluateCorrectly) {
intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());
VariablesParseState vps = expCtx->variablesParseState;
// Returns an array with values greater than 1.
auto filterSpec = BSON(
"$filter" << BSON("input" << BSON_ARRAY(1 << 2 << 3 << 4 << 5 << 6 << 7 << 8 << 9) << "as"
<< "arr"
<< "cond" << BSON("$gt" << BSON_ARRAY("$$arr" << 1))));
auto sliceSpec = BSON("$slice" << BSON_ARRAY(filterSpec << 2 << 2));
auto expSlice = ExpressionSlice::parse(expCtx, sliceSpec.firstElement(), vps);
auto optimizedSlice = dynamic_cast<ExpressionSlice*>(expSlice.get())->optimize();
ASSERT_VALUE_EQ(optimizedSlice->evaluate(Document()), Value(BSON_ARRAY(4 << 5)));

auto sliceNegative = BSON("$slice" << BSON_ARRAY(filterSpec << -4 << 4));
auto expSliceNegative = ExpressionSlice::parse(expCtx, sliceNegative.firstElement(), vps);
auto optimizedSliceNegative =
dynamic_cast<ExpressionSlice*>(expSliceNegative.get())->optimize();
ASSERT_VALUE_EQ(optimizedSliceNegative->evaluate(Document()),
Value(BSON_ARRAY(6 << 7 << 8 << 9)));
auto sliceNegative2ndArg = BSON("$slice" << BSON_ARRAY(filterSpec << -2));
auto expSliceNegative2ndArg =
ExpressionSlice::parse(expCtx, sliceNegative2ndArg.firstElement(), vps);
auto optimizedSliceNegative2ndArg =
dynamic_cast<ExpressionSlice*>(expSliceNegative2ndArg.get())->optimize();
ASSERT_VALUE_EQ(optimizedSliceNegative2ndArg->evaluate(Document()), Value(BSON_ARRAY(8 << 9)));
}

namespace Or {

class ExpectedResultBase {
Expand Down