multi-task learning

Multi-task Learning and Beyond: 过去,现在与未来

近期 Multi-task Learning (MTL) 的研究进展有着众多的科研突破,和许多有趣新方向的探索。这激起了我极大的兴趣来写一篇新文章,尝试概括并总结近期 MTL 的研究进展,并探索未来对于 MTL 研究其他方向的可能。

这篇文章将顺着我 18 年硕士论文:Universal Representations: Towards Multi-Task Learning & Beyond 的大体框架,并加以补充近期新文章的方法,和未来新方向的讨论。

Disclaimer: 在硕士论文里提及的自己的文章均为当时未发表的 preliminary results,对于任何人想要了解论文里的细节请直接参看发表的 conference 文章。


Multi-task Learning 的两大研究分支

在绝大部分情况下,MTL 的研究可以归类为以下两个方向,一个是 MTL Network 网络设计;另一个是 MTL Loss function 损失函数设计。我们以下对于这两个方向进行详细解读。

Multi-task Learning Network Design / What to share? [网络设计]

在起初,MTL 的网络设计通常可以列为两种情况:Hard parameter sharing 和 soft parameter sharing。

img

Hard-parameter sharing – 现在是几乎所有做 MTL 不可缺少的 baseline 之一,就是将整个 backbone 网络作为 shared network 来 encode 任务信息,在最后一层网络 split 成几个 task-specific decoders 做 prediction。Hard-parameter sharing 是网络设计参数数量 parameter space 的 (并不严格的,假设不考虑用 network pruning) lower bound,由此作为判断新设计网络对于 efficiency v.s. accuracy 平衡的重要参考对象。

Soft-parameter sharing – 可以看做是 hard-parameter sharing 的另外一个极端,并不常见于现在 MTL 网络设计的比较。在 soft-parameter sharing 中,每一个任务都有其相同大小的 backbone network 作为 parameter space。我们对于其 parameter space 加于特定的 constraint 可以是 sparsity, 或 gradient similarity, 或 LASSO penalty 来 softly* constrain 不同任务网络的 representation space。假设我们不对于 parameter space 加以任何 constraint,那么 soft-parameter sharing 将塌缩成 single task learning。

任一 MTL 网络设计可以看做是找 hard 和 soft parameter sharing 的平衡点:1. 如何网络设计可以小巧轻便。2. 如何网络设计可以最大幅度的让不同任务去共享信息。

MTL network design is all about sharing.

  • Cross-Stitch Network

Cross-Stitch Network 是过去几年内比较经典的网络设计,也已常用于各类 MTL 研究的baseline 之一。其核心思想是将每个独立的 task-specific 网络使用 learnable parameters (cross-stitch units) 以 linear combination 的方式连接其中不同任务的 convolutional blocks。

img

Visualisation of Cross-Stitch Network

对于任务 A 与 B,每个 convolutional block 输出层 xA,Bi,j ,我们将计算:

[xAijxBij]=[ΛAAΛABΛBAΛBB][xAijxBij].

通过这样的运算,下一个 convolutional block 的输入层则为 x~A,Bi,j .

启发于 Cross-stitch 的设计,NDDR-CNN 也有类似的思路。然而不同的是,对于中间层的 convolutional block 的信息融合,他们采用了 concatenate 并通过 [1 x 1] 的 convolutional layer 来 reduce dimensionality。 这样的设计使得每个任务的 channel 都可以与其他不同 index 的 channel 交融信息,而规避了原始 Cross-stitch 只能 infuse 相同 channel 信息的局限性。当 NDDR 的 convolutional layer weights 的 non-diagonal elements 是 0 的话, NDDR-CNN 则数学上等价于 Cross-Stich Network。

Cross-Stitch Network 和 NDDR-CNN 的最大弱势就是对于每个任务都需要一个新的网络,以此整个 parameter space 会对于任务的数量增加而线性增加,因此并不 efficient。

  • Multi-task Attention Network

基于 Cross-stitch Network efficiency 的缺点,我后续提出了 Multi-task Attention Network (MTAN) 让网络设计更加小巧轻便,整个网络的 parameter space 对于任务数量的增加以 sub-linearly 的方式增加。

img

Visualisation of Multi-task Attention Network

MTAN 的核心思想是,assume 在 shared network encode 得到 general representation 信息之后,我们只需要少量的参数来 refine task-shared representation into task-specific representation, 就可以对于任意任务得到一个很好的 representation. 因此整个网络只需要增加少量的 task-specific attention module,两层 [1 x 1] conv layer,作为额外的 parameter space 来 attend 到 task-shared represenation。整个模型的参数相对于 Cross-Stitch Network 来说则大量减少。

  • AdaShare
    AdaShare 则更是将 MTL 网络设计的 efficiency 做到的极致。与其增加额外的 conv layer 来 refine task-shared representation,AdaShare 将单个 backbone 网络看做 representation 的整体,通过 differentiable task-specific policy 来决定对于任何一个 task,否用去更新或者利用这个网络的 block 的 representation。

img

Visualisation of AdaShare

由于整个网络是应用于所有任务的 representation,因此 network parameter space 是 agnostic 于任务数量,永远为常数,等价于 hard-parameter sharing。而搭接的 task-specific policy 是利用 gumbel-softmax 对于每一个 conv block 来 categorical sampling “select” 或者 “skip” 两种 policy,因为整个 MTL 的网络设计也因此会随着任务的不同而变化,类似于最近大火的 Neural Architecture Search 的思想。

  • MTL + NAS

MTL-NAS 则是将 MTL 和 NAS 结合的另外一个例子。他搭载于 NDDR 的核心思想,将其拓展到任意 block 的交融,因此网络搜索于如何将不同 task 的不同 block 交融来获得最好的 performance。

img

Visualisation of MTL-NAS

我个人更偏向 Adashare 的搜索方式,在单个网络里逐层搜索,这样的 task-specific representation 已经足够好过将每一个 task 定义成新网络的结果。由此, MTL-NAS 也躲不掉网络参数线性增加的特点,不过对于 MTL 网络设计提供了新思路。

MTL + NAS 和传统的 single-task NAS 会有着不同需求,和训练方式。

  1. MTL+NAS 并不适合用 NAS 里最常见的 two-stage training 方式:以 validation performance 作为 supervision 来 update architecture 参数,得到 converged architecture 后再 re-train 整个网络。因为 MTL 的交融信息具备 training-adaptive 的性质, 因此 fix 网络结构后,这样的 training-adaptive 信息会丢失,得到的 performance 会低于边搜边收敛的 one-stage 方式。换句话说,训练中的 oscillation 和 feature fusion 对于 MTL 网络是更重要的,而在 single task learning 中,并没有 feature fusion 这个概念。这间接导致了 NAS 训练方式的需求不同。
  2. MTL+NAS is task-specific. 在 NAS 训练中,要是 dataset 的 complexity 过大,有时候我们会采用 proxy task 的方式来加快训练速度。最常见的情况则是用 CIFAR-10 作为 proxy dataset 来搜好的网络结构,应用于过大的 ImageNet dataset。而这一方式并不适用于 MTL,因为对于任一任务,或者几个任务的 pair,他们所需要的 feature 信息和任务特性并不同,因此无法通过 proxy task 的方式来加速训练。每一组任务的网络都是独特和唯一的。

我相信在未来 MTL 网络设计的研究中,我们会得到更具备 interpretable/human-understandable 的网络特性,能够理解任务与任务之间的相关性,和复杂性。再通过得到的任务相关性,我们可以作为一个很好 prior knowledge 去 initialise 一个更好的起始网络,而由此得到一个更优秀的模型,一种良性循环。

A Better Task Relationship⟺A Better Multi-task Architecture


Multi-task Learning Loss Function Design / How to learn? [损失函数设计与梯度优化]

平行于网络设计,另外一个较为热门的方向是 MTL 的 loss function design, 或者理解为如何去更好得 update 网络里的 task-specific gradients。

对于任意 task i, 我们有损失函数: L=∑iαiLi , 其中 αi 为 task-specific learning parameters. 那么,我们需要找到一组很好的 αi 来 optimise 所有 task i 的 performance Li 。 其中最为简单且直接的方式则为 equal weighting: αi=1 , 也就是默认每一个 task 对于 representation 的 contribution 是相同的。

  • Weight Uncertainty

Weight Uncertainty 是最早几篇研究 MTL loss function design 的文章之一。这篇文章 assume 在每个 model 里存在一种 data-agnostic task-dependent uncertainty 称之为 Homoscedastic uncertainty。(这种说法其实非常的古怪,只有剑桥组喜欢这么称呼。)而通过 maximise log -likelihood of the model prediction uncertainty 可以来 balance MTL training。这里的 likelihood (通常 parameterised by Gaussian) 可以看做是 relative confidence between tasks。

对于任何 model prediction y, 我们定义 Gaussian likelihood of model: p(y|fw(x))=N(fw(x),σ2) 其中这里的 σ 为 learnable noise scalar (Gaussian variance),那么我们需要 maximise:

log⁡p(y|fw(x))∝−12σ2‖y−fw(x)‖2−log⁡σ,

由此我们可以得到新定义的 loss function:

L=∑i12σi2Li+log⁡σi.

最后推导的公式非常简洁,也因此用在很多 MTL benchmark 里。

然而这篇文章有着非常大的争议,其中最著名的一点是作者对于如此简单的公式却一直拒绝开源,并且无视大量其他 researchers 的邮件对于 implementation 的细节询问,惹怒了不少同行(包括我)。此外,weight uncertainty 非常依赖于 optimiser 的选择,在我个人实验里,我发现有且仅有 ADAM optimiser 可以让 weight uncertainty 正确收敛,而在其他 optimiser 上 weight uncertainty 没有任何收敛趋势。这篇博客 则更为指出,这个 weight uncertainty 公式可以直接得到 closed-form solution:当 learnable σ is minimised, 整个 loss function 将转化成 geometric mean of task losses,因此再次对于这里 uncertainty assumption 可行性提出了质疑。

  • GradNorm

GradNorm 则为另外一篇最早期做 MTL loss function 的文章。 GradNorm 的实现是通过计算 inverse training rate: L~(t)=L(t)/L(0) 下降速率,作为 indicator 来平衡不同任务之间的梯度更新。

我们定义 GW(t) 为 W 参数在 weighted multi-task loss 在 t step 上计算到梯度的 L2-norm; mean task gradient 为 G¯W(t)=E[GiW(t)]; relative inverse training rate 为 ri(t)=Li(t)/E[Li(t)] 。 GradNorm 通过以下 objective 来更新 task-specific weighting:

|GiW(t)−G¯W(t)⋅ri(t)α|

其中 G¯W(t)⋅ri(t)α 则为理想的梯度 L2-norm (作为 constant), 来调整 task-specific weighting. α 作为一个平衡超参, α 越大则 task-specific weighting 越平衡。由于每次计算 GW(t) 需要对所有 task 在每个 layer 进行 backprop,因此非常 computational expensive。由此,作者就以计算最后一层的 shared layer 作为 approximation 来加快训练速度。

  • Dynamic Weight Average

由于 GradNorm 在计算 task-specific weighting 上需要运算两次 backprop 因此在 implementation 上非常复杂。我后续提出了一个简单的方法,只通过计算 loss 的 relative descending rate 来计算 task weighting:

αk(t):=Kexp⁡(wk(t−1)/T)∑iexp⁡(wi(t−1)/T),wk(t−1)=Lk(t−1)Lk(t−2).

这里的 wk 则通过计算两个相邻的 time step 的 loss ratio 作为 descending rate。因此 wk 越小,收敛速率就越大,任务就越简单,得到权重也就越小。

  • MTL as Multi-objective Optimisation

之前介绍的几个 task-weighting 方法都基于一些特定的 heuristic,很难保证在 MTL optimisation 取得 optimum. 在 MTL as Multi-objective Optimisation 里,作者将 MTL 问题看做是多目标优化问题,其目标为取得 Pareto optimum.

Pareto optimum 是指任何对其中一个任务的 performance 变好的情况,一定会对其他剩余所有任务的 performance 变差。作者利用了一个叫 multiple gradient descent algorithm (MGDA) 的方法来寻找这个 Pareto stationary point。大致方式是,在每次计算 task-specific gradients 后,其得到 common direction 来更新 shared parameter s。这个 common direction 如果存在,则整个 optimisation 并未收敛到 Pareto optimum。这样的收敛方法保证了 shared parameter 不会出现 conflicting gradients 让每一个任务的 loss 收敛更加平滑。

  • Other heuristic methods

对于 MTL loss function 的设计,还有其他不同的 heuristic 方法,或基于任务的难易性 (Dynamic Task Prioritization), 或直接对于计算到的任务梯度进行映射,防止出现任务梯度之间的 conflicting gradients 情况 (Gradient Surgery)。

但在各式各样的 MTL loss function design 里,很难出现其中一个方法在所有数据集里都 outperform 其他方法的情况。甚至,在部分数据集里,最简单的 equal task weighting 也表现得较为优异。一方面,task weighting 的有效性非常依赖于 MTL 网络本身的设计;此外 task weighting 的更新也依赖于数据集和 optimiser 本身。假设如果核心目标仅仅是取得最好的 MTL performance,那我建议应该花更多的时间去研究更好的网络而不是 task weighting。但不可否认的是,task weighting 的研究可以更好的帮助人们理解任务之间的相关性和复杂性,以此反过来帮助人们更好的设计模型本身。


Auxiliary Learning – Not all tasks are created equal [辅助学习]

跟 MTL 高度相关的一个方向被称之为 Auxiliary learning (AL, 辅助学习):他的训练过程与 MTL 完全一致。唯一的不同是,在 Auxiliary Learning 里,只有部分任务的 performance 是需要被考虑的 (primary task),其他(辅助)任务 (auxiliary task) 的存在的意义是,帮助那部分需要被考虑的任务学习到更好的 representation。

  • Supervised Auxiliary Learning

Auxiliary Learning 存在的普遍性其实远超于我们的想象。比如,MTL 其实就是一种特殊形式的 AL,我们可以把其中任意一个 task 作为 primary task,其他剩余的 task 看作为 auxiliary task。在 MTL 里,我们默认所有 task 与 task 存在一种 mutual beneficial 的关系,因此所有 learning tasks 都 benefit 到这种相关性。

Auxiliary Learning 还应用在很多领域里,比如 这篇文章 发现在训练 depth 和 normal prediction 的同时,可以有效的帮助 object detection 的精确度。或者 这篇文章 发现在做 short sequence 的重建时,可以帮助 RNN 更有效的训练 very long sequence。

在 supervised auxiliary learning 的 setting 中,整个网络和任务的选择非常依赖于人类的先验知识,并不具备绝对的普遍性。

  • Meta Auxiliary Learning

考虑 supervised auxiliary learning 对于任务选择的局限性,我后续提出了一种基于 meta learning 的方法来自动生成 auxiliary task,我把这种方法称之为 Meta Auxiliary Learning (MAXL)。

在传统的 supervised auxiliary learning,我发现有这样如下两个规律:

  1. 假设 primary 和 auxiliary task 是在同一个 domain,那么 primary task 的 performance 会提高当且仅当 auxiliary task 的 complexity 高于 primary task。
  2. 假设 primary 和 auxiliary task 是在同一个 domain,那么 primary task 的最终 performance 只依赖于 complexity 最高的 auxiliary task。

这里对于 task 的 complexity 的定义比较 tricky 并不 general,目前我只考虑了最简单的图片分类的情况:细分类任务的 complexity 是高于粗分类任务。

比如在下图,我们看到猫,狗两类的分类的信息,直觉上一定低于细分类,约克夏,波斯猫之类更为细节的信息。而分类猫狗所需要的信息是分类细分类的子集,因此得出了规律 2.

img

再由于规律 2,我们只需要考虑最简单的两个任务训练的情况: primary task 和 auxiliary task 各为一个任务。因此在这里,我们”生成“一个好的 auxiliary task的问题,也就可转化为对任意输入图片,我们需要有一个好的网络去生成好的细分类标签。

img

Visualisation of MAXL

因此在 MAXL 框架里,我们有两个网络:一个网络是 multi-task network 类同于 hard parameter sharing 来做 multi-task training。另外一个网络是 label-generation network 来生成细分类标签,给 multi-task network 作为 auxiliary task 的 prediction label。

于是对于 Multi-task network fθ1(⋅) ,我们需要做以下更新:

argminθ1(L(fθ1pri(xtrain(i)),ytrain(i))+L(fθ1aux(xtrain(i)),gθ2(xtrain(i),ytrain(i),ψ))).

对于 label-generation network gθ2(⋅) , 我们需要通过 meta update 的方式做以下更新:

argminθ2L(fθ1+pri(xval(i)),yval(i))

其中,

θ1+=θ1−α∇θ1(L(fθ1pri(xval(i)),yval(i))+L(fθ1aux(xval(i)),gθ2(xval(i),yval(i),ψ)).

这里, θ1+ 代表了 meta update,来通过 maximise validation performance of primary task 来寻找最好的 auxiliary label。这里的 second derivative trick 类同于经典 meta learning 算法 MAML 的梯度更新。

在 label-generation network 里还存在一个 hyper-parameter ψ 代表人类定义的 dataset hierarchy。 假设我们在做简单的二分类,那么 ψ=[2,3] 则意味着将第一个类再细分成 2 类,第二个类再细分成 3 类。那么 label-generation network 就以这样的 hierarchy 通过 masked version of softmax 来生成相应的合适的 auxiliary label。于是 multi-task network 就在进行两个分类任务: primary task 是二分类,auxilairy task 是五分类。

通过 MAXL,我们发现他可以对这样的图片分类任务有着一定的效果提升。我们后续 visualize 生成的标签,发现在一些简单的数据集里有着人类可理解的 clustering 含义。

img

在上图中,上半部分是 CIFAR-100 的分类,下半部分是 MNIST 的分类。其中这三类 auxiliary class 中的图片是通过 label-generation network 生成在这个 class 里 5 个 confidence 最高的图片。

在较为复杂的 CIFAR-100 数据集中,我们很难理解 MAXL 的分类到底在干什么。而在 MNIST 中,我们可以发现 不同粗细的 数字 3,不同方向的 数字 9,有无中间的 horizontal bar 的数字 7 cluster 到了一起。这种有趣的现象开拓了一个新颖的方向,对自动化辅助任务生成的探索。

  • Self-supervised Auxiliary Learning: Learning to X by Y

Auxiliary learning 还可以应用在 self-supervised learning 中,假设对于 primary task 和 auxiliary task 都没有任何 human label。这里的 self-supervised training 跟跟 supervised auxiliary task 一样,是有效得利用了人类先验知识对于部分任务组合的特性。

较为常见的组合是:depth 和 ego-motion 的 prediction 可以在时间上具备 consistency。同样利用 time consistency 的还有 colorisation 和 tracking, feature learning 和 ego-motion.

同样的这样的 consistency 也可以存在基于 RL 的robot learning 里,例如 grasping 和 pushing,或者 robot manipulation

总结

终于到总结了!近些年来 MTL 的研究出现了很多新颖且有价值的工作,但是对于任务自身的理解,和任务之间关系的理解还是有很大的不足和进步空间。在 Taskonomy 里,作者尝试了上千种(大量 CO2 排放)任务的组合来绘制出不同任务之间的关系图。但是真实 MTL 训练中,我相信这种关系图应该随着时间的变化而变化,且依赖网络本身。因此,如何更好得通过任务之间的关系去优化网络结构还是一个未解之谜,如何设计/生成辅助任务并通过 MTL 更好得帮助 primary task 也并未了解透彻。希望在后续的研究中能看到更多文章对于 MTL 的深入探索,实现 universal representation 的最终愿景。

写于五月七日,在疫情笼罩中的伦敦里摸鱼。

转载只需注明作者和来源即可,无需私信确认。

编辑于 2020-05-08 02:07

Multi-task Learning and Beyond: 过去,现在与未来 - 知乎 (zhihu.com)


multi-task learning
http://csgituser.github.io/2023/05/08/multi-task-learning/
Author
Museum
Posted on
May 8, 2023
Licensed under