未验证 提交 520b0546 编写于 作者: C caozhou 提交者: GitHub

add grad var shape infer of matmul and mul (#45237)

上级 41294cb5
...@@ -1043,7 +1043,7 @@ def set_grad_var_shape(program, dist_context): ...@@ -1043,7 +1043,7 @@ def set_grad_var_shape(program, dist_context):
"fill_zeros_like" "fill_zeros_like"
]: ]:
forward_var_name = op.input_arg_names[0] forward_var_name = op.input_arg_names[0]
elif op.type == "matmul_v2_grad": elif op.type == "matmul_v2_grad" or op.type == "matmul_grad" or op.type == "mul_grad":
forward_var_name = None forward_var_name = None
for output_name in op.output_names: for output_name in op.output_names:
if var_name in op.output(output_name): if var_name in op.output(output_name):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册