Created by: chenwhql
当StaticModelRunner处理一些较为复杂的静态图预训练模型时,会存在预训练模型有多个fetch输出的情况。在多个fetch输出的场景中,有一种特殊的case,模型的多个输出可能在同一个分支上,如下图所示,示例取自PaddleHub ernie
将这多个输出作为StaticModuleRunner的输出进行处理时,会通过fluid.gradients
为载入的预训练预测program添加反向,写法类似
fluid.gradients(targets=[sequqnce_out, pooled_out])
注意此时的网络,由于这两个输出在同一分支上,gradients
生成反向的时候,并不知道后续这两个输出后续使用的状况,中间的输出节点仅仅是一个正常的中间节点而已
但是之后用户可能仅使用其中一个输出,并仅为这个输出添加后续的损失计算逻辑,如果用户只为中间的变量添加了后续op,网络将变成如下状态
这里网络的性质发生改变,两个输出其实关联了两个分支,这些后续添加的操作在StaticModelRunner初始化时是不可知的,这导致后续在sequence_out
节点上需要的梯度累加操作,在StaticModelRunner里面添加反向时并未被添加
后续计算时,pooled_out
没有使用,梯度为0,且不会在sequence_out
节点累加,导致后续一些利梯度都为0,参数无法正确更新
所以这里提前为这种情况增加scale操作,提前让StaticModelRunner将这里识别为两个分支,确保反向添加时,梯度累加操作不会遗漏,修改后的初始网络为
-
补充1:这种处理方法可以解决当前问题,但是在执行上和静态图是有差别的,这里保留了无用梯度反向传播的op,理论上性能不如直接使用原静态图训练。在静态图中,由于网络在最后一次性添加反向,这种需要可以被正确识别,并进行合理的裁剪
-
补充2:这里为什么不在forward执行的时候根据用户需求,动态调整网络结构呢?
- 这需要用户添加额外声明,比如一定要写明
pooled_out.stop_gradient=True
,否则动态图没有全局网络,StaticModelRunner无法感知到外部需求 - forward是每个step都需要被执行的操作,在这里动态修改网络,会在循环中引入额外开销,同样导致性能下降,复杂度提高
- 后续可以详细考虑下如何设计此处的实现
- 这需要用户添加额外声明,比如一定要写明