未验证 提交 76c495d7 编写于 作者: Z zhaoyingli 提交者: GitHub

[AutoParallel] fix set_grad_var_shape (#50722)

* fix set_grad_var_shape

* recover modify
上级 d4217fc6
......@@ -1263,6 +1263,7 @@ def set_grad_var_shape(program, dist_context):
"relu_grad",
"exp_grad",
"sigmoid_grad",
"unsqueeze2_grad",
]
forward_list = [
"reshape2",
......@@ -1283,6 +1284,7 @@ def set_grad_var_shape(program, dist_context):
"relu",
"exp",
"sigmoid",
"unsqueeze2",
]
if op.type in need_set_shape_list:
for forward_op in block.ops:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册