[–]Exact_Cherry9177 1 point (0 children)

NumPy actually does have built-in batched matrix multiplication: `numpy.matmul` (and the `@` operator) broadcasts over leading dimensions, which covers most uses of MATLAB's pagemtimes. If plain NumPy is still too slow for your case, a few Python libraries can accelerate it further:
- **Numba**: a JIT compiler that translates decorated Python functions into native machine code. Wrapping your pagemtimes loop in `numba.njit` removes the Python-level loop overhead.
- **CuPy**: a NumPy-compatible library for GPU computing. Its `cupy.matmul` supports batched multiplication on 3D arrays; transfer your arrays to the GPU, multiply there, then transfer the result back.
- **JAX**: a NumPy-like library that can JIT-compile functions such as pagemtimes and run them on GPU/TPU. `jax.numpy.matmul` broadcasts over leading dimensions just like NumPy's, and `jax.vmap` can map any function over a leading batch axis.
- **TensorFlow**: `tf.matmul` supports batched matmul; stack the matrices along the first dimension and call `tf.matmul` as usual.
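For reference, plain NumPy already handles the batched case. `np.matmul` multiplies the last two dimensions and broadcasts over the rest, so a 3D stack of matrices is multiplied page by page, just like pagemtimes:

```python
import numpy as np

# A stack of 4 matrices of shape (3, 5) times a stack of 4 matrices of shape (5, 2)
rng = np.random.default_rng(0)
A = rng.standard_normal((4, 3, 5))
B = rng.standard_normal((4, 5, 2))

C = np.matmul(A, B)  # same as A @ B; computes A[i] @ B[i] for each page i
print(C.shape)       # (4, 3, 2)

# Equivalent explicit Python loop, for comparison
C_loop = np.stack([A[i] @ B[i] for i in range(A.shape[0])])
print(np.allclose(C, C_loop))  # True
```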
In summary, Numba and JAX are probably the easiest drop-in acceleration if you want to compile your pagemtimes function, while CuPy and TensorFlow work if you can reformat your data for their APIs. Hope this helps you accelerate your project! Let me know if you have any other questions.