Zing Forum

TorchJD: Solving Gradient Conflict in Multi-Task Learning with Jacobian Descent

TorchJD is a PyTorch extension library that implements the Jacobian Descent algorithm, specifically designed to resolve gradient conflicts between multiple loss functions in multi-task learning.

Tags: PyTorch, multi-task learning, Jacobian descent, gradient aggregation, machine learning, neural network optimization
Published 2026-05-21 03:45 · Recent activity 2026-05-21 03:49 · Estimated read 5 min

Section 01

[Main Floor] TorchJD: A PyTorch Extension Library for Resolving Gradient Conflicts in Multi-Task Learning

TorchJD is a PyTorch extension library that implements the Jacobian Descent algorithm, specifically designed to resolve gradient conflicts between multiple loss functions in multi-task learning. This post introduces its background, core method, usage, and application scenarios to help readers understand what the tool offers.

Section 02

[Background] The Dilemma of Gradient Conflicts in Multi-Task Learning

In deep learning, multi-task learning lets a single neural network handle several related tasks at once. However, the loss functions of different tasks often produce conflicting gradients. When the inner product of two task gradients is negative, the tasks pull the shared parameters in partly opposing directions, and simply averaging the gradients can yield an update that increases one of the losses. For example, in vision models, classification tends to rely on global features while localization emphasizes local details. When their gradients point in opposing directions, naive averaging cannot serve both tasks, and the model can easily settle into a suboptimal solution.
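
To make the conflict concrete, here is a small pure-Python sketch with two made-up 2-D gradients (the values are illustrative and do not come from TorchJD). A gradient step theta <- theta - lr * d changes each loss L_i by roughly -lr * (g_i . d), so a task whose gradient has a negative inner product with the averaged direction d sees its loss go up.

```python
# Illustrative gradients of two task losses w.r.t. the same two shared parameters.
g1 = [4.0, 0.0]
g2 = [-1.0, 1.0]


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def first_order_change(grad, direction, lr):
    """Approximate change of a loss after the step theta <- theta - lr * direction."""
    return -lr * dot(grad, direction)


conflict = dot(g1, g2)  # negative, so the two gradients conflict
mean_dir = [(a + b) / 2 for a, b in zip(g1, g2)]  # naive averaging: [1.5, 0.5]

lr = 0.1
delta_l1 = first_order_change(g1, mean_dir, lr)  # negative: task 1 improves
delta_l2 = first_order_change(g2, mean_dir, lr)  # positive: task 2 gets worse

print(f"g1 . g2 = {conflict}")
print(f"approx. change in L1: {delta_l1:+.3f}")
print(f"approx. change in L2: {delta_l2:+.3f}")
```

The larger gradient g1 dominates the average, which is exactly the imbalance that conflict-aware aggregation is meant to prevent.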

Section 03

[Method] Jacobian Descent: A New Paradigm for Multi-Task Optimization

The Jacobian Descent algorithm implemented by TorchJD changes the paradigm of multi-task optimization. Traditional gradient descent first reduces the losses to a single scalar (for example, their sum) and follows its gradient. Jacobian Descent instead works directly with the Jacobian matrix of the loss vector, in which each row is the gradient of one loss function with respect to the model parameters. An aggregator then maps this matrix to a single update vector, so the update rule can take conflicts between the rows into account instead of blindly summing them.
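
The structure of one Jacobian Descent step can be sketched in plain Python. The two toy quadratic losses, their hand-written gradients, and the helper names `jacobian`, `mean_aggregator`, and `jd_step` are hypothetical illustrations, not TorchJD's API. The mean aggregator used here is the simplest possible choice and reduces to ordinary gradient descent on the averaged loss.

```python
from typing import Callable, List

Vector = List[float]
Matrix = List[Vector]


def loss_vector(theta: Vector) -> Vector:
    """Two toy quadratic losses that prefer different parameter values."""
    l1 = (theta[0] - 1.0) ** 2 + theta[1] ** 2
    l2 = theta[0] ** 2 + (theta[1] - 2.0) ** 2
    return [l1, l2]


def jacobian(theta: Vector) -> Matrix:
    """Jacobian of loss_vector: row i is the gradient of loss i w.r.t. theta."""
    return [
        [2.0 * (theta[0] - 1.0), 2.0 * theta[1]],
        [2.0 * theta[0], 2.0 * (theta[1] - 2.0)],
    ]


def mean_aggregator(jac: Matrix) -> Vector:
    """Simplest aggregator: average the rows (equivalent to GD on the mean loss)."""
    m = len(jac)
    return [sum(row[j] for row in jac) / m for j in range(len(jac[0]))]


def jd_step(theta: Vector, aggregator: Callable[[Matrix], Vector], lr: float) -> Vector:
    """One Jacobian Descent step: build the Jacobian, aggregate it, move against it."""
    update = aggregator(jacobian(theta))
    return [t - lr * u for t, u in zip(theta, update)]


theta = [0.0, 0.0]
for _ in range(100):
    theta = jd_step(theta, mean_aggregator, lr=0.1)
print(theta, loss_vector(theta))
```

With the mean aggregator, theta converges to the minimizer of the averaged loss, (0.5, 1.0). Replacing `mean_aggregator` with a conflict-aware aggregator changes only the aggregation line; the rest of the step stays the same.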

Section 04

[Core Mechanism] Conflict-Free Gradient Projection Ensures Task Performance

TorchJD provides more than 10 aggregators. The representative one, UPGrad (Unconflicting Projection of Gradients), works as follows: before averaging, it projects each gradient onto the dual cone of the Jacobian's rows, that is, the set of directions that have a non-negative inner product with every task's gradient. Because the average of these projections also lies in the dual cone, the resulting update does not conflict with any task: with a sufficiently small learning rate, a step along it does not increase any of the losses, to first order.
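
Below is a minimal pure-Python sketch of the UPGrad idea, written for illustration rather than taken from TorchJD. It relies on the fact that the projection of row i of J onto the dual cone can be written as J^T v, where v minimizes v^T G v subject to v >= e_i and G = J J^T is the Gram matrix of the gradients. That small box-constrained problem is solved here with plain projected gradient descent, a simplifying assumption; a real implementation would use a proper quadratic-programming solver. The function names are hypothetical.

```python
from typing import List

Vector = List[float]
Matrix = List[Vector]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def gramian(jac: Matrix) -> Matrix:
    """G = J J^T: pairwise inner products of the task gradients."""
    return [[dot(gi, gj) for gj in jac] for gi in jac]


def project_row_weights(gram: Matrix, i: int, iters: int = 5000) -> Vector:
    """Weights v such that J^T v is the projection of row i onto the dual cone.

    Solves min_v v^T G v subject to v >= e_i by projected gradient descent
    on lam = v - e_i >= 0 (a simple solver chosen for illustration only).
    """
    m = len(gram)
    u = [1.0 if k == i else 0.0 for k in range(m)]
    lam = [0.0] * m
    # 1 / trace(G) is a safe step size because trace(G) >= largest eigenvalue of G.
    step = 1.0 / max(sum(gram[k][k] for k in range(m)), 1e-12)
    for _ in range(iters):
        v = [u[k] + lam[k] for k in range(m)]
        grad = [dot(gram[k], v) for k in range(m)]  # gradient of 0.5 * v^T G v
        lam = [max(0.0, lam[k] - step * grad[k]) for k in range(m)]
    return [u[k] + lam[k] for k in range(m)]


def upgrad(jac: Matrix) -> Vector:
    """Average of the rows of J after projecting each one onto the dual cone."""
    m, n = len(jac), len(jac[0])
    gram = gramian(jac)
    weights = [0.0] * m
    for i in range(m):
        v = project_row_weights(gram, i)
        weights = [w + vk / m for w, vk in zip(weights, v)]
    # The aggregated update is J^T w with the averaged weights.
    return [sum(weights[k] * jac[k][j] for k in range(m)) for j in range(n)]


J = [[4.0, 0.0], [-1.0, 1.0]]  # two conflicting gradients (inner product -4)
d = upgrad(J)
print(d, [dot(g, d) for g in J])  # every inner product is non-negative
```

For this J, the projections of the two rows are (2, 2) and (0, 1), so the update is their mean (1, 1.5), which has a positive inner product with both gradients. When the rows do not conflict, every row already lies in the dual cone and the sketch reduces to plain averaging.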

Section 05

[Practical Usage] Seamless Integration of TorchJD with PyTorch

TorchJD follows familiar PyTorch conventions. Instead of calling loss.backward() on a single scalar loss, you call torchjd.autojac.backward(losses) on the vector of losses, then use jac_to_grad() to aggregate the resulting Jacobian into ordinary gradients that the optimizer can consume. For multi-task models, the mtl_backward() function computes the ordinary gradients of each loss with respect to its task-specific parameters and, separately, the Jacobian with respect to the shared parameters. Each task head is therefore still optimized independently, while conflicts are resolved only in the shared parameters.
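
The split that mtl_backward() makes can be illustrated with a tiny hand-differentiated model in plain Python. This is a hypothetical sketch of the idea, not TorchJD's implementation: a shared weight vector w produces a feature z = w . x, each task i has its own scalar head a_i with loss (a_i * z - y_i) ** 2, the heads receive the ordinary gradient of their own loss, and only the shared parameters receive a Jacobian (one row per loss) that an aggregator turns into a single update. The model, the function names, and the placeholder mean aggregator are all assumptions made for illustration.

```python
from typing import Callable, List, Tuple

Vector = List[float]
Matrix = List[Vector]


def mean_aggregator(jac: Matrix) -> Vector:
    """Placeholder aggregator; a conflict-aware one could be passed in instead."""
    m = len(jac)
    return [sum(row[j] for row in jac) / m for j in range(len(jac[0]))]


def mtl_grads(
    w: Vector,
    heads: Vector,
    x: Vector,
    targets: Vector,
    aggregator: Callable[[Matrix], Vector],
) -> Tuple[Vector, Vector, Matrix]:
    """Return (shared-parameter update, head gradients, shared-parameter Jacobian)."""
    z = sum(wi * xi for wi, xi in zip(w, x))
    residuals = [a * z - y for a, y in zip(heads, targets)]
    # Task-specific parameters: head i only affects loss i, so it gets a plain gradient.
    head_grads = [2.0 * r * z for r in residuals]
    # Shared parameters: one Jacobian row per loss, dL_i/dw = 2 * r_i * a_i * x.
    jac = [[2.0 * r * a * xj for xj in x] for r, a in zip(residuals, heads)]
    return aggregator(jac), head_grads, jac


w, heads = [0.5, -0.5], [1.0, -1.0]
x, targets = [1.0, 2.0], [1.0, 1.0]
shared_update, head_grads, jac = mtl_grads(w, heads, x, targets, mean_aggregator)

lr = 0.1
w = [wi - lr * g for wi, g in zip(w, shared_update)]
heads = [a - lr * g for a, g in zip(heads, head_grads)]
print(jac, shared_update, head_grads)
```

Here the two Jacobian rows point in opposite directions, so the aggregator is the only place where the conflict has to be handled; the head updates are unaffected by that choice.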

Section 06

[Application Scenarios] Wide Applicability and Extensibility of TorchJD

Beyond traditional multi-task learning, TorchJD also applies to instance-wise risk minimization, where the loss on each training example is treated as a separate objective (for example, in personalized recommendation and federated learning). The library's torchjd.autojac.jac() function computes the Jacobian and returns it directly instead of storing it in the .jac field of the parameters, which provides a foundation for more complex custom algorithms.
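
In instance-wise risk minimization, every training example contributes its own loss, so the Jacobian has one row per example. The plain-Python sketch below builds that per-example Jacobian for a linear least-squares model and returns it directly, in the spirit of computing a Jacobian without storing it on the parameters. The functions `per_example_jacobian` and `batch_gradient` and the data are hypothetical, and this is not how torchjd.autojac.jac() is implemented. Averaging the rows recovers the usual mini-batch gradient that standard empirical risk minimization follows; a conflict-aware aggregator would treat the rows differently.

```python
from typing import List

Vector = List[float]
Matrix = List[Vector]


def per_example_jacobian(theta: Vector, xs: Matrix, ys: Vector) -> Matrix:
    """Row k is the gradient of (theta . x_k - y_k) ** 2 w.r.t. theta."""
    rows = []
    for x, y in zip(xs, ys):
        residual = sum(t * xi for t, xi in zip(theta, x)) - y
        rows.append([2.0 * residual * xi for xi in x])
    return rows


def batch_gradient(theta: Vector, xs: Matrix, ys: Vector) -> Vector:
    """Gradient of the mean squared error over the batch, computed in one pass."""
    n = len(theta)
    grad = [0.0] * n
    for x, y in zip(xs, ys):
        residual = sum(t * xi for t, xi in zip(theta, x)) - y
        for j in range(n):
            grad[j] += 2.0 * residual * x[j] / len(xs)
    return grad


theta = [0.5, -1.0]
xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
ys = [1.0, 0.0, 2.0]

jac = per_example_jacobian(theta, xs, ys)  # 3 examples x 2 parameters
mean_row = [sum(col) / len(jac) for col in zip(*jac)]
print(jac)
print(mean_row, batch_gradient(theta, xs, ys))
```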

Section 07

[Conclusion] TorchJD Opens New Possibilities for Neural Network Optimization

TorchJD brings an important tool to the PyTorch ecosystem. It offers a principled way to address gradient conflicts in multi-task learning and broadens the range of optimization methods available for neural networks. For deep learning practitioners who need to balance multiple objectives, it is a library worth exploring in depth.