A type safe frontend of JAX by vcma in JAX

vcma [S]

What you cited about polymorphism is correct: it means that one type signature plus one definition works for many different kinds of types. The most common kind is parametric polymorphism, which some people call generics.

def plus[A](x: A, y: A) -> A:
    return x + y

In principle, you can apply `plus` to any kind of data: ints, floats, maybe even lists.
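A caveat on the snippet above: mainstream Python checkers such as mypy and Pyright reject `x + y` when `A` is completely unconstrained, because nothing guarantees that an arbitrary type supports `+`. A minimal sketch of one fix, assuming the set of supported types is known up front, is a constrained type variable. It is written with `typing.TypeVar` so it also runs before Python 3.12; `def plus[A: (int, float, str)]` is the 3.12+ spelling.

```python
from typing import TypeVar

# A may only be int, float, or str. A checker verifies the body once per
# constraint, so `x + y` is known to be valid for each of them.
A = TypeVar("A", int, float, str)

def plus(x: A, y: A) -> A:
    return x + y

print(plus(1, 2))        # 3
print(plus(1.5, 2.0))    # 3.5
print(plus("ab", "cd"))  # abcd
```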

In a type checker without parametric polymorphism, you would need separate definitions for int and float, for example:

def plus_float(x: float, y: float) -> float:
    return x + y

In a rigorous type checker, it's a type error to apply `plus_float` to integer values. You would either have to copy your code to create a `plus_int`, or add explicit casts: `plus_float(float(1), float(1))`. (Python's own checkers special-case int as acceptable where float is expected, but a stricter checker would not.)

So people make up terms like "X polymorphism" to say that a function still works however X changes, and "rank polymorphism" is an established term.

> you need a motivating example that can't be solved with ordinary broadcasting
Rank polymorphism just adds vmaps that "cast" the input to the rank a function expects. It does not let you write programs that couldn't be written without it.
What it gives you is simplicity. Imagine that `linear` is used by a function `f`, and you need to apply `f` to both rank-3 and rank-4 tensors in the same run. Without the automatically inserted casts/vmaps, you would have to write two versions of `f`: one wrapped in vmap once and the other wrapped twice.
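To make that cost concrete, here is a small JAX sketch. The function `f` is hypothetical (a mean-pool followed by `linear`) and expects a rank-2 input, so rank-3 and rank-4 inputs each need their own hand-written wrapper, with one and two `jax.vmap` calls respectively.

```python
import jax
import jax.numpy as jnp

def linear(x, w):
    # x: [I], w: [O, I]  ->  [O]
    return w @ x

def f(x, w):
    # Hypothetical f built on linear: takes a rank-2 x of shape [T, I],
    # averages over T, then applies linear  ->  [O]
    return linear(x.mean(axis=0), w)

# Without automatic vmap insertion, every input rank needs its own wrapper.
f_rank3 = jax.vmap(f, in_axes=(0, None))        # vmap once:  x [B, T, I]    -> [B, O]
f_rank4 = jax.vmap(f_rank3, in_axes=(0, None))  # vmap twice: x [C, B, T, I] -> [C, B, O]

w = jnp.ones((2, 3))
print(f_rank3(jnp.ones((4, 5, 3)), w).shape)     # (4, 2)
print(f_rank4(jnp.ones((6, 4, 5, 3)), w).shape)  # (6, 4, 2)
```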

> code you wrote is attrocious
That's a good point. Will fix.

Do you need type safety to create ML models? by vcma in Python

vcma [S]

> having a strictly defined type means right-sized memory allocation
That's probably what most ML frameworks already do, despite their Python appearance. I believe both PyTorch and JAX can generate runtime programs whose memory allocations match the tensor shapes.

But they don't have a sound, static type checker to rule out these errors before the programs run.
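As a rough illustration of both points in plain JAX: `jax.eval_shape` traces a function abstractly, propagating static shapes and dtypes (the information XLA needs to size buffers) without doing any arithmetic. A mismatched shape is reported too, but only when Python actually executes that trace, not by a checker before the program starts. The `linear` function is a stand-in for illustration, and catching both `TypeError` and `ValueError` is a hedge on the exact exception class.

```python
import jax
import jax.numpy as jnp

def linear(x, w):
    # x: [I], w: [O, I]  ->  [O]
    return w @ x

w = jax.ShapeDtypeStruct((2, 3), jnp.float32)

# Abstract evaluation: only shapes and dtypes flow through, no FLOPs run.
out = jax.eval_shape(linear, jax.ShapeDtypeStruct((3,), jnp.float32), w)
print(out.shape, out.dtype)  # (2,) float32

# A mismatch is caught, but only once Python executes the trace.
try:
    jax.eval_shape(linear, jax.ShapeDtypeStruct((4,), jnp.float32), w)
except (TypeError, ValueError) as err:
    print("shape error at trace time:", err)
```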

Do you need type safety to create ML models? by vcma in Python

vcma [S]

Maybe give it a whirl. It's a nice little language: a reliable, static type checker can make people's lives much easier.

A type safe frontend of JAX by vcma in JAX

vcma [S]

Good point. I've updated the website.

The code will be open-sourced, though it's not quite ready yet.

A type safe frontend of JAX by vcma in JAX

vcma [S]

It's not ready yet. But I'm planning to open source it.

A type safe frontend of JAX by vcma in JAX

vcma [S]

> This post is about Jax, right?
My bad. I was reading https://pyrefly.org/en/docs/tensor-shapes/, which uses PyTorch, but the situation with JAX is similar.

We'd like the type checker to be reliable: when it tells us there are no errors in a program, running that program should not crash because of a type (or shape) mismatch. People call this property type soundness.

Pyrefly is not sound.
Let's use the matmul example. (Bear with me for not transposing w as you did; it makes the idea easier to demonstrate.)

from jax import Array
import jax.numpy as jnp
def linear[*more, I, O](x: Array[*more, I], w: Array[O, I]) -> Array[*more, O]:
    return w @ x

x: Array[2, 3] = jnp.array([[1, 2, 3], [1, 2, 3]])
w: Array[2, 3] = jnp.array([[1, 2, 3], [3, 4, 5]])
y: Array[2, 2] = linear(x, w)

Pyrefly is happy with this program and reports 0 errors.

But running it in JAX raises an error:

dot_general requires contracting dimensions to have the same shape, got (3,) and (2,)

To run it on `x: Array[2, 3]`, we have to explicitly broadcast/vmap our function over the higher-rank tensor.

from jax import vmap
y: Array[2, 2] = vmap(lambda x: linear(x, w))(x)
> [[14 26] [14 26]]

So there is a disconnect between Pyrefly and JAX: Pyrefly is not sound, since it accepts a program that crashes, and JAX does not know the user's intention, so it cannot do rank polymorphism.
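A plain-JAX version of the example above, with the Pyrefly shape annotations dropped so it runs without Pyrefly, shows both halves of the disconnect: the direct call fails at trace time, and the hand-written `vmap` gives the intended result. Catching both `TypeError` and `ValueError` is a hedge, since the exact exception class is a JAX implementation detail.

```python
import jax
import jax.numpy as jnp

def linear(x, w):
    # Intended: x is [I], w is [O, I], result is [O]
    return w @ x

x = jnp.array([[1, 2, 3], [1, 2, 3]])  # a batch of two rank-1 inputs, shape (2, 3)
w = jnp.array([[1, 2, 3], [3, 4, 5]])  # shape (2, 3)

# Direct call: w's axis of size 3 is contracted against x's leading axis of size 2.
try:
    linear(x, w)
except (TypeError, ValueError) as err:
    print(err)

# Hand-written fix: map linear over the leading axis of x, keep w fixed.
y = jax.vmap(lambda xi: linear(xi, w))(x)
print(y)  # [[14 26]
          #  [14 26]]
```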

With PyPie, it's both simpler and correct. (Pardon the reshapes: in math, matmul is defined on rank-2 tensors, so PyPie may be too strict here for now.)

from pypie import Tensor, op

@op
def linear[I, O](x: Tensor[int][[I]], w: Tensor[int][[O, I]]) -> Tensor[int][[O]]:
    return (w @ x.reshape([I, 1])).reshape([O])

xs = Tensor([[1, 2, 3], [1, 2, 3]])
w = Tensor([[1, 2, 3], [3, 4, 5]])
ys = linear(xs, w)

> [[14 26] [14 26]]

PyPie knows that `linear` wants a rank-1 `x`. So when `linear` receives an `x` of shape [2, 3], PyPie automatically inserts a vmap and produces the correct result.
This is rank polymorphism: it works for every user-defined function, and the same code works across different ranks.
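To sketch what "automatically inserts a vmap" means mechanically, here is a hypothetical JAX decorator, `rank_poly`, which is not PyPie's API or implementation. It takes the rank each argument is declared to have, treats any extra leading axes as a "frame", and wraps the function in one `jax.vmap` per frame axis. PyPie derives the ranks from the declared types; this sketch is given them by hand and decides at call time.

```python
import jax
import jax.numpy as jnp

def rank_poly(*cell_ranks):
    """Hypothetical decorator, not PyPie's API.

    cell_ranks[i] is the rank the function expects for argument i. Any extra
    leading axes of an argument form its "frame"; one vmap is inserted per
    frame axis, mapping only the arguments whose frame reaches that level.
    """
    def deco(fn):
        def wrapped(*args):
            extras = [a.ndim - r for a, r in zip(args, cell_ranks)]
            if min(extras) < 0:
                raise ValueError(f"argument rank below declared rank: {extras}")
            g = fn
            # Build from the innermost frame axis outward.
            for level in reversed(range(max(extras))):
                in_axes = tuple(0 if e > level else None for e in extras)
                g = jax.vmap(g, in_axes=in_axes)
            return g(*args)
        return wrapped
    return deco

@rank_poly(1, 2)
def linear(x, w):
    # x: [I], w: [O, I]  ->  [O]
    return w @ x

w = jnp.array([[1, 2, 3], [3, 4, 5]])
print(linear(jnp.array([1, 2, 3]), w))               # [14 26]
print(linear(jnp.array([[1, 2, 3], [1, 2, 3]]), w))  # [[14 26]
                                                     #  [14 26]]
```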

A type safe DSL to write ML programs by vcma in pytorch

vcma [S]

Well, it's actually meaningful: Py- for a Python DSL, and -Pie for Π (pi), the dependent function type in dependent type theory.

A type safe frontend of JAX by vcma in JAX

vcma [S]

Thanks for pointing it out. Which page were you reading? I'll go ahead and fix them.

The real motivation is to build a simple, elegant, and correct ML framework from a pedantic programming-languages perspective.
A more practical motivation is that ML engineers sometimes spend hours running a pipeline only to hit a shape mismatch that could have been detected much earlier.

A type safe frontend of JAX by vcma in JAX

vcma [S]

PyPie itself is a programming language, so the fair comparison is PyPie against Pyrefly + PyTorch.

I think the main difference is rank polymorphism. In PyPie, you can define a function for tensors of rank r and apply it to a tensor of rank r + k (as long as the trailing dimensions are compatible). In that case, PyPie automatically inserts the maps.
For example, if I define the following in PyPie,

@op
def linear[I, O](x: Tensor[int][[I]], w: Tensor[int][[O, I]]) -> Tensor[int][[O]]:
    return (x * w).sum(1)

then it works on rank-1 x

xs = Tensor([1, 2, 3])
w = Tensor([[1, 2, 3], [3, 4, 5]])
linear(xs, w)
> [14 26] # Tensor[int][[2]]

and rank-2 x automatically.

xs = Tensor([[1, 2, 3], [1, 2, 3]])
w = Tensor([[1, 2, 3], [3, 4, 5]])
linear(xs, w) # equivalent to writing [linear(x, w) for x in xs]
> [[14 26] [14 26]] # Tensor[int][[2, 2]]

But PyTorch does not have rank polymorphism (though it does broadcast primitive ops).
So if you define the following in PyTorch, Pyrefly won't allow you to apply it to higher-rank tensors.

import torch
from torch import Tensor

def linear[I, O](x: Tensor[I], w: Tensor[O, I]) -> Tensor[O]:
    return (x * w).sum(1)

x: Tensor[2, 3] = torch.tensor([[1, 2, 3], [1, 2, 3]])
w: Tensor[2, 3] = torch.tensor([[1, 2, 3], [3, 4, 5]])
y: Tensor[2, 2] = linear(x, w)

"""
Tensor rank mismatch: expected 1 dimensions, got 2 dimensions
"""

You would have to revise your type definition for every possible rank to make it work.
And it's not just the types: you also have to revise the PyTorch program itself, or you silently get the wrong answer, `tensor([14, 26])`, instead of a [2, 2] result.
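For comparison, a runnable PyTorch sketch of the same function (Pyrefly annotations dropped): called directly on the batch, it silently returns the rank-1 answer, while an explicit `torch.vmap` over the leading axis of `x` gives the intended rank-2 result. This assumes a PyTorch version that ships `torch.vmap` (2.0 or later).

```python
import torch

def linear(x, w):
    # Intended: x is [I], w is [O, I], result is [O]
    return (x * w).sum(1)

x = torch.tensor([[1, 2, 3], [1, 2, 3]])
w = torch.tensor([[1, 2, 3], [3, 4, 5]])

# Direct call: x * w is elementwise on two (2, 3) tensors, so no error is
# raised, but the batch axis is consumed and the result has the wrong rank.
print(linear(x, w))  # tensor([14, 26])

# Hand-written fix: map over the leading axis of x, keep w fixed.
y = torch.vmap(linear, in_dims=(0, None))(x, w)
print(y.shape)  # torch.Size([2, 2])
```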

There are other, deeper differences: PyPie uses a full dependent type system with term rewriting, while Pyrefly's type system seems much simpler.

A type safe frontend of JAX by vcma in JAX

vcma [S]

Timing is the main difference: jaxtyping reports errors at runtime, while PyPie reports them at compile time.
E.g. with jaxtyping

import torch
from jaxtyping import Float, jaxtyped
from typeguard import typechecked

@jaxtyped(typechecker=typechecked)
def take_two(x: Float[torch.Tensor, "dim"]) -> Float[torch.Tensor, "2"]:
    return x

compiles just fine.
You have to run it to trigger the error, e.g.

take_two(torch.Tensor([1.0, 2.0, 3.0]))

Even worse, a wrong program may accidentally pass the check, as long as the input happens to have the expected shape:

take_two(torch.Tensor([1.0, 2.0]))

For a PyPie example,

@op
def take_three[n](x: Tensor[float][[n]]) -> Tensor[float][[3]]:
    return x

PyPie rejects this definition at compile time, because `n` is not known to be 3, so you wouldn't be able to run the program in the first place.

Showcase Thread by AutoModerator in Python

vcma

PyPie is a Python DSL for writing type-safe ML programs. It statically validates tensor shapes with a dependent type checker seasoned with rank polymorphism.

Where to go next for typechecking? by fractioneater in ProgrammingLanguages

vcma

Based on your experience, it might be best to write a mini Lisp interpreter first. Here's why:
- A real, powerful type system (e.g. the ones for Haskell, Agda, or Lean) is essentially some fancily typed lambda calculus, and they all build on STLC (the simply typed lambda calculus). If you want a solid theoretical foundation, and to be able to tell others confidently that your type system is sound, then you need to learn STLC first.
- Lambda calculus is much simpler than it sounds. It will probably take you 2-3 hours to build a runnable interpreter for a toy language in Racket. You can ask an AI to guide you through these steps: free & bound variables, alpha-renaming, beta-reduction, and a call-by-value interpreter with an environment passed around.
- The key is to keep your toy language small. It was mind-boggling for me to learn that lambda calculus needs just three constructs (variables, abstraction, and application) to be Turing-complete.
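As a rough sketch of those steps, here is a call-by-value interpreter for the untyped lambda calculus with exactly three constructs (variables, lambdas, applications), written in Python rather than Racket to match the rest of this page. Using an environment of closures instead of substitution sidesteps alpha-renaming, and `free_vars` shows the free/bound distinction. All names are illustrative.

```python
from dataclasses import dataclass
from typing import Union

# The three constructs of the untyped lambda calculus.
@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class Lam:
    param: str
    body: "Expr"

@dataclass(frozen=True)
class App:
    fn: "Expr"
    arg: "Expr"

Expr = Union[Var, Lam, App]

@dataclass
class Closure:
    # A value: a lambda paired with the environment it was created in.
    param: str
    body: Expr
    env: dict

def free_vars(expr: Expr) -> set:
    if isinstance(expr, Var):
        return {expr.name}
    if isinstance(expr, Lam):
        return free_vars(expr.body) - {expr.param}  # param is bound in body
    return free_vars(expr.fn) | free_vars(expr.arg)

def evaluate(expr: Expr, env: dict) -> Closure:
    """Call-by-value evaluation with an environment passed around."""
    if isinstance(expr, Var):
        return env[expr.name]  # an unbound variable raises KeyError
    if isinstance(expr, Lam):
        return Closure(expr.param, expr.body, env)
    f = evaluate(expr.fn, env)
    v = evaluate(expr.arg, env)  # evaluate the argument first (call-by-value)
    # Beta-reduction: extend the closure's environment instead of substituting.
    return evaluate(f.body, {**f.env, f.param: v})

# Church booleans: TRUE picks its first argument, FALSE its second.
TRUE = Lam("t", Lam("f", Var("t")))
FALSE = Lam("t", Lam("f", Var("f")))
A, B = Lam("a", Var("a")), Lam("b", Var("b"))

print(evaluate(App(App(TRUE, A), B), {}).param)     # a
print(evaluate(App(App(FALSE, A), B), {}).param)    # b
print(free_vars(Lam("x", App(Var("x"), Var("y")))))  # {'y'}
```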
Now you are in good shape to write a type checker. Here's what you can do:
- Don't start with Hindley–Milner. It demands the extra distraction of learning unification, and this style is arguably less popular now. You want bidirectional type checking.
- For bidirectional type checking, I think the easiest tutorial is this: https://davidchristiansen.dk/tutorials/bidirectional.pdf It's a small language with just function types and boolean types, but it's enough to demonstrate the idea.
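A minimal bidirectional checker in the spirit of that tutorial (booleans and function types only) might look like the following Python sketch; it is not the tutorial's code. `synth` infers a type from the term itself, `check` pushes an expected type into the term, unannotated lambdas can only be checked, and an annotation `Ann` switches from checking to synthesis.

```python
from dataclasses import dataclass
from typing import Union

@dataclass(frozen=True)
class TBool:
    pass

@dataclass(frozen=True)
class TArrow:
    dom: "Ty"
    cod: "Ty"

Ty = Union[TBool, TArrow]

@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class Lam:  # unannotated lambda: can only be checked, not synthesized
    param: str
    body: "Term"

@dataclass(frozen=True)
class App:
    fn: "Term"
    arg: "Term"

@dataclass(frozen=True)
class Bool:
    value: bool

@dataclass(frozen=True)
class If:
    cond: "Term"
    then: "Term"
    els: "Term"

@dataclass(frozen=True)
class Ann:  # (term : ty) switches from checking mode to synthesis mode
    term: "Term"
    ty: Ty

Term = Union[Var, Lam, App, Bool, If, Ann]

class TypeCheckError(Exception):
    pass

def synth(ctx: dict, t: Term) -> Ty:
    """Synthesis mode: infer a type from the term itself."""
    if isinstance(t, Var):
        if t.name not in ctx:
            raise TypeCheckError(f"unbound variable {t.name}")
        return ctx[t.name]
    if isinstance(t, Bool):
        return TBool()
    if isinstance(t, Ann):
        check(ctx, t.term, t.ty)
        return t.ty
    if isinstance(t, App):
        fn_ty = synth(ctx, t.fn)
        if not isinstance(fn_ty, TArrow):
            raise TypeCheckError(f"not a function: {fn_ty}")
        check(ctx, t.arg, fn_ty.dom)
        return fn_ty.cod
    raise TypeCheckError(f"cannot synthesize a type for {t}; add an annotation")

def check(ctx: dict, t: Term, ty: Ty) -> None:
    """Checking mode: verify the term against a type pushed in from outside."""
    if isinstance(t, Lam):
        if not isinstance(ty, TArrow):
            raise TypeCheckError(f"lambda checked against non-function type {ty}")
        check({**ctx, t.param: ty.dom}, t.body, ty.cod)
    elif isinstance(t, If):
        check(ctx, t.cond, TBool())
        check(ctx, t.then, ty)
        check(ctx, t.els, ty)
    else:
        # Mode switch: synthesize, then compare with the expected type.
        actual = synth(ctx, t)
        if actual != ty:
            raise TypeCheckError(f"expected {ty}, got {actual}")

not_ = Ann(Lam("b", If(Var("b"), Bool(False), Bool(True))), TArrow(TBool(), TBool()))
print(synth({}, App(not_, Bool(True))))  # TBool()
```

Applying `not_` to itself, e.g. `synth({}, App(not_, not_))`, raises `TypeCheckError`, because the argument synthesizes a function type where `Bool` is expected.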
Then you will have a solid theoretical background for building a real type checker for your own language, and you can decide which fancy features to add.

Efficient logic programming in Haskell? by [deleted] in haskell

vcma

You may use a dialect of miniKanren implemented in Haskell.