未验证 提交 c687edec 编写于 作者: G gongweibao 提交者: GitHub

Fix reshape on GE graph. (#31084)

Fix reshape on GE graph
上级 a6edbc47
......@@ -214,7 +214,8 @@ class AscendOptimizer(Optimizer):
parameter_list=None,
no_grad_set=None,
auto_dp=False,
rank_table_file=None):
rank_table_file=None,
precision_mode="must_keep_origin_dtype"):
minimized = None
if self.inner_opt:
minimized = self.inner_opt.minimize(
......@@ -234,7 +235,7 @@ class AscendOptimizer(Optimizer):
config = {
"ge.exec.deviceId": str(fleet.local_device_ids()),
"ge.graphRunMode": "1",
"ge.exec.precision_mode": "must_keep_origin_dtype",
"ge.exec.precision_mode": precision_mode,
}
# if multi trainers
if rank_table_file and fleet.world_size() > 1:
......
......@@ -203,7 +203,8 @@ class AscendParserBase(object):
def _accumulated_op_id(self):
global global_cnt
global_cnt += 1
return "." + str(global_cnt)
name = "." + str(global_cnt)
return name
def _create_ge_tensor(self, shape, dtype, value):
tensor_desc = core.GETensorDesc(
......@@ -1630,10 +1631,14 @@ class MulGradParser(AscendParserBase):
"unsqueeze" + self._accumulated_op_id(),
"Unsqueeze").set_input("x",
y).set_attr_vec_int32("axes", [0])
y_stack = core.GEOperatorFactory.create_operator(
"stack" + self._accumulated_op_id(),
"TileWithAxis").set_input("x", y_unsqueeze).set_attr_int32(
"axis", 0).set_attr_int32("tiles", shape_out_grad[0])
x_grad = core.GEOperatorFactory.create_operator(
self.parser_name + self._accumulated_op_id(),
"BatchMatMul").set_input("x1", out_grad).set_input(
"x2", y_unsqueeze).set_attr_bool(
"x2", y_stack).set_attr_bool(
"adj_x1", False).set_attr_bool("adj_x2", True)
y_grad = core.GEOperatorFactory.create_operator(
self.parser_name + self._accumulated_op_id(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册