Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions crates/polars-core/src/frame/group_by/position.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,10 @@ impl GroupsType {
}
}

pub fn is_rolling(&self) -> bool {
matches!(self, GroupsType::Slice { rolling: true, .. })
}

pub fn take_group_firsts(self) -> Vec<IdxSize> {
match self {
GroupsType::Idx(mut groups) => std::mem::take(&mut groups.first),
Expand All @@ -371,6 +375,16 @@ impl GroupsType {
}
}

pub fn check_lengths(self: &GroupsType, other: &GroupsType) -> PolarsResult<()> {
if std::ptr::eq(self, other) {
return Ok(());
}
polars_ensure!(self.iter().zip(other.iter()).all(|(a, b)| {
a.len() == b.len()
}), ComputeError: "expressions must have matching group lengths");
Ok(())
}

/// # Safety
/// This will not do any bounds checks. The caller must ensure
/// all groups have members.
Expand Down Expand Up @@ -620,12 +634,6 @@ impl Clone for GroupPositions {
}
}

impl PartialEq for GroupPositions {
fn eq(&self, other: &Self) -> bool {
self.offset == other.offset && self.len == other.len && self.sliced == other.sliced
}
}

impl AsRef<GroupsType> for GroupPositions {
fn as_ref(&self) -> &GroupsType {
self.sliced.deref()
Expand Down
119 changes: 99 additions & 20 deletions crates/polars-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,19 +125,68 @@ impl BinaryExpr {
fn apply_elementwise<'a>(
&self,
mut ac_l: AggregationContext<'a>,
ac_r: AggregationContext,
mut ac_r: AggregationContext<'a>,
aggregated: bool,
) -> PolarsResult<AggregationContext<'a>> {
// We want to be able to mutate in place, so we take the lhs to make sure that we drop.
let lhs = ac_l.get_values().clone();
let rhs = ac_r.get_values().clone();
// At this stage, we do not have both AggregatedList and NotAggregated ACs

// Drop lhs so that we might operate in place.
drop(ac_l.take());
// Check group lengths in case of all AggList
if [&ac_l, &ac_r]
.iter()
.all(|ac| matches!(ac.state, AggState::AggregatedList(_)))
{
ac_l.groups.check_lengths(&ac_r.groups)?;
}

let out = apply_operator_owned(lhs, rhs, self.op)?;
ac_l.with_values(out, aggregated, Some(&self.expr))?;
Ok(ac_l)
// The first non-LiteralScalar AC will be used as the base AC to retain the context
let left_is_literal = ac_l.is_literal();

match if !left_is_literal {
ac_l.agg_state()
} else {
ac_r.agg_state()
} {
AggState::AggregatedList(s) => {
let ca = s.list().unwrap();
let cols = [&ac_l, &ac_r]
.iter()
.map(|ac| ac.flat_naive().into_owned())
.collect::<Vec<_>>();

let out = ca.apply_to_inner(&|_| {
apply_operator(&cols[0], &cols[1], self.op)
.map(|c| c.take_materialized_series())
})?;
let out = out.into_column();

if ac_l.is_literal() {
std::mem::swap(&mut ac_l, &mut ac_r);
}

ac_l.with_values(out.into_column(), true, Some(&self.expr))?;
Ok(ac_l)
},

_ => {
// We want to be able to mutate in place, so we take the lhs to make sure that we drop.
let lhs = ac_l.get_values().clone();
let rhs = ac_r.get_values().clone();

let out = apply_operator_owned(lhs, rhs, self.op)?;

// Make sure ac_l is a non_literal AC so we retain correct group info, e.g.
// in the case of (LiteralScalar, NotAggregated with mutated groups)
if ac_l.is_literal() {
std::mem::swap(&mut ac_l, &mut ac_r);
}

// Drop lhs so that we might operate in place.
drop(ac_l.take());

ac_l.with_values(out, aggregated, Some(&self.expr))?;
Ok(ac_l)
},
}
}

fn apply_all_literal<'a>(
Expand Down Expand Up @@ -239,8 +288,37 @@ impl PhysicalExpr for BinaryExpr {
)
});
let mut ac_l = result_a?;
let ac_r = result_b?;
let mut ac_r = result_b?;

// Aggregate NotAggregated into AggregatedList, but only if strictly required AND
// when there is no risk of memory explosion. See ApplyExpr for additional context
// TODO - extend rolling to group_by_dynamic
let has_agg_list = [&ac_l, &ac_r]
.iter()
.any(|ac| matches!(ac.state, AggState::AggregatedList(_)));
let not_agg_has_rolling = [&ac_l, &ac_r]
.iter()
.any(|ac| matches!(ac.state, AggState::NotAggregated(_)) && ac.groups.is_rolling());

let not_agg_groups_may_diverge = [&ac_l, &ac_r]
.iter()
.filter(|ac| matches!(ac.state, AggState::NotAggregated(_)))
.map(|ac| ac.groups.as_ref())
.collect::<Vec<_>>()
.windows(2)
.any(|w| !std::ptr::eq(w[0], w[1]));

for ac in [&mut ac_l, &mut ac_r] {
if matches!(ac.state, AggState::NotAggregated(_)) {
if !not_agg_has_rolling && (has_agg_list || not_agg_groups_may_diverge) {
ac.aggregated();
}
}
}

// Dispatch
// TODO - consolidate into 4 branches:
// all_lit, aggscalar incompatible, compatible but expensive, other)
match (ac_l.agg_state(), ac_r.agg_state()) {
(AggState::LiteralScalar(_), AggState::LiteralScalar(_)) => {
self.apply_all_literal(ac_l, ac_r)
Expand All @@ -251,7 +329,11 @@ impl PhysicalExpr for BinaryExpr {
_ => self.apply_group_aware(ac_l, ac_r),
},
(AggState::NotAggregated(_), AggState::NotAggregated(_)) => {
self.apply_elementwise(ac_l, ac_r, false)
if not_agg_groups_may_diverge && not_agg_has_rolling {
self.apply_group_aware(ac_l, ac_r)
} else {
self.apply_elementwise(ac_l, ac_r, false)
}
},
(
AggState::AggregatedScalar(_) | AggState::LiteralScalar(_),
Expand All @@ -261,16 +343,13 @@ impl PhysicalExpr for BinaryExpr {
| (AggState::NotAggregated(_), AggState::AggregatedScalar(_)) => {
self.apply_group_aware(ac_l, ac_r)
},
(AggState::AggregatedList(lhs), AggState::AggregatedList(rhs)) => {
let lhs = lhs.list().unwrap();
let rhs = rhs.list().unwrap();
let out = lhs.apply_to_inner(&|lhs| {
apply_operator(&lhs.into_column(), &rhs.get_inner().into_column(), self.op)
.map(|c| c.take_materialized_series())
})?;
ac_l.with_values(out.into_column(), true, Some(&self.expr))?;
Ok(ac_l)
(AggState::AggregatedList(_), AggState::AggregatedList(_)) => {
self.apply_elementwise(ac_l, ac_r, true)
},
(
AggState::AggregatedList(_) | AggState::LiteralScalar(_),
AggState::AggregatedList(_) | AggState::LiteralScalar(_),
) => self.apply_elementwise(ac_l, ac_r, true),
_ => self.apply_group_aware(ac_l, ac_r),
}
}
Expand Down
132 changes: 131 additions & 1 deletion py-polars/tests/unit/operations/namespaces/test_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import random
import struct
from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import numpy as np
import pytest
Expand Down Expand Up @@ -463,3 +463,133 @@ def test_binary_compounded_literal_aggstate_24460() -> None:
)
expected = pl.DataFrame({"g": [10], "z": [27]})
assert_frame_equal(out, expected)


# parametric tuples: (expr, is_scalar, values with broadcast)
agg_expressions = [
(pl.lit(7, pl.Int64), True, [7, 7, 7]), # LiteralScalar
(pl.col("n"), False, [2, 1, 3]), # NotAggregated
(pl.int_range(pl.len()), False, [0, 1, 0]), # AggregatedList
(pl.col("n").first(), True, [2, 2, 3]), # AggregatedScalar
]


@pytest.mark.parametrize("lhs", agg_expressions)
@pytest.mark.parametrize("rhs", agg_expressions)
@pytest.mark.parametrize("n_rows", [0, 1, 2, 3])
@pytest.mark.parametrize("maintain_order", [True, False])
def test_add_aggstates_in_binary_expr_24504(
lhs: tuple[pl.Expr, bool, list[int]],
rhs: tuple[pl.Expr, bool, list[int]],
n_rows: int,
maintain_order: bool,
) -> None:
df = pl.DataFrame({"g": [10, 10, 20], "n": [2, 1, 3]})
lf = df.head(n_rows).lazy()
expr = pl.Expr.add(lhs[0].alias("lhs"), rhs[0].alias("rhs")).alias("expr")
q = lf.group_by("g", maintain_order=maintain_order).agg(expr)
out = q.collect()

# check schema
assert q.collect_schema() == out.schema

# check output against ground truth
if n_rows in [1, 2, 3]:
data = df.to_dict(as_series=False)
result: dict[int, Any] = {}
for gg, ll, rr in zip(data["g"][:n_rows], lhs[2][:n_rows], rhs[2][:n_rows]):
result.setdefault(gg, []).append(ll + rr)
if lhs[1] and rhs[1]:
# expect scalar result
result = {k: v[0] for k, v in result.items()}
expected = pl.DataFrame(
{"g": list(result.keys()), "expr": list(result.values())}
)
assert_frame_equal(out, expected, check_row_order=maintain_order)

# check output against non_aggregated expression evaluation
if n_rows in [1, 2, 3]:
print(f"df\n{df}")
grouped = df.head(n_rows).group_by("g", maintain_order=maintain_order)
out_non_agg = pl.DataFrame({})
for df_group in grouped:
df = df_group[1]
print(f"df pre expr:\n{df}", flush=True)
if lhs[1] and rhs[1]:
df = df.head(1)
df = df.select(["g", expr])
else:
df = df.select(["g", expr.implode()]).head(1)
print(f"df post expr:{df}\n")
out_non_agg = out_non_agg.vstack(df)
print(f"out_non_agg:\n{out_non_agg}")

assert_frame_equal(out, out_non_agg, check_row_order=maintain_order)


# parametric tuples: (expr, is_scalar)
agg_expressions_sort = [
(pl.lit(7, pl.Int64), True), # LiteralScalar
(pl.col("n"), False), # NotAggregated
(pl.col("n").sort(), False), # NotAggregated w groups modified
(pl.int_range(pl.len()), False), # AggregatedList
(pl.int_range(pl.len()).reverse(), False), # AggregatedList w groups modified
(pl.col("n").first(), True), # AggregatedScalar
]


@pytest.mark.parametrize("lhs", agg_expressions_sort)
@pytest.mark.parametrize("rhs", agg_expressions_sort)
@pytest.mark.parametrize("maintain_order", [True, False])
def test_add_aggstates_with_sort_in_binary_expr_24504(
lhs: tuple[pl.Expr, bool, list[int]],
rhs: tuple[pl.Expr, bool, list[int]],
maintain_order: bool,
) -> None:
df = pl.DataFrame({"g": [10, 10, 20], "n": [2, 1, 3]})
lf = df.lazy()
expr = pl.Expr.add(lhs[0].alias("lhs"), rhs[0].alias("rhs")).alias("expr")
q = lf.group_by("g", maintain_order=maintain_order).agg(expr)
out = q.collect()

# check schema
assert q.collect_schema() == out.schema

# check output against non_aggregated expression evaluation
print(f"df\n{df}")
grouped = df.group_by("g", maintain_order=maintain_order)
out_non_agg = pl.DataFrame({})
for df_group in grouped:
df = df_group[1]
print(f"df pre expr:\n{df}", flush=True)
if lhs[1] and rhs[1]:
df = df.head(1)
df = df.select(["g", expr])
else:
df = df.select(["g", expr.implode()]).head(1)
print(f"df post expr:{df}\n")
out_non_agg = out_non_agg.vstack(df)
print(f"out_non_agg:\n{out_non_agg}")

assert_frame_equal(out, out_non_agg, check_row_order=maintain_order)


@pytest.mark.parametrize("maintain_order", [True, False])
def test_binary_context_nested(maintain_order: bool) -> None:
df = pl.DataFrame({"groups": [1, 1, 2, 2, 3, 3], "vals": [1, 13, 3, 87, 1, 6]})
out = (
df.lazy()
.group_by(pl.col("groups"), maintain_order=maintain_order)
.agg(
[
pl.when(pl.col("vals").eq(pl.lit(1)))
.then(pl.col("vals").sum())
.otherwise(pl.lit(90))
.alias("vals")
]
)
).collect()
expected = pl.DataFrame(
{"groups": [1, 2, 3], "vals": [[14, 90], [90, 90], [7, 90]]}
)
assert_frame_equal(out, expected, check_row_order=maintain_order)
21 changes: 21 additions & 0 deletions py-polars/tests/unit/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,3 +762,24 @@ def test_shift_with_null_deprecated_24105(fill_value: Any) -> None:
pl.DataFrame({"x": [None, None, None]}),
check_dtypes=False,
)


def test_raies_on_mismatch_column_length_binary_expr() -> None:
df = pl.DataFrame(
{
"a": [10, 10, 10, 20, 20, 20],
"b": [2, 0, 99, 0, 0, 0],
"c": [3, 0, 0, 2, 0, 99],
}
)

with pytest.raises(
ComputeError,
match="expressions must have matching group lengths",
):
df.group_by("a").agg(
pl.Expr.add(
pl.col("b").head(pl.col("b").first()),
pl.col("c").head(pl.col("c").first()),
)
)
Loading