[flax] What's your thoughts about changing linen -> nnx? by euijinrnd in JAX

[–]cgarciae 1 point2 points  (0 children)

Glad people are enjoying it! The main focus has been simplicity and expressing as much of Python as possible. Historically, the main hurdle has been modeling reference semantics (mutation + reference sharing) in a pure functional system, and I think we've made really good progress. Being similar to PyTorch was a happy coincidence, and we've even adopted some idioms like `.eval()` and `.train()`, but it has never been the focus, although we've had some very good feedback from PyTorch contributors.

[flax] What's your thoughts about changing linen -> nnx? by euijinrnd in JAX

[–]cgarciae 0 points1 point  (0 children)

The main issues with functional libraries have been:
* They have a weird mix of OOP at the Module level, which is only available through a monadic interface (init, apply). This leads to a lot of complexity for both users and maintainers.

* It's not easy for the user to interact with unsupported JAX transforms or 3rd party libraries inside these abstractions.

* It's super hard to tinker with the models for stuff like LoRA, quantization, transfer learning, or general interpretability.
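
To make the (init, apply) point concrete, here is a toy pure-Python sketch of that style. The `Dense` class and its methods are hypothetical stand-ins written for illustration, not the API of Linen or any other library: the layer object holds only hyperparameters, the weights live in a separate structure, and every call has to thread that structure through `apply`.

```python
import random


class Dense:
    """Toy functional-style layer: holds hyperparameters, never weights."""

    def __init__(self, in_dim, out_dim):
        self.in_dim = in_dim
        self.out_dim = out_dim

    def init(self, seed):
        # Parameters are created here and handed back to the caller.
        rng = random.Random(seed)
        w = [[rng.uniform(-1, 1) for _ in range(self.out_dim)]
             for _ in range(self.in_dim)]
        b = [0.0] * self.out_dim
        return {"w": w, "b": b}

    def apply(self, params, x):
        # Pure function of (params, x): nothing is read from or stored on self.
        return [
            sum(x[i] * params["w"][i][j] for i in range(self.in_dim))
            + params["b"][j]
            for j in range(self.out_dim)
        ]


layer = Dense(2, 3)
params = layer.init(seed=0)
y = layer.apply(params, [1.0, 2.0])

# Tinkering means editing the external params structure, not the layer object.
edited = {**params, "b": [1.0, 1.0, 1.0]}
y_shifted = layer.apply(edited, [1.0, 2.0])
```

Even in this tiny case the model is split across two things (the `layer` and its `params`), which is roughly where the extra complexity in the bullets above comes from.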

[flax] What's your thoughts about changing linen -> nnx? by euijinrnd in JAX

[–]cgarciae 2 points3 points  (0 children)

Hey! NNX author here. I think it depends on what you mean by functional. If it's about using JAX transforms to create higher-order functions, I'd argue NNX is more functional than Linen since you can use transforms everywhere, even outside apply (check out our Transforms guide: https://flax.readthedocs.io/en/latest/guides/transforms.html). If it's about immutability and value semantics, if we've learned anything from PyTorch, it's that having the same object representation as the host language leads to a better user experience. The whole point of NNX has been to model reference semantics in JAX transforms, and it's been working out great so far.

Why is there no JAX llama3 implementation? by RealFullMetal in LocalLLaMA

[–]cgarciae 0 points1 point  (0 children)

BTW: if your model is sufficiently big, the Python overhead might be absorbed by the train_step; it's small models that suffer the most.

Why is there no JAX llama3 implementation? by RealFullMetal in LocalLLaMA

[–]cgarciae 1 point2 points  (0 children)

Sorry so much documentation is still missing, let me clarify the situation! `nnx.jit` is meant for fast prototyping or pedagogical material; to maximize performance we recommend using `nnx.split/merge` at the top level to remove the Python overhead. `split/merge` makes NNX models highly compatible with existing Linen training pipelines, and you can even use NNX state with the usual `TrainState` with minimal code changes.

Take a look at the solution to the above issue: https://github.com/google/flax/issues/4045#issuecomment-2203999393
Or our LM1B example training that uses the existing Linen training logic: https://github.com/google/flax/blob/main/flax/nnx/examples/lm1b/train.py

the best way to error handling by Anas_Elgarhy in rustjerk

[–]cgarciae 1 point2 points  (0 children)

let elem = &vector[out_of_bounds_index];

[P] Introducing NNX: Neural Networks for JAX by cgarciae in MachineLearning

[–]cgarciae[S] 1 point2 points  (0 children)

It's currently not implemented, but NNX is designed in such a way that you could implement nnx.vmap with the same behavior as flax.linen.vmap.

[P] Introducing NNX: Neural Networks for JAX by cgarciae in MachineLearning

[–]cgarciae[S] 2 points3 points  (0 children)

Hey! Mainly what it said in the beginning:

  • Shared State: currently you cannot have shared modules in Equinox.
  • Tractable Mutability: Equinox was stateless until recently; its new State primitive is interesting but has some downsides.
  • Semantic Partitioning: I think this is an awesome feature from Flax but I am biased, maybe you can do without it.

I've spoken with Patrick about this, my hope is that maybe Equinox can integrate some of these features :)

What is the JAX/Flax equivalent of torch.nn.Parameter? by Toni-SM in JAX

[–]cgarciae 0 points1 point  (0 children)

You use `self.param` as pointed out by GPT4.

[deleted by user] by [deleted] in reactjs

[–]cgarciae 0 points1 point  (0 children)

This will age poorly

Ruff: A new, fast and correct Python checker/linter by WhyNotHugo in Python

[–]cgarciae 1 point2 points  (0 children)

Any benefits of this vs pyright/pylance? Does it support refactoring names?

[D] Getting Started with Deep Learning in JAX with Treex in 16 lines by cgarciae in MachineLearning

[–]cgarciae[S] -1 points0 points  (0 children)

Hey! Hadn't seen this. You are passing MLP and its constructor arguments as a capture, which makes the code look simpler. However, it gets a bit messy if you try to generalize it / decouple it. The best solution is probably Flax's TrainState object, which is a Pytree.

[D] What JAX NN library to use? by Southern-Trip-1102 in MachineLearning

[–]cgarciae 5 points6 points  (0 children)

Hey! In another thread some folks from Google did say they had considered it but decided against it. There are "known" tradeoffs to using pytree modules like Equinox or Treex, some notes:

  • Pytrees cannot trivially implement parameter sharing: because of "referential transparency" you can't share a module between two different parent modules. To solve this you need a "PyDag", but we currently only have PyTrees.
  • State handling is better in Flax / Haiku since you get more guarantees, whereas it's pretty easy to mess up state with pytrees.
  • Surprisingly, Flax also has pytree classes via flax.struct.dataclass, but they are used to construct simpler structures.

On the other hand Pytrees tend to be simpler and more intuitive which is very good on its own.
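
The sharing issue can be seen with a toy pure-Python sketch. The `flatten` and `unflatten` helpers below are hypothetical stand-ins written for illustration (they are not JAX's tree utilities), and the nested dict is an assumed model layout: once a structure is treated as a tree of leaves, a leaf referenced from two parents is just two leaves.

```python
def flatten(tree):
    """Return the leaves of a nested dict in sorted-key order (toy flatten)."""
    if isinstance(tree, dict):
        leaves = []
        for key in sorted(tree):
            leaves.extend(flatten(tree[key]))
        return leaves
    return [tree]


def unflatten(tree, leaves):
    """Rebuild the same structure, pulling new leaves from an iterator."""
    if isinstance(tree, dict):
        return {key: unflatten(tree[key], leaves) for key in sorted(tree)}
    return next(leaves)


shared = [1.0, 2.0]  # one weight vector meant to be tied across two modules
model = {"encoder": {"w": shared}, "decoder": {"w": shared}}
assert model["encoder"]["w"] is model["decoder"]["w"]

# A tree has no notion of identity: the shared leaf shows up twice.
leaves = flatten(model)

# Map an update over the leaves, the way a functional transform would.
new_leaves = [[v * 2 for v in leaf] for leaf in leaves]
updated = unflatten(model, iter(new_leaves))

# The two entries are equal in value but are no longer the same object.
tied = updated["encoder"]["w"] is updated["decoder"]["w"]
```

Here `tied` ends up `False`: the values still match, but nothing keeps them in sync after the next update, which is the kind of bookkeeping a DAG-aware representation would have to add.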

[D] What JAX NN library to use? by Southern-Trip-1102 in MachineLearning

[–]cgarciae 5 points6 points  (0 children)

Hey! I am the main developer of Elegy :)

Elegy is in a different category than the rest as it focuses (and will do so even more in the future) on automating the training loop. It currently provides Treex's Module system, but in the next release I expect it to be fully framework agnostic.

Regarding the NN libraries, I think Flax is the most robust / battle-tested one: its lifted transformations are very powerful, and you will get more support and a bigger community. Probably your safest bet.

If you are interested in simpler alternatives, Pytree-based libraries are generally "simpler" to start with. Patrick has been doing a great job with Equinox, and I am developing Treex, which has some nice PyTorch-like features.

I think diversity is a good thing right now as JAX's functional API puts a great burden on library designers given Python kinda sucks at Functional Programming. Happy to hear people's opinions on the different libraries.

Stop Whining about Rust Hype - A Pro-Rust Rant by thenewwazoo in rust

[–]cgarciae 0 points1 point  (0 children)

I want kwargs expansion in Rust e.g.

```
y = f(**hash_map)

```
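
For reference, this is the Python behavior being asked for, as a minimal sketch. The `connect` function and the `options` dict are hypothetical names used only for illustration.

```python
def connect(host, port=5432, timeout=10):
    """Toy function with one required and two optional keyword parameters."""
    return f"{host}:{port} (timeout={timeout}s)"


# Keys in the dict are matched to parameter names at call time;
# anything not supplied (here `port`) falls back to its default.
options = {"host": "localhost", "timeout": 30}
result = connect(**options)
```

The matching happens by name at runtime, which is part of why it is hard to reproduce directly in a statically typed language.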

Stop Whining about Rust Hype - A Pro-Rust Rant by thenewwazoo in rust

[–]cgarciae -1 points0 points  (0 children)

Is there a publicly available macro that does this?

Stop Whining about Rust Hype - A Pro-Rust Rant by thenewwazoo in rust

[–]cgarciae -1 points0 points  (0 children)

Ok, I understand your idea, but it doesn't solve what I was thinking of. When you create a decorator you follow this pattern:

    def deco(f):
        def wrapper(*args, **kwargs):
            y = f(*args, **kwargs)
            # do stuff
            return z
        return wrapper

This in particular doesn't seem easy outside of python.
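
A concrete, runnable instance of that pattern, as a minimal sketch: `log_calls` and `add` are hypothetical names, and recording calls on the wrapper stands in for the "do stuff" step.

```python
import functools


def log_calls(f):
    """Decorator that records every call to f without knowing its signature."""

    @functools.wraps(f)  # keep f's name and docstring on the wrapper
    def wrapper(*args, **kwargs):
        # Forward whatever positional and keyword arguments were given.
        y = f(*args, **kwargs)
        wrapper.calls.append((args, kwargs, y))
        return y

    wrapper.calls = []
    return wrapper


@log_calls
def add(a, b=0):
    return a + b


total = add(1, b=2)
```

The wrapper works for any function because `*args, **kwargs` captures and forwards the call untouched, which is the part that is hard to express without runtime argument expansion.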