未验证 提交 76c495d7 编写于 作者: Z zhaoyingli 提交者: GitHub

[AutoParallel] fix set_grad_var_shape (#50722)

* fix set_grad_var_shape

* recover modify
上级 d4217fc6
...@@ -1263,6 +1263,7 @@ def set_grad_var_shape(program, dist_context): ...@@ -1263,6 +1263,7 @@ def set_grad_var_shape(program, dist_context):
"relu_grad", "relu_grad",
"exp_grad", "exp_grad",
"sigmoid_grad", "sigmoid_grad",
"unsqueeze2_grad",
] ]
forward_list = [ forward_list = [
"reshape2", "reshape2",
...@@ -1283,6 +1284,7 @@ def set_grad_var_shape(program, dist_context): ...@@ -1283,6 +1284,7 @@ def set_grad_var_shape(program, dist_context):
"relu", "relu",
"exp", "exp",
"sigmoid", "sigmoid",
"unsqueeze2",
] ]
if op.type in need_set_shape_list: if op.type in need_set_shape_list:
for forward_op in block.ops: for forward_op in block.ops:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册