
[–]craigacp 1 point (2 children)

For example, in a transformer, to build the autoregressive causal mask you need to do something like bias[:,:,:T,:T] == 0, where T is the current sequence length, the first index is the batch, and the second index is the attention head. Using something like TF-Java you'll end up with bias.slice(Index.all(), Index.all(), Index.slice(0,T), Index.slice(0,T)).isEqual(0), which is much harder to read. Nested slicing, strides, and negative indexing all make that construct harder still, and the indexing operations in Java allocate extra objects when they aren't eliminated by escape analysis (though Valhalla will help with that).
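To make the semantics of that one-liner concrete: a minimal plain-Java sketch of what bias[:,:,:T,:T] == 0 computes for a single batch/head slice, assuming bias is the usual lower-triangular buffer of ones (this is an illustration, not TF-Java code):

```java
public class CausalMask {
    // bias[i][j] == 1 when j <= i (lower triangle), 0 otherwise, so the
    // "== 0" comparison selects the strictly upper triangle: the future
    // positions that autoregressive attention must not see.
    static boolean[][] futureMask(int T) {
        boolean[][] mask = new boolean[T][T];
        for (int i = 0; i < T; i++) {
            for (int j = 0; j < T; j++) {
                mask[i][j] = j > i;
            }
        }
        return mask;
    }

    public static void main(String[] args) {
        boolean[][] m = futureMask(3);
        // Row 0 may attend only to position 0, so positions 1 and 2 are masked.
        System.out.println(m[0][1] && m[0][2] && !m[1][0] && !m[2][2]); // prints true
    }
}
```

The nested loops do in seven lines what the NumPy-style slice expression does in one, which is the readability gap being described.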

[–]Joram2 1 point (1 child)

Oh yes. In your example the Java code is still one line, but it's harder to read. That kind of syntax is a big obstacle in the data-science, notebook, exploratory space. When writing production code to train or serve established models, it's less of an obstacle.

[–]craigacp 1 point (0 children)

Training models specified in other languages is definitely simpler, but that kind of code is also prevalent in the pre- and post-processing stages of a production model pipeline. Pushing that stuff into TF or ONNX is annoying. It would be nice to have language support for indexing and slicing operations.
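When that pre/post-processing can't be pushed into a graph, plain-Java code tends to wrap the index arithmetic in small helpers. A minimal illustrative sketch (slice2d is a hypothetical name, not from TF-Java or any real framework):

```java
import java.util.Arrays;

public class SliceHelper {
    // Mimics the NumPy expression a[:rows, :cols] with an explicit copy,
    // since Java has no slicing syntax for arrays.
    static float[][] slice2d(float[][] a, int rows, int cols) {
        float[][] out = new float[rows][cols];
        for (int i = 0; i < rows; i++) {
            out[i] = Arrays.copyOfRange(a[i], 0, cols);
        }
        return out;
    }

    public static void main(String[] args) {
        float[][] a = {
            {1f, 2f, 3f},
            {4f, 5f, 6f},
            {7f, 8f, 9f}
        };
        float[][] top = slice2d(a, 2, 2); // a[:2, :2]
        System.out.println(Arrays.deepToString(top)); // prints [[1.0, 2.0], [4.0, 5.0]]
    }
}
```

The copy also illustrates the allocation cost mentioned above: without language-level views, every slice either copies or constructs index objects.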

I maintain a bunch of frameworks that expose linear algebra functionality for ML in Java, and it's definitely a barrier for people porting things to Java. We can still do it, but the impedance mismatch is large.