diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index 46e77eb50ca3f7842c431f5cc577337903036e86..6cf53c106d1efb770c47120be312635e379aa640 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -1043,7 +1043,7 @@ def set_grad_var_shape(program, dist_context): "fill_zeros_like" ]: forward_var_name = op.input_arg_names[0] - elif op.type == "matmul_v2_grad": + elif op.type == "matmul_v2_grad" or op.type == "matmul_grad" or op.type == "mul_grad": forward_var_name = None for output_name in op.output_names: if var_name in op.output(output_name):