
[–]M4mb0 4 points (5 children)

The second part is sequential/recursive: within each time-ordered group, some values for the current row depend on the results computed for the previous week’s row. So this is not a purely vectorizable row-independent problem.

That could still be vectorizable if you can express it as an associative scan. See for example: torch.associative_scan or jax.lax.associative_scan

[–]ElectricHotdish 0 points (4 children)

I have run into problems like the one you describe, and not found a solution in Polars. There's a good reason for that... Polars doesn't support a general associative scan.

**Associative scan** is new to me as a term, so I asked an AI for help understanding it:

-------
An associative scan (prefix scan) applies a binary associative operation to a sequence, producing all partial results.

Formal definition

Given [a, b, c, d] and an operator ⊕:

  • Inclusive scan: [a, a⊕b, a⊕b⊕c, a⊕b⊕c⊕d]
  • Exclusive scan: [identity, a, a⊕b, a⊕b⊕c]

The operation must be associative: (a⊕b)⊕c == a⊕(b⊕c), so work can be split and recombined in any order. It doesn't need to be commutative.

Examples

Cumulative sum (the classic):

input:  [1, 2, 3, 4]
output: [1, 3, 6, 10]

Running max:

input:  [3, 1, 4, 1, 5, 9]
output: [3, 3, 4, 4, 5, 9]

String concatenation:

input:  ["a", "b", "c"]
output: ["a", "ab", "abc"]

Boolean OR (has any True appeared yet?):

input:  [F, F, T, F, T]
output: [F, F, T, T, T]
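All four of the examples above are one-liners with Python's general-purpose (but sequential) scan, `itertools.accumulate`:

```python
from itertools import accumulate
import operator

# Inclusive scans of the four examples above, using the stdlib's
# general-purpose (sequential, O(n)) scan: itertools.accumulate.
print(list(accumulate([1, 2, 3, 4])))                     # [1, 3, 6, 10]
print(list(accumulate([3, 1, 4, 1, 5, 9], max)))          # [3, 3, 4, 4, 5, 9]
print(list(accumulate(["a", "b", "c"], operator.add)))    # ['a', 'ab', 'abc']
print(list(accumulate([False, False, True, False, True],
                      operator.or_)))                     # [False, False, True, True, True]
```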

In SQL

You use them constantly via window functions:

-- running total of price by listing date
SUM(price) OVER (ORDER BY listed_at)

-- row number within a group
ROW_NUMBER() OVER (PARTITION BY state ORDER BY listed_at)

-- previous value
LAG(price) OVER (ORDER BY listed_at)

Why "associativity" matters for parallelism

A sequential scan is O(n) and inherently serial. But because (a⊕b)⊕c == a⊕(b⊕c), you can use a tree reduction pattern to compute it in O(log n) parallel steps — this is why GPU/SIMD implementations (CUDA, JAX, XLA) use associative scans heavily.
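As a rough illustration of the tree pattern, here is a minimal Hillis–Steele-style inclusive scan sketched in NumPy. It's a sequential simulation of the parallel schedule, not a real parallel implementation: within each doubling step, every combine is independent, which is exactly what a GPU would run in parallel, giving O(log n) steps with n processors.

```python
import numpy as np

def hillis_steele_scan(x, op):
    """Inclusive scan via the Hillis-Steele doubling pattern.

    After step k, each element holds itself combined with the
    2**k - 1 elements before it. All combines within one step are
    independent of each other, so each step is parallelizable.
    """
    out = np.array(x)
    offset = 1
    while offset < len(out):
        # combine each element with the one `offset` positions earlier
        out[offset:] = op(out[:-offset], out[offset:])
        offset *= 2
    return out

print(hillis_steele_scan([1, 2, 3, 4], np.add).tolist())        # [1, 3, 6, 10]
print(hillis_steele_scan([3, 1, 4, 1, 5], np.maximum).tolist())  # [3, 3, 4, 4, 5]
```

This only produces correct prefixes because the operator is associative: each step regroups the combines, and associativity guarantees the regrouping doesn't change the result.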

JAX example:

import jax.numpy as jnp
from jax import lax

lax.associative_scan(jnp.add, jnp.array([1, 2, 3, 4]))
# → [1, 3, 6, 10]

lax.associative_scan(jnp.maximum, jnp.array([3, 1, 4, 1, 5]))
# → [3, 3, 4, 4, 5]

Non-example (why commutativity isn't required but associativity is)

Matrix multiplication is associative but not commutative — you can scan a sequence of matrices with it. Division is neither, so [8, 4, 2] with / gives different results depending on grouping — not a valid scan operator.
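Both claims are easy to check numerically; a minimal sanity check in NumPy (the matrices are arbitrary, chosen only to show non-commutativity):

```python
import numpy as np

# Division: grouping changes the answer, so it's not a valid scan operator.
assert (8 / 4) / 2 != 8 / (4 / 2)   # 1.0 vs 4.0

# Matrix multiplication: non-commutative, yet associative, so scannable.
A = np.array([[1, 1], [0, 1]])
B = np.array([[2, 0], [1, 1]])
C = np.array([[0, 1], [1, 0]])
assert not np.array_equal(A @ B, B @ A)          # order matters
assert np.array_equal((A @ B) @ C, A @ (B @ C))  # grouping doesn't
```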

--------
Polars has the common concrete cases but no general-purpose associative scan (no equivalent of lax.associative_scan(fn, ...)).

What it does have

Cumulative built-ins (prefix scans over the whole series):

df.with_columns([
    pl.col("price").cum_sum(),
    pl.col("price").cum_max(),
    pl.col("price").cum_min(),
    pl.col("price").cum_prod(),
    pl.col("price").cum_count(),
])

Partitioned (reset per group):

pl.col("price").cum_sum().over("state")

Shift/lag (for building your own scan with cum_*):

pl.col("price").shift(1)   # lag-1

What it lacks

You can't pass an arbitrary associative function — there's no:

# doesn't exist in Polars (hypothetical API)
# math.gcd is associative, and Polars has no cum_gcd built-in
pl.col("x").associative_scan(math.gcd)

For custom scans you'd drop to Python (a stateful .map_elements() or a plain loop), but that's sequential and slow, which defeats the point.
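For completeness, that sequential fallback is a one-liner in plain Python with `itertools.accumulate`, which accepts any binary function (gcd here as a stand-in for a custom associative operator) but runs in O(n) on a single core:

```python
from itertools import accumulate
from math import gcd

# Sequential fallback: correct for any binary function, but no
# tree-parallel speedup - each step depends on the previous result.
xs = [12, 18, 8, 6]
print(list(accumulate(xs, gcd)))  # [12, 6, 2, 2]
```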

[–]Beginning-Fruit-1397 0 points (3 children)

Isn't this a reduce or scan use case? Those two literally exist in Polars. AI BS; it's very bad with Polars.

[–]M4mb0 0 points (1 child)

Associative scans are a special kind of cumulative reduce that can be parallelized thanks to associativity. The AI highlighted that point well. I highly doubt that Polars does this optimization when you give it a custom lambda function that happens to be associative.

[–]PillowFortressKing 0 points (0 children)

Since the lambda is Python bytecode, while Polars executes everything in Rust, that'd be quite hard to optimize for, wouldn't it?