Created by: Aurelius84
背景
在控制流中(如cond、switch等),如果某些分支不会实际运行,则为了保证框架兼容性,对此分支内的Variable
的反向grad
进行了值为0.的初始化。避免由于反向grad
未初始化导致输入到其他op时,在runtime infershape时挂掉,尤其对于优化器op。
相关PR:
但是,这种方式,会带来性能上的损失。因为控制流的False分支里的所有param
都会对应一个优化器Op,并进行了梯度值为0的更新计算。
目标
因为,此PR旨在对optimizer ops进行优化,支持未初始化grad(参考sum_op),以避免优化器op的Compute
函数调用。
方案
修改optimizer op里的infershape和Compute函数,对grad是否初始化进行判断。若未初始化,则直接return。
目前仅对SGD
和Adam
两个op进行升级测试。
存在问题
- 如果fetch 控制流False分支的反向grad值,则返回为
None
,而非之前的全0值 - 需要sum op支持全未初始化的Tensor Array。由于在optimizer.py中会对部分同名grad作梯度聚合,有可能输入sum的所有同名梯度都是未初始化的(目前sum的输入至少一个是初始化的)
- 允许在runtime期间grad未初始化,可能存在输入到其他op的情况(如assign,cast等),需评估是否需要同步修改此类op支持未初始化tensor。
- 可能需要给出更加优雅的兼容方式,比如在optimizer.py对False分支的参数梯度进行剪枝