From 520b05469be4e257827590d8df51bf1de97ddc04 Mon Sep 17 00:00:00 2001 From: caozhou <48191911+Caozhou1995@users.noreply.github.com> Date: Thu, 18 Aug 2022 18:46:27 +0800 Subject: [PATCH] add grad var shape infer of matmul and mul (#45237) --- python/paddle/distributed/auto_parallel/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index 46e77eb50ca..6cf53c106d1 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -1043,7 +1043,7 @@ def set_grad_var_shape(program, dist_context): "fill_zeros_like" ]: forward_var_name = op.input_arg_names[0] - elif op.type == "matmul_v2_grad": + elif op.type == "matmul_v2_grad" or op.type == "matmul_grad" or op.type == "mul_grad": forward_var_name = None for output_name in op.output_names: if var_name in op.output(output_name): -- GitLab