提交 37b33973 编写于 作者: C cyber-pioneer

fix composite check output

上级 2590527f
......@@ -145,6 +145,13 @@
variance : Variance
scale : Scale
bias : Bias
outputs :
out : Y
mean_out: MeanOut
variance_out: VarianceOut
saved_mean: SavedMean
saved_variance: SavedVariance
reserve_space: ReserveSpace
extra :
attrs : [bool use_mkldnn = false, bool fuse_with_relu = false]
......
......@@ -84,7 +84,7 @@ def generate_code(
else:
op_name = key
map_dct[op_name] = {"phi_name": op_name}
for element in ["inputs", "attrs"]:
for element in ["inputs", "outputs", "attrs"]:
if element in item.keys():
map_dct[op_name][element] = item[element]
for element in ["scalar", "int_array"]:
......
......@@ -609,6 +609,10 @@ def _lower_composite(block, blacklist=[]):
expand_nested_list(as_tensors(lower_fn(op, *input_args))),
):
if new_out is not None:
assert orig_out.shape == new_out.shape, (
f'when replace origin op with composite rule, origin out shape should be equal to new out shape, '
f'but orig_out.shape={orig_out.shape} and new_out.shape={new_out.shape}'
)
assert not (orig_out is None) ^ (
new_out is None
), "orig_out and new_out should match."
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册