未验证 提交 f694e991 编写于 作者: X xiongkun 提交者: GitHub

prepare_gradient_aggregation for non-leaf output of PartialProgramLayer (#44893)

* 1. add prepare_gradient_aggregation in PartialProgramLayer

* 1. draft

* fix ci problem
上级 68b06ba6
......@@ -287,6 +287,74 @@ class PartialProgramLayer:
return main_program
def prepare_gradient_aggregation(self, main_program, target_program):
"""
Why we need add gradient aggregation operation ?
In some cases, if non leaf nodes are used as output, gradient overwriting will occur, such as
def forward(self, in):
x = 2 * in # <---- x is a non-leaf node in program.
y = x + 3
return x, y
loss = forward(in)[0].sum()
loss.backward() # <----- x@grad will be overwrited by elementwise_add_grad Op
"""
def _need_aggregation(var):
"""
if exist a op whose inputs is var, then return True
"""
if not isinstance(var, framework.Variable) or var.type not in [
core.VarDesc.VarType.LOD_TENSOR,
core.VarDesc.VarType.SELECTED_ROWS
]:
return False
if var.dtype not in [paddle.float32, paddle.float64]:
return False
for op in main_program.block(0).ops:
for in_arg in op.input_arg_names:
if in_arg == var.name:
return True
return False
def _insert_aggregation_ops_for_var(target_program, var):
suffix = "@dy2static"
var_grad_name = var.grad_name
new_grad_name = var.name + suffix + "@GRAD"
finded_ops = list(
filter(
lambda x: any([
out_arg == var_grad_name
for out_arg in x[1].output_arg_names
]), enumerate(target_program.block(0).ops)))
# len(finded_ops) may equals zero when stop_gradient works.
# len(finded_ops) may > 1, because we may have fill_constant op.
if len(finded_ops) == 0:
return None
# step1: create a new var named var.name@GRAD
target_program.block(0).create_var(name=new_grad_name,
type=var.type,
dtype=var.dtype,
shape=var.shape)
# step2: rename the var.name@GRAD to var.name@GRAD@dy2static
for idx, op in finded_ops:
op._rename_input(var_grad_name, new_grad_name)
op._rename_output(var_grad_name, new_grad_name)
# step3: insert sum op to aggregate the gradient.
# var.name@GRAD = sum(var.name@dy2static@GRAD, var.name@GRAD)
target_program.block(0)._insert_op(
finded_ops[-1][0] + 1,
type='sum',
inputs={'X': [var_grad_name, new_grad_name]},
outputs={"Out": var_grad_name})
return None
to_processed_vars = list(
filter(_need_aggregation, self._outputs.tolist()))
for _var in to_processed_vars:
_insert_aggregation_ops_for_var(target_program, _var)
@switch_to_static_graph
def _append_backward_desc(self, main_program):
# make sure all status of is_test are False in train mode.
......@@ -299,6 +367,8 @@ class PartialProgramLayer:
if targets and self._params:
backward.gradients(targets=targets, inputs=[])
self.prepare_gradient_aggregation(main_program, program)
return program
def _prune_unused_params(self, program):
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
import numpy as np
SEED = 2020
np.random.seed(SEED)
class SimpleNet(paddle.nn.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
self.linear1 = paddle.nn.Linear(10, 3)
self.linear2 = paddle.nn.Linear(3, 1)
def forward(self, x):
out1 = self.linear1(x)
out2 = self.linear2(out1)
return [out1, out2] # 梯度为0
#return [out1] # 梯度正常
#return [out2, out1] # 梯度正常
class TestGradientAggregationInDy2Static(unittest.TestCase):
def test_to_static(self):
def simplenet_grad(inp, to_static=False):
net = SimpleNet()
if to_static: net = paddle.jit.to_static(net)
loss = net(inp)
loss[0].backward()
return net.linear1.weight.grad
inp = paddle.to_tensor(np.random.randn(10, )).astype("float32")
self.assertTrue(
np.allclose(
simplenet_grad(inp, True).numpy(),
simplenet_grad(inp, False).numpy()))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册