From b15734315bc4d5a72c603c8583ad242a4b7e678b Mon Sep 17 00:00:00 2001 From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com> Date: Thu, 17 Aug 2023 17:10:52 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90New=20IR=E3=80=91=20backward=20gradien?= =?UTF-8?q?ts=20accumulate=20test=20and=20pulish=20append=5Fbackward=5Fops?= =?UTF-8?q?=20func=20for=20op=5Fpattern=20=20(#56265)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [prim][newir] add basic framework for primitive * support desctensor in new ir * add vjp interface * support vjp in new ir * support vjp in new ir * polish vjp interface * fix stop_gradients set * fix vjp dispatch * add comment * add vjp test for new ir * add test for tanh vjp * [prim][newir] add basic framework for primitive * support desctensor in new ir * support vjp in new ir * support vjp in new ir * polish vjp interface * fix stop_gradients set * fix vjp dispatch * add comment * add vjp test for new ir * add test for tanh vjp * add eager and static backend for warp lower level api * support call_vjp pybind * polish code and add test for vjp * remove useless code * polish code * remove useless code * support mean vjp * backward origin code * add test for mean vjp and support has_vjp function * fix call_vjp * polish code * add attrs and dtype interface * add primitive ops set for backend * fix compile bugs * fix some bugs * fix windows bugs * add vjp test for tanh_ * fix inference CI * fix inference ci * modify fluid cmake * origin test of tanh and mean passed * fix conflict * modify stop_gradient * remove useless deps * add cmake * modify block.ops * modify test * fix conflict * reply review comments * reply review comments * pulish code * fix comment * fix test * polish code * modify backward stop_gradients * modify static_backend.cc * refactor grad_op * support add and add_inplace vjp * remove useless code * remove useless code * remove cout * modify add_n * modify add_n with add_vjp test * modify add_n with add_vjp test * fix conflict and concat call_vjp * modify backward test * Add more gen api --------- Co-authored-by: cxxly Co-authored-by: Charles-hit Co-authored-by: zhangbo9674 Co-authored-by: YuanRisheng Co-authored-by: 0x45f --- python/paddle/autograd/backward.py | 214 +++++++++++++++++------ python/paddle/autograd/backward_utils.py | 6 +- test/ir/new_ir/test_ir_backward.py | 87 +++++---- 3 files changed, 218 insertions(+), 89 deletions(-) diff --git a/python/paddle/autograd/backward.py b/python/paddle/autograd/backward.py index 6ea1c491d4a..671182f7c30 100644 --- a/python/paddle/autograd/backward.py +++ b/python/paddle/autograd/backward.py @@ -181,7 +181,6 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set): ''' relevant_op_flags = [True] * len(total_ops) - # from input to output if inputs_set: for i, op in enumerate(total_ops): @@ -192,7 +191,6 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set): for value in op.results(): if value not in no_grad_set: inputs_set.add(value) - else: relevant_op_flags[i] = False @@ -313,22 +311,50 @@ def append_backward_ops( v2_g = call_vjp(op3, [v3_g], [v2_stopgradient]) + special pattern 1: + v11 -> combine_op -> v1 -> op -> v3 + v12 -> + v2 -> + value_to_valuegrad[v3] = [[v3_g]] + + v1 is inside python api, we don't describe it in backward process(state) + so v1_grad is inside vjp, we don't describe it in backward process(state) + [[v11_g, v12_g], v2_g] = call_vjp(combine_op, [v3_g], [[v11_stopgradient, v12_stopgradient], 
v2_stop_gradient) + + + op_vjp is: + v11_g <- split_op <- v1_g <- op_g <- v3_g + v12_g <- + v2_g <- + + value_to_valuegrad[v11] = [[v11_g]] + value_to_valuegrad[v12] = [[v12_g]] + value_to_valuegrad[v2] = [[v2_g]] + if op don't has grad_op, if it don't has input and it's output has more than one output_grad, add sumop for grad aggregation. (eg: full op and get_parameter op etc.) else continue to next op. ''' - for op in effective_forward_op: - if paddle.framework.core.has_vjp(op): - # prepare output_grad - output_grad_list = [] # (opresult) - zero_flag = [False] * op.num_results() - for i, value in enumerate(op.results()): - if ( - value not in state.value_to_valuegrad - or state.value_to_valuegrad[value] is None - ): + + def make_output_grad(op, split_op): + zero_flag = [False] * op.num_results() + for i, value in enumerate(op.results()): + if ( + value not in state.value_to_valuegrad + or state.value_to_valuegrad[value] is None + ): + if split_op is not None and value == split_op.operand_source(0): + # pattern case: + # this fwd_op's output is vectorType, it will split to + # Type by builtin.split op, so need get from split op's ouput + split_zero_flag, split_output_grad = make_output_grad( + split_op, None + ) + zero_flag[i] = all(split_zero_flag) + grad_value = [op_list[0] for op_list in split_output_grad] + else: # first case: # this fwd_op's output didn't used by other fwd_op, # so no output_grad created. @@ -336,7 +362,6 @@ def append_backward_ops( # second case: # last bwd_op return None because input in no_grad_set, # but this bwd_op need a input. - grad_value = paddle.full( value.shape, 0.0, @@ -347,25 +372,103 @@ def append_backward_ops( update_bwdop_structure( backward_ops, state.op_to_opgrad[op], fillop ) - state.value_to_valuegrad[value] = [[grad_value]] zero_flag[i] = True - if len(state.value_to_valuegrad[value]) > 1: - # one value is input of more than one fwd_op, - # so more than one bwd_op create input_grad, - # need add sum op to accumulate gradient + state.value_to_valuegrad[value] = [[grad_value]] - paddle.add_n(list(state.value_to_valuegrad[value])) - sumop = block.ops[len(block.ops) - 1] - update_bwdop_structure( - backward_ops, state.op_to_opgrad[op], sumop - ) - state.value_to_valuegrad[value] = [[sumop.result(0)]] - state.value_to_sumvaluegrad[ - value - ] = state.value_to_valuegrad[value] + if len(state.value_to_valuegrad[value]) > 1: + # one value is input of more than one fwd_op, + # so more than one bwd_op create input_grad, + # need add sum op to accumulate gradient + + paddle.add_n( + [item[0] for item in state.value_to_valuegrad[value]] + ) + combineop = block.ops[len(block.ops) - 2] + sumop = block.ops[len(block.ops) - 1] + update_bwdop_structure( + backward_ops, state.op_to_opgrad[op], combineop + ) + update_bwdop_structure( + backward_ops, state.op_to_opgrad[op], sumop + ) + state.value_to_valuegrad[value] = [[sumop.result(0)]] + state.value_to_sumvaluegrad[value] = state.value_to_valuegrad[ + value + ] + + output_grad = state.value_to_valuegrad[value][0] + return zero_flag, output_grad + + def make_input_stopgradient(combine_op, op): + input_grad_stopgradient_list = [] + for input in op.operands_source(): + if combine_op is not None and input == combine_op.result(0): + stop_gradient = make_input_stopgradient(None, combine_op) + input_grad_stopgradient_list.append( + [info[0] for info in stop_gradient] + ) + else: + if input in no_grad_set: + input_grad_stopgradient_list.append([True]) + else: + input_grad_stopgradient_list.append([False]) + + return 
input_grad_stopgradient_list
+
+    def update_input_grad_map(combine_op, op, input_grad_list):
+        for i, input in enumerate(op.operands_source()):
+            if combine_op is not None and input == combine_op.result(0):
+                update_input_grad_map(None, combine_op, input_grad_list[i])
+            else:
+                input_grad = input_grad_list[i]
+                if isinstance(input_grad, list):
+                    state.value_to_valuegrad[input].append(input_grad)
+                else:
+                    state.value_to_valuegrad[input].append([input_grad])
+
+    # group ops into op patterns; there are four patterns:
+    # [builtin.combine, op1] (op1's one input is vectorType, outputs are not vectorType)
+    # [op2, builtin.split] (op2's inputs are not vectorType, one output is vectorType)
+    # [builtin.combine, op3, builtin.split] (op3's one input and one output are vectorType)
+    # [op4] (op4's inputs and outputs are not vectorType)
+    # einsum has two vectorType outputs, which is a special pattern
+
+    pattern_effective_op_list = []
+    for idx, op in enumerate(effective_forward_op):
+        if op.name() == "builtin.combine":
+            pattern_effective_op_list.append([op])
+            pattern_effective_op_list[-1].append(effective_forward_op[idx + 1])
+        elif op.name() == "builtin.split":
+            pattern_effective_op_list[-1].append(op)
+        else:
+            if (
+                not pattern_effective_op_list
+                or op not in pattern_effective_op_list[-1]
+            ):
+                pattern_effective_op_list.append([op])
+
+    for op_pattern in pattern_effective_op_list:
+        combine_op = None
+        split_op = None
+        if len(op_pattern) == 1:
+            op = op_pattern[0]
+        elif len(op_pattern) == 2:
+            if op_pattern[0].name() == "builtin.combine":
+                combine_op = op_pattern[0]
+                op = op_pattern[1]
+            else:
+                op = op_pattern[0]
+                split_op = op_pattern[1]
+        else:
+            combine_op = op_pattern[0]
+            op = op_pattern[1]
+            split_op = op_pattern[2]
-            output_grad = state.value_to_valuegrad[value][0]
+        if paddle.framework.core.has_vjp(op):
+            # prepare output_grad
+            output_grad_list = []  # (opresult)
+            zero_flag, output_grad = make_output_grad(op, split_op)
             output_grad_list.append(output_grad)

             # all(zero_flag) support this op has no contribution for grad
@@ -374,42 +477,42 @@ def append_backward_ops(
                 continue

             # prepare input_grad stop_gradient info.
- input_grad_stopgradient_list = [] - for input in op.operands_source(): - if input in no_grad_set: - input_grad_stopgradient_list.append([True]) - else: - input_grad_stopgradient_list.append([False]) + input_grad_stopgradient_list = make_input_stopgradient( + combine_op, op + ) + # create grad_op before_ops_num = len(block.ops) - # prim should be a globel flag, it will make create_grad_op choose diffrient func input_grad_list = paddle.framework.core.call_vjp( op, output_grad_list, input_grad_stopgradient_list ) after_ops_num = len(block.ops) - # find new grad_op_list - grad_op_list = [] + # update grad_op structure for i in range(before_ops_num, after_ops_num): - grad_op_list.append(block.ops[i]) - - for i, input in enumerate(op.operands()): - input_grad = input_grad_list[i] - state.value_to_valuegrad[input.source()].append(input_grad) - - # add grad_op - for grad_op in grad_op_list: update_bwdop_structure( - backward_ops, state.op_to_opgrad[op], grad_op + backward_ops, state.op_to_opgrad[op], block.ops[i] ) + # update input_grad map + update_input_grad_map(combine_op, op, input_grad_list) + else: if op.num_operands() == 0 and op.num_results() != 0: for value in op.results(): if len(state.value_to_valuegrad[value]) > 1: # need add sum op - paddle.add_n(list(state.value_to_valuegrad[value])) + paddle.add_n( + [ + item[0] + for item in state.value_to_valuegrad[value] + ] + ) + combineop = block.ops[len(block.ops) - 2] sumop = block.ops[len(block.ops) - 1] + update_bwdop_structure( + backward_ops, state.op_to_opgrad[op], combineop + ) update_bwdop_structure( backward_ops, state.op_to_opgrad[op], sumop ) @@ -441,7 +544,7 @@ def create_backward_purne_set(inputs, outputs, no_grad_set, state): for key in state.value_to_sumvaluegrad: if key in no_grad_set: - for item in state.value_to_valuegrad[key][0]: + for item in state.value_to_sumvaluegrad[key][0]: no_gradvar_set.add(item) return outputs_set, inputs_set, no_gradvar_set @@ -457,10 +560,14 @@ def remove_op(block, op, state): state.op_to_opgrad[fwd_op].remove(op) for valuegrad in op.results(): - value = state.valuegrad_to_value[valuegrad][0] - state.value_to_valuegrad[value] = [] - if value in state.sumvaluegrad_to_value: - raise ValueError('input_grad in [%s] is value which need to sum ') + if state.valuegrad_to_value[valuegrad] != []: + value = state.valuegrad_to_value[valuegrad][0] + state.value_to_valuegrad[value] = [] + + if value in state.sumvaluegrad_to_value: + raise ValueError( + 'input_grad in [%s] is value which need to sum ', op.name() + ) def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set): @@ -500,13 +607,14 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set): _, remove_ops = prune_ops( backward_ops, inputs_set, outputs_set, no_gradvar_set ) - state.turn_map() for bwd_op in inverse_sort_op(remove_ops): remove_op(block, bwd_op, state) input_grad_map = state.value_to_valuegrad + + state.turn_map() return input_grad_map diff --git a/python/paddle/autograd/backward_utils.py b/python/paddle/autograd/backward_utils.py index 8febb5c6dad..2a8fa240474 100644 --- a/python/paddle/autograd/backward_utils.py +++ b/python/paddle/autograd/backward_utils.py @@ -7,7 +7,7 @@ # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, +# distributed under the License is distributed on an "AS IS" BASIS,tes # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. @@ -37,6 +37,10 @@ class State: self.opgrad_to_op = collections.defaultdict(list) def turn_map(self) -> None: + self.valuegrad_to_value = collections.defaultdict(list) + self.sumvaluegrad_to_value = collections.defaultdict(list) + self.opgrad_to_op = collections.defaultdict(list) + for k, v in self.value_to_valuegrad.items(): if v != []: for value in v[0]: diff --git a/test/ir/new_ir/test_ir_backward.py b/test/ir/new_ir/test_ir_backward.py index 2cd375468f2..e6b47bbcd10 100644 --- a/test/ir/new_ir/test_ir_backward.py +++ b/test/ir/new_ir/test_ir_backward.py @@ -35,25 +35,8 @@ def get_ir_program_0(): return newir_program -def get_ir_program_1(): - x = paddle.randn([2, 2]) - main_program, start_program = ( - paddle.static.Program(), - paddle.static.Program(), - ) - with paddle.static.program_guard(main_program, start_program): - x_s = paddle.static.data('x', [4, 4], x.dtype) - y_s = paddle.static.data('y', [4, 4], x.dtype) - x_s.stop_gradient = False - z_x = paddle.tanh(y_s) - k_s = paddle.tanh(x_s) - out = paddle.add(z_x, k_s) - newir_program = ir.translate_to_new_ir(main_program.desc) - return newir_program - - -class TesBackward(unittest.TestCase): - def test_1(self): +class TesBackward_1(unittest.TestCase): + def test_grad(self): newir_program = get_ir_program_0() input = newir_program.block().ops[-1].operand(0).source() tanh_out = newir_program.block().ops[-1].result(0) @@ -63,7 +46,6 @@ class TesBackward(unittest.TestCase): out2 = paddle.mean(tanh_out) input_grad = grad(out, input, out2) - print(newir_program) self.assertEqual(out.get_defining_op().name(), "pd.mean") self.assertEqual(input_grad[0].get_defining_op().name(), "pd.tanh_grad") self.assertEqual( @@ -76,7 +58,7 @@ class TesBackward(unittest.TestCase): ) paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False}) - def test_2(self): + def test_full(self): # test create output_grad in backward use full op newir_program = get_ir_program_0() input = newir_program.block().ops[-1].operand(0).source() @@ -86,7 +68,6 @@ class TesBackward(unittest.TestCase): out = paddle.mean(tanh_out) input_grad = grad(out, input) - print(newir_program) self.assertEqual(newir_program.block().ops[-3].name(), "pd.full") self.assertEqual(input_grad[0].get_defining_op().name(), "pd.tanh_grad") self.assertEqual( @@ -100,19 +81,55 @@ class TesBackward(unittest.TestCase): ) paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False}) - # TODO(Ruting) test add_n op when add_n api and add_grad finished - # def test_3(self): - # # test add_n op - # newir_program = get_ir_program_1() - # input = newir_program.block().ops[-1].operand(0).source() - # tanh_out = newir_program.block().ops[-1].result(0) - # paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True}) - # with paddle.ir.core.program_guard(newir_program): - # out = paddle.mean(tanh_out) - # input_grad = grad(out, input) - - # print(newir_program) - # self.assertEqual(newir_program.block().ops[-1].name(), "pd.add_n") + def test_no_grad_set(self): + # test create output_grad in backward use full op + newir_program = get_ir_program_0() + input = newir_program.block().ops[-1].operand(0).source() + tanh_out = newir_program.block().ops[-1].result(0) + paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True}) + with paddle.ir.core.program_guard(newir_program): + out = paddle.mean(tanh_out) + input_grad = grad(out, input, no_grad_vars=[input]) + + 
self.assertEqual(newir_program.block().ops[-1].name(), "pd.mean") + paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False}) + + +def get_ir_program_1(): + x = paddle.randn([2, 2]) + main_program, start_program = ( + paddle.static.Program(), + paddle.static.Program(), + ) + with paddle.static.program_guard(main_program, start_program): + x_s = paddle.static.data('x', [4, 4], x.dtype) + y_s = paddle.static.data('y', [4, 4], x.dtype) + x_s.stop_gradient = False + y_s.stop_gradient = False + + k_s = paddle.tanh(x_s) + z_x = paddle.tanh(x_s) + out = paddle.add(z_x, k_s) + newir_program = ir.translate_to_new_ir(main_program.desc) + return newir_program + + +class TesBackward_2(unittest.TestCase): + def test_add_n(self): + # test add_n op + newir_program = get_ir_program_1() + input_x = newir_program.block().ops[-3].operand(0).source() + + add_out = newir_program.block().ops[-1].result(0) + paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True}) + with paddle.ir.core.program_guard(newir_program): + out = paddle.mean(add_out) + input_grad = grad(out, input_x) + + self.assertEqual(newir_program.block().ops[-1].name(), "pd.add_n") + self.assertEqual( + newir_program.block().ops[-2].name(), "builtin.combine" + ) if __name__ == "__main__": -- GitLab
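
The op-pattern grouping added to append_backward_ops folds every builtin.combine into the pattern of the op that consumes its vector output, and every builtin.split into the pattern of the op that produced its vector input. A minimal standalone sketch of that grouping follows; MockOp and group_into_patterns are illustrative names only, not the real ir.Operation API used by the patch.

# Minimal sketch of the op-pattern grouping used in append_backward_ops.
# MockOp and group_into_patterns are hypothetical; the real pass works on
# ir.Operation objects taken from effective_forward_op.
class MockOp:
    def __init__(self, name):
        self._name = name

    def name(self):
        return self._name


def group_into_patterns(forward_ops):
    """Group ops into [combine, op], [op, split], [combine, op, split], [op]."""
    patterns = []
    for idx, op in enumerate(forward_ops):
        if op.name() == "builtin.combine":
            # a combine op is grouped with the op that consumes its vector output
            patterns.append([op, forward_ops[idx + 1]])
        elif op.name() == "builtin.split":
            # a split op is grouped with the op that produced its vector input
            patterns[-1].append(op)
        elif not patterns or op not in patterns[-1]:
            patterns.append([op])
    return patterns


ops = [MockOp(n) for n in
       ["pd.tanh", "builtin.combine", "pd.concat", "pd.split", "builtin.split"]]
print([[o.name() for o in p] for p in group_into_patterns(ops)])
# [['pd.tanh'], ['builtin.combine', 'pd.concat'], ['pd.split', 'builtin.split']]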
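
make_output_grad fills in a zero gradient (via a pd.full op recorded with update_bwdop_structure) whenever a forward output was never used downstream or its backward op returned None. A framework-free sketch of that decision, with plain lists standing in for OpResults and a dict standing in for state.value_to_valuegrad (all names here are hypothetical):

def make_zero_filled_grads(results, value_to_valuegrad):
    """Return (zero_flag, grads); a zero grad is created for any result
    that has no recorded gradient yet (mirrors the paddle.full branch)."""
    zero_flag = [False] * len(results)
    grads = []
    for i, value in enumerate(results):
        if value not in value_to_valuegrad or not value_to_valuegrad[value]:
            # stand-in for paddle.full(value.shape, 0.0, value.dtype)
            value_to_valuegrad[value] = [[0.0]]
            zero_flag[i] = True
        grads.append(value_to_valuegrad[value][0])
    return zero_flag, grads


state_grads = {"y": [[1.5]]}          # y already has a grad, z does not
flags, grads = make_zero_filled_grads(["y", "z"], state_grads)
print(flags, grads)                   # [False, True] [[1.5], [0.0]]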
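
The core of the gradient-accumulation change is that state.value_to_valuegrad keeps a list of candidate gradients per forward value, and a builtin.combine plus pd.add_n pair is emitted once more than one candidate exists. A rough, framework-free sketch of that bookkeeping, using a dict of lists in place of State and a plain Python sum in place of pd.add_n (names hypothetical):

import collections

# value_to_valuegrad maps a forward value to a list of gradient candidates,
# one per backward op that produced a grad for it (plain floats here).
value_to_valuegrad = collections.defaultdict(list)


def accumulate(value):
    """If several bwd ops contributed a grad for `value`, fold them into one.

    In the real pass this emits builtin.combine + pd.add_n into the block and
    records both ops via update_bwdop_structure; here we just sum numbers.
    """
    grads = value_to_valuegrad[value]
    if len(grads) > 1:
        summed = sum(g[0] for g in grads)      # pd.add_n over the candidates
        value_to_valuegrad[value] = [[summed]]  # the sum replaces the list
    return value_to_valuegrad[value][0]


# x feeds two tanh ops, so two backward contributions arrive for it.
value_to_valuegrad["x"].append([0.25])   # grad from the first tanh_grad
value_to_valuegrad["x"].append([0.25])   # grad from the second tanh_grad
print(accumulate("x"))                   # [0.5]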
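
TesBackward_2.test_add_n checks the IR structure (a builtin.combine followed by pd.add_n) for the graph built by get_ir_program_1, where x feeds tanh twice before an add and a mean. As a sanity check of the value being accumulated, an eager-mode equivalent of the same graph could look like the sketch below; it assumes a working Paddle install and is not part of the patch.

import paddle

# Same graph as get_ir_program_1: x -> tanh (twice) -> add -> mean, so
# dL/dx accumulates two tanh_grad contributions.
x = paddle.randn([4, 4])
x.stop_gradient = False

out = paddle.mean(paddle.add(paddle.tanh(x), paddle.tanh(x)))
out.backward()

# Analytically: d mean(2 * tanh(x)) / dx = 2 * (1 - tanh(x)**2) / N
n = 4 * 4  # number of elements in the input
expected = 2 * (1 - paddle.tanh(x) ** 2) / n
assert paddle.allclose(x.grad, expected, atol=1e-6), "accumulated grad mismatch"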