未验证 提交 c687edec 编写于 作者: G gongweibao 提交者: GitHub

Fix reshape on GE graph. (#31084)

Fix reshape on GE graph
上级 a6edbc47
...@@ -214,7 +214,8 @@ class AscendOptimizer(Optimizer): ...@@ -214,7 +214,8 @@ class AscendOptimizer(Optimizer):
parameter_list=None, parameter_list=None,
no_grad_set=None, no_grad_set=None,
auto_dp=False, auto_dp=False,
rank_table_file=None): rank_table_file=None,
precision_mode="must_keep_origin_dtype"):
minimized = None minimized = None
if self.inner_opt: if self.inner_opt:
minimized = self.inner_opt.minimize( minimized = self.inner_opt.minimize(
...@@ -234,7 +235,7 @@ class AscendOptimizer(Optimizer): ...@@ -234,7 +235,7 @@ class AscendOptimizer(Optimizer):
config = { config = {
"ge.exec.deviceId": str(fleet.local_device_ids()), "ge.exec.deviceId": str(fleet.local_device_ids()),
"ge.graphRunMode": "1", "ge.graphRunMode": "1",
"ge.exec.precision_mode": "must_keep_origin_dtype", "ge.exec.precision_mode": precision_mode,
} }
# if multi trainers # if multi trainers
if rank_table_file and fleet.world_size() > 1: if rank_table_file and fleet.world_size() > 1:
......
...@@ -203,7 +203,8 @@ class AscendParserBase(object): ...@@ -203,7 +203,8 @@ class AscendParserBase(object):
def _accumulated_op_id(self): def _accumulated_op_id(self):
global global_cnt global global_cnt
global_cnt += 1 global_cnt += 1
return "." + str(global_cnt) name = "." + str(global_cnt)
return name
def _create_ge_tensor(self, shape, dtype, value): def _create_ge_tensor(self, shape, dtype, value):
tensor_desc = core.GETensorDesc( tensor_desc = core.GETensorDesc(
...@@ -1630,10 +1631,14 @@ class MulGradParser(AscendParserBase): ...@@ -1630,10 +1631,14 @@ class MulGradParser(AscendParserBase):
"unsqueeze" + self._accumulated_op_id(), "unsqueeze" + self._accumulated_op_id(),
"Unsqueeze").set_input("x", "Unsqueeze").set_input("x",
y).set_attr_vec_int32("axes", [0]) y).set_attr_vec_int32("axes", [0])
y_stack = core.GEOperatorFactory.create_operator(
"stack" + self._accumulated_op_id(),
"TileWithAxis").set_input("x", y_unsqueeze).set_attr_int32(
"axis", 0).set_attr_int32("tiles", shape_out_grad[0])
x_grad = core.GEOperatorFactory.create_operator( x_grad = core.GEOperatorFactory.create_operator(
self.parser_name + self._accumulated_op_id(), self.parser_name + self._accumulated_op_id(),
"BatchMatMul").set_input("x1", out_grad).set_input( "BatchMatMul").set_input("x1", out_grad).set_input(
"x2", y_unsqueeze).set_attr_bool( "x2", y_stack).set_attr_bool(
"adj_x1", False).set_attr_bool("adj_x2", True) "adj_x1", False).set_attr_bool("adj_x2", True)
y_grad = core.GEOperatorFactory.create_operator( y_grad = core.GEOperatorFactory.create_operator(
self.parser_name + self._accumulated_op_id(), self.parser_name + self._accumulated_op_id(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册