diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
index da7cdc7f8f525bbf039a64c25fae9f4dcf5ac39c..2d0d159c5eea69af2abd06426bcb31c2d8fa0c49 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
@@ -287,6 +287,74 @@ class PartialProgramLayer:
 
         return main_program
 
+    def prepare_gradient_aggregation(self, main_program, target_program):
+        """
+        Why do we need to add a gradient aggregation operation?
+        In some cases, if a non-leaf node is used as an output, its gradient will be overwritten, for example:
+        def forward(self, inp):
+            x = 2 * inp  # <---- x is a non-leaf node in the program.
+            y = x + 3
+            return x, y
+
+        loss = forward(inp)[0].sum()
+        loss.backward()  # <----- x@grad will be overwritten by the elementwise_add_grad op
+        """
+
+        def _need_aggregation(var):
+            """
+            Return True if there exists an op that takes var as an input.
+            """
+            if not isinstance(var, framework.Variable) or var.type not in [
+                    core.VarDesc.VarType.LOD_TENSOR,
+                    core.VarDesc.VarType.SELECTED_ROWS
+            ]:
+                return False
+            if var.dtype not in [paddle.float32, paddle.float64]:
+                return False
+            for op in main_program.block(0).ops:
+                for in_arg in op.input_arg_names:
+                    if in_arg == var.name:
+                        return True
+            return False
+
+        def _insert_aggregation_ops_for_var(target_program, var):
+            suffix = "@dy2static"
+            var_grad_name = var.grad_name
+            new_grad_name = var.name + suffix + "@GRAD"
+            finded_ops = list(
+                filter(
+                    lambda x: any([
+                        out_arg == var_grad_name
+                        for out_arg in x[1].output_arg_names
+                    ]), enumerate(target_program.block(0).ops)))
+
+            # len(finded_ops) may equal zero when stop_gradient takes effect.
+            # len(finded_ops) may be > 1, because a fill_constant op may also write this grad.
+            if len(finded_ops) == 0:
+                return None
+            # step1: create a new var named var.name@dy2static@GRAD
+            target_program.block(0).create_var(name=new_grad_name,
+                                               type=var.type,
+                                               dtype=var.dtype,
+                                               shape=var.shape)
+            # step2: rename var.name@GRAD to var.name@dy2static@GRAD in the ops found above
+            for idx, op in finded_ops:
+                op._rename_input(var_grad_name, new_grad_name)
+                op._rename_output(var_grad_name, new_grad_name)
+            # step3: insert a sum op to aggregate the gradient:
+            # var.name@GRAD = sum(var.name@GRAD, var.name@dy2static@GRAD)
+            target_program.block(0)._insert_op(
+                finded_ops[-1][0] + 1,
+                type='sum',
+                inputs={'X': [var_grad_name, new_grad_name]},
+                outputs={"Out": var_grad_name})
+            return None
+
+        to_processed_vars = list(
+            filter(_need_aggregation, self._outputs.tolist()))
+        for _var in to_processed_vars:
+            _insert_aggregation_ops_for_var(target_program, _var)
+
     @switch_to_static_graph
     def _append_backward_desc(self, main_program):
         # make sure all status of is_test are False in train mode.
@@ -299,6 +367,8 @@ class PartialProgramLayer:
         if targets and self._params:
             backward.gradients(targets=targets, inputs=[])
 
+        self.prepare_gradient_aggregation(main_program, program)
+
         return program
 
     def _prune_unused_params(self, program):
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_gradient_aggregation.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_gradient_aggregation.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b7cca31ce989fa0ae87debd2ed2f6227e93f354
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_gradient_aggregation.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+
+import paddle
+import numpy as np
+
+SEED = 2020
+np.random.seed(SEED)
+
+
+class SimpleNet(paddle.nn.Layer):
+
+    def __init__(self):
+        super(SimpleNet, self).__init__()
+        self.linear1 = paddle.nn.Linear(10, 3)
+        self.linear2 = paddle.nn.Linear(3, 1)
+
+    def forward(self, x):
+        out1 = self.linear1(x)
+        out2 = self.linear2(out1)
+        return [out1, out2]  # without aggregation, out1's gradient is overwritten to 0
+        # return [out1]  # gradient is correct
+        # return [out2, out1]  # gradient is correct
+
+
+class TestGradientAggregationInDy2Static(unittest.TestCase):
+
+    def test_to_static(self):
+
+        def simplenet_grad(inp, to_static=False):
+            net = SimpleNet()
+            if to_static: net = paddle.jit.to_static(net)
+            loss = net(inp)
+            loss[0].backward()
+            return net.linear1.weight.grad
+
+        inp = paddle.to_tensor(np.random.randn(10)).astype("float32")
+        self.assertTrue(
+            np.allclose(
+                simplenet_grad(inp, True).numpy(),
+                simplenet_grad(inp, False).numpy()))
+
+
+if __name__ == '__main__':
+    unittest.main()
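
Below is a minimal eager-mode sketch (not part of the patch above) of the scenario the
prepare_gradient_aggregation docstring describes; the tensor values are illustrative only.
Eager mode already accumulates this gradient correctly, and the inserted sum op is what lets
the dy2static program match it instead of letting elementwise_add_grad overwrite x@grad.

import paddle

inp = paddle.to_tensor([1.0], stop_gradient=False)
x = 2 * inp      # non-leaf intermediate that is also returned as an output
y = x + 3        # a second consumer of x
loss = x.sum()   # the loss depends only on x
loss.backward()
print(inp.grad)  # eager mode yields [2.]; with this patch the dy2static program
                 # should produce the same value instead of 0

The rename-then-sum approach keeps the gradient contributed by x being a program output in
var.name@GRAD, redirects the contribution computed by the backward ops into
var.name@dy2static@GRAD, and then adds the two, so neither write can clobber the other.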