diff --git a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py index a8aca3d2b88e27ee717f63b9ab0aec32b245d58f..c3aec546156e1873817634a943508e6964afb6f0 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py @@ -24,6 +24,7 @@ from collections import namedtuple HcomGroupConfig = namedtuple('HcomGroupConfig', ['name', 'nranks', 'rank_ids']) + class AscendIRParser(object): def __init__(self): self.graph_idx = 0 @@ -34,19 +35,26 @@ class AscendIRParser(object): ret_map = {} ge_in_operator = [] for id, var in enumerate(input_varlist): - if var.is_data: # input data - ge_input = core.GEOperatorFactory.create_operator(var.name, "Data").set_attr_int32("index", id) + if var.is_data: # input data + ge_input = core.GEOperatorFactory.create_operator( + var.name, "Data").set_attr_int32("index", id) ret_map[var.name] = ge_input ge_in_operator.append(ge_input) - else: # param, learning ... - ge_input = core.GEOperatorFactory.create_operator(var.name, "Variable") - ge_input.update_output_desc("y", core.GETensorDesc(core.GEShape(var.shape), core.GEFormat.FORMAT_ND, core.GEDataType.DT_FLOAT)) + else: # param, learning ... + ge_input = core.GEOperatorFactory.create_operator(var.name, + "Variable") + ge_input.update_output_desc("y", + core.GETensorDesc( + core.GEShape(var.shape), + core.GEFormat.FORMAT_ND, + core.GEDataType.DT_FLOAT)) ret_map[var.name] = ge_input return ge_in_operator, ret_map def _endpoint_to_world_rank_id(self, endpoint): world_endpoints = fleet.worker_endpoints() - assert endpoint in world_endpoints, "endpoint (%s) not in worker_endpoints (%s) " % (endpoint, fleet.world_device_ids()) + assert endpoint in world_endpoints, "endpoint (%s) not in worker_endpoints (%s) " % ( + endpoint, fleet.world_device_ids()) return world_endpoints.index(endpoint) def parse_op(self, op): @@ -62,26 +70,40 @@ class AscendIRParser(object): self.hcom_endpoints[nccl_id] = other_endpoints[:] self.hcom_endpoints[nccl_id].insert(rank, endpoint) - print("nccl_id (%s) registered endpoints %s" % (nccl_id, self.hcom_endpoints[nccl_id])) + print("nccl_id (%s) registered endpoints %s" % + (nccl_id, self.hcom_endpoints[nccl_id])) elif op.type == 'c_comm_init': nccl_id = op.input_arg_names[0] nranks = op.attr("nranks") - assert nranks == len(self.hcom_endpoints[nccl_id]), "nranks doesn't match endpoint count" + assert nranks == len(self.hcom_endpoints[ + nccl_id]), "nranks doesn't match endpoint count" rank = op.attr("rank") ring_id = op.attr("ring_id") group_name = "hcom_group_" + str(ring_id) - global_rank_ids = [self._endpoint_to_world_rank_id(endpoint) for endpoint in self.hcom_endpoints[nccl_id]] - self.groups_to_create.append(HcomGroupConfig(name=group_name, nranks=nranks, rank_ids=global_rank_ids)) - print("append to create group: %s, with rank_ids: %s" % (group_name, global_rank_ids)) + global_rank_ids = [ + self._endpoint_to_world_rank_id(endpoint) + for endpoint in self.hcom_endpoints[nccl_id] + ] + self.groups_to_create.append( + HcomGroupConfig( + name=group_name, nranks=nranks, rank_ids=global_rank_ids)) + print("append to create group: %s, with rank_ids: %s" % + (group_name, global_rank_ids)) elif op.type in ascend_parser.registerd_op: print("Op[%s] has been registered, begin to parse it" % (op.type)) - op_parser = self.parser_factory.create_parse(ascend_parser.registerd_op[op.type]) + op_parser = self.parser_factory.create_parse( + ascend_parser.registerd_op[op.type]) op_parser.apply(op) else: - print("Op[%s] has not been registered, so we have to skip it" % (op.type)) - - def _parse_program(self, graph_name, program, input_varlist=[], fetch_list=[]): + print("Op[%s] has not been registered, so we have to skip it" % + (op.type)) + + def _parse_program(self, + graph_name, + program, + input_varlist=[], + fetch_list=[]): begin_graph_idx = self.graph_idx ge_in_operator = [] ge_out_operator = [] @@ -96,7 +118,8 @@ class AscendIRParser(object): ge_in_operator, self.var2geop = self._construct_input_map(input_varlist) - self.parser_factory = ascend_parser.AscendParserFactory(graph, self.var2geop) + self.parser_factory = ascend_parser.AscendParserFactory(graph, + self.var2geop) for i, curop in list(enumerate(block.ops)): self.parse_op(curop) @@ -133,9 +156,11 @@ class AscendIRParser(object): self.graph_idx += 1 return graph - def parse_program(self, startup_program, main_program, input_varlist, fetch_list): + def parse_program(self, startup_program, main_program, input_varlist, + fetch_list): startup_graph = self._parse_program("startup", startup_program) - main_graph = self._parse_program("main", main_program, input_varlist, fetch_list) + main_graph = self._parse_program("main", main_program, input_varlist, + fetch_list) return startup_graph, main_graph @@ -174,14 +199,16 @@ class AscendOptimizer(Optimizer): auto_dp=False): minimized = None if self.inner_opt: - minimized = self.inner_opt.minimize(loss, startup_program=startup_program) + minimized = self.inner_opt.minimize( + loss, startup_program=startup_program) self.ascend_instance = core.AscendInstance() from paddle.distributed import fleet if auto_dp and fleet.worker_num() > 1: from paddle.fluid.transpiler import ascend_transpiler - t = ascend_transpiler.AscendTranspiler(startup_program, loss.block.program) + t = ascend_transpiler.AscendTranspiler(startup_program, + loss.block.program) t.transpile() print(loss.block.program) @@ -211,7 +238,8 @@ class AscendOptimizer(Optimizer): for cfg in self.parser.groups_to_create: hccl.create_group(cfg.name, cfg.nranks, cfg.rank_ids) - print("create group (%s), nranks: %d, rank_ids: %s" % (cfg.name, cfg.nranks, cfg.rank_ids)) + print("create group (%s), nranks: %d, rank_ids: %s" % + (cfg.name, cfg.nranks, cfg.rank_ids)) self.ascend_instance.add_ascend_subgraph(0, startup_graph) self.ascend_instance.add_ascend_subgraph(1, main_graph) diff --git a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py index 2b3313aa9e99ae8682ceb22d490254aab657199a..b7f21b4051ceca195f55d894e37f53263acaf2ca 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py +++ b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py @@ -69,11 +69,13 @@ class AscendHelper(object): } def dtype2ge(self, dtype): - assert dtype in self.dtype2ge_map, "dtype[%d] is not supported %d" % (dtype) + assert dtype in self.dtype2ge_map, "dtype[%d] is not supported %d" % ( + dtype) return self.dtype2ge_map[dtype] def dtype2np(self, index): - assert index in self.dtype2np_map, "index[%d] is not supported %d" % (dtype) + assert index in self.dtype2np_map, "index[%d] is not supported %d" % ( + dtype) return self.dtype2np_map[index] @@ -98,7 +100,8 @@ class AscendParserBase(object): self.ascend_helper = AscendHelper() def _get_ge_input(self, input_var_name): - assert input_var_name in self.var2geop, "var %s not created before" % (input_var_name) + assert input_var_name in self.var2geop, "var %s not created before" % ( + input_var_name) return self.var2geop[input_var_name] def update_output(self, geop_list, index_list): @@ -119,7 +122,8 @@ class AscendParserBase(object): for i in range(len(arguments)): print("assgin index_list[%d][%d] to %s" % (output_id, i, arguments[i])) - self.var2geop[arguments[i]] = geop_list[index_list[output_id][i]] + self.var2geop[arguments[i]] = geop_list[index_list[ + output_id][i]] for geop in geop_list: self.graph.add_op(geop) @@ -483,11 +487,15 @@ class TruncatedNormalParser(AscendParserBase): "const" + self._accumulated_op_id(), "Const").set_attr_tensor( "value", tensor3) - tensor4 = self._create_ge_tensor([1], dtype, mean-2*std) - min_tensor = core.GEOperatorFactory.create_operator("const" + self._accumulated_op_id(), "Const").set_attr_tensor("value", tensor4) + tensor4 = self._create_ge_tensor([1], dtype, mean - 2 * std) + min_tensor = core.GEOperatorFactory.create_operator( + "const" + self._accumulated_op_id(), "Const").set_attr_tensor( + "value", tensor4) - tensor5 = self._create_ge_tensor([1], dtype, mean+2*std) - max_tensor = core.GEOperatorFactory.create_operator("const" + self._accumulated_op_id(), "Const").set_attr_tensor("value", tensor5) + tensor5 = self._create_ge_tensor([1], dtype, mean + 2 * std) + max_tensor = core.GEOperatorFactory.create_operator( + "const" + self._accumulated_op_id(), "Const").set_attr_tensor( + "value", tensor5) self._mark_as_input(shape_tensor) self._mark_as_input(mean_tensor) @@ -546,10 +554,11 @@ class AllGatherParser(AscendParserBase): "rank_size", rank_size).set_attr_string("group", group) return [allgather], [[0]] + class AllReduceParser(AscendParserBase): def __init__(self, graph, var2geop, reduction): super(AllReduceParser, self).__init__(graph, var2geop) - self.parser_name = "c_allreduce_" + reduction + self.parser_name = "c_allreduce_" + reduction self.reduction = reduction def _apply(self): @@ -557,8 +566,8 @@ class AllReduceParser(AscendParserBase): reduction = self.reduction ring_id = self.op.attr("ring_id") group = "hcom_group_" + str(ring_id) - fusion = None #self.op.attr("fusion") - fusion_id = None #self.op.attr("fusion_id") + fusion = None #self.op.attr("fusion") + fusion_id = None #self.op.attr("fusion_id") allreduce = core.GEOperatorFactory.create_operator( "allreduce" + self._accumulated_op_id(), "HcomAllReduce").set_input( @@ -611,10 +620,10 @@ class ReduceScatterParser(AscendParserBase): rank_size = self.op.attr("rank_size") reduce_scatter = core.GEOperatorFactory.create_operator( - "reducescatter" + self._accumulated_op_id(), "HcomReduceScatter").set_input( - "x", x).set_attr_string( - "reduction", reduction).set_attr_string( - "group", group).set_attr_int32("rank_size", rank_size) + "reducescatter" + self._accumulated_op_id(), + "HcomReduceScatter").set_input("x", x).set_attr_string( + "reduction", reduction).set_attr_string( + "group", group).set_attr_int32("rank_size", rank_size) return [reduce_scatter], [[0]] @@ -631,9 +640,8 @@ class SendParser(AscendParserBase): send = core.GEOperatorFactory.create_operator( "send" + self._accumulated_op_id(), "HcomSend").set_input( - "x", x).set_attr_int32( - "sr_tag", sr_tag).set_attr_int32( - "dest_rank", dest_rank).set_attr_string("group", group) + "x", x).set_attr_int32("sr_tag", sr_tag).set_attr_int32( + "dest_rank", dest_rank).set_attr_string("group", group) return [send], [[0]] @@ -652,11 +660,10 @@ class ReceiveParser(AscendParserBase): receive = core.GEOperatorFactory.create_operator( "receive" + self._accumulated_op_id(), "HcomReceive").set_input( - "x", x).set_attr_int32( - "sr_tag", sr_tag).set_attr_int32( - "src_rank", src_rank).set_attr_string( - "group", group).set_attr_vec_int32( - "shape", shape).set_attr_int32("dtype", dtype) + "x", x).set_attr_int32("sr_tag", sr_tag).set_attr_int32( + "src_rank", src_rank).set_attr_string( + "group", group).set_attr_vec_int32( + "shape", shape).set_attr_int32("dtype", dtype) return [receive], [[0]] @@ -667,18 +674,30 @@ class ScaleParser(AscendParserBase): def _apply(self): x = self._get_ge_input(self.op.input_arg_names[0]) - scale = self.op.attr("scale") #self.get_ge_input(self.op.input_arg_names[1]) + scale = self.op.attr( + "scale") #self.get_ge_input(self.op.input_arg_names[1]) bias = self.op.attr("bias") bias_after_scale = self.op.attr("bias_after_scale") if bias_after_scale: - scale_value = core.GEOperatorFactory.create_operator("scale" + self._accumulated_op_id(), "Power").set_input("x", x).set_attr_float("power", 1.0).set_attr_float("scale", scale).set_attr_float("shift", bias) + scale_value = core.GEOperatorFactory.create_operator( + "scale" + self._accumulated_op_id(), "Power").set_input( + "x", x).set_attr_float("power", 1.0).set_attr_float( + "scale", scale).set_attr_float("shift", bias) else: - x_add_bias = core.GEOperatorFactory.create_operator("adds" + self._accumulated_op_id(), "Adds").set_input("x", x).set_attr_float("value", bias) #set_input("x2", bias) - scale_value = core.GEOperatorFactory.create_operator("scale" + self._accumulated_op_id(), "Power").set_input("x", x_add_bias).set_attr_float("power", 1.0).set_attr_float("scale", scale).set_attr_float("shift", 0.0) + x_add_bias = core.GEOperatorFactory.create_operator( + "adds" + self._accumulated_op_id(), "Adds").set_input( + "x", x).set_attr_float("value", + bias) #set_input("x2", bias) + scale_value = core.GEOperatorFactory.create_operator( + "scale" + self._accumulated_op_id(), "Power").set_input( + "x", x_add_bias).set_attr_float( + "power", 1.0).set_attr_float( + "scale", scale).set_attr_float("shift", 0.0) #tensor_zeros = core.GEOperatorFactory.create_operator("zeroslike" + self.getid(), "ZerosLike").set_input("x", x) #bias_ = self.create_ge_tensor([1], 5, bias) #const_bias = core.GEOperatorFactory.create_operator("const" + self.getid(), "Const").set_attr_tensor("value", tensor_bias) - return [scale_value],[[0]] + return [scale_value], [[0]] + class ReshapeParser(AscendParserBase): def __init__(self, graph, var2geop): @@ -695,9 +714,12 @@ class ReshapeParser(AscendParserBase): print("shape: ", shape) data_x1_shape = self._get_ge_input(self.op.input_arg_names[0]) tensor = self._create_ge_tensor([len(shape)], 2, shape) - const_shape = core.GEOperatorFactory.create_operator("shape" + self._accumulated_op_id(), "Const").set_attr_tensor("value", tensor) - reshape = core.GEOperatorFactory.create_operator("reshape" + self._accumulated_op_id(), "Reshape").set_input("x", data_x1_shape).set_input("shape", const_shape).set_attr_int32("axis", axis) - - return [reshape, reshape], [[0],[1]] - + const_shape = core.GEOperatorFactory.create_operator( + "shape" + self._accumulated_op_id(), "Const").set_attr_tensor( + "value", tensor) + reshape = core.GEOperatorFactory.create_operator( + "reshape" + self._accumulated_op_id(), "Reshape").set_input( + "x", data_x1_shape).set_input( + "shape", const_shape).set_attr_int32("axis", axis) + return [reshape, reshape], [[0], [1]]