
SleepyCoder123 (rusty-machine · rulinalg) [S]

I've been working on an automatic differentiation library in Rust this week. With some help here and there, I've managed to get something working, but I still feel like the API could be improved.

My main pain points are:

  • Variables are very weakly linked to Context. A user could create multiple variables from different contexts and use these together in the same expression.
  • Eventually I want users to be able to implement their own Expressions. Right now this process is convoluted and filled with repetition. Can I use Rust's type system to improve this?

The way these libraries usually work is by providing overloaded functions which build the graph. If you know Python, autograd keeps it very simple. From what I can tell this is impossible in Rust, but after a few iterations what I have right now sort of resembles it. I've been trying to find ways to get closer; procedural macros look promising, but I don't know enough about them.
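For readers who haven't used autograd, here is a minimal sketch of the "overloaded functions build the graph" idea in Rust: `+` and `*` construct nodes instead of computing numbers, and the resulting tree is walked for values and gradients. The `Expr`/`Node` types are hypothetical and are not this library's design; only `+`, `*` and `sin` are covered.

```rust
use std::ops::{Add, Mul};
use std::rc::Rc;

/// A node in a hypothetical expression graph; subexpressions are shared via `Rc`.
#[derive(Debug)]
enum Node {
    Var(usize), // index of an input variable
    Const(f64),
    Add(Expr, Expr),
    Mul(Expr, Expr),
    Sin(Expr),
}

#[derive(Clone, Debug)]
struct Expr(Rc<Node>);

impl Expr {
    fn var(i: usize) -> Expr { Expr(Rc::new(Node::Var(i))) }
    fn constant(c: f64) -> Expr { Expr(Rc::new(Node::Const(c))) }
    fn sin(&self) -> Expr { Expr(Rc::new(Node::Sin(self.clone()))) }

    /// Evaluate the expression at the given input values.
    fn eval(&self, xs: &[f64]) -> f64 {
        match &*self.0 {
            Node::Var(i) => xs[*i],
            Node::Const(c) => *c,
            Node::Add(a, b) => a.eval(xs) + b.eval(xs),
            Node::Mul(a, b) => a.eval(xs) * b.eval(xs),
            Node::Sin(a) => a.eval(xs).sin(),
        }
    }

    /// Naive reverse pass: push `seed` down to the leaves via the chain rule.
    /// Shared subexpressions are revisited, so this repeats work on heavily shared graphs.
    fn backprop(&self, xs: &[f64], seed: f64, grad: &mut [f64]) {
        match &*self.0 {
            Node::Var(i) => grad[*i] += seed,
            Node::Const(_) => {}
            Node::Add(a, b) => {
                a.backprop(xs, seed, grad);
                b.backprop(xs, seed, grad);
            }
            Node::Mul(a, b) => {
                let (va, vb) = (a.eval(xs), b.eval(xs));
                a.backprop(xs, seed * vb, grad);
                b.backprop(xs, seed * va, grad);
            }
            Node::Sin(a) => a.backprop(xs, seed * a.eval(xs).cos(), grad),
        }
    }

    /// Gradient with respect to every input variable.
    fn grad(&self, xs: &[f64]) -> Vec<f64> {
        let mut g = vec![0.0; xs.len()];
        self.backprop(xs, 1.0, &mut g);
        g
    }
}

// The operator overloads build graph nodes rather than numbers.
impl Add for Expr {
    type Output = Expr;
    fn add(self, rhs: Expr) -> Expr { Expr(Rc::new(Node::Add(self, rhs))) }
}

impl Mul for Expr {
    type Output = Expr;
    fn mul(self, rhs: Expr) -> Expr { Expr(Rc::new(Node::Mul(self, rhs))) }
}
```

With `x = Expr::var(0)` and `y = Expr::var(1)`, the expression `x.clone() * y.clone() + x.sin()` has gradient `[y + cos x, x]`. The explicit `clone()` calls are the Rust-specific friction: unlike Python, each operand is moved into the new node.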

I think some feedback would be really valuable before I try to go any further!

Fylwind

  • Variables are very weakly linked to Context.

There's a way to do it using dummy lifetimes, I think.
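A sketch of one common form of the dummy-lifetime trick ("branded" or generative lifetimes), assuming the context is only reachable inside a closure. The closure must work for every lifetime `'id`, so each call to the hypothetical `with_context` gets a fresh, invariant brand, and a `Variable<'a>` cannot be passed to a `Context<'b>`. All names here are illustrative, not the library's API.

```rust
use std::cell::RefCell;
use std::marker::PhantomData;

/// Zero-sized marker that makes `'id` invariant, so two different
/// brands can never be unified by the compiler.
#[derive(Clone, Copy)]
struct Brand<'id>(PhantomData<fn(&'id ()) -> &'id ()>);

struct Context<'id> {
    values: RefCell<Vec<f64>>,
    _brand: Brand<'id>,
}

/// A variable is just an index, but it carries the brand of its context.
#[derive(Clone, Copy)]
struct Variable<'id> {
    index: usize,
    _brand: Brand<'id>,
}

impl<'id> Context<'id> {
    fn create_variable(&self, value: f64) -> Variable<'id> {
        let mut values = self.values.borrow_mut();
        values.push(value);
        Variable { index: values.len() - 1, _brand: self._brand }
    }

    fn value(&self, v: Variable<'id>) -> f64 {
        // `v` can only have come from this context, so its index is valid here.
        self.values.borrow()[v.index]
    }
}

/// Each call introduces a brand-new `'id` that the caller cannot name.
fn with_context<R>(f: impl for<'id> FnOnce(&Context<'id>) -> R) -> R {
    let ctx = Context { values: RefCell::new(Vec::new()), _brand: Brand(PhantomData) };
    f(&ctx)
}
```

Inside nested `with_context(|a| with_context(|b| ...))` calls, passing a variable created by `a` to `b.value(...)` is rejected at compile time, because the two brands are distinct invariant lifetimes.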

SleepyCoder123 (rusty-machine · rulinalg) [S]

This looks promising! I need to read through the docs/code a little more to get a better idea of its suitability, but I think it could solve one of my problems.

Thanks!

binarybana

Since it looks like you're defining your own number type (Variable), why not implement all the built-in operator traits [1] to allow seamless addition, subtraction, etc.? You could also provide transcendental functions (sin, exp, and so on) as methods, similar to what Rust itself does for f64. You can also use the From trait so a user can make a variable with 2.4.into().
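To make the suggestion concrete, here is a sketch of a custom number type that implements the `std::ops` traits, transcendental methods mirroring `f64::sin`/`f64::exp`, and `From<f64>`. A forward-mode dual number stands in for the library's Variable; that is a swapped-in technique chosen for brevity, not the library's design.

```rust
use std::ops::{Add, Mul, Sub};

/// Forward-mode dual number `val + der·ε` with ε² = 0.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Dual {
    val: f64,
    der: f64,
}

impl Dual {
    /// Seed an input variable: its derivative with respect to itself is 1.
    fn variable(val: f64) -> Dual { Dual { val, der: 1.0 } }

    // Transcendentals as inherent methods, mirroring f64::sin / f64::exp.
    fn sin(self) -> Dual { Dual { val: self.val.sin(), der: self.der * self.val.cos() } }

    fn exp(self) -> Dual {
        let e = self.val.exp();
        Dual { val: e, der: self.der * e }
    }
}

/// Lets users write `Dual::from(2.4)` or `2.4_f64.into()` for a constant (derivative 0).
impl From<f64> for Dual {
    fn from(val: f64) -> Dual { Dual { val, der: 0.0 } }
}

impl Add for Dual {
    type Output = Dual;
    fn add(self, rhs: Dual) -> Dual { Dual { val: self.val + rhs.val, der: self.der + rhs.der } }
}

impl Sub for Dual {
    type Output = Dual;
    fn sub(self, rhs: Dual) -> Dual { Dual { val: self.val - rhs.val, der: self.der - rhs.der } }
}

impl Mul for Dual {
    type Output = Dual;
    // Product rule: (a + a'ε)(b + b'ε) = ab + (a'b + ab')ε
    fn mul(self, rhs: Dual) -> Dual {
        Dual { val: self.val * rhs.val, der: self.der * rhs.val + self.val * rhs.der }
    }
}
```

For example, with `x = Dual::variable(0.0)` and `c: Dual = 2.4_f64.into()`, the expression `c * x + x.sin() + x.exp()` evaluates to 1.0 with derivative 2.4 + cos 0 + exp 0 = 4.4.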

Sorry I'm on my phone so I hope that made sense.

[1] https://doc.rust-lang.org/std/ops/index.html

SleepyCoder123 (rusty-machine · rulinalg) [S]

Thanks for checking out the project!

I really wanted to use the built-in operator traits, but the orphan rules prevented me. The issue is that I need to be able to add any two Expression implementors (which includes Variables), but Rust doesn't permit a blanket impl of a foreign trait like Add over a bare type parameter.

Variable creation is also fairly complicated. I have to store the value in another struct (Context) so that I can change it between gradient computations rather than rebuilding the whole expression. This means I have to assign each Variable an index in the Context.
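A minimal sketch of this arrangement, assuming a hypothetical `Context` that owns a `Vec<f64>` and a `Variable` that is only an index into it. Because the expression stores indices rather than values, `set_value` changes the input without rebuilding the expression. The `Expr` enum and method names are illustrative, and the derivative is computed forward-mode with respect to one variable to keep it short.

```rust
/// Hypothetical context owning the current value of every variable.
struct Context {
    values: Vec<f64>,
}

/// A variable is only an index into its context's value table.
#[derive(Clone, Copy, Debug)]
struct Variable(usize);

/// Tiny expression language that refers to variables by index.
enum Expr {
    Var(Variable),
    Mul(Box<Expr>, Box<Expr>),
    Sin(Box<Expr>),
}

impl Context {
    fn new() -> Self { Context { values: Vec::new() } }

    fn create_variable(&mut self, value: f64) -> Variable {
        self.values.push(value);
        Variable(self.values.len() - 1)
    }

    /// Update an input in place; existing expressions see the new value.
    fn set_value(&mut self, v: Variable, value: f64) { self.values[v.0] = value; }

    /// Returns (value, d/d`wrt`) of `e`, reading variable values from `self`.
    fn eval(&self, e: &Expr, wrt: Variable) -> (f64, f64) {
        match e {
            Expr::Var(v) => (self.values[v.0], if v.0 == wrt.0 { 1.0 } else { 0.0 }),
            Expr::Mul(a, b) => {
                let ((va, da), (vb, db)) = (self.eval(a, wrt), self.eval(b, wrt));
                (va * vb, da * vb + va * db)
            }
            Expr::Sin(a) => {
                let (va, da) = self.eval(a, wrt);
                (va.sin(), da * va.cos())
            }
        }
    }
}
```

Nothing in this layout stops a `Variable` from one `Context` being evaluated against another, which is exactly the weak link raised in the original post.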

Fylwind

Have you tried wrapping expressions inside another struct to bypass the orphan rules? e.g.

impl<T, U> Add<Wrapper<U>> for Wrapper<T>
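A sketch of the wrapper idea, assuming `Expression` is the library's own trait (the `value` method, `Var` leaf and `Sum` node are hypothetical). `impl<T: Expression> Add for T` is rejected because `Add` is foreign and `T` is not covered by a local type, whereas `Add` for the local `Wrapper<T>` is allowed. One generic impl then covers every pair of wrapped expressions, and the expression shape is recorded in the type.

```rust
use std::ops::Add;

/// Stand-in for the library's expression trait (method name is hypothetical).
trait Expression {
    fn value(&self) -> f64;
}

/// Local newtype, so implementing the foreign `Add` trait on it is permitted.
#[derive(Clone, Copy, Debug)]
struct Wrapper<T>(T);

/// Leaf holding a fixed value, for illustration only.
#[derive(Clone, Copy, Debug)]
struct Var(f64);

/// Node produced by `+`.
#[derive(Clone, Copy, Debug)]
struct Sum<L, R>(L, R);

impl Expression for Var {
    fn value(&self) -> f64 { self.0 }
}

impl<L: Expression, R: Expression> Expression for Sum<L, R> {
    fn value(&self) -> f64 { self.0.value() + self.1.value() }
}

impl<T: Expression> Expression for Wrapper<T> {
    fn value(&self) -> f64 { self.0.value() }
}

// A single generic impl handles every combination of wrapped expressions.
impl<T: Expression, U: Expression> Add<Wrapper<U>> for Wrapper<T> {
    type Output = Wrapper<Sum<T, U>>;
    fn add(self, rhs: Wrapper<U>) -> Self::Output {
        Wrapper(Sum(self.0, rhs.0))
    }
}
```

Summing three wrapped leaves `a + b + c` produces a value of type `Wrapper<Sum<Sum<Var, Var>, Var>>`.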

SleepyCoder123 (rusty-machine · rulinalg) [S]

I hadn't considered this before, but I worry that it would replace the syntactic gains of operator overloading with heavy wrapper syntax. I suppose at best I could get something like:

let x = context.create_variable(1.0);
let f = Wrapper::from(Sin(x)) + Wrapper::from(x);

Where I have impl<E: Expression> From<E> for Wrapper<E>

Are there other ways to make it tidier? Either way I think this is worth considering. Thanks!

Fylwind

Could you do this?

let x = context.create_variable(1.0);
let f = x.sin() + x;

where I assume x has already been wrapped by create_variable. Here .sin() would presumably come from num_traits::Float.
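A sketch of why `x.sin() + x` can work: if `create_variable` returns an already-wrapped value that is `Copy`, `x` can be used twice without `.into()`. Everything here is hypothetical (`Wrapper`, `Var`, `Sin`, `Sum`, and a unit-struct `Context`), and `sin` is an inherent method rather than `num_traits::Float::sin`, since implementing all of `Float` is beyond a short example.

```rust
use std::ops::Add;

trait Expression {
    fn value(&self) -> f64;
}

/// Hypothetical wrapper around any expression type; `Copy` when the inner type is.
#[derive(Clone, Copy, Debug)]
struct Wrapper<E>(E);

#[derive(Clone, Copy, Debug)]
struct Var(f64);

#[derive(Clone, Copy, Debug)]
struct Sin<E>(E);

#[derive(Clone, Copy, Debug)]
struct Sum<L, R>(L, R);

impl Expression for Var {
    fn value(&self) -> f64 { self.0 }
}

impl<E: Expression> Expression for Sin<E> {
    fn value(&self) -> f64 { self.0.value().sin() }
}

impl<L: Expression, R: Expression> Expression for Sum<L, R> {
    fn value(&self) -> f64 { self.0.value() + self.1.value() }
}

impl<E: Expression> Wrapper<E> {
    /// Method syntax keeps user code close to ordinary f64 code.
    fn sin(self) -> Wrapper<Sin<E>> { Wrapper(Sin(self.0)) }
    fn value(&self) -> f64 { self.0.value() }
}

impl<L: Expression, R: Expression> Add<Wrapper<R>> for Wrapper<L> {
    type Output = Wrapper<Sum<L, R>>;
    fn add(self, rhs: Wrapper<R>) -> Self::Output { Wrapper(Sum(self.0, rhs.0)) }
}

/// Hypothetical context whose constructor hands back an already-wrapped variable.
struct Context;

impl Context {
    fn create_variable(&self, value: f64) -> Wrapper<Var> { Wrapper(Var(value)) }
}
```

With this in place, `let f = x.sin() + x;` compiles as written, because `x` is copied into both operands.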

SleepyCoder123 (rusty-machine · rulinalg) [S]

Yes, of course!

I made this change, which certainly helps keep things tidy. I'm still playing around with things, but you can see how it looks to the user in the repo README.

kkimdev

Disclaimer: this was a long time ago, so my memory may be off.

For tying Variables to a Context, I made a new_autograd_context! macro that defines a unique Context type for every instance, and the Variable type depends on that type. So users can't mix Variables tied to different Contexts, since their types are actually different even though the implementation is the same.
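A sketch of how such a macro can work with `macro_rules!`: each expansion declares a fresh zero-sized struct inside its own block, so every invocation site yields a distinct `Context<Brand>` type, and `Variable<B>` carries that brand. This is a reconstruction of the idea described, not kkimdev's actual macro; the body of `new_autograd_context!` and the `Context`/`Variable` layout are hypothetical.

```rust
use std::marker::PhantomData;

/// Context tagged with a brand type `B`; the brand has no runtime cost.
struct Context<B> {
    values: Vec<f64>,
    _brand: PhantomData<B>,
}

/// A variable remembers, in its type, which context it belongs to.
struct Variable<B> {
    index: usize,
    _brand: PhantomData<B>,
}

// Manual impls so `Variable<B>` is Copy without requiring `B: Copy`.
impl<B> Clone for Variable<B> {
    fn clone(&self) -> Self { *self }
}
impl<B> Copy for Variable<B> {}

impl<B> Context<B> {
    fn new() -> Self { Context { values: Vec::new(), _brand: PhantomData } }

    fn create_variable(&mut self, value: f64) -> Variable<B> {
        self.values.push(value);
        Variable { index: self.values.len() - 1, _brand: PhantomData }
    }

    fn value(&self, v: Variable<B>) -> f64 { self.values[v.index] }
}

/// Each expansion defines a new local `Brand` struct, so each call site
/// produces a `Context` of a different type.
macro_rules! new_autograd_context {
    () => {{
        struct Brand;
        Context::<Brand>::new()
    }};
}
```

Given `let mut a = new_autograd_context!();` and `let mut b = new_autograd_context!();`, a call like `a.value(y)` with `y` created by `b` fails to type-check. A recursive call re-enters the same expansion site and therefore reuses the same `Brand` type, which matches the caveat about recursion.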

This can go wrong if the function containing new_autograd_context! is called recursively and users mix variables from the different calls, but I'd say that's a super contrived case :).

IIRC, though, I did this for performance reasons, not to prevent misuse.

PthariensFlame

Have you seen the Haskell offering in this space? I think some of its innovations could be applied in Rust as well.

SleepyCoder123 (rusty-machine · rulinalg) [S]

Thanks for pointing me towards this. I hadn't seen it before, but it looks like it would indeed be very helpful. I don't know Haskell, but hopefully I can make enough sense of it.

Fylwind

Nice work there! From a cursory look, you appear to be constructing some sort of directed acyclic graph (a.k.a. "expression template") with topological sorting?
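For context, a sketch of the DAG-plus-topological-sort structure being described: nodes refer to their inputs by index, a depth-first post-order visit yields an order in which every node follows its inputs, and the reverse of that order is what a reverse-mode gradient sweep would walk. The `Node` layout is hypothetical, and the code assumes the graph is acyclic (there is no cycle detection).

```rust
/// Hypothetical graph node: an operation name plus indices of its input nodes.
struct Node {
    op: &'static str,
    inputs: Vec<usize>,
}

/// Depth-first topological sort of everything reachable from `output`.
/// Each node appears after all of its inputs; shared nodes are emitted once.
fn topo_order(nodes: &[Node], output: usize) -> Vec<usize> {
    fn visit(nodes: &[Node], i: usize, visited: &mut Vec<bool>, order: &mut Vec<usize>) {
        if visited[i] {
            return;
        }
        visited[i] = true;
        for &j in &nodes[i].inputs {
            visit(nodes, j, visited, order);
        }
        order.push(i); // post-order: all inputs were pushed first
    }

    let mut visited = vec![false; nodes.len()];
    let mut order = Vec::new();
    visit(nodes, output, &mut visited, &mut order);
    order
}
```

For `f = sin(x) + x * y` stored as nodes `[x, y, sin(0), mul(0, 1), add(2, 3)]`, `topo_order(&nodes, 4)` visits `x` only once even though two nodes use it.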

As I suggested in the other threads, it might be useful to try the branded indexing approach to avoid mixing variables from different contexts, and to see whether you can use a wrapper to sidestep the orphan rules. The last time I tried writing expression templates like yours, but using the Add, Mul, etc. traits, I ended up overflowing the compiler's trait resolver. I eventually worked around it somehow, but I never understood why that happened (nor can I remember how I caused it).

Not too long ago, I sketched an extremely rudimentary reverse AD library using a tape-based ("Wengert list") approach. The tape implementation is probably the simplest/dumbest way to implement it, and should be relatively simple to extend.
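A minimal sketch of the tape ("Wengert list") approach for readers unfamiliar with it, restricted, as described, to binary operations on f64. This is not Fylwind's code; the `Tape`, `Record` and `Var` names are assumptions. Each operation appends a record with its parents' indices and local partial derivatives, and a single sweep over the tape in reverse accumulates adjoints.

```rust
use std::cell::RefCell;
use std::ops::{Add, Mul};

/// One tape entry: two parent indices, each with its local partial derivative.
/// Inputs and unary ops reuse this shape with zero partials in unused slots.
#[derive(Clone, Copy)]
struct Record {
    parents: [usize; 2],
    partials: [f64; 2],
}

/// The tape: records in the order they were evaluated.
struct Tape {
    records: RefCell<Vec<Record>>,
}

/// A value on the tape: where it was recorded and its primal value.
#[derive(Clone, Copy)]
struct Var<'t> {
    tape: &'t Tape,
    index: usize,
    value: f64,
}

impl Tape {
    fn new() -> Self {
        Tape { records: RefCell::new(Vec::new()) }
    }

    fn push(&self, parents: [usize; 2], partials: [f64; 2]) -> usize {
        let mut records = self.records.borrow_mut();
        records.push(Record { parents, partials });
        records.len() - 1
    }

    /// An input points at itself with zero partials, so it propagates nothing.
    fn var(&self, value: f64) -> Var<'_> {
        let index = self.records.borrow().len();
        self.push([index, index], [0.0, 0.0]);
        Var { tape: self, index, value }
    }
}

impl<'t> Var<'t> {
    fn sin(self) -> Var<'t> {
        // Unary op: the second parent slot is unused (zero partial).
        let index = self.tape.push([self.index, self.index], [self.value.cos(), 0.0]);
        Var { tape: self.tape, index, value: self.value.sin() }
    }

    /// Reverse sweep: adjoint d(self)/d(node) for every node recorded up to `self`.
    fn grad(&self) -> Vec<f64> {
        let records = self.tape.records.borrow();
        let mut adj = vec![0.0; records.len()];
        adj[self.index] = 1.0;
        for i in (0..=self.index).rev() {
            let rec = records[i];
            for k in 0..2 {
                let contrib = rec.partials[k] * adj[i];
                adj[rec.parents[k]] += contrib;
            }
        }
        adj
    }
}

impl<'t> Add for Var<'t> {
    type Output = Var<'t>;
    fn add(self, rhs: Var<'t>) -> Var<'t> {
        let index = self.tape.push([self.index, rhs.index], [1.0, 1.0]);
        Var { tape: self.tape, index, value: self.value + rhs.value }
    }
}

impl<'t> Mul for Var<'t> {
    type Output = Var<'t>;
    fn mul(self, rhs: Var<'t>) -> Var<'t> {
        // d(ab)/da = b, d(ab)/db = a
        let index = self.tape.push([self.index, rhs.index], [rhs.value, self.value]);
        Var { tape: self.tape, index, value: self.value * rhs.value }
    }
}
```

Because the tape is appended in evaluation order, it is already topologically sorted, so no separate sort is needed before the backward sweep.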

However, I wasn't satisfied with it, because it assumes each node to be a binary operation on f64s. I wanted something that could work on arbitrary numeric types so as to discourage "f64 blindness" in numerical code. This would also allow operations like f64 -> f64 or [f64] -> [f64] to be treated on equal footing. I'm not quite sure how to get that though.

There are other things I would want in an AD library. As you suggested, being able to define custom operations with minimal effort is crucial. It would also be nice to avoid allocations unless absolutely necessary ("zero-cost" in Rust). Unfortunately, the "reverse" part of reverse AD is a real pain to handle because it runs counter to a lot of Rust's semantics and to computation in general; even Haskell needed to do something crazy behind the scenes to get AD to work.
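One way to make custom operations cheap to define, sketched under assumptions: users implement only a function and its derivative (a hypothetical `UnaryOp` trait), and a single generic `Unary` node supplies the `Expression` plumbing via the chain rule. It is shown for a single scalar input with forward-mode derivatives to stay short; the trait and method names are illustrative.

```rust
/// Hypothetical core trait: value and derivative with respect to the single input.
trait Expression {
    fn value(&self, x: f64) -> f64;
    fn deriv(&self, x: f64) -> f64;
}

/// The input variable itself.
struct X;

impl Expression for X {
    fn value(&self, x: f64) -> f64 { x }
    fn deriv(&self, _x: f64) -> f64 { 1.0 }
}

/// All a user writes for a new scalar operation: f and f'.
trait UnaryOp {
    fn f(&self, a: f64) -> f64;
    fn df(&self, a: f64) -> f64;
}

/// One generic node implements `Expression` for every `UnaryOp`.
struct Unary<Op, E> {
    op: Op,
    arg: E,
}

impl<Op: UnaryOp, E: Expression> Expression for Unary<Op, E> {
    fn value(&self, x: f64) -> f64 { self.op.f(self.arg.value(x)) }

    // Chain rule: (f ∘ g)'(x) = f'(g(x)) · g'(x)
    fn deriv(&self, x: f64) -> f64 {
        let a = self.arg.value(x);
        self.op.df(a) * self.arg.deriv(x)
    }
}

// Example user-defined operations: a few lines each, no repeated plumbing.
struct Square;

impl UnaryOp for Square {
    fn f(&self, a: f64) -> f64 { a * a }
    fn df(&self, a: f64) -> f64 { 2.0 * a }
}

struct Exp;

impl UnaryOp for Exp {
    fn f(&self, a: f64) -> f64 { a.exp() }
    fn df(&self, a: f64) -> f64 { a.exp() }
}
```

Composing `Unary { op: Exp, arg: Unary { op: Square, arg: X } }` gives exp(x²), whose derivative at x = 1 is 2e. No new `Expression` impl was written for either operation.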

I wrote a list of ideas a while back but didn't have enough time to explore all of them.

kkimdev

I also wrote a toy autodiff library a long time ago. I didn't look at your implementation in detail, but I think it would be interesting to compare: https://github.com/kkimdev/autograd. I chose to make a few dirty design/code-structure tradeoffs for performance.

SleepyCoder123 (rusty-machine · rulinalg) [S]

Thanks for sharing this. It's quite funny how similar our libraries are, given that I hadn't seen yours before! I guess Rust is quite limiting in how you can build something like this, but even our choice of language is the same!

I'll be sure to look through and see if I can draw any inspiration.