From 37b3397385986576ff25cc6824d2e39f78e2e0c7 Mon Sep 17 00:00:00 2001 From: cyber-pioneer Date: Fri, 10 Feb 2023 05:49:55 +0000 Subject: [PATCH] fix composite check output --- paddle/phi/api/yaml/op_compat.yaml | 7 +++++++ python/paddle/incubate/autograd/generate_op_map.py | 2 +- python/paddle/incubate/autograd/primx.py | 4 ++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index d893484a9f..46fdb3786d 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -145,6 +145,13 @@ variance : Variance scale : Scale bias : Bias + outputs : + out : Y + mean_out: MeanOut + variance_out: VarianceOut + saved_mean: SavedMean + saved_variance: SavedVariance + reserve_space: ReserveSpace extra : attrs : [bool use_mkldnn = false, bool fuse_with_relu = false] diff --git a/python/paddle/incubate/autograd/generate_op_map.py b/python/paddle/incubate/autograd/generate_op_map.py index d162789c22..34cef37c3c 100644 --- a/python/paddle/incubate/autograd/generate_op_map.py +++ b/python/paddle/incubate/autograd/generate_op_map.py @@ -84,7 +84,7 @@ def generate_code( else: op_name = key map_dct[op_name] = {"phi_name": op_name} - for element in ["inputs", "attrs"]: + for element in ["inputs", "outputs", "attrs"]: if element in item.keys(): map_dct[op_name][element] = item[element] for element in ["scalar", "int_array"]: diff --git a/python/paddle/incubate/autograd/primx.py b/python/paddle/incubate/autograd/primx.py index 5e79128e56..9314019dd5 100644 --- a/python/paddle/incubate/autograd/primx.py +++ b/python/paddle/incubate/autograd/primx.py @@ -609,6 +609,10 @@ def _lower_composite(block, blacklist=[]): expand_nested_list(as_tensors(lower_fn(op, *input_args))), ): if new_out is not None: + assert orig_out.shape == new_out.shape, ( + f'when replace origin op with composite rule, origin out shape should be equal to new out shape, ' + f'but orig_out.shape={orig_out.shape} and new_out.shape={new_out.shape}' + ) assert not (orig_out is None) ^ ( new_out is None ), "orig_out and new_out should match." -- GitLab