diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index 211f2ab7f1d97..c4ac880cd8920 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -524,6 +524,34 @@ Value ExpressionArrayElemAt::evaluate(const Document& root) const { return array[index]; } +intrusive_ptr ExpressionArrayElemAt::optimize() { + // This will optimize all arguments to this expression. + auto optimized = ExpressionNary::optimize(); + if (optimized.get() != this) + return optimized; + + + // If ExpressionArrayElemAt is passed an ExpressionFilter as its first argument, set a limit on + // the filter so filter returns an array with the last element being the value we want. + if (dynamic_cast(vpOperand[0].get())) { + if (auto expConstant = dynamic_cast(vpOperand[1].get())) { + auto indexArg = expConstant->getValue(); + + uassert(50803, + str::stream() << getOpName() << "'s second argument must be representable as" + << " a 32-bit integer: " + << indexArg.coerceToDouble(), + indexArg.integral()); + auto index = indexArg.coerceToInt(); + // Can't optimize if the index is less than 0. 
+ if (index >= 0) { + dynamic_cast(vpOperand[0].get())->setLimit(index + 1); + } + } + } + return this; +} + REGISTER_EXPRESSION(arrayElemAt, ExpressionArrayElemAt::parse); const char* ExpressionArrayElemAt::getOpName() const { return "$arrayElemAt"; @@ -2206,12 +2234,19 @@ Value ExpressionFilter::evaluate(const Document& root) const { if (_filter->evaluate(root).coerceToBool()) { output.push_back(std::move(elem)); + if (_limit && static_cast(output.size()) == _limit.get()) { + return Value(std::move(output)); + } } } return Value(std::move(output)); } +void ExpressionFilter::setLimit(int limit) { + _limit = boost::optional(limit); +} + void ExpressionFilter::_doAddDependencies(DepsTracker* deps) const { _input->addDependencies(deps); _filter->addDependencies(deps); @@ -3878,20 +3913,7 @@ Value ExpressionSlice::evaluate(const Document& root) const { return Value(BSONNULL); } - uassert(28727, - str::stream() << "Third argument to $slice must be numeric, but " - << "is of type: " - << typeName(countVal.getType()), - countVal.numeric()); - uassert(28728, - str::stream() << "Third argument to $slice can't be represented" - << " as a 32-bit integer: " - << countVal.coerceToDouble(), - countVal.integral()); - uassert(28729, - str::stream() << "Third argument to $slice must be positive: " - << countVal.coerceToInt(), - countVal.coerceToInt() > 0); + uassertIfNotIntegralAndNonNegative(countVal, "$slice", "third argument"); size_t count = size_t(countVal.coerceToInt()); end = std::min(start + count, array.size()); @@ -3900,6 +3922,49 @@ Value ExpressionSlice::evaluate(const Document& root) const { return Value(vector(array.begin() + start, array.begin() + end)); } +intrusive_ptr ExpressionSlice::optimize() { + // This will optimize all arguments to this expression. 
+ auto optimized = ExpressionNary::optimize(); + if (optimized.get() != this) + return optimized; + + // If ExpressionSlice is passed an ExpressionFilter we can stop filtering once the size of + // the array returned by the filter is equal to the last argument passed to ExpressionSlice. + if (dynamic_cast(vpOperand[0].get())) { + if (auto secondArg = dynamic_cast(vpOperand[1].get())) { + auto secondVal = secondArg->getValue(); + + uassert(50798, + str::stream() << "Second argument to $slice can't be represented as" + << " a 32-bit integer: " + << secondVal.coerceToDouble(), + secondVal.integral()); + + int arg2 = secondVal.coerceToInt(); + if (vpOperand.size() == 2) { + // Can't set a limit if it is negative. + if (arg2 >= 0) { + // With two arguments the second is the element count, so it is the limit. + dynamic_cast(vpOperand[0].get())->setLimit(arg2); + } + } else if (vpOperand.size() > 2) { + if (auto thirdArg = dynamic_cast(vpOperand[2].get())) { + auto thirdVal = thirdArg->getValue(); + + uassertIfNotIntegralAndNonNegative(thirdVal, "$slice", "third argument"); + + int arg3 = thirdVal.coerceToInt(); + if (arg2 >= 0) { + // The limit must reach the last element we want, which in this case is + // the position argument + the first n elements argument. 
+ dynamic_cast(vpOperand[0].get())->setLimit(arg2 + arg3); + } + } + } + } + } + return this; +} REGISTER_EXPRESSION(slice, ExpressionSlice::parse); const char* ExpressionSlice::getOpName() const { return "$slice"; diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h index 5b4a0286b7c03..53e684713c5d7 100644 --- a/src/mongo/db/pipeline/expression.h +++ b/src/mongo/db/pipeline/expression.h @@ -737,6 +737,7 @@ class ExpressionArrayElemAt final : public ExpressionFixedArity& expCtx) : ExpressionFixedArity(expCtx) {} + boost::intrusive_ptr optimize() final; Value evaluate(const Document& root) const final; const char* getOpName() const final; }; @@ -1155,6 +1156,7 @@ class ExpressionFilter final : public Expression { const boost::intrusive_ptr& expCtx, BSONElement expr, const VariablesParseState& vps); + void setLimit(int limit); protected: void _doAddDependencies(DepsTracker* deps) const final; @@ -1174,6 +1176,9 @@ class ExpressionFilter final : public Expression { boost::intrusive_ptr _input; // The expression determining whether each element should be present in the result array. boost::intrusive_ptr _filter; + // When $filter is passed as an argument to $arrayElemAt or $slice we can set a limit on $filter + // to stop filtering once all the values needed are in the result array. 
+ boost::optional _limit; }; @@ -1666,6 +1671,7 @@ class ExpressionSlice final : public ExpressionRangedArity(expCtx) {} Value evaluate(const Document& root) const final; + boost::intrusive_ptr optimize() final; const char* getOpName() const final; }; diff --git a/src/mongo/db/pipeline/expression_test.cpp b/src/mongo/db/pipeline/expression_test.cpp index 79a30eb56e2f4..8d8e774f85103 100644 --- a/src/mongo/db/pipeline/expression_test.cpp +++ b/src/mongo/db/pipeline/expression_test.cpp @@ -2894,6 +2894,120 @@ TEST(ExpressionObjectOptimizations, } // namespace Object +TEST(ExpressionFilter, ExpressionFilterWithASetLimitShouldReturnAnArrayNoGreaterThanTheLimit) { + intrusive_ptr expCtx(new ExpressionContextForTest()); + VariablesParseState vps = expCtx->variablesParseState; + auto filterSpec = BSON( + "$filter" << BSON("input" << BSON_ARRAY(1 << 2 << 3 << 4 << 5 << 6 << 7 << 8 << 9) << "as" + << "arr" + << "cond" + << BSON("$gt" << BSON_ARRAY("$$arr" << 3)))); + + + auto expFilter = ExpressionFilter::parse(expCtx, filterSpec.firstElement(), vps); + dynamic_cast(expFilter.get())->setLimit(1); + auto oneElemArray = expFilter->evaluate(Document()); + ASSERT_TRUE(oneElemArray.getArray().size() == 1); + ASSERT_VALUE_EQ(oneElemArray, Value(BSON_ARRAY(4))); + + dynamic_cast(expFilter.get())->setLimit(2); + auto twoElemArray = expFilter->evaluate(Document()); + ASSERT_TRUE(twoElemArray.getArray().size() == 2); + ASSERT_VALUE_EQ(twoElemArray, Value(BSON_ARRAY(4 << 5))); + + dynamic_cast(expFilter.get())->setLimit(5); + auto fiveElemArray = expFilter->evaluate(Document()); + ASSERT_TRUE(fiveElemArray.getArray().size() == 5); + ASSERT_VALUE_EQ(fiveElemArray, Value(BSON_ARRAY(4 << 5 << 6 << 7 << 8))); + // Filter runs out of elements before limit is reached + dynamic_cast(expFilter.get())->setLimit(10); + auto sixElemArray = expFilter->evaluate(Document()); + ASSERT_TRUE(sixElemArray.getArray().size() == 6); + ASSERT_VALUE_EQ(sixElemArray, Value(BSON_ARRAY(4 << 5 << 6 << 7 << 8 << 
9))); +} + +TEST(ExpressionArrayElemAt, ArrayElemAtWithAllConstantValuesShouldOptimizeToAnExpressionConstant) { + intrusive_ptr expCtx(new ExpressionContextForTest()); + VariablesParseState vps = expCtx->variablesParseState; + + auto expArrayElemAt = ExpressionArrayElemAt::parse( + expCtx, + BSON("$arrayElemAt" << BSON_ARRAY(BSON_ARRAY(1 << 2 << 3 << 4 << 5) << 1)).firstElement(), + vps); + expArrayElemAt = expArrayElemAt.get()->optimize(); + ASSERT_TRUE(dynamic_cast(expArrayElemAt.get())); +} + +TEST(ExpressionArrayElemAt, ArrayElemAtWithFilterShouldEvaluateCorrectly) { + intrusive_ptr expCtx(new ExpressionContextForTest()); + VariablesParseState vps = expCtx->variablesParseState; + // Returns an array with all values greater than 3. + auto filterSpec = BSON( + "$filter" << BSON("input" << BSON_ARRAY(1 << 2 << 3 << 4 << 5 << 6 << 7 << 8 << 9) << "as" + << "arr" + << "cond" + << BSON("$gt" << BSON_ARRAY("$$arr" << 3)))); + + + auto arrayElemAtSpec = BSON("$arrayElemAt" << BSON_ARRAY(filterSpec << 2)); + + auto expArrayElemAt = ExpressionArrayElemAt::parse(expCtx, arrayElemAtSpec.firstElement(), vps); + auto optimized = dynamic_cast(expArrayElemAt.get())->optimize(); + auto val = optimized->evaluate(Document()); + ASSERT_VALUE_EQ(val, Value(6)); + + expArrayElemAt = ExpressionArrayElemAt::parse( + expCtx, + BSON("$arrayElemAt" << BSON_ARRAY(BSON_ARRAY(1 << 2 << 3 << 4 << 5) << 1)).firstElement(), + vps); + + optimized = dynamic_cast(expArrayElemAt.get())->optimize(); + val = optimized->evaluate(Document()); + ASSERT_VALUE_EQ(val, Value(2)); + + expArrayElemAt = ExpressionArrayElemAt::parse( + expCtx, BSON("$arrayElemAt" << BSON_ARRAY(filterSpec << -2)).firstElement(), vps); + optimized = dynamic_cast(expArrayElemAt.get())->optimize(); + ASSERT_VALUE_EQ(optimized->evaluate(Document()), Value(8)); +} + +TEST(ExpressionSlice, ExpressionSliceWithAllConstantValuesShouldOptimizeToAnExpressionConstant) { + intrusive_ptr expCtx(new ExpressionContextForTest()); + 
VariablesParseState vps = expCtx->variablesParseState; + + auto expSlice = ExpressionSlice::parse( + expCtx, + BSON("$slice" << BSON_ARRAY(BSON_ARRAY(1 << 2 << 3 << 4 << 5) << 1 << 1)) + .firstElement(), + vps); + expSlice = expSlice.get()->optimize(); + ASSERT_TRUE(dynamic_cast(expSlice.get())); +} + +TEST(ExpressionSlice, SliceWithFilterShouldEvaluateCorrectly) { + intrusive_ptr expCtx(new ExpressionContextForTest()); + VariablesParseState vps = expCtx->variablesParseState; + // Returns an array with values greater than 1. + auto filterSpec = BSON( + "$filter" << BSON("input" << BSON_ARRAY(1 << 2 << 3 << 4 << 5 << 6 << 7 << 8 << 9) << "as" + << "arr" + << "cond" + << BSON("$gt" << BSON_ARRAY("$$arr" << 1)))); + auto sliceSpec = BSON("$slice" << BSON_ARRAY(filterSpec << 2 << 2)); + auto expSlice = ExpressionSlice::parse(expCtx, sliceSpec.firstElement(), vps); + auto optimizedSlice = dynamic_cast(expSlice.get())->optimize(); + ASSERT_VALUE_EQ(optimizedSlice->evaluate(Document()), Value(BSON_ARRAY(4 << 5))); + + sliceSpec = BSON("$slice" << BSON_ARRAY(filterSpec << -4 << 4)); + expSlice = ExpressionSlice::parse(expCtx, sliceSpec.firstElement(), vps); + optimizedSlice = dynamic_cast(expSlice.get())->optimize(); + ASSERT_VALUE_EQ(optimizedSlice->evaluate(Document()), Value(BSON_ARRAY(6 << 7 << 8 << 9))); + sliceSpec = BSON("$slice" << BSON_ARRAY(filterSpec << -2)); + expSlice = ExpressionSlice::parse(expCtx, sliceSpec.firstElement(), vps); + optimizedSlice = dynamic_cast(expSlice.get())->optimize(); + ASSERT_VALUE_EQ(optimizedSlice->evaluate(Document()), Value(BSON_ARRAY(8 << 9))); +} + namespace Or { class ExpectedResultBase {