From c687edecd80363782a41a9d73e883ff4e7687bdb Mon Sep 17 00:00:00 2001 From: gongweibao Date: Sat, 20 Feb 2021 17:33:38 +0800 Subject: [PATCH] Fix reshape on GE graph. (#31084) Fix reshape on GE graph --- .../fleet/meta_optimizers/ascend/ascend_optimizer.py | 5 +++-- .../fleet/meta_optimizers/ascend/ascend_parser.py | 9 +++++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py index 978899604e..824225fd77 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py @@ -214,7 +214,8 @@ class AscendOptimizer(Optimizer): parameter_list=None, no_grad_set=None, auto_dp=False, - rank_table_file=None): + rank_table_file=None, + precision_mode="must_keep_origin_dtype"): minimized = None if self.inner_opt: minimized = self.inner_opt.minimize( @@ -234,7 +235,7 @@ class AscendOptimizer(Optimizer): config = { "ge.exec.deviceId": str(fleet.local_device_ids()), "ge.graphRunMode": "1", - "ge.exec.precision_mode": "must_keep_origin_dtype", + "ge.exec.precision_mode": precision_mode, } # if multi trainers if rank_table_file and fleet.world_size() > 1: diff --git a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py index d7ba61a8e4..e191776ffe 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py +++ b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py @@ -203,7 +203,8 @@ class AscendParserBase(object): def _accumulated_op_id(self): global global_cnt global_cnt += 1 - return "." + str(global_cnt) + name = "." + str(global_cnt) + return name def _create_ge_tensor(self, shape, dtype, value): tensor_desc = core.GETensorDesc( @@ -1630,10 +1631,14 @@ class MulGradParser(AscendParserBase): "unsqueeze" + self._accumulated_op_id(), "Unsqueeze").set_input("x", y).set_attr_vec_int32("axes", [0]) + y_stack = core.GEOperatorFactory.create_operator( + "stack" + self._accumulated_op_id(), + "TileWithAxis").set_input("x", y_unsqueeze).set_attr_int32( + "axis", 0).set_attr_int32("tiles", shape_out_grad[0]) x_grad = core.GEOperatorFactory.create_operator( self.parser_name + self._accumulated_op_id(), "BatchMatMul").set_input("x1", out_grad).set_input( - "x2", y_unsqueeze).set_attr_bool( + "x2", y_stack).set_attr_bool( "adj_x1", False).set_attr_bool("adj_x2", True) y_grad = core.GEOperatorFactory.create_operator( self.parser_name + self._accumulated_op_id(), -- GitLab