diff --git a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py
index 8e8447ad7eab0a2bc887b0aad157d56500660bc8..a8aca3d2b88e27ee717f63b9ab0aec32b245d58f 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py
@@ -18,11 +18,17 @@ from paddle.fluid.optimizer import Optimizer
 import paddle.fluid.core as core
 import numpy as np
 from . import ascend_parser
+from paddle.distributed import fleet
+import hccl.manage.api as hccl
+from collections import namedtuple
+HcomGroupConfig = namedtuple('HcomGroupConfig', ['name', 'nranks', 'rank_ids'])
 
 
 class AscendIRParser(object):
     def __init__(self):
         self.graph_idx = 0
+        self.hcom_endpoints = {}
+        self.groups_to_create = []
 
     def _construct_input_map(self, input_varlist):
         ret_map = {}
@@ -38,8 +44,37 @@ class AscendIRParser(object):
             ret_map[var.name] = ge_input
         return ge_in_operator, ret_map
 
+    def _endpoint_to_world_rank_id(self, endpoint):
+        world_endpoints = fleet.worker_endpoints()
+        assert endpoint in world_endpoints, "endpoint (%s) not in worker_endpoints (%s) " % (endpoint, fleet.world_device_ids())
+        return world_endpoints.index(endpoint)
+
     def parse_op(self, op):
-        if op.type in ascend_parser.registerd_op:
+        if op.type == 'c_gen_nccl_id':
+            endpoint = op.attr("endpoint")
+            other_endpoints = op.attr("other_endpoints")
+            rank = op.attr("rank")
+
+            nccl_id = op.output_arg_names[0]
+
+            # c_gen_nccl_id operator splits endpoints into local endpoint and other_endpoints
+            # we should combine these together to produce world_rank_ids
+            self.hcom_endpoints[nccl_id] = other_endpoints[:]
+            self.hcom_endpoints[nccl_id].insert(rank, endpoint)
+
+            print("nccl_id (%s) registered endpoints %s" % (nccl_id, self.hcom_endpoints[nccl_id]))
+        elif op.type == 'c_comm_init':
+            nccl_id = op.input_arg_names[0]
+            nranks = op.attr("nranks")
+            assert nranks == len(self.hcom_endpoints[nccl_id]), "nranks doesn't match endpoint count"
+            rank = op.attr("rank")
+            ring_id = op.attr("ring_id")
+
+            group_name = "hcom_group_" + str(ring_id)
+            global_rank_ids = [self._endpoint_to_world_rank_id(endpoint) for endpoint in self.hcom_endpoints[nccl_id]]
+            self.groups_to_create.append(HcomGroupConfig(name=group_name, nranks=nranks, rank_ids=global_rank_ids))
+            print("append to create group: %s, with rank_ids: %s" % (group_name, global_rank_ids))
+        elif op.type in ascend_parser.registerd_op:
             print("Op[%s] has been registered, begin to parse it" % (op.type))
             op_parser = self.parser_factory.create_parse(ascend_parser.registerd_op[op.type])
             op_parser.apply(op)
@@ -137,7 +172,9 @@ class AscendOptimizer(Optimizer):
                  parameter_list=None,
                  no_grad_set=None,
                  auto_dp=False):
-        minimized = self.inner_opt.minimize(loss, startup_program=startup_program)
+        minimized = None
+        if self.inner_opt:
+            minimized = self.inner_opt.minimize(loss, startup_program=startup_program)
 
         self.ascend_instance = core.AscendInstance()
 
@@ -172,6 +209,10 @@ class AscendOptimizer(Optimizer):
         startup_graph, main_graph = self.parser.parse_program(
             startup_program, main_block.program, input_varlist, self.fetch_list)
 
+        for cfg in self.parser.groups_to_create:
+            hccl.create_group(cfg.name, cfg.nranks, cfg.rank_ids)
+            print("create group (%s), nranks: %d, rank_ids: %s" % (cfg.name, cfg.nranks, cfg.rank_ids))
+
         self.ascend_instance.add_ascend_subgraph(0, startup_graph)
         self.ascend_instance.add_ascend_subgraph(1, main_graph)
 
diff --git a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py
index 36fa1575fcdb375c901a4d05bab1ff135912c750..2b3313aa9e99ae8682ceb22d490254aab657199a 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py
@@ -1,21 +1,21 @@
 # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-# 
+#
 #     http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import paddle.fluid.framework as framework
 from paddle.fluid.optimizer import Optimizer
 import paddle.fluid.core as core
 import numpy as np
+from paddle.distributed import fleet
 
 registerd_op = {
     "elementwise_add": "AddParser",
@@ -555,7 +555,8 @@ class AllReduceParser(AscendParserBase):
     def _apply(self):
         x = self._get_ge_input(self.op.input_arg_names[0])
         reduction = self.reduction
-        group = "hccl_world_group" #self.op.attr("group")
+        ring_id = self.op.attr("ring_id")
+        group = "hcom_group_" + str(ring_id)
         fusion = None #self.op.attr("fusion")
         fusion_id = None #self.op.attr("fusion_id")
 
@@ -658,12 +659,13 @@ class ReceiveParser(AscendParserBase):
             "shape", shape).set_attr_int32("dtype", dtype)
         return [receive], [[0]]
 
+
 class ScaleParser(AscendParserBase):
     def __init__(self, graph, var2geop):
         super(ScaleParser, self).__init__(graph, var2geop)
         self.parser_name = "scale"
 
-    def _apply(self): 
+    def _apply(self):
         x = self._get_ge_input(self.op.input_arg_names[0])
         scale = self.op.attr("scale") #self.get_ge_input(self.op.input_arg_names[1])
         bias = self.op.attr("bias")
@@ -672,9 +674,9 @@ class ScaleParser(AscendParserBase):
             scale_value = core.GEOperatorFactory.create_operator("scale" + self._accumulated_op_id(), "Power").set_input("x", x).set_attr_float("power", 1.0).set_attr_float("scale", scale).set_attr_float("shift", bias)
         else:
             x_add_bias = core.GEOperatorFactory.create_operator("adds" + self._accumulated_op_id(), "Adds").set_input("x", x).set_attr_float("value", bias) #set_input("x2", bias)
-            scale_value = core.GEOperatorFactory.create_operator("scale" + self._accumulated_op_id(), "Power").set_input("x", x_add_bias).set_attr_float("power", 1.0).set_attr_float("scale", scale).set_attr_float("shift", 0.0) 
+            scale_value = core.GEOperatorFactory.create_operator("scale" + self._accumulated_op_id(), "Power").set_input("x", x_add_bias).set_attr_float("power", 1.0).set_attr_float("scale", scale).set_attr_float("shift", 0.0)
             #tensor_zeros = core.GEOperatorFactory.create_operator("zeroslike" + self.getid(), "ZerosLike").set_input("x", x)
-            #bias_ = self.create_ge_tensor([1], 5, bias) 
+            #bias_ = self.create_ge_tensor([1], 5, bias)
             #const_bias = core.GEOperatorFactory.create_operator("const" + self.getid(), "Const").set_attr_tensor("value", tensor_bias)
         return [scale_value],[[0]]
 
@@ -695,5 +697,7 @@ class ReshapeParser(AscendParserBase):
         tensor = self._create_ge_tensor([len(shape)], 2, shape)
         const_shape = core.GEOperatorFactory.create_operator("shape" + self._accumulated_op_id(), "Const").set_attr_tensor("value", tensor)
         reshape = core.GEOperatorFactory.create_operator("reshape" + self._accumulated_op_id(), "Reshape").set_input("x", data_x1_shape).set_input("shape", const_shape).set_attr_int32("axis", axis)
-        
+
         return [reshape, reshape], [[0],[1]]
+
+
diff --git a/python/paddle/fluid/tests/unittests/ascend_group.py b/python/paddle/fluid/tests/unittests/ascend_group.py
index 2d5b709a48eefffe5b2b0a5f328fc3bdd40b919a..0bc810373c9567b3ed7e1d346b8e7fa0e384d6da 100644
--- a/python/paddle/fluid/tests/unittests/ascend_group.py
+++ b/python/paddle/fluid/tests/unittests/ascend_group.py
@@ -21,6 +21,11 @@ import paddle.fluid.core as core
 import paddle
 from paddle.fluid.layer_helper import LayerHelper
 from paddle.distributed import fleet
+from paddle.distributed.fleet.meta_optimizers.ascend import ascend_parser, ascend_optimizer
+from collections import namedtuple
+
+Block = namedtuple('Block', ['program'])
+Loss = namedtuple('Loss', ['block'])
 
 paddle.enable_static()
 
@@ -63,10 +68,6 @@ def init_communicator(startup_program, main_program, current_endpoint, endpoints
                 'ring_id': ring_id,
                 OP_ROLE_KEY: OpRole.Forward,
             })
-    block.create_var(
-        name="data",
-        persistable=True,
-        dtype='float32')
 
     with fluid.program_guard(main_program):
         op_type="c_allreduce_sum"
@@ -79,6 +80,9 @@
             attrs={'ring_id': ring_id,
                    'use_calc_stream': True})
 
+    print("startup program:", startup_program)
+    print("main program:", main_program)
+
 def train(world_endpoints, world_device_ids, local_device_ids,local_rank):
     startup_programs=[]
     main_programs=[]
@@ -89,6 +93,7 @@ def train(world_endpoints, world_device_ids, local_device_ids,local_rank):
     groups[0]=[trainer_endpoints[0], trainer_endpoints[1]]
     groups[1]=[trainer_endpoints[2], trainer_endpoints[3]]
     groups[2]=[trainer_endpoints[0], trainer_endpoints[2]]
+    print("groups:", groups)
 
     for i in range(len(trainer_endpoints)):
         startup_programs.append(fluid.Program())
@@ -105,6 +110,20 @@ def train(world_endpoints, world_device_ids, local_device_ids,local_rank):
     print(startup_programs[local_rank])
     print(main_programs[local_rank])
 
+    print("local rank: ", local_rank)
+    print("local startup program: ", startup_programs[local_rank])
+
+    startup_program = startup_programs[local_rank]
+    main_program = main_programs[local_rank]
+    loss = Loss(Block(main_program))
+    optimizer = ascend_optimizer.AscendOptimizer(None, fetch_list=[])
+    optimizer.minimize(loss, startup_program, auto_dp=True)
+
+    exe = paddle.static.Executor(paddle.CPUPlace())
+    #exe.run(startup_program)
+    exe.run(main_program)
+
+
 worker_endpoints=fleet.worker_endpoints()
 world_device_ids=fleet.world_device_ids()
 local_device_ids=fleet.local_device_ids()