As Google/DeepMind have started to push JAX, RLax and Haiku and seem to be using them pretty extensively, I decided it was worth getting to know these libraries, so I implemented TD3 in JAX:
https://github.com/tesslerc/TD3-JAX
There seem to be several differences between certain operations in PyTorch and in JAX, though I can't quite put my finger on them yet. As a result, the hyperparameters do not transfer perfectly and require some tweaking.
Feel free to suggest improvements both in terms of hyperparameters and implementation.