Unlocking Neural Networks with JAX: Optimizers with Optax


In the previous installment of the “Unlocking in JAX” series, we covered the foundational concepts of neural networks and implemented one using JAX. Building on that knowledge, this post focuses on a component that strongly affects how efficiently a network learns: the optimizer. Optax, a gradient processing and optimization library for JAX, provides ready-made optimizers as well as the building blocks to compose new ones. This entry shows how to integrate Optax with JAX and how its optimizers can replace a hand-written update rule to streamline the training process.




Introduction

In our previous blog post, we implemented a neural network from scratch using JAX, demonstrating the fundamentals of model building and parameter updates. The update function we used was straightforward:

Python
@jax.jit
def update(params, inputs, targets, lr=1e-3):
    loss_val, grads = jax.value_and_grad(loss_fn)(params, inputs, targets)
    
    return [(W - lr * dW, b - lr * db) for (W, b), (dW, db) in zip(params, grads)], loss_val

While effective for basic scenarios, manually updating parameters with a fixed learning rate has limitations, especially as models and datasets grow in complexity. A single learning rate may not suit every stage of training: too small a value slows convergence, while too large a value can make training diverge. Additionally, this simple update rule ignores the history of gradients, which can help an optimizer navigate complex loss landscapes more efficiently.
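To make the divergence claim concrete, here is a toy example that is not from the original post: plain gradient descent on f(x) = x², whose gradient is 2x. Each step multiplies x by (1 - 2·lr), so with a fixed learning rate above 1.0 the magnitude of x grows instead of shrinking.

```python
def gradient_descent(lr: float, x0: float = 1.0, steps: int = 20) -> float:
    """Run fixed-learning-rate gradient descent on f(x) = x**2 and return the final x."""
    x = x0
    for _ in range(steps):
        grad = 2.0 * x        # derivative of x**2
        x = x - lr * grad     # same rule as W - lr * dW in the update function above
    return x


print(gradient_descent(lr=0.1))   # |x| shrinks by a factor of 0.8 per step
print(gradient_descent(lr=1.1))   # |x| grows by a factor of 1.2 per step
```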

In this post, we’ll explore how more advanced optimization algorithms address these challenges, using their implementations in Optax. Optimizers such as Adam adapt the effective step size for each parameter and use past gradient information to accelerate convergence.

Optax

Optax is a gradient processing and optimization library built around a modular design: optimizers are assembled from small, composable gradient transformations, so you can pick the pieces you need and leave out the rest. This keeps even elaborate optimization setups simple to express. Whether you are tweaking a model to squeeze out the last bit of accuracy or trying to shorten training time, Optax gives you a straightforward way to change how your model is optimized.

Optax provides many modules, but for now we’ll concentrate on its suite of optimizers, which implement a wide range of algorithms ready for immediate use. Like JAX, Optax is built around pure functions: an optimizer holds no internal mutable state, and everything it needs between steps is passed in and returned explicitly. This approach may seem unusual at first, but it quickly becomes intuitive; the learning curve is steep but short.

SGD

Stochastic Gradient Descent (SGD) is a cornerstone optimization method in machine learning and is essential for training deep neural networks. Although SGD traditionally updates model parameters one data point at a time, modern implementations almost always use mini-batches. This combines the computational efficiency of vectorized batch processing with the stochastic nature of the updates: averaging gradients over a batch reduces their variance, which speeds up training and stabilizes convergence compared with single-sample updates.
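A minimal NumPy sketch of mini-batch SGD, on a synthetic linear-regression problem used only for illustration (the data, batch size, learning rate, and epoch count are arbitrary): each epoch reshuffles the data, and every mini-batch contributes one parameter update computed from its averaged gradient.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data: y = 3x + 2 plus a little noise.
X = rng.uniform(-1.0, 1.0, size=1000)
y = 3.0 * X + 2.0 + 0.1 * rng.normal(size=1000)

w, b = 0.0, 0.0
lr, batch_size = 0.1, 32

for epoch in range(20):
    order = rng.permutation(len(X))           # reshuffle every epoch
    for start in range(0, len(X), batch_size):
        idx = order[start:start + batch_size]
        xb, yb = X[idx], y[idx]
        err = (w * xb + b) - yb               # residuals on this mini-batch
        # Gradients of the loss 0.5 * mean(err**2), averaged over the batch
        dw = np.mean(err * xb)
        db = np.mean(err)
        w -= lr * dw
        b -= lr * db

print(w, b)   # should end up close to 3 and 2
```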

Building on this, let’s look at our first optimizer and integrate it into the neural network we developed in our previous post. We’ll examine how this optimizer operates and how it drives our model’s learning process.

The first step is to create the optimizer and initialize the state it needs to function correctly.

Python
optimizer = optax.sgd(learning_rate=learning_rate)
opt_state = optimizer.init(_params)
  • optimizer = optax.sgd(learning_rate=learning_rate)
    • this line creates an SGD optimizer with the given learning_rate
  • opt_state = optimizer.init(_params)
    • initializes the optimizer state from the model parameters (_params)
    • optimizer.init prepares the state the optimizer uses to carry information specific to the model’s parameters across updates
      • depending on the optimizer, this state can include momentum terms, running averages of past gradients, or other optimizer-specific statistics needed to update the parameters; for plain SGD it is empty

The next step is to modify the update function that the training loop calls for each batch. Inside it, jax.value_and_grad computes the loss and its gradients, the optimizer transforms those gradients into parameter updates, and the updates are applied to the parameters according to the chosen optimization algorithm.

Python

# old implementation
# @jax.jit
# def update(params, inputs, targets, lr=1e-3):
#     loss_val, grads = jax.value_and_grad(loss_fn)(params, inputs, targets)
#     return [(W - lr * dW, b - lr * db) for (W, b), (dW, db) in zip(params, grads)], loss_val



# new implementation
@jax.jit
def update(opt_state, params, inputs, targets):
    loss_value, grads = jax.value_and_grad(loss_fn)(params, inputs, targets)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
        
    return new_params, opt_state, loss_value    
  • the first modification is to the update function signature, which now includes opt_state
    • this is required by the functional programming style of JAX and Optax: functions must not rely on internal state, so the optimizer state has to be passed in and returned explicitly
  • updates, opt_state = optimizer.update(grads, opt_state)
    • calls the optimizer’s update method with the current gradients (grads) and the optimizer state (opt_state)
    • optimizers that also need the current parameters (for example optax.adamw, which applies weight decay) take them as a third argument: optimizer.update(grads, opt_state, params)
    • the method returns two values
      • updates
        • the computed changes to be applied to the model parameters
      • opt_state
        • the new state (like momentum or adaptive learning rate statistics) needed for the next update
  • new_params = optax.apply_updates(params, updates)
    • adds the computed updates (e.g., transformed gradients) to the current parameters (params), leaf by leaf
    • returns new_params
      • the parameters after the updates have been applied, moved one step in the direction chosen by the optimization strategy
  • the update function then returns the new parameters (new_params), the updated optimizer state (opt_state), and the loss value (loss_value)

The new update function can now be called inside a loop over epochs and batches to minimize the loss, updating the model’s parameters once per batch so that performance improves gradually over training.

Python
for _iter in range(iterations):
    train_loss = 0
    for batch_X, batch_Y in train_batch_generator:
        _params, opt_state, loss_batch = update(opt_state, _params, batch_X, batch_Y)
        train_loss += loss_batch

With all the ingredients in place, we can put together the entire workflow.

Python
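import jax
import jax.numpy as jnp
import numpy as np
import optax
from sklearn.metrics import accuracy_score, top_k_accuracy_score
from tqdm import tqdm

# BatchGenerator, TRAIN_DATA_NP and TRAIN_LABELS_NP come from the previous post.


def init_mlp_params(layers_configuration: list[int], seed: int):
    key = jax.random.PRNGKey(seed)
    # one subkey per layer (the first output of split is kept as the parent key)
    key, *subkeys = jax.random.split(key, len(layers_configuration))

    params = []
    for i in range(len(layers_configuration) - 1):
        # separate keys for W and b so the two random draws are independent
        w_key, b_key = jax.random.split(subkeys[i])
        params.append((
            jax.random.normal(w_key, (layers_configuration[i], layers_configuration[i + 1])),  # W
            jax.random.normal(b_key, (layers_configuration[i + 1],))                           # b
        ))

    return params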


@jax.jit
def forward(params, inputs):
    activations = inputs
    *hidden_layers, output_layer = params

    for (W, b) in hidden_layers:
        activations = jnp.dot(activations, W) + b
        activations = jax.nn.relu(activations)

    return jax.nn.softmax(jnp.dot(activations, output_layer[0]) + output_layer[1])


@jax.jit
def predict(params, inputs):
    return jnp.argmax(forward(params, inputs), axis=1)


@jax.jit
def predict_proba(params, inputs):
    return forward(params, inputs)


@jax.jit
def make_safe_log(x):
    epsilon = 1e-12
    safe_x = jnp.where(x < epsilon, epsilon, x)

    return jnp.log(safe_x)


@jax.jit
def loss_fn(params, inputs, targets):
    targets_one_hot = jax.nn.one_hot(targets, 10)
    preds = forward(params, inputs)

    result = -jnp.mean(jnp.sum(targets_one_hot * make_safe_log(preds), axis=1))
    return result


@jax.jit
def update(opt_state, params, inputs, targets):
    loss_value, grads = jax.value_and_grad(loss_fn)(params, inputs, targets)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
        
    return new_params, opt_state, loss_value    


_params = init_mlp_params([784, 32, 10], seed=42)

batch_size = 128
iterations = 500
learning_rate = 1e-2

# get train and validation indices
np.random.seed(42)
indices = np.random.permutation(len(TRAIN_LABELS_NP))

train_indices = indices[:int(0.8 * len(TRAIN_LABELS_NP))]
val_indices = indices[int(0.8 * len(TRAIN_LABELS_NP)):]

train_batch_generator = BatchGenerator(TRAIN_DATA_NP[train_indices], TRAIN_LABELS_NP[train_indices], batch_size)
val_batch_generator = BatchGenerator(TRAIN_DATA_NP[val_indices], TRAIN_LABELS_NP[val_indices], batch_size)

history = {
    "loss": [],
    "val_loss": [],
    "val_accuracy": [],
    "val_top_3_accuracy": []
}


with tqdm(total=iterations, desc='Training') as pbar:
    optimizer = optax.sgd(learning_rate=learning_rate)
    opt_state = optimizer.init(_params)

    for _iter in range(iterations):
        train_loss = 0
        for batch_X, batch_Y in train_batch_generator:
            _params, opt_state, loss_batch = update(opt_state, _params, batch_X, batch_Y)
            train_loss += loss_batch

        train_loss /= train_batch_generator.num_batches
        history["loss"].append(train_loss)

        # compute validation loss
        val_loss = 0
        for batch_X, batch_Y in val_batch_generator:
            val_loss += loss_fn(_params, batch_X, batch_Y)

        val_loss /= val_batch_generator.num_batches
        history["val_loss"].append(val_loss)
        history["val_accuracy"].append(accuracy_score(TRAIN_LABELS_NP[val_indices], predict(_params, TRAIN_DATA_NP[val_indices])))
        history["val_top_3_accuracy"].append(top_k_accuracy_score(TRAIN_LABELS_NP[val_indices], predict_proba(_params, TRAIN_DATA_NP[val_indices]), k=3))

        pbar.set_postfix({"val_loss": val_loss, "loss": train_loss})
        pbar.update(1)
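The training loop relies on a `BatchGenerator` class from the previous post that is not reproduced here. If you are following along without it, a minimal stand-in could look like the following. This is a hypothetical sketch that only assumes what the loop uses: the object is iterable, yields `(batch_X, batch_Y)` pairs, and exposes `num_batches`. It does not shuffle, and the original class may behave differently.

```python
import numpy as np


class BatchGenerator:
    """Hypothetical stand-in: yields (X, y) mini-batches in order; the last batch may be smaller."""

    def __init__(self, data: np.ndarray, labels: np.ndarray, batch_size: int):
        assert len(data) == len(labels), "data and labels must have the same length"
        self.data = data
        self.labels = labels
        self.batch_size = batch_size
        # Ceiling division so a final partial batch is still counted.
        self.num_batches = (len(data) + batch_size - 1) // batch_size

    def __iter__(self):
        # A fresh generator on every iteration, so the same object can be reused each epoch.
        for i in range(self.num_batches):
            start = i * self.batch_size
            end = start + self.batch_size
            yield self.data[start:end], self.labels[start:end]
```

Because the final batch can be smaller, the jitted `update` is traced once more for that shape; dropping the remainder batch would avoid the extra compilation.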

In our implementation, we configured the SGD optimizer with only a learning rate, but optax.sgd supports two additional options: momentum and Nesterov acceleration. Momentum accelerates SGD along consistent directions and dampens oscillations by adding a fraction of the previous update to the current one. Nesterov momentum, a refinement of classical momentum, first takes a step in the direction of the accumulated gradient and then evaluates the gradient at that look-ahead point, which can lead to faster convergence and less overshooting.

To use these options, pass them when constructing the optimizer. Note that nesterov=True only has an effect when momentum is also set.

Python
optimizer = optax.sgd(learning_rate=learning_rate, momentum=0.9)

# or

optimizer = optax.sgd(learning_rate=learning_rate, momentum=0.9, nesterov=True)

# signature
optax.sgd(learning_rate, momentum=None, nesterov=False, accumulator_dtype=None)

Adam

The Adam optimizer is a widely used extension of classical gradient descent, well suited to large-scale machine learning problems. Its name stands for “Adaptive Moment Estimation”: it combines momentum with adaptive learning rates by keeping exponential moving averages of the gradients (the first moment) and of the squared gradients (the second moment), and uses them to scale the step size for each parameter individually. In practice, Adam often trains faster than standard stochastic gradient descent and tends to be less sensitive to the choice of learning rate. Because it uses moving averages rather than accumulating all past squared gradients, it also avoids the vanishing learning rate of methods such as AdaGrad, which makes it a favorite among deep learning practitioners. Let’s see how easy it is to use Adam in our implementation with the Optax library.

Python
# signature
optax.adam(learning_rate, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, mu_dtype=None, *, nesterov=False)

In our case, replacing SGD with Adam only requires changing the line where the optimizer is created in the previous example. The rest of the code continues to work as before, so we benefit from Adam’s adaptive updates without any other modifications.

Python
# ...


with tqdm(total=iterations, desc='Training') as pbar:
    optimizer = optax.adam(learning_rate=learning_rate, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0)
    opt_state = optimizer.init(_params)

    for _iter in range(iterations):
        train_loss = 0
        for batch_X, batch_Y in train_batch_generator:
            _params, opt_state, loss_batch = update(opt_state, _params, batch_X, batch_Y)
            train_loss += loss_batch

        train_loss /= train_batch_generator.num_batches
        history["loss"].append(train_loss)

        # compute validation loss
        val_loss = 0
        for batch_X, batch_Y in val_batch_generator:
            val_loss += loss_fn(_params, batch_X, batch_Y)

        val_loss /= val_batch_generator.num_batches
        history["val_loss"].append(val_loss)
        history["val_accuracy"].append(accuracy_score(TRAIN_LABELS_NP[val_indices], predict(_params, TRAIN_DATA_NP[val_indices])))
        history["val_top_3_accuracy"].append(top_k_accuracy_score(TRAIN_LABELS_NP[val_indices], predict_proba(_params, TRAIN_DATA_NP[val_indices]), k=3))

        pbar.set_postfix({"val_loss": val_loss, "loss": train_loss})
        pbar.update(1)

The complete code is available as a separate notebook in my JAX tutorial repository. The notebook extends our previous neural network implementation to use the Optax library for optimization.

Summary

In this entry of the Unlocking in JAX series, we explored how the Optax library can improve neural network training. We replaced the hand-written update rule from the previous post with Optax optimizers, showed how straightforward the integration is, and highlighted how easy it is to experiment with the different optimizers the library provides. Optax’s flexible, modular design gives precise control over the learning process and makes it simple to adopt modern optimization techniques, such as momentum and Adam, that can speed up convergence and improve model performance.
