diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index 7e33ea40a60769486887ac7dded8f5cbac7440c2..51209301bce6e61969a0b1e0f3efba6a7799526b 100644
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -145,6 +145,13 @@
     variance : Variance
     scale : Scale
     bias : Bias
+  outputs :
+    out : Y
+    mean_out : MeanOut
+    variance_out : VarianceOut
+    saved_mean : SavedMean
+    saved_variance : SavedVariance
+    reserve_space : ReserveSpace
   extra :
     attrs : [bool use_mkldnn = false, bool fuse_with_relu = false]

@@ -421,6 +428,17 @@

 - op : dropout
   backward : dropout_grad
+  inputs :
+    x : X
+  outputs :
+    out : Out
+    mask : Mask
+  attrs :
+    p : dropout_prob
+    is_test : is_test
+    mode : dropout_implementation
+    seed : seed
+    fix_seed : fix_seed
   extra :
     attrs : [bool fix_seed = false, int seed = 0]

@@ -808,6 +826,14 @@

 - op : layer_norm
   backward : layer_norm_grad
+  inputs :
+    x : X
+    scale : Scale
+    bias : Bias
+  outputs :
+    out : Y
+    mean : Mean
+    variance : Variance
   extra :
     attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool is_test = false]

@@ -978,6 +1004,17 @@
   outputs :
     out : Out

+- op : mean (reduce_mean)
+  backward : reduce_mean_grad
+  inputs :
+    x : X
+  outputs :
+    out : Out
+  attrs :
+    {axis : dim, keepdim : keep_dim}
+  extra :
+    attrs : [bool use_mkldnn = false]
+
 - op : meshgrid
   backward : meshgrid_grad
   inputs :
@@ -1196,11 +1233,6 @@
   extra :
     attrs : [bool use_mkldnn = false]

-- op : reduce_mean
-  backward : reduce_mean_grad
-  extra :
-    attrs : [bool use_mkldnn = false]
-
 - op : reduce_min
   backward : reduce_min_grad
   extra :
diff --git a/python/paddle/fluid/core.py b/python/paddle/fluid/core.py
index d7b4be3b4d3d0c8fdbebfc5f0d3dc50ff64df0fc..c94d92f17f8723d74a76e530a07bed48809a87c1 100644
--- a/python/paddle/fluid/core.py
+++ b/python/paddle/fluid/core.py
@@ -449,14 +449,14 @@ def _test_use_sync(value):
 # ops in forward_blacklisk will not be replaced by composite ops.
-prim_config = {"forward_blacklist": []}
+prim_config = {"forward_blacklist": set(), "composite_ops_record": set()}


 def _set_prim_forward_blacklist(ops=None):
     if ops is None:
-        prim_config["forward_blacklist"] = []
+        prim_config["forward_blacklist"] = set()
     elif isinstance(ops, str):
-        prim_config["forward_blacklist"].append(ops)
+        prim_config["forward_blacklist"].add(ops)
     elif isinstance(ops, (list, tuple)):
         for item in ops:
             if not isinstance(item, str):
@@ -464,7 +464,7 @@ def _set_prim_forward_blacklist(ops=None):
                     "ops set in forward_blacklist must belong to [str, str of tuple or list]"
                 )
             else:
-                prim_config["forward_blacklist"].append(item)
+                prim_config["forward_blacklist"].add(item)
     else:
         raise TypeError(
             "ops set in forward_blacklist must belong to [str, str of tuple or list]"
         )
diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_softmax.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_softmax.py
index fd54850b2cb6f2a55d18f28c4644b7ea4b859fd1..6be130bbc57131f2666c8341b636d94bc2b76df5 100644
--- a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_softmax.py
+++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_softmax.py
@@ -68,7 +68,7 @@ def expect_forward(inputs):
 class TestCompositeSoftmax(unittest.TestCase):
     def setUp(self):
         self.dtypes = ["float32", "float64"]
-        self.shapes = [[2, 3, 4], [2, 3]]
+        self.shapes = [[], [2, 3, 4], [2, 3]]
         self.axes = [-1, 0, 1]

     def cal_composite(self, inputs):
@@ -101,6 +101,9 @@ class TestCompositeSoftmax(unittest.TestCase):
         return res

     def compare_forward(self):
+        if not attrs.shape and attrs.axis not in [-1, 0]:
+            # op softmax does not support 0-D input with axis other than -1 or 0
+            return
         np_data = generate_data(attrs.shape)
         tensor_data = paddle.to_tensor(np_data)

diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_softmax_grad.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_softmax_grad.py
index d8b373133280f274713bf97596bc00273a55e647..87a2fafb50f607e302e50fc0275e3890148aefa9 100644
--- a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_softmax_grad.py
+++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_softmax_grad.py
@@ -143,7 +143,7 @@ class TestCompositeSoftmaxPrimBackward(unittest.TestCase):
     def setUp(self):
         core._set_prim_backward_enabled(True)
         self.dtypes = ["float32", "float64"]
-        self.shapes = [[2, 3, 4], [2, 3]]
+        self.shapes = [[], [2, 3, 4], [2, 3]]
         self.axes = [-1, 0, 1]

     def cal_composite_grad(self, inputs):
@@ -169,6 +169,9 @@ class TestCompositeSoftmaxPrimBackward(unittest.TestCase):
         return res

     def compare_backward(self):
+        if not attrs.shape and attrs.axis not in [-1, 0]:
+            # op softmax does not support 0-D input with axis other than -1 or 0
+            return
         np_data = generate_data(attrs.shape)
         tensor_data = paddle.to_tensor(np_data)

diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py
index ac2da7df81ce5e473ce636a284c8473ee9b10424..2900e59c7bdf208f5b9e16966bf9a4f8f7e8803d 100644
--- a/python/paddle/incubate/autograd/composite_rules.py
+++ b/python/paddle/incubate/autograd/composite_rules.py
@@ -18,8 +18,6 @@
 # ops.yaml or legacy_ops.yaml.
-import paddle
-
 from .primitives import *  # noqa: F403
 from .primreg import REGISTER_COMPOSITE, lookup_composite


@@ -32,6 +30,10 @@ def _composite(op, *args):
 @REGISTER_COMPOSITE('softmax')
 def softmax_composite(x, axis):
     """define composite rule of op softmax"""
+    if not x.shape:
+        # do not return a constant 1, so that the output stays connected to x and gradients can flow
+        res = divide(x + 1e-5, x + 1e-5)
+        return res
     max_temp = max(x, axis, keepdim=True)
     max_temp.stop_gradient = True
     molecular = exp(x - max_temp)
@@ -95,16 +97,15 @@ def composite_batchnorm(
     y = reshape(scale, stats_shape) * x_hat + reshape(bias, stats_shape)

     # add op assign to detach tensor in void unsafe change outside the rule.
+    batch_mean_ = assign(reshape(batch_mean, run_mean.shape))
+    batch_var_ = assign(reshape(batch_var, run_var.shape))
+    run_mean_ = assign(run_mean)
+    run_var_ = assign(run_var)
-    batch_mean_ = paddle.assign(batch_mean)
-    batch_var_ = paddle.assign(batch_var)
-    run_mean_ = paddle.assign(run_mean)
-    run_var_ = paddle.assign(run_var)

+    # reserve_space is not needed by the composite rule, but None is still returned to match the phi op definition.
+    reserve_space = None
-    if trainable_statistics or not is_test:
-        return run_mean_, None, batch_mean_, batch_var_, run_var_, y
-    else:
-        return run_mean_, batch_mean_, batch_var_, run_var_, y
+    return y, run_mean_, run_var_, batch_mean_, batch_var_, reserve_space


 @REGISTER_COMPOSITE('gelu')
diff --git a/python/paddle/incubate/autograd/generate_op_map.py b/python/paddle/incubate/autograd/generate_op_map.py
index d162789c226324096ff9c4eed95a5e2ff8ae1c74..34cef37c3cc995e10049b19a3fdfaab7b15f9fc4 100644
--- a/python/paddle/incubate/autograd/generate_op_map.py
+++ b/python/paddle/incubate/autograd/generate_op_map.py
@@ -84,7 +84,7 @@ def generate_code(
         else:
             op_name = key
         map_dct[op_name] = {"phi_name": op_name}
-        for element in ["inputs", "attrs"]:
+        for element in ["inputs", "outputs", "attrs"]:
            if element in item.keys():
                map_dct[op_name][element] = item[element]
         for element in ["scalar", "int_array"]:
diff --git a/python/paddle/incubate/autograd/primapi.py b/python/paddle/incubate/autograd/primapi.py
index 5f817a06ba6df89f6e496f8ccb7a27d8d2f02044..df4fc1c513ae56c109b56ad10b440c88ba320e9d 100644
--- a/python/paddle/incubate/autograd/primapi.py
+++ b/python/paddle/incubate/autograd/primapi.py
@@ -236,6 +236,8 @@ def to_prim(blocks):
             f"Expect block or sequence of blocks, but got {type(blocks)}."
         )
     with framework.program_guard(main_program):
-        print("Running lowering for forward...")
+        print("Lowering composite forward ops begins...")
         primx._lower_composite(blocks, prim_config["forward_blacklist"])
+        replace_ops = prim_config["composite_ops_record"]
+        print(f"Lowering composite forward ops finished: {replace_ops}")
     return
diff --git a/python/paddle/incubate/autograd/primx.py b/python/paddle/incubate/autograd/primx.py
index 5e79128e568c4168c9b205cbf7fc6dd72222ebff..09c13e9aa40ddb78588821cea821dda9cc189150 100644
--- a/python/paddle/incubate/autograd/primx.py
+++ b/python/paddle/incubate/autograd/primx.py
@@ -18,6 +18,7 @@
 from collections import OrderedDict

 import paddle
 from paddle.fluid import framework as framework
+from paddle.fluid.core import prim_config
 from paddle.fluid.framework import Operator, default_main_program
 from paddle.incubate.autograd.utils import as_tensors

@@ -36,6 +37,7 @@
 from .utils import (
     flatten_and_remove_none,
     get_input_var_list,
     get_output_var_list,
+    map_output_for_composite,
     prepare_python_api_arguments,
 )

@@ -596,19 +598,43 @@ def _lower_composite(block, blacklist=[]):
         # if output var of composite rule is None, this means this var is not needed
         none_vars_to_remove = set()

+        change = False
         # Step2: Process all ops in the target block
         for op_idx in range(len(block.ops)):
             op = block.ops[op_idx]
             ops_to_remove.append(op_idx)
             if lookup_fn(op.type) is not None and op.type not in blacklist:
+                change = True
+                op_name = op.type
+                prim_config["composite_ops_record"].add(op_name)
                 input_args = prepare_python_api_arguments(op)
                 bind(input_args, to_bind, value_table)
+                orig_outs = expand_nested_list(map_output_for_composite(op))
+                new_outs = expand_nested_list(
+                    as_tensors(lower_fn(op, *input_args))
+                )
+                assert len(orig_outs) == len(new_outs), (
+                    f'when replacing origin op {op_name} with composite rule, the number of origin outs should equal the number of new outs, '
+                    f'but len(orig_outs) = {len(orig_outs)} and len(new_outs) = {len(new_outs)}'
+                )
                 for orig_out, new_out in zip(
-                    expand_nested_list(get_output_var_list(op)),
-                    expand_nested_list(as_tensors(lower_fn(op, *input_args))),
+                    orig_outs,
+                    new_outs,
                 ):
-                    if new_out is not None:
+                    if orig_out is None:
+                        # to keep the same as the phi op definition, orig_out may be None
+                        continue
+                    elif new_out is not None:
+                        assert orig_out.dtype == new_out.dtype, (
+                            f'when replacing origin op {op_name} with composite rule, the origin out dtype should equal the new out dtype, '
+                            f'but orig_out.dtype={orig_out.dtype} and new_out.dtype={new_out.dtype}'
+                        )
+                        if orig_out.shape and new_out.shape:
+                            assert orig_out.shape == new_out.shape, (
+                                f'when replacing origin op {op_name} with composite rule, the origin out shape should equal the new out shape, '
+                                f'but orig_out.shape={orig_out.shape} and new_out.shape={new_out.shape}'
+                            )
                         assert not (orig_out is None) ^ (
                             new_out is None
                         ), "orig_out and new_out should match."

@@ -675,6 +701,10 @@ def _lower_composite(block, blacklist=[]):
                 block.desc._remove_var(var_name.encode())
                 del block.vars[var_name]
         block._sync_with_cpp()
+
+        # a composite op may be lowered into other composite ops, so call _lower_composite again.
+        if change:
+            _lower_composite(block, blacklist)
         return

     elif isinstance(block, typing.Sequence):
diff --git a/python/paddle/incubate/autograd/utils.py b/python/paddle/incubate/autograd/utils.py
index 211851160b17fd6c148428c2983624fba3062e5f..b4a78dec8692f34f5d982c59bcc7edadbc1156fa 100644
--- a/python/paddle/incubate/autograd/utils.py
+++ b/python/paddle/incubate/autograd/utils.py
@@ -169,6 +169,7 @@ def _get_args_values(op, phi_name):
         arg_type, arg_name = _solve_arg(item)
         op_content = op_map[op.type]
         if arg_type in ("Tensor", "Tensor[]"):
+            # assume an arg of Tensor type must belong to inputs
            if (
                "inputs" in op_content.keys()
                and arg_name in op_content["inputs"].keys()
@@ -182,8 +183,12 @@ def _get_args_values(op, phi_name):
                "attrs" in op_content.keys()
                and arg_name in op_content["attrs"].keys()
            ):
-                attrs.append(op.attr(op_content["attrs"][arg_name]))
-            attrs.append(op.attr(arg_name))
+                arg_name = op_content["attrs"][arg_name]
+            # Note: attrs may be optional in some cases; record None for a missing attr.
+            if arg_name not in op.attr_names:
+                attrs.append(None)
+            else:
+                attrs.append(op.attr(arg_name))

     return inputs, attrs

@@ -202,7 +207,13 @@ def prepare_python_api_arguments(op):
     else:
         phi_name = op.type
     inputs, attrs = _get_args_values(op, phi_name)
-    res = [get_var_block(op.block, op.input(n)) for n in inputs]
+    res = []
+    for item in inputs:
+        if item in op.input_names:
+            res.append(get_var_block(op.block, op.input(item)))
+        else:
+            # Note: inputs may be optional in some cases; record None for a missing input.
+            res.append(None)
     if attrs:
         res.extend(attrs)
     return res

@@ -218,6 +229,38 @@ def get_output_var_list(op):
     ]


+def map_output_for_composite(op):
+    """Map the outputs of the origin op to the outputs of its composite rule; the map info is defined in op_compat.yaml."""
+    origin_output_names = op.output_names
+    if origin_output_names is None:
+        return []
+    else:
+        name = op.type
+        res = []
+        if op_map[name].get("outputs"):
+            for item in op_map[name]["outputs"].keys():
+                origin_output_name = op_map[name]["outputs"][item]
+                if origin_output_name not in origin_output_names:
+                    # Note: some outputs of an origin op are optional, so the mapped name may be absent from origin_output_names.
+                    res.append(None)
+                    continue
+                origin_output_var = get_var_block(
+                    op.block, op.output(origin_output_name)
+                )
+                res.append(origin_output_var)
+        elif len(origin_output_names) == 1:
+            # When the origin op has exactly one output, no map info is needed.
+            origin_output_var = get_var_block(
+                op.block, op.output(origin_output_names[0])
+            )
+            res.append(origin_output_var)
+        else:
+            raise ValueError(
+                "When replacing an op with a composite rule, output map info from the origin op to the composite rule must exist."
+            )
+        return res
+
+
 def flatten(inp):
     if inp is None or isinstance(inp, paddle.fluid.framework.Variable):
         return [inp]
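
Reviewer note: a minimal usage sketch of the set-based forward blacklist added in python/paddle/fluid/core.py, assuming a Paddle build that includes this patch. Because the config now holds a set, repeated registrations are deduplicated automatically:

    from paddle.fluid import core

    # register a single op name, then a list; duplicates collapse into one entry
    core._set_prim_forward_blacklist("softmax")
    core._set_prim_forward_blacklist(["softmax", "dropout"])
    assert core.prim_config["forward_blacklist"] == {"softmax", "dropout"}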
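Reviewer note on the 0-D softmax branch in composite_rules.py: softmax of a 0-D tensor is identically 1, but returning a literal constant would leave the output disconnected from x, so no gradient could be computed. divide(x + 1e-5, x + 1e-5) is numerically 1 while keeping x in the graph. A standalone dynamic-graph sketch of the same idea (not the composite rule itself, which is built from static-graph primitives):

    import paddle

    x = paddle.to_tensor(2.0, stop_gradient=False)
    res = (x + 1e-5) / (x + 1e-5)  # numerically 1.0, but still a function of x
    res.backward()
    print(res)     # 1.0
    print(x.grad)  # 0.0; the gradient exists, whereas a constant 1 would yield no gradient at all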
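Reviewer note on the output mapping: for batch_norm, the outputs map added to op_compat.yaml (Y, MeanOut, VarianceOut, SavedMean, SavedVariance, ReserveSpace) must line up with the composite rule's return order (y, run_mean_, run_var_, batch_mean_, batch_var_, None). A simplified sketch of the pairing that _lower_composite performs, with plain strings standing in for block variables:

    orig_outs = ["Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance", "ReserveSpace"]
    new_outs = ["y", "run_mean_", "run_var_", "batch_mean_", "batch_var_", None]

    assert len(orig_outs) == len(new_outs)
    for orig_out, new_out in zip(orig_outs, new_outs):
        if new_out is None:
            # the rule does not produce this output; the origin var is marked for removal
            print(f"{orig_out}: not needed, will be removed")
        else:
            # the real pass also checks dtype equality and, for non-0-D tensors, shape equality here
            print(f"rebind {orig_out} -> {new_out}")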