
Library for Jacobian descent with PyTorch. It enables the optimization of neural networks with multiple losses (e.g. multi-task learning).


SimplexLab/TorchJD


TorchJD


TorchJD is a library extending autograd to enable Jacobian descent with PyTorch. It can be used to train neural networks with multiple objectives. In particular, it supports multi-task learning, with a wide variety of aggregators from the literature. It also enables the instance-wise risk minimization paradigm. The full documentation is available at torchjd.org, with several usage examples.

Jacobian descent (JD)

Jacobian descent is an extension of gradient descent that supports the optimization of vector-valued functions. This algorithm can be used to train neural networks with multiple loss functions. In this context, JD iteratively updates the parameters of the model using the Jacobian matrix of the vector of losses (the matrix stacking the gradients of the individual losses). For more details, please refer to Section 2.1 of the paper Jacobian Descent For Multi-Objective Optimization (see Citation below).
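The sketch below illustrates these definitions in plain PyTorch. It builds the Jacobian of a vector of two losses by stacking per-loss gradients. It then takes one step $\theta \leftarrow \theta - \eta\,\mathcal A(J)$. The toy model, the losses, and the `mean_aggregator` placeholder for $\mathcal A$ are assumptions made for the example. This is not TorchJD code.

```python
import torch
from torch.nn import Linear

torch.manual_seed(0)
model = Linear(4, 2)  # toy model: one output per objective
params = list(model.parameters())
x, y = torch.randn(8, 4), torch.randn(8, 2)

# Vector of two losses, one per output
losses = ((model(x) - y) ** 2).mean(dim=0)  # shape: [2]

# Jacobian: row i is the gradient of losses[i] w.r.t. all parameters, flattened
rows = []
for loss in losses:
    grads = torch.autograd.grad(loss, params, retain_graph=True)
    rows.append(torch.cat([g.flatten() for g in grads]))
jacobian = torch.stack(rows)  # shape: [2, 10] (8 weights + 2 biases)


def mean_aggregator(matrix):
    # Placeholder for the aggregator A: averaging the rows gives gradient descent on the mean loss
    return matrix.mean(dim=0)


# One step of Jacobian descent: theta <- theta - lr * A(J)
lr = 0.1
update = mean_aggregator(jacobian)
with torch.no_grad():
    offset = 0
    for p in params:
        p -= lr * update[offset:offset + p.numel()].view_as(p)
        offset += p.numel()
```

Swapping `mean_aggregator` for another function of the matrix is what distinguishes Jacobian descent from plain gradient descent.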

How does this compare to averaging the different losses and using gradient descent?

Averaging the losses and computing the gradient of the mean is mathematically equivalent to computing the Jacobian and averaging its rows. However, this approach has limitations. If two gradients are conflicting (they have a negative inner product), simply averaging them can result in an update vector that is conflicting with one of the two gradients. Averaging the losses and making a step of gradient descent can thus lead to an increase of one of the losses.
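Here is a small numerical sketch of both claims. It uses two hypothetical linear losses whose gradients, $g_1 = (1, 0)$ and $g_2 = (-3, 1)$, are chosen by hand to conflict.

```python
import torch

# Two hypothetical linear losses whose gradients g1 and g2 are picked by hand
g1 = torch.tensor([1.0, 0.0])
g2 = torch.tensor([-3.0, 1.0])
theta = torch.zeros(2, requires_grad=True)


def compute_losses(t):
    return torch.stack([g1 @ t, g2 @ t])


# Gradient of the mean loss vs. mean of the rows of the Jacobian: the same vector
(grad_of_mean,) = torch.autograd.grad(compute_losses(theta).mean(), theta)
jacobian = torch.stack([g1, g2])
assert torch.allclose(grad_of_mean, jacobian.mean(dim=0))  # both equal (-1.0, 0.5)

inner_g1_g2 = (g1 @ g2).item()  # -3.0: the two gradients conflict
inner_update_g1 = (grad_of_mean @ g1).item()  # -1.0: the averaged update conflicts with g1

# One gradient descent step on the mean loss
with torch.no_grad():
    old_losses = compute_losses(theta)  # (0.0, 0.0)
    new_losses = compute_losses(theta - 0.1 * grad_of_mean)  # approximately (0.1, -0.35)
```

The mean loss decreases after the step, but the first loss rises from 0 to about 0.1.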

This is illustrated in the following picture, in which the two objectives' gradients $g_1$ and $g_2$ are conflicting, and averaging them gives an update direction that is detrimental to the first objective. Note that in this picture, the dual cone, represented in green, is the set of vectors that have a non-negative inner product with both $g_1$ and $g_2$.

[Figure: the conflicting gradients $g_1$ and $g_2$, their average, and the dual cone (green)]

With Jacobian descent, $g_1$ and $g_2$ are computed individually and carefully aggregated using an aggregator $\mathcal A$. In this example, the aggregator is the Unconflicting Projection of Gradients $\mathcal A_{\text{UPGrad}}$: it projects each gradient onto the dual cone, and averages the projections. This ensures that the update will always be beneficial to each individual objective (given a sufficiently small step size). In addition to $\mathcal A_{\text{UPGrad}}$, TorchJD supports more than 10 aggregators from the literature.
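The brute-force sketch below makes the UPGrad definition concrete in plain PyTorch. `project_onto_dual_cone` and `upgrad_sketch` are hypothetical helpers, not TorchJD's API. The projection enumerates every set of active constraints, which is only practical for a handful of objectives. It is not meant to reflect how TorchJD implements UPGrad.

```python
import itertools

import torch


def project_onto_dual_cone(v, jacobian, tol=1e-5):
    """Hypothetical helper: Euclidean projection of v onto {x : jacobian @ x >= 0}.

    Tries every set of active constraints and keeps the closest feasible candidate,
    which is exact but exponential in the number of rows.
    """
    best, best_dist = None, float("inf")
    m = jacobian.shape[0]
    for k in range(m + 1):
        for active in itertools.combinations(range(m), k):
            if active:
                a = jacobian[list(active)]
                # Project v onto the subspace {x : a @ x = 0}
                candidate = v - a.T @ (torch.linalg.pinv(a @ a.T) @ (a @ v))
            else:
                candidate = v
            feasible = bool((jacobian @ candidate >= -tol).all())
            dist = torch.linalg.norm(candidate - v).item()
            if feasible and dist < best_dist:
                best, best_dist = candidate, dist
    return best


def upgrad_sketch(jacobian):
    # Project each gradient (row) onto the dual cone, then average the projections
    projections = [project_onto_dual_cone(g, jacobian) for g in jacobian]
    return torch.stack(projections).mean(dim=0)


g1, g2 = torch.tensor([1.0, 0.0]), torch.tensor([-3.0, 1.0])
jacobian = torch.stack([g1, g2])
update = upgrad_sketch(jacobian)  # approximately (0.05, 0.65)
inner_products = jacobian @ update  # approximately (0.05, 0.5): both non-negative
```

With these gradients, $g_1$ projects to about $(0.1, 0.3)$ and $g_2$ to $(0, 1)$. The plain average $(-1, 0.5)$ would conflict with $g_1$.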

Installation

TorchJD can be installed directly with pip:

pip install torchjd

Some aggregators may have additional dependencies. Please refer to the installation documentation for them.

Usage

Compared to standard torch, torchjd only changes the way the .grad fields of your model parameters are obtained.

Using the autojac engine

The autojac engine is for computing and aggregating Jacobians efficiently.

1. backward + jac_to_grad

In standard torch, you generally combine your losses into a single scalar loss and call loss.backward(). This computes the gradient of the loss with respect to each model parameter and stores it in the .grad fields of those parameters. The basic usage of torchjd is to replace this loss.backward() with a call to torchjd.autojac.backward(losses). Instead of computing the gradient of a scalar loss, it computes the Jacobian of a vector of losses and stores it in the .jac fields of the model parameters. You then have to call torchjd.autojac.jac_to_grad to aggregate this Jacobian using the specified Aggregator and store the result in the .grad fields of the model parameters. See the usage examples in the documentation at torchjd.org for more details.
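The two steps can be mimicked in plain PyTorch to show what they compute. This sketch makes a few assumptions:

- `jacobian_backward` and `jac_to_grad_sketch` are hypothetical helpers.
- The Jacobians are returned in a list instead of being stored in `.jac` fields.
- The per-parameter Jacobians are aggregated jointly, as one matrix.

With the mean as aggregator, the resulting `.grad` fields should match those of `losses.mean().backward()`.

```python
import torch
from torch.nn import Linear, ReLU, Sequential


def jacobian_backward(losses, params):
    # Hypothetical stand-in for step 1: one Jacobian per parameter, of shape
    # [num_losses, *param.shape], returned in a list instead of stored in .jac fields
    per_loss = [torch.autograd.grad(loss, params, retain_graph=True) for loss in losses]
    return [torch.stack([grads[i] for grads in per_loss]) for i in range(len(params))]


def jac_to_grad_sketch(params, jacobians, aggregate):
    # Hypothetical stand-in for step 2: concatenate the per-parameter Jacobians into
    # one [num_losses, num_params] matrix, aggregate it into a vector, split it into .grad
    matrix = torch.cat([jac.flatten(start_dim=1) for jac in jacobians], dim=1)
    vector = aggregate(matrix)
    offset = 0
    for p in params:
        p.grad = vector[offset:offset + p.numel()].view_as(p).clone()
        offset += p.numel()


torch.manual_seed(0)
model = Sequential(Linear(3, 4), ReLU(), Linear(4, 2))
params = list(model.parameters())
x, y = torch.randn(5, 3), torch.randn(5, 2)
losses = ((model(x) - y) ** 2).mean(dim=0)  # vector of 2 losses

jacobians = jacobian_backward(losses, params)
jac_to_grad_sketch(params, jacobians, aggregate=lambda matrix: matrix.mean(dim=0))
```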

2. mtl_backward + jac_to_grad

In the case of multi-task learning, an alternative to torchjd.autojac.backward is torchjd.autojac.mtl_backward. It computes the gradient of each task-specific loss with respect to the corresponding task's parameters and stores it in their .grad fields. It also computes the Jacobian of the vector of losses with respect to the shared parameters and stores it in their .jac fields. Then, torchjd.autojac.jac_to_grad can be called to aggregate this Jacobian and replace the .jac fields of the shared parameters with .grad fields.
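The following plain-PyTorch sketch shows the split between task-specific and shared parameters. It illustrates what ends up in each field, not how `mtl_backward` computes it. The shared Jacobian is simply returned as a tensor rather than stored in `.jac` fields.

```python
import torch
import torch.nn.functional as F
from torch.nn import Linear, ReLU, Sequential

torch.manual_seed(0)
shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_module, task2_module = Linear(3, 1), Linear(3, 1)
shared_params = list(shared_module.parameters())

x = torch.randn(16, 10)
target1, target2 = torch.randn(16, 1), torch.randn(16, 1)
features = shared_module(x)
loss1 = F.mse_loss(task1_module(features), target1)
loss2 = F.mse_loss(task2_module(features), target2)

# Task-specific parameters: plain gradient of their own task's loss
for module, loss in [(task1_module, loss1), (task2_module, loss2)]:
    task_params = list(module.parameters())
    grads = torch.autograd.grad(loss, task_params, retain_graph=True)
    for p, g in zip(task_params, grads):
        p.grad = g

# Shared parameters: one row per loss, i.e. the Jacobian to be aggregated
rows = []
for loss in (loss1, loss2):
    grads = torch.autograd.grad(loss, shared_params, retain_graph=True)
    rows.append(torch.cat([g.flatten() for g in grads]))
shared_jacobian = torch.stack(rows)  # shape: [2, 73]
```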

The following example shows how to use TorchJD to train a multi-task model with Jacobian descent, using UPGrad.

  import torch
  from torch.nn import Linear, MSELoss, ReLU, Sequential
  from torch.optim import SGD

+ from torchjd.autojac import jac_to_grad, mtl_backward
+ from torchjd.aggregation import UPGrad

  shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
  task1_module = Linear(3, 1)
  task2_module = Linear(3, 1)
  params = [
      *shared_module.parameters(),
      *task1_module.parameters(),
      *task2_module.parameters(),
  ]

  loss_fn = MSELoss()
  optimizer = SGD(params, lr=0.1)
+ aggregator = UPGrad()

  inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
  task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
  task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task

  for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
      features = shared_module(input)
      output1 = task1_module(features)
      output2 = task2_module(features)
      loss1 = loss_fn(output1, target1)
      loss2 = loss_fn(output2, target2)

-     loss = loss1 + loss2
-     loss.backward()
+     mtl_backward([loss1, loss2], features=features)
+     jac_to_grad(shared_module.parameters(), aggregator)
      optimizer.step()
      optimizer.zero_grad()

Note

In this example, the Jacobian is only with respect to the shared parameters. The task-specific parameters are simply updated via the gradient of their task’s loss with respect to them.

Tip

Once all your model parameters have a .grad field, it is the role of the optimizer to update the parameter values. This is exactly the same as in standard torch.

3. jac

If you are simply interested in computing Jacobians without storing them in the .jac fields, you can also use the torchjd.autojac.jac function. It is analogous to torch.autograd.grad, except that it computes the Jacobian of a vector of losses rather than the gradient of a scalar loss.
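For comparison, plain PyTorch can already return the Jacobian of a vector of losses through `torch.autograd.grad`. You pass the rows of an identity matrix as `grad_outputs` and set `is_grads_batched=True`. This sketch (toy model and losses assumed) shows the resulting layout, with one leading row per loss. It makes no claim about how `torchjd.autojac.jac` is implemented or about its exact return format.

```python
import torch
from torch.nn import Linear

torch.manual_seed(0)
model = Linear(3, 2)
params = tuple(model.parameters())
x, y = torch.randn(4, 3), torch.randn(4, 2)
losses = ((model(x) - y) ** 2).mean(dim=0)  # vector of 2 losses

# Each row of the identity selects one loss; is_grads_batched vectorizes the
# corresponding backward passes, giving one leading row per loss
jac_weight, jac_bias = torch.autograd.grad(
    losses, params, grad_outputs=torch.eye(2), is_grads_batched=True, retain_graph=True
)
# jac_weight has shape [2, 2, 3] and jac_bias has shape [2, 2]
```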

Using the autogram engine

The Gramian of the Jacobian, defined as the Jacobian multiplied by its transpose, contains all the inner products between individual gradients. It thus contains all the information about conflict and gradient imbalance. It turns out that most aggregators from the literature (e.g. UPGrad) compute a linear combination of the rows of the Jacobian, with weights that depend only on the Gramian of the Jacobian.
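Here is a small sketch of why the Gramian $G = J J^\top$ is enough. For any weights $w$, the update $w^\top J$ has inner products $J J^\top w = G w$ with the individual gradients and squared norm $w^\top G w$. Conflicts can therefore be read off $G$ without touching $J$. The random Jacobian and the uniform weights are assumptions for illustration.

```python
import torch

torch.manual_seed(0)
jacobian = torch.randn(3, 5)  # hypothetical: 3 losses, 5 parameters
gramian = jacobian @ jacobian.T  # shape: [3, 3], symmetric

# Entry (i, j) is the inner product between the gradients of losses i and j;
# the diagonal holds squared gradient norms and negative entries flag conflicts
assert torch.allclose(gramian[0, 1], jacobian[0] @ jacobian[1])

weights = torch.full((3,), 1 / 3)  # e.g. the weights of the mean
update = weights @ jacobian  # linear combination of the rows of J

# Both quantities are available from the Gramian and the weights alone
assert torch.allclose(jacobian @ update, gramian @ weights, atol=1e-5)
assert torch.allclose(update @ update, weights @ gramian @ weights, atol=1e-5)
```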

An alternative implementation of Jacobian descent is thus to:

  • Compute this Gramian incrementally (layer by layer), without ever storing the full Jacobian in memory.
  • Extract the weights from it using a Weighting.
  • Combine the losses using those weights and make a step of gradient descent on the combined loss.
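These three steps can be sketched in plain PyTorch under two simplifying assumptions. First, the Gramian is computed from the full Jacobian, which is exactly what the autogram engine avoids. Second, a hypothetical closed-form weighting for two losses (the minimum-norm, MGDA-style weights) stands in for `UPGradWeighting`. The point is that a weighting that reads only the Gramian, followed by one ordinary backward pass, yields the update `weights @ jacobian`.

```python
import torch
from torch.nn import Linear, ReLU, Sequential


def min_norm_weighting(gramian):
    # Hypothetical two-loss weighting: closed-form weights of the minimum-norm
    # convex combination of two gradients (MGDA-style); it reads only the Gramian
    g11, g12, g22 = gramian[0, 0], gramian[0, 1], gramian[1, 1]
    denom = g11 + g22 - 2 * g12  # squared norm of g1 - g2
    if denom <= 0:  # identical gradients: any convex combination works
        return torch.tensor([0.5, 0.5])
    alpha = ((g22 - g12) / denom).clamp(0.0, 1.0)
    return torch.stack([alpha, 1 - alpha])


torch.manual_seed(0)
model = Sequential(Linear(3, 4), ReLU(), Linear(4, 2))
params = list(model.parameters())
x, y = torch.randn(4, 3), torch.randn(4, 2)
losses = ((model(x) - y) ** 2).mean(dim=0)  # vector of 2 losses

# Step 1 (simplified): Gramian from the full Jacobian, not computed incrementally
rows = []
for loss in losses:
    grads = torch.autograd.grad(loss, params, retain_graph=True)
    rows.append(torch.cat([g.flatten() for g in grads]))
jacobian = torch.stack(rows)
gramian = jacobian @ jacobian.T  # shape: [2, 2]

# Step 2: weights extracted from the Gramian only
weights = min_norm_weighting(gramian)

# Step 3: ordinary backward pass on the weighted combination of the losses
losses.backward(weights)
```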

The main advantage of this approach is memory savings: the Jacobian, which is typically large, never has to be stored in memory. The torchjd.autogram.Engine is designed precisely to compute the Gramian of the Jacobian efficiently.
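Here is a back-of-the-envelope count for the per-instance example below. The Jacobian has one row per loss and one column per parameter, while the Gramian has only one row and one column per loss. These figures ignore any other memory the engine needs.

```python
from torch.nn import Linear, ReLU, Sequential

# Same model as in the per-instance example, with 16 losses per batch
model = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU(), Linear(3, 1), ReLU())
num_params = sum(p.numel() for p in model.parameters())  # 55 + 18 + 4 = 77
num_losses = 16

jacobian_entries = num_losses * num_params  # [16, 77]: 1232 entries
gramian_entries = num_losses * num_losses  # [16, 16]: 256 entries

# The gap grows with model size: for a hypothetical 100M-parameter model,
# the Jacobian would have 1.6 billion entries while the Gramian keeps 256
big_jacobian_entries = num_losses * 100_000_000
```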

The following example shows how to use the autogram engine to minimize the vector of per-instance losses with Jacobian descent using UPGrad.

  import torch
  from torch.nn import Linear, MSELoss, ReLU, Sequential
  from torch.optim import SGD

+ from torchjd.autogram import Engine
+ from torchjd.aggregation import UPGradWeighting

  model = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU(), Linear(3, 1), ReLU())

- loss_fn = MSELoss()
+ loss_fn = MSELoss(reduction="none")
  optimizer = SGD(model.parameters(), lr=0.1)

+ weighting = UPGradWeighting()
+ engine = Engine(model, batch_dim=0)

  inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
  targets = torch.randn(8, 16)  # 8 batches of 16 targets

  for input, target in zip(inputs, targets):
      output = model(input).squeeze(dim=1)  # shape: [16]
-     loss = loss_fn(output, target)  # shape: [] (scalar)
+     losses = loss_fn(output, target)  # shape: [16]

-     loss.backward()
+     gramian = engine.compute_gramian(losses)  # shape: [16, 16]
+     weights = weighting(gramian)  # shape: [16]
+     losses.backward(weights)
      optimizer.step()
      optimizer.zero_grad()

You can even go one step further by considering the multiple tasks and each element of the batch independently. We call that Instance-Wise Multitask Learning (IWMTL).

import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.aggregation import Flattening, UPGradWeighting
from torchjd.autogram import Engine

shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_module = Linear(3, 1)
task2_module = Linear(3, 1)
params = [
    *shared_module.parameters(),
    *task1_module.parameters(),
    *task2_module.parameters(),
]

optimizer = SGD(params, lr=0.1)
mse = MSELoss(reduction="none")
weighting = Flattening(UPGradWeighting())
engine = Engine(shared_module, batch_dim=0)

inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
task1_targets = torch.randn(8, 16)  # 8 batches of 16 targets for the first task
task2_targets = torch.randn(8, 16)  # 8 batches of 16 targets for the second task

for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
    features = shared_module(input)  # shape: [16, 3]
    out1 = task1_module(features).squeeze(1)  # shape: [16]
    out2 = task2_module(features).squeeze(1)  # shape: [16]

    # Compute the matrix of losses: one loss per element of the batch and per task
    losses = torch.stack([mse(out1, target1), mse(out2, target2)], dim=1)  # shape: [16, 2]

    # Compute the gramian (inner products between pairs of gradients of the losses)
    gramian = engine.compute_gramian(losses)  # shape: [16, 2, 2, 16]

    # Obtain the weights that lead to no conflict between reweighted gradients
    weights = weighting(gramian)  # shape: [16, 2]

    # Do the standard backward pass, but weighted using the obtained weights
    losses.backward(weights)
    optimizer.step()
    optimizer.zero_grad()

Note

Here, because the losses are a matrix instead of a simple vector, we compute a generalized Gramian and we extract weights from it using a GeneralizedWeighting.
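One way to picture the flattening step is to assume it amounts to viewing the matrix of losses as a flat vector. The sketch below builds an ordinary Gramian for the flattened losses and applies a uniform weighting in place of `UPGradWeighting`. It then reshapes the weights back to the shape of the losses before the weighted backward pass. The toy model and the weighting are hypothetical. Unlike with the autogram engine, the full Jacobian is materialized here.

```python
import torch
from torch.nn import Linear

torch.manual_seed(0)
model = Linear(3, 2)  # hypothetical toy model: one output per task
params = list(model.parameters())
x, y = torch.randn(5, 3), torch.randn(5, 2)
losses = (model(x) - y) ** 2  # matrix of losses, shape: [5, 2] (instance, task)

# View the [5, 2] matrix as a vector of 10 losses and build an ordinary Gramian
flat_losses = losses.reshape(-1)
rows = []
for loss in flat_losses:
    grads = torch.autograd.grad(loss, params, retain_graph=True)
    rows.append(torch.cat([g.flatten() for g in grads]))
jacobian = torch.stack(rows)  # shape: [10, 8] (6 weights + 2 biases)
gramian = jacobian @ jacobian.T  # shape: [10, 10]

# A uniform weighting stands in for UPGradWeighting here
flat_weights = torch.full((gramian.shape[0],), 1 / gramian.shape[0])
weights = flat_weights.reshape(losses.shape)  # back to shape: [5, 2]

# Weighted backward pass on the matrix of losses, as in the example above
losses.backward(weights)
```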

More usage examples can be found in the documentation at torchjd.org.

Supported Aggregators and Weightings

TorchJD provides many existing aggregators from the literature, listed in the following table.

| Aggregator | Weighting | Publication |
| --- | --- | --- |
| UPGrad (recommended) | UPGradWeighting | Jacobian Descent For Multi-Objective Optimization |
| AlignedMTL | AlignedMTLWeighting | Independent Component Alignment for Multi-Task Learning |
| CAGrad | CAGradWeighting | Conflict-Averse Gradient Descent for Multi-task Learning |
| ConFIG | - | ConFIG: Towards Conflict-free Training of Physics Informed Neural Networks |
| Constant | ConstantWeighting | - |
| DualProj | DualProjWeighting | Gradient Episodic Memory for Continual Learning |
| GradDrop | - | Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout |
| IMTLG | IMTLGWeighting | Towards Impartial Multi-task Learning |
| Krum | KrumWeighting | Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent |
| Mean | MeanWeighting | - |
| MGDA | MGDAWeighting | Multiple-gradient descent algorithm (MGDA) for multiobjective optimization |
| NashMTL | - | Multi-Task Learning as a Bargaining Game |
| PCGrad | PCGradWeighting | Gradient Surgery for Multi-Task Learning |
| Random | RandomWeighting | Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning |
| Sum | SumWeighting | - |
| Trimmed Mean | - | Byzantine-Robust Distributed Learning: Towards Optimal Statistical Rates |

Contribution

Please read the Contribution page.

Citation

If you use TorchJD for your research, please cite:

@article{jacobian_descent,
  title={Jacobian Descent For Multi-Objective Optimization},
  author={Quinton, Pierre and Rey, Valérian},
  journal={arXiv preprint arXiv:2406.16232},
  year={2024}
}