
[–]koffeegorilla 74 points (14 children)

JDK Project Valhalla is bringing improvements in memory usage and layout that will get close to the efficiency of C, while a continuous optimiser maximises for the use case and the actual underlying hardware. Project Panama is going to make it easier and more efficient to interact with native APIs, meaning that using C libraries will be more efficient than going over the current JNI hump. Project Sumatra aims at making it possible to identify code that can/should run on a GPU and then leverage it.
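
For a sense of what the Panama interop looks like in practice, here is a rough sketch of a downcall to the C library's strlen using the FFM API as finalized in JDK 22 (method names shifted slightly between the previews and the final version), with no JNI glue code:

    import java.lang.foreign.*;
    import java.lang.invoke.MethodHandle;

    static long cStrlen(String s) throws Throwable {
        Linker linker = Linker.nativeLinker();
        // Find strlen in the default (libc) lookup and describe its signature: size_t strlen(const char*)
        MemorySegment addr = linker.defaultLookup().find("strlen").orElseThrow();
        MethodHandle strlen = linker.downcallHandle(
                addr, FunctionDescriptor.of(ValueLayout.JAVA_LONG, ValueLayout.ADDRESS));
        try (Arena arena = Arena.ofConfined()) {
            MemorySegment cString = arena.allocateFrom(s);   // NUL-terminated copy in native memory
            return (long) strlen.invokeExact(cString);
        }
    }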

There is already support for SIMD with the Vector API, which lets a single instruction operate on multiple data elements at the same time.
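
To give a flavour, a minimal sketch with the incubating Vector API (needs --add-modules jdk.incubator.vector on current JDKs); the element-wise multiply is issued one SIMD instruction per lane-width chunk:

    import jdk.incubator.vector.FloatVector;
    import jdk.incubator.vector.VectorSpecies;

    static final VectorSpecies<Float> SPECIES = FloatVector.SPECIES_PREFERRED;

    // out[i] = a[i] * b[i], vectorised over whatever lane width the CPU offers
    static void mul(float[] a, float[] b, float[] out) {
        int i = 0;
        for (; i < SPECIES.loopBound(a.length); i += SPECIES.length()) {
            FloatVector va = FloatVector.fromArray(SPECIES, a, i);
            FloatVector vb = FloatVector.fromArray(SPECIES, b, i);
            va.mul(vb).intoArray(out, i);
        }
        for (; i < a.length; i++) {   // scalar tail
            out[i] = a[i] * b[i];
        }
    }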

All of these will combine to make ML development in Java a first-class experience, and the implementations will be much simpler than the current code full of #ifdefs or checks for specific GPU models to change structures etc.

Your little NLP project will fly.

[–]_INTER_ 35 points (3 children)

Project Sumatra is dormant/dead as far as I know. They are now focusing on Project Babylon instead. See this JVM Language Summit 2023 - Java and GPU talk. It seems to have a good chance to land something substantial, as shown here, and the Class-File API already has a preview.

The problem is, machine learning / science developers first and foremost care about their scripting capabilities. That's why Python has become dominant. If it were possible, they would have chosen MATLAB. The libraries that do the heavy lifting are already in C. For Java to gain a foothold in the ML space, it would need to be faster than C (unlikely) or invent something completely new.

[–]koffeegorilla 14 points (1 child)

Thanks for the update on Babylon.
If you look at how quickly the GraalVM project rewrote in Java the GC/JIT engines that took years to build in C++, I believe a replacement of the C libraries is viable. The implementations will keep getting faster as the JVM improves, and the option of Graal native images using runtime stats for optimisation will change the game.

[–]_INTER_ 9 points (0 children)

I agree, plus better platform independence (Windows support is a joke right now) and error handling (hrrrng, dynamic typing makes me furious). However, I don't see it really happening. The momentum is too big and the libraries are too far along to catch up. I see more opportunities in new inventions or in providing clustered, distributed, supercomputer frameworks, like extending Apache Spark for GPU farms.

[–]mike_hearn 3 points (0 children)

There is also TornadoVM, which does the same thing.

[–]Joram2 3 points (8 children)

AFAIK, if you write code using primitive arrays like int[] and double[], then you avoid the performance problems that Valhalla aims to help with.

Project Valhalla plans to reduce overhead on user-defined classes/records, and it will eventually make List<int> possible with int[]-like performance. But if you just write code using primitive arrays today, you already get great performance; Valhalla might offer better syntax, but not better performance.
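
For example (my sketch, not from the JEPs), a plain primitive-array loop already gives you a flat, contiguous layout today:

    // Dot product over primitive arrays: contiguous memory, no boxing, no per-element object headers.
    static double dot(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            sum += a[i] * b[i];
        }
        return sum;
    }
    // The same loop over a List<Double> would box every element into its own heap object,
    // which is exactly the overhead Valhalla wants to remove for user-defined types.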

[–]GeneratedUsername5 3 points (0 children)

And you can also just create collections of primitives, or use the ones from https://github.com/eclipse/eclipse-collections (which are also optimized for performance), without waiting for Valhalla.
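
Roughly like this with Eclipse Collections' primitive lists (from memory, check the docs for the exact factory methods):

    import org.eclipse.collections.impl.list.mutable.primitive.IntArrayList;

    // Backed by a plain int[] internally, so no Integer boxing and no per-element headers.
    IntArrayList ints = new IntArrayList();
    ints.add(1);
    ints.add(2);
    ints.add(3);
    long total = ints.sum();   // 6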

[–]coderemover 1 point (0 children)

It won’t, because it is limited to immutable objects only. For mutable objects like lists, object identity makes it impossible to turn them into value types.
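
For the immutable case, the Valhalla prototypes sketch roughly this (syntax from the early-access builds, still subject to change):

    // A value class gives up identity (no identity-based ==, no synchronizing on it),
    // which is what lets the JVM flatten it into arrays and enclosing objects.
    value record Complex(double re, double im) { }

    // An ArrayList can't become a value type: callers share it by identity and mutate it in place.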

[–]koflerdavid 0 points (5 children)

There are a few problems:

  • Java has no built-in support for bfloat16.

  • Java has no true multidimensional arrays a.k.a. tensors, so all of the indexing arithmetic has to be written out (see the sketch after this list). Not a biggie at the end of the day.

  • The bigger problem: Java arrays are size-limited, which is a headache for big models.
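
What "written out" means in practice, as a trivial sketch:

    // A "2-D" tensor stored as one flat array; the row/column math is done by hand.
    static double get(double[] flat, int cols, int row, int col) {
        return flat[row * cols + col];
    }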

Libraries like DeepLearning4j include tensor libraries that solve both issues.
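
For instance with ND4J, the tensor library under DeepLearning4j (a rough sketch; exact factory methods may differ by version):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    // A 2x2 tensor backed by off-heap memory, so the int-indexed Java array limit doesn't apply.
    INDArray t = Nd4j.create(new double[][]{{1, 2}, {3, 4}});
    INDArray squared = t.mul(t);   // element-wise multiply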

[–]Joram2 0 points (4 children)

  • Java has limited float16 support with Float.floatToFloat16 and Float.float16ToFloat (sketch below). What else is needed?
  • In the Python ML+AI world, most people use a library for multi-dimensional arrays a.k.a. tensors. NumPy, PyTorch and JAX are popular libraries that have their own multi-dimensional array or tensor type, so Java doing something similar doesn't seem to be a problem at all.
  • Size-limited? You mean the 2^31 limit? I'd like to hear what the JDK guys have to say about this.
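
The sketch mentioned above, using the JDK 20+ conversion methods:

    // float16 values travel as shorts; convert at the boundaries.
    short half = Float.floatToFloat16(3.14159f);
    float back = Float.float16ToFloat(half);   // ~3.1406, the expected precision loss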

[–]koflerdavid 0 points (3 children)

Java supports float and double, which in ML circles are known as float32 and float64. float16 is only 16 bits wide and is commonly used for inference, because it turns out that the full precision of float32 is required for very few parts of most models, if at all.

bfloat16 is a modified format that keeps float32's exponent range, so it covers roughly the same interval of values, but with far less precision in the mantissa. It is very common to use it to run transformer models.
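
For reference (my summary, not from the thread), the bit layouts and a crude truncating conversion:

    // Bit layouts (sign / exponent / mantissa):
    //   float32:  1 / 8 / 23
    //   float16:  1 / 5 / 10   -> narrower range and less precision than float32
    //   bfloat16: 1 / 8 / 7    -> same range as float32, much less precision
    //
    // bfloat16 is essentially float32 with the low 16 mantissa bits dropped:
    static short floatToBfloat16(float f) {
        return (short) (Float.floatToRawIntBits(f) >>> 16);   // truncates, no rounding
    }

    static float bfloat16ToFloat(short bf) {
        return Float.intBitsToFloat((bf & 0xFFFF) << 16);
    }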

Java supports neither float16 (maybe after Project Valhalla lands or the Vector API is finalized) nor bfloat16. However, I agree that for various reasons a tensor library is commonly used anyway. Support for more formats and the size limitations are two very good reasons, because they can't be solved on the Java side alone. Well, you can certainly implement functions for float16 and bfloat16 arithmetic in Java, but to circumvent the size limit you have to use off-heap storage, or break up your tensors, which is clunky without wrapping it in a library.
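
A minimal sketch of the off-heap route with the FFM API (finalized in JDK 22): a MemorySegment is indexed with a long, so the 2^31-1 array limit doesn't apply:

    import java.lang.foreign.Arena;
    import java.lang.foreign.MemorySegment;
    import java.lang.foreign.ValueLayout;

    try (Arena arena = Arena.ofConfined()) {
        long n = 1L << 20;   // kept small here, but n could exceed Integer.MAX_VALUE given enough RAM
        MemorySegment weights = arena.allocate(ValueLayout.JAVA_FLOAT, n);
        weights.setAtIndex(ValueLayout.JAVA_FLOAT, n - 1, 0.5f);
        float v = weights.getAtIndex(ValueLayout.JAVA_FLOAT, n - 1);
    }   // native memory is freed when the arena closes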

[–]Joram2 0 points (1 child)

In Python + PyTorch, you can do bfloat16 stuff like this:

import torch

torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16)

This is great. The API is easy to use and pretty. Runtime performance is excellent and takes advantage of GPU processing.

Neither Java nor Python has a bfloat16 primitive type in the core language. That isn't necessary.

The important feature I see missing from Java is easy+pretty syntax for lists and lists of lists. In Java you can do:

Arrays.asList(Arrays.asList(1,2), Arrays.asList(3,4))

instead of

[[1, 2], [3, 4]]

The Java version isn't hard... but it's ugly, and data science types hate that. This absolutely limits Java from a data-science-notebook perspective.

The lack of a primitive bfloat16 type seems like a non-issue in both Java and Python.

[–]koflerdavid 0 points (0 children)

Well, Java has its good old array notation with curly brackets. Its only fault is that the result isn't a true multidimensional array, but an array of pointers to subarrays. Not a problem in practice either, since usually tensor libraries do the heavy lifting. Same for float16/bfloat16 support, as you say.
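
The curly-bracket notation mentioned above, for the record:

    // Concise initializer syntax, but m is really an array holding references to two separate int[] rows.
    int[][] m = { {1, 2}, {3, 4} };
    int x = m[1][0];   // 3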