diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index 46fdb3786df7f63c672c9b7cefce959b7c3dc444..5214e462b98eaf14f385779cb21edba892a14a25 100644
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -414,6 +414,17 @@
 - op : dropout
   backward : dropout_grad
+  inputs :
+    x : X
+  outputs :
+    out : Out
+    mask : Mask
+  attrs :
+    p : dropout_prob
+    is_test : is_test
+    mode : dropout_implementation
+    seed : seed
+    fix_seed : fix_seed
   extra :
     attrs : [bool fix_seed = false, int seed = 0]

@@ -790,6 +801,14 @@
 - op : layer_norm
   backward : layer_norm_grad
+  inputs :
+    x : X
+    scale : Scale
+    bias : Bias
+  outputs :
+    out : Y
+    mean : Mean
+    variance : Variance
   extra :
     attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool is_test = false]

@@ -940,6 +959,17 @@
   outputs :
     out : Out

+- op : mean (reduce_mean)
+  backward : reduce_mean_grad
+  inputs :
+    x : X
+  outputs :
+    out : Out
+  attrs :
+    {axis : dim, keepdim : keep_dim}
+  extra :
+    attrs : [bool use_mkldnn = false]
+
 - op : meshgrid
   backward : meshgrid_grad
   inputs :
@@ -1145,17 +1175,6 @@
   extra :
     attrs : [bool use_mkldnn = false]

-- op : mean (reduce_mean)
-  backward : reduce_mean_grad
-  inputs :
-    x : X
-  outputs :
-    out : Out
-  attrs :
-    {axis : dim, keepdim : keep_dim}
-  extra :
-    attrs : [bool use_mkldnn = false]
-
 - op : reduce_min
   backward : reduce_min_grad
   extra :
diff --git a/python/paddle/incubate/autograd/primx.py b/python/paddle/incubate/autograd/primx.py
index b2a6cc806276650863c7d71e7633826a6e0072c0..a69bce3c37bc43f5c39f8b0ca560c17fecfcafa7 100644
--- a/python/paddle/incubate/autograd/primx.py
+++ b/python/paddle/incubate/autograd/primx.py
@@ -597,11 +597,13 @@ def _lower_composite(block, blacklist=[]):
     # if output var of composite rule is None, this means this var is not needed
     none_vars_to_remove = set()

+    change = None
     # Step2: Process all ops in the target block
     for op_idx in range(len(block.ops)):
         op = block.ops[op_idx]
         ops_to_remove.append(op_idx)
         if lookup_fn(op.type) is not None and op.type not in blacklist:
+            change = True
             input_args = prepare_python_api_arguments(op)
             bind(input_args, to_bind, value_table)

@@ -681,6 +683,10 @@ def _lower_composite(block, blacklist=[]):
                 block.desc._remove_var(var_name.encode())
                 del block.vars[var_name]
         block._sync_with_cpp()
+
+        # composite ops may contain other ops, thus, call _lower_composite again.
+        if change:
+            _lower_composite(block, blacklist)
         return

     elif isinstance(block, typing.Sequence):
diff --git a/python/paddle/incubate/autograd/utils.py b/python/paddle/incubate/autograd/utils.py
index 2f651f0e16f630de01806eadeeaf956d797d843d..90bdb78336dc4a257e39c7ae37f3babf2a6cca0f 100644
--- a/python/paddle/incubate/autograd/utils.py
+++ b/python/paddle/incubate/autograd/utils.py
@@ -232,14 +232,20 @@ def get_output_vars_from_comosite(op):
             origin_output_name = op_map[name]["outputs"][item]
             if origin_output_name not in origin_output_names:
                 continue
-            origin_output_var = get_var_block(op.block, op.output(origin_output_name))
+            origin_output_var = get_var_block(
+                op.block, op.output(origin_output_name)
+            )
             res.append(origin_output_var)
     elif len(origin_output_names) == 1:
         # When origin output num is 1, map info is not needed.
-        origin_output_var = get_var_block(op.block, op.output(origin_output_names[0]))
+        origin_output_var = get_var_block(
+            op.block, op.output(origin_output_names[0])
+        )
         res.append(origin_output_var)
     else:
-        raise ValueError("When replace op with composite rule, there must exist output map info from origin op to composite rule.")
+        raise ValueError(
+            "When replace op with composite rule, there must exist output map info from origin op to composite rule."
+        )
     return res
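
Note on the `_lower_composite` change above: the new `change` flag records whether any composite op was expanded during the pass, and the trailing recursive call re-runs the pass because a composite rule may itself emit further composite ops, so lowering must repeat until a fixpoint. A minimal standalone sketch of that pattern, assuming toy expansion RULES and helper names (`lower_once`, `lower`) that are illustrative only, not Paddle APIs:

    # Toy rules: a composite op expands into a list of ops that may
    # themselves still be composite; anything without a rule is primitive.
    RULES = {
        "dropout": ["scale", "mask_mul"],  # expands into another composite
        "mask_mul": ["mul"],               # expands into a primitive
    }

    def lower_once(ops):
        """Expand every op that has a rule; report whether anything changed."""
        out, changed = [], False
        for op in ops:
            if op in RULES:
                out.extend(RULES[op])
                changed = True
            else:
                out.append(op)
        return out, changed

    def lower(ops):
        """Repeat single passes until nothing changes (a fixpoint),
        mirroring the `if change: _lower_composite(block, blacklist)` call."""
        changed = True
        while changed:
            ops, changed = lower_once(ops)
        return ops

    print(lower(["dropout", "relu"]))  # ['scale', 'mul', 'relu']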