
Unlocking Neural Networks with JAX: optimizers with optax

In our previous exploration within the “Unlocking with JAX” series, we covered the foundational concepts of neural networks and their implementation using JAX. Building on that knowledge, this post shifts focus to a crucial component that significantly enhances learning efficiency: optimizers. Optax, a gradient processing and optimization library tailored for JAX, provides the tools necessary […]
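Since the excerpt cuts off, here is a minimal sketch of the kind of training loop Optax enables: create an optimizer, initialize its state, then repeatedly compute gradients and apply updates. The choice of `optax.adam`, the learning rate, and the quadratic `loss_fn` are illustrative assumptions, not necessarily the post's actual code.

```python
import jax
import jax.numpy as jnp
import optax

# Hypothetical toy loss: distance of params from 3.0 (stands in for a model loss).
def loss_fn(params):
    return jnp.sum((params - 3.0) ** 2)

params = jnp.zeros(5)
optimizer = optax.adam(learning_rate=1e-2)  # assumed optimizer choice
opt_state = optimizer.init(params)

for _ in range(100):
    grads = jax.grad(loss_fn)(params)                    # gradients via JAX autodiff
    updates, opt_state = optimizer.update(grads, opt_state)  # Optax transforms the gradients
    params = optax.apply_updates(params, updates)        # apply the processed updates
```

Swapping `optax.adam` for `optax.sgd` or a chained transform changes only the first line of the optimizer setup; the init/update/apply pattern stays the same.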



Unlocking Neural Networks with JAX

Welcome back to our Unlocking with JAX series! Today, we’re getting hands-on with neural networks by building a multilayer perceptron (MLP) using JAX. JAX sits perfectly between the lower-level details of CUDA and the higher-level abstractions offered by frameworks like Keras, offering both clarity and control. This balance helps immensely in understanding the inner workings […]
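As a taste of what the post builds, here is a minimal MLP sketch in raw JAX. The parameter layout (a list of `(weights, biases)` pairs), the layer sizes, and the He-style initialization are assumptions for illustration; the post's own implementation may differ.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, sizes):
    """One (W, b) pair per layer, e.g. sizes = [784, 128, 10] (assumed shapes)."""
    params = []
    for in_dim, out_dim in zip(sizes[:-1], sizes[1:]):
        key, subkey = jax.random.split(key)
        w = jax.random.normal(subkey, (in_dim, out_dim)) * jnp.sqrt(2.0 / in_dim)
        params.append((w, jnp.zeros(out_dim)))
    return params

def mlp_forward(params, x):
    """ReLU on every hidden layer; the final layer stays linear (logits)."""
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b

key = jax.random.PRNGKey(0)
params = init_mlp(key, [784, 128, 10])
logits = mlp_forward(params, jnp.ones((1, 784)))
```

Because the parameters are plain pytrees (lists of arrays), `jax.grad`, `jax.jit`, and `jax.vmap` apply to `mlp_forward` directly, which is exactly the clarity-plus-control trade-off described above.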

