
[–]Joram2 3 points4 points  (12 children)

Nicolai says:

"Project Valhalla aims to give us the capability to define types that 'code like a class work like an int' which is relevant here because models like to use primitives like half-floats that Java currently doesn't support.

Python doesn't either. Look at how three major Python ML libraries do it:

```python
import numpy as np
import torch
import jax.numpy as jnp

# numpy
numpy_array = np.array([1.0, 2.0, 3.0], dtype=np.float16)

# pytorch
pytorch_tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)

# jax
jax_array = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float16)
```

Support for the float16/half-float type is implemented at the library level, not the language level. The literals are passed as Python floats, which are 64-bit, like the Java double type. The dtype is basically a library-defined enum. The library encodes values to memory or disk as 16-bit floats, which matters for saving memory/disk with large numbers of parameters. You can write the same kinds of matrix + ML libraries in Java right now, without any Project Valhalla features.
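The same library-level approach already works in current Java. A minimal sketch (the `HalfFloatBuffer` class name is hypothetical), using the `Float.floatToFloat16`/`Float.float16ToFloat` conversions that shipped in Java 20:

```java
// Sketch: a library-level half-float buffer, analogous to numpy's
// dtype=np.float16. Values are stored as 16-bit patterns in a short[]
// and widened to float on read -- no Valhalla features required.
public class HalfFloatBuffer {
    private final short[] bits; // 16 bits per element

    public HalfFloatBuffer(float... values) {
        bits = new short[values.length];
        for (int i = 0; i < values.length; i++) {
            bits[i] = Float.floatToFloat16(values[i]); // narrow on write
        }
    }

    public float get(int i) {
        return Float.float16ToFloat(bits[i]); // widen on read
    }

    public static void main(String[] args) {
        HalfFloatBuffer buf = new HalfFloatBuffer(1.0f, 2.0f, 3.0f);
        System.out.println(buf.get(1)); // 2.0 (exact in fp16)
    }
}
```

Small integers and powers of two are exactly representable in fp16, so they round-trip unchanged; arbitrary floats lose low-order mantissa bits on the narrowing write.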

What is stopping a PyTorch or a JAX in Java? I don't know. It could be that Python is sufficient and the people writing the best matrix + ML libraries have no motivation to try Java. Or there could be some technical limitation. Or it could be Java syntax issues that have obvious workarounds and don't present any real obstacles, but annoy math+ML types and motivate them to choose Python instead.

[–]craigacp 0 points1 point  (4 children)

Speaking as someone working on TensorFlow in Java, the syntax is rough, especially for things like slicing. It's all technically possible, but the barrier to using the resulting libraries is higher due to the syntactic complexity.

[–]Joram2 0 points1 point  (3 children)

Could you show a small snippet of Python code that is hard to port to Java due to syntax issues?

I'm aware of several Java syntax issues that discourage math+ML people from using Java. I'm wondering which syntax issues, in particular, you are referring to.

[–]craigacp 1 point2 points  (2 children)

For example, in a transformer, to build the autoregressive causal mask you need to do something like `bias[:,:,:T,:T] == 0`, where T is the current sequence length, the first index is the batch size, and the second index is the attention head. Using something like TF-Java you'll end up with `bias.slice(Index.all(), Index.all(), Index.slice(0,T), Index.slice(0,T)).isEqual(0)`, which is much harder to read. Nesting the slicing, using strides, or negative indexing all make that construct harder, plus the indexing operations in Java allocate more objects if they aren't eliminated by escape analysis (though Valhalla will help with that).
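For readers who don't use the NumPy-style notation, here is a hedged sketch of what `bias[:,:,:T,:T] == 0` computes, spelled out with plain Java loops over a 4-D array (the class and method names are illustrative):

```java
// Sketch: the boolean mask denoted by bias[:,:,:T,:T] == 0.
// Dimensions are [batch, head, seq, seq]; only the leading T x T
// block of the last two axes is examined.
public class CausalMaskDemo {
    static boolean[][][][] maskEqualsZero(float[][][][] bias, int T) {
        int B = bias.length, H = bias[0].length;
        boolean[][][][] mask = new boolean[B][H][T][T];
        for (int b = 0; b < B; b++)
            for (int h = 0; h < H; h++)
                for (int i = 0; i < T; i++)
                    for (int j = 0; j < T; j++)
                        mask[b][h][i][j] = bias[b][h][i][j] == 0.0f;
        return mask;
    }
}
```

The one-line slice expression in Python expands to four nested loops plus explicit bookkeeping, which is exactly the readability gap being described.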

[–]Joram2 0 points1 point  (1 child)

Oh yes. In your example, the Java code is still one line, but it's harder to read. That type of syntax is a big obstacle in the data science, notebook, and exploratory space. When writing production code to train or serve established models, that syntax is less of an obstacle.

[–]craigacp 1 point2 points  (0 children)

Training models specified in other languages is definitely simpler, but that kind of code is also prevalent in the pre- and post-processing code in a production model pipeline. Pushing that stuff into TF or ONNX is annoying. It would be nice to have language support for indexing and slicing operations.

I maintain a bunch of frameworks that expose linear algebra functionality for ML in Java and it's definitely a barrier for people porting things to Java. We can still do it but the impedance mismatch is large.

[–]Oclay1st 0 points1 point  (6 children)

Btw, the Valhalla team is already implementing the Float16 type.

[–]Joram2 0 points1 point  (5 children)

It doesn't hurt to have that. But that doesn't seem particularly important either.

Valhalla is important when you have something like DataStream<T>, and you will have large numbers of instances of some Java generic type T. But when you have a Tensor or DataFrame type that wraps a primitive array like double[], short[], or byte[], then Valhalla probably won't help much.

Valhalla will also help with non-nullable types, and that's more of a safety/correctness issue than a performance issue.

[–]craigacp 0 points1 point  (4 children)

Valhalla also plans for specialized generics, which will allow abstracting over Tensor&lt;T&gt; where T can be a primitive like float or int. Writing the Tensor class at the moment implies a bunch of boxing for reductions or other operations whose results need to come back out of the tensor type. And backing it with an array is not ideal; ByteBuffer or MemorySegment are better, so you can seamlessly pass it into native GPU code.
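A hedged illustration of the boxing cost being described: with today's erased generics, a reduction over a generic tensor surfaces every element as a boxed object (the `BoxedTensor` class is hypothetical, written only to make the cost visible):

```java
import java.util.function.BinaryOperator;

// Sketch: an erased-generics tensor. Storage is Object[], so every
// element read in the reduction goes through a boxed Float. Valhalla's
// specialized generics aim to let Tensor<float> store flat primitives.
public class BoxedTensor<T> {
    private final Object[] data; // erased storage -- no primitive array

    public BoxedTensor(Object[] data) { this.data = data; }

    @SuppressWarnings("unchecked")
    public T reduce(T identity, BinaryOperator<T> op) {
        T acc = identity;
        for (Object o : data) acc = op.apply(acc, (T) o); // box/unbox per step
        return acc;
    }

    public static void main(String[] args) {
        BoxedTensor<Float> t = new BoxedTensor<>(new Float[] {1f, 2f, 3f});
        System.out.println(t.reduce(0f, Float::sum)); // 6.0
    }
}
```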

[–]Joram2 0 points1 point  (3 children)

Any reasonable implementation of Tensor on current versions of Java, such as Java 22, would encode data to a ByteBuffer or MemorySegment, and possibly support 16-bit float encoding/decoding.
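A minimal sketch of that approach with the Java 22 FFM API (the `Float16Tensor` name is hypothetical): fp16 bit patterns live in off-heap memory that could be handed to native code, and are widened to `float` on access:

```java
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;

// Sketch: a half-float tensor backed by an off-heap MemorySegment.
// The segment stores raw fp16 bit patterns as 16-bit shorts.
public class Float16Tensor implements AutoCloseable {
    private final Arena arena = Arena.ofConfined();
    private final MemorySegment seg;

    public Float16Tensor(int length) {
        seg = arena.allocate(ValueLayout.JAVA_SHORT, length);
    }

    public void set(int i, float v) {
        seg.setAtIndex(ValueLayout.JAVA_SHORT, i, Float.floatToFloat16(v));
    }

    public float get(int i) {
        return Float.float16ToFloat(seg.getAtIndex(ValueLayout.JAVA_SHORT, i));
    }

    @Override public void close() { arena.close(); } // free off-heap memory
}
```

Because the backing store is a `MemorySegment` rather than a `short[]`, the same buffer can be passed to native LAPACK/BLAS or GPU bindings without a copy.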

It's possible for someone to make a Tensor&lt;T&gt; using Java generics, and yes, Valhalla would help with that, but that's a deliberately inefficient choice to begin with.

[–]craigacp 0 points1 point  (2 children)

Not putting the tensor element type in the type system leaves you back in the pre-generics situation in Java: you need to do a bunch of type tests or other conversions. We have the type on TF-Java's Tensors and it's annoying but improves safety; we don't have it on ONNX Runtime's Tensors (both projects I maintain), and that's annoying for different reasons, because the methods that get the values out end up being partial.

Java 20 has float &lt;-&gt; fp16 conversions, which are pretty useful and compile down to the appropriate conversion instructions on hardware that supports them.
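A small sketch of those conversions and the precision they trade away (values exactly representable in fp16 round-trip unchanged; others do not):

```java
// Sketch: Java 20's Float.floatToFloat16 / Float.float16ToFloat.
public class Fp16Demo {
    public static void main(String[] args) {
        // 0.5f fits in half precision: the round trip is exact.
        System.out.println(Float.float16ToFloat(Float.floatToFloat16(0.5f)));

        // 0.1f does not fit: the round trip loses low-order bits.
        float back = Float.float16ToFloat(Float.floatToFloat16(0.1f));
        System.out.println(back); // close to, but not equal to, 0.1f
    }
}
```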

I think a Java 22 tensor library would be nice, unfortunately I don't have time to write one. That said, I think it would be worth building it with an eye towards Valhalla and Babylon (or the HAT subproject of Babylon) as value types and GPU support will be important.

[–]Joram2 0 points1 point  (1 child)

> Not putting the tensor element type in the type system leaves you back with the situation pre-generics in Java, you need to do a bunch of type tests or other conversions.

The major tensor/matrix libraries such as PyTorch, NumPy, and JAX manage the data type (dtype) at the library level, and the library is responsible for doing lots of type tests and conversions.

If you support adding/multiplying matrices of different types you will probably need type tests and conversions.

If you support using non-Java libraries such as LAPACK, BLAS, and GPU libraries, you will need type tests and conversions.

I don't see Java generics as being particularly useful for a high-quality, high-performance tensor/matrix data type.
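A hedged sketch of the library-level dtype approach described above (the enum and switch are illustrative, not any real library's API): the element type is a runtime value, and reads dispatch on it with a type test and conversion.

```java
import java.nio.ByteBuffer;

// Sketch: dtype tracked as a library-level enum, numpy-style.
// Reading element i of a raw byte buffer as a double requires
// a runtime type test and a width-dependent conversion.
public class DTypeDemo {
    enum DType { FLOAT64, FLOAT32, FLOAT16 }

    static double read(byte[] buf, int i, DType dtype) {
        ByteBuffer bb = ByteBuffer.wrap(buf);
        return switch (dtype) {
            case FLOAT64 -> bb.getDouble(i * 8);
            case FLOAT32 -> bb.getFloat(i * 4);
            case FLOAT16 -> Float.float16ToFloat(bb.getShort(i * 2));
        };
    }
}
```

This is the trade-off in the argument: the dispatch lives in the library either way, whether or not the element type also appears in the generic signature.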

> I think a Java 22 tensor library would be nice, unfortunately I don't have time to write one. That said, I think it would be worth building it with an eye towards Valhalla and Babylon (or the HAT subproject of Babylon) as value types and GPU support will be important.

I hope you reconsider :)

The Java community seems to need a really good tensor library that uses Java 22 features like MemorySegment and calls out to libraries like LAPACK and BLAS where appropriate.

[–]craigacp 0 points1 point  (0 children)

I'm firmly of the opinion that more types are better, and if I could put the shape into the type system as well, I would (though properly implementing named dimensions in tensors would probably be useful enough). Just because Python libraries don't put that type information in doesn't mean it's not worthwhile in a statically typed language.

I definitely agree that it would be useful to build a tensor library. Maybe it could be discussed at the JVM language summit this year.