From de42d1933677d34df62a0f691378fe75507c0714 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Mon, 1 Feb 2021 15:34:54 +0800 Subject: [PATCH] Add paddle ascend distribution training supported (#30796) Add paddle ascend distribution training supported --- .../ascend/ascend_optimizer.py | 43 +++++++++++++------ .../meta_optimizers/ascend/ascend_parser.py | 29 ++++++++++--- 2 files changed, 51 insertions(+), 21 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py index 89a19b6479..71b22d9519 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py @@ -26,10 +26,12 @@ HcomGroupConfig = namedtuple('HcomGroupConfig', ['name', 'nranks', 'rank_ids']) class AscendIRParser(object): - def __init__(self): + def __init__(self, auto_dp=False, world_rank_size=1): self.graph_idx = 0 self.hcom_endpoints = {} self.groups_to_create = [] + self._auto_dp = auto_dp + self._world_rank_size = world_rank_size def _construct_input_map(self, input_varlist): ret_map = {} @@ -91,13 +93,12 @@ class AscendIRParser(object): print("append to create group: %s, with rank_ids: %s" % (group_name, global_rank_ids)) elif op.type in ascend_parser.registerd_op: - print("Op[%s] has been registered, begin to parse it" % (op.type)) op_parser = self.parser_factory.create_parse( ascend_parser.registerd_op[op.type]) op_parser.apply(op) else: - print("Op[%s] has not been registered, so we have to skip it" % - (op.type)) + assert False, "Op[%s] has not been registered, so we have to skip it" % ( + op.type) def _parse_program(self, graph_name, @@ -161,6 +162,17 @@ class AscendIRParser(object): startup_graph = self._parse_program("startup", startup_program) main_graph = self._parse_program("main", main_program, input_varlist, fetch_list) + if self._auto_dp and self._world_rank_size > 1: + assert len(self.groups_to_create + ) == 0, "can't parse program under auto_dp mode" + + from paddle.distributed import fleet + self.groups_to_create.append( + HcomGroupConfig( + name="hcom_group_0", + nranks=fleet.world_size(), + rank_ids=[x for x in range(fleet.world_size())])) + return startup_graph, main_graph @@ -196,7 +208,8 @@ class AscendOptimizer(Optimizer): startup_program=None, parameter_list=None, no_grad_set=None, - auto_dp=False): + auto_dp=False, + rank_table_file=None): minimized = None if self.inner_opt: minimized = self.inner_opt.minimize( @@ -205,24 +218,25 @@ class AscendOptimizer(Optimizer): self.ascend_instance = core.AscendInstance() from paddle.distributed import fleet - if auto_dp and fleet.worker_num() > 1: + if auto_dp and fleet.world_size() > 1: from paddle.fluid.transpiler import ascend_transpiler t = ascend_transpiler.AscendTranspiler(startup_program, loss.block.program) t.transpile() - print(loss.block.program) + #print(loss.block.program) # Config about Graph Engine can be found in https://support.huaweicloud.com/ config = { "ge.exec.deviceId": str(fleet.local_device_ids()), "ge.graphRunMode": "1", "ge.exec.precision_mode": "must_keep_origin_dtype", - # if multi mode - "ge.exec.rankTableFile": os.getenv("RANK_TABLE_FILE"), - "ge.exec.rankId": str(fleet.worker_index()), - "ge.exec.isUseHcom": "1", - "ge.exec.deployMode": "0", } + # if multi trainers + if rank_table_file and fleet.world_size() > 1: + config["ge.exec.rankTableFile"] = rank_table_file + config["ge.exec.rankId"] = str(fleet.worker_index()) + config["ge.exec.isUseHcom"] = "1" + config["ge.exec.deployMode"] = "0" print("ge_initialize config:", config) core.ge_initialize(config) @@ -230,7 +244,8 @@ class AscendOptimizer(Optimizer): self.ascend_instance.init_global_resources() main_block = loss.block - self.parser = AscendIRParser() + self.parser = AscendIRParser( + auto_dp=auto_dp, world_rank_size=fleet.world_size()) input_varlist = self._get_input_varlist(main_block.program) @@ -238,9 +253,9 @@ class AscendOptimizer(Optimizer): startup_program, main_block.program, input_varlist, self.fetch_list) for cfg in self.parser.groups_to_create: - hccl.create_group(cfg.name, cfg.nranks, cfg.rank_ids) print("create group (%s), nranks: %d, rank_ids: %s" % (cfg.name, cfg.nranks, cfg.rank_ids)) + hccl.create_group(cfg.name, cfg.nranks, cfg.rank_ids) self.ascend_instance.add_ascend_subgraph(0, startup_graph) self.ascend_instance.add_ascend_subgraph(1, main_graph) diff --git a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py index 222543856a..8e2f5b60ab 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py +++ b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py @@ -170,6 +170,7 @@ class AscendParserBase(object): self.parser_name, len(index_list), output_num) for output_id in range(output_num): arguments = self.op.output(self.op.output_names[output_id]) + #print("%d argument: %s" % (output_id, str(arguments))) if len(arguments) > 0: assert len(arguments) == len( index_list[output_id] @@ -177,6 +178,8 @@ class AscendParserBase(object): self.parser_name, output_id, len(index_list[output_id]), len(arguments)) for i in range(len(arguments)): + #print("assgin index_list[%d][%d] to %s" % + # (output_id, i, arguments[i])) self.var2geop[arguments[i]] = geop_list[index_list[ output_id][i]] @@ -789,6 +792,8 @@ class FillConstantParser(AscendParserBase): "Const").set_attr_tensor("value", tensor) self._mark_as_input(const) if self.op.block.var(self.op.output('Out')[0]).persistable: + #print("%s is Persistable in fill_constant" % + # (self.op.output('Out')[0])) var = core.GEOperatorFactory.create_operator( self.op.output('Out')[0], "Variable") var.update_output_desc("y", @@ -800,6 +805,10 @@ class FillConstantParser(AscendParserBase): "assign" + self._accumulated_op_id(), "Assign").set_input( "value", const).set_input("ref", var) return [const], [[0]] + #else: + # print( + # "self.op.output('Out')[0]: %s is not persistable in fill_constant" + # % (self.op.output('Out')[0])) return [const], [[0]] @@ -853,6 +862,8 @@ class TruncatedNormalParser(AscendParserBase): ## wirte the output of truncatedNormal from startup_program to main_program if self.op.block.var(self.op.output('Out')[0]).persistable: + #print("%s is Persistable in truncated_normal" % + # (self.op.output('Out')[0])) var = core.GEOperatorFactory.create_operator( self.op.output('Out')[0], "Variable") var.update_output_desc("y", @@ -867,6 +878,10 @@ class TruncatedNormalParser(AscendParserBase): shape_tensor, mean_tensor, std_tensor, min_tensor, max_tensor, truncated_normal ], [[-1]] + #else: + # print( + # "self.op.output('Out')[0] is not persistable in truncated_noraml" + # ) return [truncated_normal], [[0]] @@ -1366,7 +1381,7 @@ class UniformRandomParser(AscendParserBase): tensor1 = self._create_ge_tensor([len(shape)], 2, shape) shape_tensor = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), + "const" + self._accumulated_op_id(), "Const").set_attr_tensor("value", tensor1) ge_ur = core.GEOperatorFactory.create_operator( @@ -1379,9 +1394,9 @@ class UniformRandomParser(AscendParserBase): scale = max_v - min_v scale_value = core.GEOperatorFactory.create_operator( - "scale" + self._accumulated_op_id(), "Power").set_input( - "x", ge_ur).set_attr_float("power", 1.0).set_attr_float( - "scale", scale).set_attr_float("shift", min_v) + "scale" + self._accumulated_op_id(), "Power").set_input( + "x", ge_ur).set_attr_float("power", 1.0).set_attr_float( + "scale", scale).set_attr_float("shift", min_v) return [scale_value], [[0]] @@ -1429,14 +1444,15 @@ class SqueezeParser(AscendParserBase): def _apply(self): tensor = self._get_ge_input(self.op.input_arg_names[0]) - axes = self.op.attr("axes") + axes = self.op.attr("axes") data_squeezed = core.GEOperatorFactory\ .create_operator("squeeze" + self._accumulated_op_id(), "Squeeze")\ .set_input("x", tensor)\ .set_attr_vec_int32("axes", axes) shape = core.GEOperatorFactory.create_operator( - "shape" + self._accumulated_op_id(), "Shape").set_input("x", data_squeezed) + "shape" + self._accumulated_op_id(), + "Shape").set_input("x", data_squeezed) return [shape, data_squeezed], [[1], [0]] @@ -2172,4 +2188,3 @@ class AdamParser(AscendParserBase): "epsilon", epsilon).set_input("grad", grad) return [adam], [[0]] - -- GitLab