diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index d893484a9f577bf725bbb4a0574109dd3c5576ee..46fdb3786df7f63c672c9b7cefce959b7c3dc444 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -145,6 +145,13 @@ variance : Variance scale : Scale bias : Bias + outputs : + out : Y + mean_out: MeanOut + variance_out: VarianceOut + saved_mean: SavedMean + saved_variance: SavedVariance + reserve_space: ReserveSpace extra : attrs : [bool use_mkldnn = false, bool fuse_with_relu = false] diff --git a/python/paddle/incubate/autograd/generate_op_map.py b/python/paddle/incubate/autograd/generate_op_map.py index d162789c226324096ff9c4eed95a5e2ff8ae1c74..34cef37c3cc995e10049b19a3fdfaab7b15f9fc4 100644 --- a/python/paddle/incubate/autograd/generate_op_map.py +++ b/python/paddle/incubate/autograd/generate_op_map.py @@ -84,7 +84,7 @@ def generate_code( else: op_name = key map_dct[op_name] = {"phi_name": op_name} - for element in ["inputs", "attrs"]: + for element in ["inputs", "outputs", "attrs"]: if element in item.keys(): map_dct[op_name][element] = item[element] for element in ["scalar", "int_array"]: diff --git a/python/paddle/incubate/autograd/primx.py b/python/paddle/incubate/autograd/primx.py index 5e79128e568c4168c9b205cbf7fc6dd72222ebff..9314019dd556d9d7d0fe24a903cc60270c2e37bd 100644 --- a/python/paddle/incubate/autograd/primx.py +++ b/python/paddle/incubate/autograd/primx.py @@ -609,6 +609,10 @@ def _lower_composite(block, blacklist=[]): expand_nested_list(as_tensors(lower_fn(op, *input_args))), ): if new_out is not None: + assert orig_out.shape == new_out.shape, ( + f'when replace origin op with composite rule, origin out shape should be equal to new out shape, ' + f'but orig_out.shape={orig_out.shape} and new_out.shape={new_out.shape}' + ) assert not (orig_out is None) ^ ( new_out is None ), "orig_out and new_out should match."