diff --git a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py index 978899604eaf8c2ee45c03f866f2d5a081a7e502..824225fd776d1363d79e2218959507df8668bcee 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py @@ -214,7 +214,8 @@ class AscendOptimizer(Optimizer): parameter_list=None, no_grad_set=None, auto_dp=False, - rank_table_file=None): + rank_table_file=None, + precision_mode="must_keep_origin_dtype"): minimized = None if self.inner_opt: minimized = self.inner_opt.minimize( @@ -234,7 +235,7 @@ class AscendOptimizer(Optimizer): config = { "ge.exec.deviceId": str(fleet.local_device_ids()), "ge.graphRunMode": "1", - "ge.exec.precision_mode": "must_keep_origin_dtype", + "ge.exec.precision_mode": precision_mode, } # if multi trainers if rank_table_file and fleet.world_size() > 1: diff --git a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py index d7ba61a8e40144fd5f75156788a95c4b4cb235ea..e191776ffe41e7e9291b1fb5f3d560e8f81fc880 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py +++ b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py @@ -203,7 +203,8 @@ class AscendParserBase(object): def _accumulated_op_id(self): global global_cnt global_cnt += 1 - return "." + str(global_cnt) + name = "." + str(global_cnt) + return name def _create_ge_tensor(self, shape, dtype, value): tensor_desc = core.GETensorDesc( @@ -1630,10 +1631,14 @@ class MulGradParser(AscendParserBase): "unsqueeze" + self._accumulated_op_id(), "Unsqueeze").set_input("x", y).set_attr_vec_int32("axes", [0]) + y_stack = core.GEOperatorFactory.create_operator( + "stack" + self._accumulated_op_id(), + "TileWithAxis").set_input("x", y_unsqueeze).set_attr_int32( + "axis", 0).set_attr_int32("tiles", shape_out_grad[0]) x_grad = core.GEOperatorFactory.create_operator( self.parser_name + self._accumulated_op_id(), "BatchMatMul").set_input("x1", out_grad).set_input( - "x2", y_unsqueeze).set_attr_bool( + "x2", y_stack).set_attr_bool( "adj_x1", False).set_attr_bool("adj_x2", True) y_grad = core.GEOperatorFactory.create_operator( self.parser_name + self._accumulated_op_id(),