From 0399b39f08177c59ecd29f7f58078b9b7fbe91ea Mon Sep 17 00:00:00 2001 From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com> Date: Wed, 9 Aug 2023 19:29:08 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90New=20IR=E3=80=91backward=20code=20of?= =?UTF-8?q?=20new=20ir=20(#55957)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [prim][newir] add basic framework for primitive * support desctensor in new ir * add vjp interface * support vjp in new ir * support vjp in new ir * polish vjp interface * fix stop_gradients set * fix vjp dispatch * add comment * add vjp test for new ir * add test for tanh vjp * [prim][newir] add basic framework for primitive * support desctensor in new ir * support vjp in new ir * support vjp in new ir * polish vjp interface * fix stop_gradients set * fix vjp dispatch * add comment * add vjp test for new ir * add test for tanh vjp * add eager and static backend for warp lower level api * support call_vjp pybind * polish code and add test for vjp * remove useless code * polish code * remove useless code * support mean vjp * backward origin code * add test for mean vjp and support has_vjp function * fix call_vjp * polish code * add attrs and dtype interface * add primitive ops set for backend * fix compile bugs * fix some bugs * fix windows bugs * add vjp test for tanh_ * fix inference CI * fix inference ci * modify fluid cmake * origin test of tanh and mean passed * fix conflict * modify stop_gradient * modify block.ops * modify test * fix conflict * reply review comments * reply review comments * pulish code --------- Co-authored-by: cxxly Co-authored-by: Charles-hit Co-authored-by: zhangbo9674 Co-authored-by: YuanRisheng --- paddle/fluid/pybind/ir.cc | 17 +- python/paddle/autograd/backward.py | 650 +++++++++++++++++++++++ python/paddle/autograd/backward_utils.py | 50 ++ python/paddle/decomposition/decomp.py | 2 +- test/ir/new_ir/test_build_op.py | 6 +- test/ir/new_ir/test_ir_backward.py | 109 ++++ test/ir/new_ir/test_ir_pybind.py | 32 +- test/ir/new_ir/test_ir_vjp.py | 24 +- test/prim/new_ir_prim/test_decomp_op.py | 2 +- 9 files changed, 851 insertions(+), 41 deletions(-) create mode 100644 python/paddle/autograd/backward.py create mode 100644 python/paddle/autograd/backward_utils.py create mode 100644 test/ir/new_ir/test_ir_backward.py diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index b55906b67ed..2081d327a23 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -82,14 +82,15 @@ void BindBlock(py::module *m) { block.def("front", &Block::front, return_value_policy::reference) .def("get_parent_program", [](Block &self) { return self.GetParentOp()->GetParentProgram(); }) - .def("get_ops", - [](Block &self) -> py::list { - py::list op_list; - for (auto iter = self.begin(); iter != self.end(); iter++) { - op_list.append(*iter); - } - return op_list; - }) + .def_property_readonly( + "ops", + [](Block &self) -> py::list { + py::list op_list; + for (auto iter = self.begin(); iter != self.end(); iter++) { + op_list.append(*iter); + } + return op_list; + }) .def("remove_op", [](Block &self, Operation *op) { auto op_iter = std::find(self.begin(), self.end(), op); self.erase(op_iter); diff --git a/python/paddle/autograd/backward.py b/python/paddle/autograd/backward.py new file mode 100644 index 00000000000..ba9d8a7a3f2 --- /dev/null +++ b/python/paddle/autograd/backward.py @@ -0,0 +1,650 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+from collections.abc import Sequence
+
+import paddle.ir
+from paddle.autograd.backward_utils import State
+
+"""
+    grad: a temporary interface for tests; it will be merged into paddle.grad .
+    calc_gradient: for internal use (op tests, parallel training, etc.).
+    calc_gradient_helper: for dygraph-to-static.
+"""
+__all__ = ['grad', 'calc_gradient', 'calc_gradient_helper']
+
+
+def check_type(input, input_name, expected_type, op_name, extra_message=''):
+    if not isinstance(input, expected_type):
+        raise TypeError(
+            f"The type of '{input_name}' in {op_name} must be {expected_type}, but received {type(input)}. {extra_message}"
+        )
+
+
+def _as_list(x):
+    if x is None:
+        return []
+    return list(x) if isinstance(x, Sequence) else [x]
+
+
+def check_all_puts(block, inputs, outputs):
+    for output in outputs:
+        if output.get_defining_op().get_parent_block() != block:
+            raise ValueError("all outputs must be in the same block")
+    for input in inputs:
+        if input.get_defining_op().get_parent_block() != block:
+            raise ValueError(
+                "all inputs must be in the same block as outputs"
+            )
+
+
+def update_no_grad_set_by_stopgradient(block, no_grad_set):
+    for op in block.ops:
+        for opresult_idx in range(op.num_results()):
+            value = op.result(opresult_idx)
+            if value.stop_gradient and value not in no_grad_set:
+                no_grad_set.add(value)
+
+
+def update_bwdop_structure(backward_ops, op_to_opgrad_list, grad_op):
+    backward_ops.append(grad_op)
+    op_to_opgrad_list.append(grad_op)
+
+
+def prepare_grad_outputs(
+    block, grad_outputs, outputs, value_to_valuegrad, op_to_opgrad
+):
+    """
+    If grad_outputs is None, create the grad_outputs with full(1.0) ops;
+    otherwise check that each grad_output matches the shape and dtype of
+    the corresponding output, and raise an error if it does not.
+
+    If only some of an op's results appear in outputs, create the missing
+    grad_outputs with full(0.0) ops (e.g. for split).
+
+    Update value_to_valuegrad and op_to_opgrad.
+
+    Return complete_outputs, complete_gradoutputs and backward_ops.
+    """
+    if not grad_outputs:
+        grad_outputs = [None] * len(outputs)
+
+    if len(grad_outputs) != len(outputs):
+        raise ValueError(
+            "grad_outputs should have the same length as outputs."
+ ) + backward_ops = [] + for i, grad in enumerate(grad_outputs): + output = outputs[i] + # fwd : op1 -> op2 -> op3 -> output + # bwd : op1G <- op2G <- op3G <- outputG <- fillop/feedop + if grad is None: + output_grad = paddle.full( + output.shape, + 1.0, + dtype=output.dtype, + ) + fillop = output_grad.get_defining_op() + + update_bwdop_structure( + backward_ops, + op_to_opgrad[output.get_defining_op()], + fillop, + ) + value_to_valuegrad[output] = [[output_grad]] + else: + if output.shape != grad.shape: + raise ValueError( + "The shape of grad_output[%d] should be the same as the shape of output[%d]" + % (i, i) + ) + if output.dtype != grad.dtype: + raise ValueError( + "The dtype of grad_output[%d] should be the same as the dtype of output[%d]" + % (i, i) + ) + feedop = grad.get_defining_op() + update_bwdop_structure( + backward_ops, op_to_opgrad[output.get_defining_op()], feedop + ) + value_to_valuegrad[output] = [[grad]] + + # add input for bwd first op + complete_outputs = outputs + complete_gradoutputs = grad_outputs + + visited_output = set() + for output in outputs: + if output in visited_output: + continue + for opresult in output.get_defining_op().results(): + if opresult in value_to_valuegrad: + visited_output.add(opresult) + continue + else: + grad_value = paddle.full( + opresult.shape, + 0.0, + opresult.dtype, + ) + fillop = grad.get_defining_op() + + update_bwdop_structure( + backward_ops, + op_to_opgrad[opresult.get_defining_op()], + fillop, + ) + value_to_valuegrad[opresult] = [grad_value] + + visited_output.add(opresult) + + complete_outputs.append(opresult) + complete_gradoutputs.append(grad_value) + + return complete_outputs, complete_gradoutputs, backward_ops + + +def some_in_set(value_list, value_set): + def operand2value(values): + value_set = set() + for item in values: + if isinstance(item, paddle.ir.OpOperand): + value_set.add(item.source()) + else: + value_set.add(item) + return value_set + + if operand2value(value_list) & operand2value(value_set): + return True + else: + return False + + +def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set): + ''' + prune ops which do not in the path from inputs_set to outputs_set, + prune ops which do not in the path from outputs_set to inputs_set, + + pruned op in total_ops is uneffective_ops, else is effective_ops + + ''' + relevant_op_flags = [True] * len(total_ops) + + # from input to output + if inputs_set: + for i, op in enumerate(total_ops): + if some_in_set(op.results(), inputs_set): + continue + + if some_in_set(op.operands_source(), inputs_set): + for value in op.results(): + if value not in no_grad_set: + inputs_set.add(value) + + else: + relevant_op_flags[i] = False + + # from output to input + for i, op in reversed(list(enumerate(total_ops))): + # while op support + if some_in_set(op.results(), outputs_set): + for operand in op.operands_source(): + if operand not in no_grad_set: + outputs_set.add(operand) + else: + relevant_op_flags[i] = False + + effective_ops = [ + total_ops[i] for i in range(len(total_ops)) if relevant_op_flags[i] + ] + uneffective_ops = [ + total_ops[i] + for i in reversed(range(len(total_ops))) + if not relevant_op_flags[i] + ] + + return effective_ops, uneffective_ops + + +def update_no_grad_set_after_purne( + block, effective_forward_op, no_grad_set, inputs, outputs +): + ''' + update no_grad_set after forward purne + + from inputs to outputs add value not in the path to no_grad_set, + from outputs to inputs add value not in the path to no_grad_set, + ''' + inputs_set = set(inputs) 
+    if inputs_set:
+        for op in block.ops:
+            if some_in_set(op.operands_source(), inputs_set):
+                for value in op.results():
+                    if value not in no_grad_set:
+                        inputs_set.add(value)
+
+    for op in effective_forward_op:
+        for value in op.operands_source():
+            if value not in inputs_set:  # and value.get_stopgradient():
+                no_grad_set.add(value)
+
+    outputs_set = set(outputs)
+    no_grad_set_tmp = set()
+    for op in reversed(effective_forward_op):
+        for output in op.results():
+            if output not in outputs_set and not some_in_set(
+                [output], set(op.operands_source())
+            ):
+                no_grad_set_tmp.add(output)
+
+        for input in op.operands_source():
+            if input not in no_grad_set:
+                outputs_set.add(input)
+
+    no_grad_set.update(no_grad_set_tmp)
+
+
+def inverse_sort_op(ops):
+    '''
+    Sort ops in inverse topological order, e.g. if the topo graph is
+    op1 -> op2 -> op3, return [op3, op2, op1].
+    '''
+
+    # pending_count[op] records how many ops in the set consume one of
+    # op's results; an op can be emitted once all of its consumers have
+    # been emitted.
+    pending_count = collections.defaultdict(int)
+    ops_set = set(ops)
+    sorted_list = []
+    for op in ops:
+        for x in op.operands():
+            if x.source().get_defining_op() in ops_set:
+                pending_count[x.source().get_defining_op()] += 1
+
+    queue = collections.deque()
+
+    for op in ops:
+        if pending_count[op] == 0:
+            queue.append(op)
+
+    while queue:
+        op = queue.popleft()
+        sorted_list.append(op)
+
+        for x in op.operands():
+            x_op = x.source().get_defining_op()
+            pending_count[x_op] -= 1
+            if pending_count[x_op] == 0:
+                queue.append(x_op)
+
+    if len(sorted_list) != len(ops):
+        raise ValueError(
+            "inverse_sort_op failed: the size of sorted_list is not equal to the size of the original op list"
+        )
+
+    return sorted_list
+
+
+def append_backward_ops(
+    block, effective_forward_op, no_grad_set, backward_ops, state
+):
+    '''
+    Add grad_ops in inverse topological order, e.g.
+        forward:  op1 -> v1 -> op2 -> v2 -> op3 -> v3
+        backward: op1_g <- v1_g <- op2_g <- v2_g <- op3_g <- v3_g
+
+    If an op has a grad_op, prepare the grad_op's inputs from
+    value_to_valuegrad, e.g.
+        value_to_valuegrad[v3] = [[v3_g]]
+        v2_g = call_vjp(op3, [v3_g], [v2_stopgradient])
+
+    If an op has no grad_op but also has no inputs (e.g. the full op,
+    the get_parameter op, etc.) and one of its outputs has more than one
+    output_grad, add a sum op for gradient aggregation; otherwise
+    continue to the next op.
+    '''
+    for op in effective_forward_op:
+        if paddle.framework.core.has_vjp(op):
+            # prepare output_grad
+            output_grad_list = []  # (opresult)
+            zero_flag = [False] * op.num_results()
+            for i, value in enumerate(op.results()):
+                if (
+                    value not in state.value_to_valuegrad
+                    or state.value_to_valuegrad[value] is None
+                ):
+                    # first case:
+                    # this fwd_op's output is not used by any other fwd_op,
+                    # so no output_grad was created.
+
+                    # second case:
+                    # the previous bwd_op returned None because its input
+                    # is in no_grad_set, but this bwd_op needs that grad
+                    # as an input.
+
+                    grad_value = paddle.full(
+                        value.shape,
+                        0.0,
+                        dtype=value.dtype,
+                    )
+                    fillop = grad_value.get_defining_op()
+
+                    update_bwdop_structure(
+                        backward_ops, state.op_to_opgrad[op], fillop
+                    )
+                    state.value_to_valuegrad[value] = [[grad_value]]
+                    zero_flag[i] = True
+
+                if len(state.value_to_valuegrad[value]) > 1:
+                    # one value is the input of more than one fwd_op,
+                    # so more than one bwd_op creates an input_grad;
+                    # add a sum op to accumulate the gradients.
+
+                    paddle.add_n(list(state.value_to_valuegrad[value]))
+                    sumop = block.ops[-1]
+                    update_bwdop_structure(
+                        backward_ops, state.op_to_opgrad[op], sumop
+                    )
+                    state.value_to_valuegrad[value] = [[sumop.result(0)]]
+                    state.value_to_sumvaluegrad[
+                        value
+                    ] = state.value_to_valuegrad[value]
+
+                output_grad = state.value_to_valuegrad[value][0]
+                output_grad_list.append(output_grad)
+
+            # all(zero_flag) means this op contributes nothing to the
+            # gradient and should be deleted (prune the sub_graph).
+            if len(output_grad_list) == 0 or all(zero_flag):
+                continue
+
+            # prepare input_grad stop_gradient info.
+            input_grad_stopgradient_list = []
+            for input in op.operands_source():
+                if input in no_grad_set:
+                    input_grad_stopgradient_list.append([1])
+                else:
+                    input_grad_stopgradient_list.append([0])
+
+            before_ops_num = len(block.ops)
+            # prim should be a global flag; it makes create_grad_op choose
+            # a different function.
+            input_grad_list = paddle.framework.core.call_vjp(
+                op, output_grad_list, input_grad_stopgradient_list
+            )
+            after_ops_num = len(block.ops)
+
+            # collect the newly created grad_ops
+            grad_op_list = []
+            for i in range(before_ops_num, after_ops_num):
+                grad_op_list.append(block.ops[i])
+
+            for i, input in enumerate(op.operands()):
+                input_grad = input_grad_list[i]
+                state.value_to_valuegrad[input.source()].append(input_grad)
+
+            # add grad_op
+            for grad_op in grad_op_list:
+                update_bwdop_structure(
+                    backward_ops, state.op_to_opgrad[op], grad_op
+                )
+
+        else:
+            if op.num_operands() == 0 and op.num_results() != 0:
+                for value in op.results():
+                    if len(state.value_to_valuegrad[value]) > 1:
+                        # add a sum op to accumulate the gradients
+                        paddle.add_n(list(state.value_to_valuegrad[value]))
+                        sumop = block.ops[-1]
+                        update_bwdop_structure(
+                            backward_ops, state.op_to_opgrad[op], sumop
+                        )
+                        state.value_to_valuegrad[value] = [[sumop.result(0)]]
+                        state.value_to_sumvaluegrad[
+                            value
+                        ] = state.value_to_valuegrad[value]
+                    else:
+                        state.op_to_opgrad[op] = []
+            else:
+                state.op_to_opgrad[op] = []
+
+
+def create_backward_prune_set(inputs, outputs, no_grad_set, state):
+    outputs_set = set()
+    for input in inputs:
+        if state.value_to_valuegrad[input] != []:
+            outputs_set.add(state.value_to_valuegrad[input][0][0])
+
+    inputs_set = set()
+    for output in outputs:
+        if state.value_to_valuegrad[output] != []:
+            inputs_set.add(state.value_to_valuegrad[output][0][0])
+
+    no_gradvar_set = set()  # grad values of the values in no_grad_set
+    for key in state.value_to_valuegrad:
+        if key in no_grad_set:
+            no_gradvar_set.add(state.value_to_valuegrad[key][0][0])
+
+    for key in state.value_to_sumvaluegrad:
+        if key in no_grad_set:
+            for item in state.value_to_sumvaluegrad[key][0]:
+                no_gradvar_set.add(item)
+
+    return outputs_set, inputs_set, no_gradvar_set
+
+
+def remove_op(block, op, state):
+    '''
+    remove op from block
+    '''
+    block.remove_op(op)
+    if state.opgrad_to_op[op] != []:
+        fwd_op = state.opgrad_to_op[op][0]
+        state.op_to_opgrad[fwd_op].remove(op)
+
+    for valuegrad in op.results():
+        value = state.valuegrad_to_value[valuegrad][0]
+        state.value_to_valuegrad[value] = []
+        if value in state.sumvaluegrad_to_value:
+            raise ValueError(
+                "input_grad of [%s] is a value that needs to be summed" % value
+            )
+
+
+def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
+    block = outputs[0].get_defining_op().get_parent_block()
+    state = State(block.get_parent_program())
+    # check that all inputs and outputs are in the same block
+    check_all_puts(block, inputs, outputs)
+    # update no_grad_set for the values with stop_gradient=True
+    update_no_grad_set_by_stopgradient(block, no_grad_set)
+    complete_outputs, _, backward_ops = prepare_grad_outputs(
+        block,
+        grad_outputs,
+        outputs,
+        state.value_to_valuegrad,
+        state.op_to_opgrad,
+    )
+
+    inputs_set = set(inputs)
+    outputs_set = set(complete_outputs)
+    effective_forward_op, _ = prune_ops(
+        block.ops, inputs_set, outputs_set, no_grad_set
+    )
+    update_no_grad_set_after_prune(
+        block, effective_forward_op, no_grad_set, inputs, complete_outputs
+    )
+
+    sorted_effective_forward_op = inverse_sort_op(effective_forward_op)
+
+    append_backward_ops(
+        block, sorted_effective_forward_op, no_grad_set, backward_ops, state
+    )
+    # now value_to_valuegrad should map value <-> value (a sum op has been
+    # added for each value with more than one grad value)
+
+    outputs_set, inputs_set, no_gradvar_set = create_backward_prune_set(
+        inputs, complete_outputs, no_grad_set, state
+    )
+    _, remove_ops = prune_ops(
+        backward_ops, inputs_set, outputs_set, no_gradvar_set
+    )
+
+    state.turn_map()
+
+    for bwd_op in inverse_sort_op(remove_ops):
+        remove_op(block, bwd_op, state)
+
+    input_grad_map = state.value_to_valuegrad
+    return input_grad_map
+
+
+def calc_gradient(outputs, inputs, grad_outputs, no_grad_set):
+    """
+    Calculate the gradients of `inputs` with respect to `outputs`.
+
+    Args:
+        outputs (Value|list(Value)|tuple(Value)): the output Value or
+            Value list/tuple of the graph to compute gradients.
+        inputs (Value|list(Value)|tuple(Value)): the input Value or
+            Value list/tuple of the graph to compute gradients. The returned
+            values of this API are the gradients of `inputs` .
+        grad_outputs (Value|list(Value|None)|tuple(Value|None), optional):
+            initial gradient values of `outputs` . If `grad_outputs` is None,
+            the initial gradient values of `outputs` would be Values filled with 1;
+            if `grad_outputs` is not None, it must have the same length as `outputs` ,
+            and in this case, the initial gradient value of the i-th `outputs` would
+            be: (1) a Value filled with 1 when the i-th element of `grad_outputs`
+            is None; (2) the i-th element of `grad_outputs` when the i-th element of
+            `grad_outputs` is a Value. Default None.
+        no_grad_set (set(Value), optional):
+            the Values whose gradients are not needed to compute. Default None.
+
+    Return:
+        list[Value]: a list of gradients for inputs.
+        If an input does not affect targets, the corresponding gradient Value
+        will be None.
+        TODO: raise TypeError() when allow_unused=False and input_grad contains None.
+    """
+    # record each input value and its gradient (Value to Value)
+    input_to_inputgrad_map = calc_gradient_helper(
+        outputs, inputs, grad_outputs=grad_outputs, no_grad_set=no_grad_set
+    )
+
+    inputgrad = []
+    for input in inputs:
+        inputgrad.append(
+            input_to_inputgrad_map[input][0][0]
+            if input_to_inputgrad_map[input] != []
+            else None
+        )
+    return inputgrad
+
+
+def grad(
+    outputs,
+    inputs,
+    grad_outputs=None,
+    retain_graph=None,
+    create_graph=False,
+    only_inputs=True,
+    allow_unused=False,
+    no_grad_vars=None,
+):
+    '''
+    .. note::
+        **This API is only available for the new IR (static graph) mode.**
+
+    This API computes the sum of gradients of `outputs` with respect to each of the `inputs` .
+
+    Parameters:
+        outputs (Value|list(Value)|tuple(Value)): the output Value or
+            Value list/tuple of the graph to compute gradients.
+        inputs (Value|list(Value)|tuple(Value)): the input Value or
+            Value list/tuple of the graph to compute gradients. The returned
+            values of this API are the gradients of `inputs` .
+        grad_outputs (Value|list(Value|None)|tuple(Value|None), optional):
+            initial gradient values of `outputs` . If `grad_outputs` is None,
+            the initial gradient values of `outputs` would be Values filled with 1;
+            if `grad_outputs` is not None, it must have the same length as `outputs` ,
+            and in this case, the initial gradient value of the i-th `outputs` would
+            be: (1) a Value filled with 1 when the i-th element of `grad_outputs`
+            is None; (2) the i-th element of `grad_outputs` when the i-th element of
+            `grad_outputs` is a Value. Default None.
+        retain_graph (bool, optional): whether to retain the forward graph which
+            is used to calculate the gradient. When it is True, the graph would
+            be retained, in which way users can calculate backward twice for the
+            same graph. When it is False, the graph would be freed. Default None,
+            which means it is equal to `create_graph` .
+        create_graph (bool, optional): whether to create the gradient graphs of
+            the computing process. When it is True, higher order derivatives are
+            supported to compute; when it is False, the gradient graphs of the
+            computing process would be discarded. Default False.
+        only_inputs (bool, optional): whether to only compute the gradients of
+            `inputs` . If it is False, the gradients of all remaining leaf
+            Values in the graph would also be computed and accumulated.
+            If it is True, only the gradients of `inputs` would be computed.
+            Default True. only_inputs=False is under development, and it is
+            not supported yet.
+        allow_unused (bool, optional): whether to raise an error or return None if
+            some Values of `inputs` are unreachable in the graph. If some Values of
+            `inputs` are unreachable in the graph (i.e., their gradients are None),
+            an error would be raised if allow_unused=False, or None would be returned
+            as their gradients if allow_unused=True. Default False.
+        no_grad_vars (Value|list(Value)|tuple(Value)|set(Value), optional):
+            the Values whose gradients are not needed to compute. Default None.
+
+    Returns:
+        list: a list of Values, whose length is the same as the number of Values
+        inside `inputs`; the i-th returned Value is the sum of gradients of
+        `outputs` with respect to the i-th `inputs`.
+ ''' + check_type( + outputs, + 'outputs', + ((paddle.ir.Value, paddle.ir.OpResult), list, tuple), + 'paddle.ir.grad', + ) + check_type( + inputs, + 'inputs', + ((paddle.ir.Value, paddle.ir.OpResult), list, tuple), + 'paddle.ir.grad', + ) + check_type( + grad_outputs, + 'grad_outputs', + ((paddle.ir.Value, paddle.ir.OpResult), list, tuple, type(None)), + 'paddle.ir.grad', + ) + + check_type( + no_grad_vars, + 'no_grad_vars', + ((paddle.ir.Value, paddle.ir.OpResult), list, tuple, set, type(None)), + 'paddle.ir.grad', + ) + outputs = _as_list(outputs) + inputs = _as_list(inputs) + grad_outputs = _as_list(grad_outputs) + if no_grad_vars is None: + no_grad_set = set() + elif no_grad_vars is not set: + no_grad_set = set(no_grad_vars) + else: + no_grad_set = no_grad_vars + + input_grad = calc_gradient(outputs, inputs, grad_outputs, no_grad_set) + + return input_grad diff --git a/python/paddle/autograd/backward_utils.py b/python/paddle/autograd/backward_utils.py new file mode 100644 index 00000000000..8febb5c6dad --- /dev/null +++ b/python/paddle/autograd/backward_utils.py @@ -0,0 +1,50 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections + + +class State: + """ + record relationship of forward op/value and backward op/value + one state must be bining with a program + + """ + + def __init__(self, program): + self.program = program + # opresult -> list(list(opresult)) + self.value_to_valuegrad = collections.defaultdict(list) + self.value_to_sumvaluegrad = collections.defaultdict(list) + # operation -> list(operation) + self.op_to_opgrad = collections.defaultdict(list) + + # opresult -> list(opresult) + self.valuegrad_to_value = collections.defaultdict(list) + self.sumvaluegrad_to_value = collections.defaultdict(list) + # operation -> list(operation) + self.opgrad_to_op = collections.defaultdict(list) + + def turn_map(self) -> None: + for k, v in self.value_to_valuegrad.items(): + if v != []: + for value in v[0]: + self.valuegrad_to_value[value] = [k] + for k, v in self.value_to_sumvaluegrad.items(): + if v != []: + for value in v[0]: + self.sumvaluegrad_to_value[value] = [k] + for k, v in self.op_to_opgrad.items(): + if v != []: + self.opgrad_to_op[v[0]] = [k] diff --git a/python/paddle/decomposition/decomp.py b/python/paddle/decomposition/decomp.py index 47f1d05bbb0..9d8fa8a05b2 100644 --- a/python/paddle/decomposition/decomp.py +++ b/python/paddle/decomposition/decomp.py @@ -162,7 +162,7 @@ def _decompose_subgraph(block, op_filter): """ if isinstance(block, Block): - ops_list = block.get_ops() + ops_list = block.ops for op in ops_list: op_name = op.name() decom_rule = register.get_decomp_rule(op_name) diff --git a/test/ir/new_ir/test_build_op.py b/test/ir/new_ir/test_build_op.py index 70668864615..40a882d7a57 100644 --- a/test/ir/new_ir/test_build_op.py +++ b/test/ir/new_ir/test_build_op.py @@ -39,7 +39,7 @@ def get_ir_program(): class TestBuildOp(unittest.TestCase): def test_build_mean_op(self): newir_program = 
get_ir_program()
-        tanh_out = newir_program.block().get_ops()[-1].result(0)
+        tanh_out = newir_program.block().ops[-1].result(0)
         paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
         with paddle.ir.core.program_guard(newir_program):
             out = paddle.mean(tanh_out)
@@ -82,8 +82,8 @@ class TestBuildOp3(unittest.TestCase):
     def test_insertion_point(self):
         newir_program = get_ir_program()
         paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
-        add_op = newir_program.block().get_ops()[-2]
-        tanh_op = newir_program.block().get_ops()[-1]
+        add_op = newir_program.block().ops[-2]
+        tanh_op = newir_program.block().ops[-1]
         add_out = add_op.result(0)
         tanh_operand = tanh_op.operands()[0]
diff --git a/test/ir/new_ir/test_ir_backward.py b/test/ir/new_ir/test_ir_backward.py
new file mode 100644
index 00000000000..8c93162f105
--- /dev/null
+++ b/test/ir/new_ir/test_ir_backward.py
@@ -0,0 +1,109 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import paddle
+from paddle import ir
+from paddle.autograd.backward import grad
+
+paddle.enable_static()
+
+
+def get_ir_program_0():
+    x = paddle.randn([4, 4])
+    main_program, start_program = (
+        paddle.static.Program(),
+        paddle.static.Program(),
+    )
+    with paddle.static.program_guard(main_program, start_program):
+        x_s = paddle.static.data('x', [4, 4], x.dtype)
+        x_s.stop_gradient = False
+        k_s = paddle.tanh(x_s)
+    newir_program = ir.translate_to_new_ir(main_program.desc)
+    return newir_program
+
+
+def get_ir_program_1():
+    x = paddle.randn([4, 4])
+    main_program, start_program = (
+        paddle.static.Program(),
+        paddle.static.Program(),
+    )
+    with paddle.static.program_guard(main_program, start_program):
+        x_s = paddle.static.data('x', [4, 4], x.dtype)
+        y_s = paddle.static.data('y', [4, 4], x.dtype)
+        x_s.stop_gradient = False
+        z_x = paddle.tanh(y_s)
+        k_s = paddle.tanh(x_s)
+        out = paddle.add(z_x, k_s)
+    newir_program = ir.translate_to_new_ir(main_program.desc)
+    return newir_program
+
+
+class TestBackward(unittest.TestCase):
+    def test_1(self):
+        newir_program = get_ir_program_0()
+        input = newir_program.block().ops[-1].operand(0).source()
+        tanh_out = newir_program.block().ops[-1].result(0)
+        paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
+        with paddle.ir.core.program_guard(newir_program):
+            out = paddle.mean(tanh_out)
+            out2 = paddle.mean(tanh_out)
+            input_grad = grad(out, input, out2)
+
+        print(newir_program)
+        self.assertEqual(out.get_defining_op().name(), "pd.mean")
+        self.assertEqual(input_grad[0].get_defining_op().name(), "pd.tanh_grad")
+        self.assertEqual(
+            out.get_defining_op()
+            .operands()[0]
+            .source()
+            .get_defining_op()
+            .name(),
+            "pd.tanh",
+        )
+        paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
+
+    def test_2(self):
+        # test that backward creates missing output_grads with full ops
+        newir_program = get_ir_program_0()
+        input = newir_program.block().ops[-1].operand(0).source()
+        tanh_out = newir_program.block().ops[-1].result(0)
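+        # under FLAGS_enable_new_ir_api, the following ops are built
+        # directly into newir_program through program_guard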
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True}) + with paddle.ir.core.program_guard(newir_program): + out = paddle.mean(tanh_out) + input_grad = grad(out, input) + + print(newir_program) + self.assertEqual(newir_program.block().ops[-3].name(), "pd.full") + paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False}) + + # TODO(Ruting) test add_n op when add_n api and add_grad finished + # def test_3(self): + # # test add_n op + # newir_program = get_ir_program_1() + # input = newir_program.block().ops[-1].operand(0).source() + # tanh_out = newir_program.block().ops[-1].result(0) + # paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True}) + # with paddle.ir.core.program_guard(newir_program): + # out = paddle.mean(tanh_out) + # input_grad = grad(out, input) + + # print(newir_program) + # self.assertEqual(newir_program.block().ops[-1].name(), "pd.add_n") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/ir/new_ir/test_ir_pybind.py b/test/ir/new_ir/test_ir_pybind.py index feb89cbc88d..26844acc958 100644 --- a/test/ir/new_ir/test_ir_pybind.py +++ b/test/ir/new_ir/test_ir_pybind.py @@ -49,21 +49,21 @@ class TestPybind(unittest.TestCase): def test_block(self): newir_program = get_ir_program() block = newir_program.block() - ops = block.get_ops() + ops = block.ops self.assertEqual( len(ops), 4 ) # ir program add "builtin.get_parameter" by default, so size is 4 block.remove_op(ops[3]) - self.assertEqual(len(block.get_ops()), 3) + self.assertEqual(len(block.ops), 3) def test_operation(self): newir_program = get_ir_program() - ops = newir_program.block().get_ops() - matmul_op = newir_program.block().get_ops()[1] - add_op = newir_program.block().get_ops()[2] - tanh_op = newir_program.block().get_ops()[3] + ops = newir_program.block().ops + matmul_op = newir_program.block().ops[1] + add_op = newir_program.block().ops[2] + tanh_op = newir_program.block().ops[3] parent_block = tanh_op.get_parent_block() - parent_ops_num = len(parent_block.get_ops()) + parent_ops_num = len(parent_block.ops) self.assertEqual(parent_ops_num, 4) self.assertEqual(tanh_op.num_results(), 1) self.assertEqual(len(matmul_op.get_input_names()), 2) @@ -72,9 +72,9 @@ class TestPybind(unittest.TestCase): def test_value(self): newir_program = get_ir_program() - matmul_op = newir_program.block().get_ops()[1] - add_op = newir_program.block().get_ops()[2] - tanh_op = newir_program.block().get_ops()[3] + matmul_op = newir_program.block().ops[1] + add_op = newir_program.block().ops[2] + tanh_op = newir_program.block().ops[3] self.assertEqual( matmul_op.result(0).dtype, paddle.fluid.core.DataType.FLOAT32 @@ -123,8 +123,8 @@ class TestPybind(unittest.TestCase): def test_type(self): newir_program = get_ir_program() - matmul_op = newir_program.block().get_ops()[1] - add_op = newir_program.block().get_ops()[2] + matmul_op = newir_program.block().ops[1] + add_op = newir_program.block().ops[2] print(matmul_op.result(0).type()) self.assertEqual( matmul_op.result(0).type() == add_op.result(0).type(), True @@ -152,8 +152,8 @@ class TestPybind(unittest.TestCase): newir_program = ir.translate_to_new_ir(main_program.desc) print(newir_program) - conv_attr = newir_program.block().get_ops()[3].attrs() - full_attr = newir_program.block().get_ops()[8].attrs() + conv_attr = newir_program.block().ops[3].attrs() + full_attr = newir_program.block().ops[8].attrs() self.assertEqual(conv_attr["stop_gradient"], [False]) self.assertEqual(conv_attr["dilations"], [1, 1]) self.assertEqual(conv_attr["data_format"], "NCHW") @@ 
-166,13 +166,13 @@ class TestPybind(unittest.TestCase): def test_operands(self): newir_program = get_ir_program() - matmul_op = newir_program.block().get_ops()[1] + matmul_op = newir_program.block().ops[1] operands = matmul_op.operands() self.assertEqual(len(operands), 2) def test_results(self): newir_program = get_ir_program() - matmul_op = newir_program.block().get_ops()[1] + matmul_op = newir_program.block().ops[1] results = matmul_op.results() self.assertEqual(len(results), 1) diff --git a/test/ir/new_ir/test_ir_vjp.py b/test/ir/new_ir/test_ir_vjp.py index ec0e8632a4b..12931b89cca 100644 --- a/test/ir/new_ir/test_ir_vjp.py +++ b/test/ir/new_ir/test_ir_vjp.py @@ -38,8 +38,8 @@ def get_ir_program(): class TestTanhVjp(unittest.TestCase): def test_tanh_vjp1(self): newir_program = get_ir_program() - tanh_op = newir_program.block().get_ops()[-2] - fill_constant_op = newir_program.block().get_ops()[-1] + tanh_op = newir_program.block().ops[-2] + fill_constant_op = newir_program.block().ops[-1] out_grads = [[fill_constant_op.result(0)]] stop_gradients = [[0]] with paddle.ir.core.program_guard(newir_program): @@ -65,12 +65,12 @@ class TestTanhVjp(unittest.TestCase): .name(), "pd.full", ) - self.assertEqual(len(newir_program.block().get_ops()), 4) + self.assertEqual(len(newir_program.block().ops), 4) def test_tanh_vjp2(self): newir_program = get_ir_program() - tanh_op = newir_program.block().get_ops()[-2] - fill_constant_op = newir_program.block().get_ops()[-1] + tanh_op = newir_program.block().ops[-2] + fill_constant_op = newir_program.block().ops[-1] out_grads = [[fill_constant_op.result(0)]] stop_gradients = [[1]] with paddle.ir.core.program_guard(newir_program): @@ -90,8 +90,8 @@ class TestMeanVjp(unittest.TestCase): paddle.mean(x, axis=[0, 1]) paddle.tensor.fill_constant(shape=[1], dtype='float32', value=2.0) newir_program = ir.translate_to_new_ir(main_program.desc) - fill_constant_op = newir_program.block().get_ops()[-1] - mean_op = newir_program.block().get_ops()[-2] + fill_constant_op = newir_program.block().ops[-1] + mean_op = newir_program.block().ops[-2] out_grads = [[fill_constant_op.result(0)]] stop_gradients = [[0]] with paddle.ir.core.program_guard(newir_program): @@ -117,7 +117,7 @@ class TestMeanVjp(unittest.TestCase): .name(), "pd.full", ) - self.assertEqual(len(newir_program.block().get_ops()), 4) + self.assertEqual(len(newir_program.block().ops), 4) def test_mean_vjp2(self): main_program, start_program = ( @@ -130,8 +130,8 @@ class TestMeanVjp(unittest.TestCase): paddle.mean(x, axis=[0, 1]) paddle.tensor.fill_constant(shape=[1], dtype='float32', value=2.0) newir_program = ir.translate_to_new_ir(main_program.desc) - fill_constant_op = newir_program.block().get_ops()[-1] - mean_op = newir_program.block().get_ops()[-2] + fill_constant_op = newir_program.block().ops[-1] + mean_op = newir_program.block().ops[-2] out_grads = [[fill_constant_op.result(0)]] stop_gradients = [[1]] with paddle.ir.core.program_guard(newir_program): @@ -151,8 +151,8 @@ class TesthasVjp(unittest.TestCase): paddle.mean(x, axis=[0, 1]) paddle.tensor.fill_constant(shape=[1], dtype='float32', value=2.0) newir_program = ir.translate_to_new_ir(main_program.desc) - fill_constant_op = newir_program.block().get_ops()[-1] - mean_op = newir_program.block().get_ops()[-2] + fill_constant_op = newir_program.block().ops[-1] + mean_op = newir_program.block().ops[-2] self.assertEqual(has_vjp(fill_constant_op), False) self.assertEqual(has_vjp(mean_op), True) diff --git a/test/prim/new_ir_prim/test_decomp_op.py 
b/test/prim/new_ir_prim/test_decomp_op.py index f56b68f2317..bb6ae6d69ef 100644 --- a/test/prim/new_ir_prim/test_decomp_op.py +++ b/test/prim/new_ir_prim/test_decomp_op.py @@ -43,7 +43,7 @@ class TestBuildOp(unittest.TestCase): newir_program = get_ir_program() paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True}) decompose(newir_program) - op_name_list = [op.name() for op in newir_program.block().get_ops()] + op_name_list = [op.name() for op in newir_program.block().ops] self.assertEqual( op_name_list, [ -- GitLab
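For reference, the snippet below is a minimal usage sketch of the grad API introduced by this patch, distilled from test/ir/new_ir/test_ir_backward.py. It only rearranges calls that already appear in the tests above; the variable names are illustrative.

import paddle
from paddle import ir
from paddle.autograd.backward import grad

paddle.enable_static()

# build a legacy static program: y = tanh(x)
main_program = paddle.static.Program()
start_program = paddle.static.Program()
with paddle.static.program_guard(main_program, start_program):
    x = paddle.static.data('x', [4, 4], 'float32')
    x.stop_gradient = False
    paddle.tanh(x)

# translate it into a new IR program and locate the tanh input/output
newir_program = ir.translate_to_new_ir(main_program.desc)
input_value = newir_program.block().ops[-1].operand(0).source()
tanh_out = newir_program.block().ops[-1].result(0)

# append mean + backward ops under the new IR API
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
with paddle.ir.core.program_guard(newir_program):
    out = paddle.mean(tanh_out)
    # input_grad[0] is the OpResult produced by pd.tanh_grad
    input_grad = grad(out, input_value)
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})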