未验证 提交 904cc443 编写于 作者: V Void Main 提交者: GitHub

[Feature] Build parser to support distributed training (#30658)

[Feature] Build parser to support distributed training
上级 5b77b259
...@@ -18,11 +18,17 @@ from paddle.fluid.optimizer import Optimizer ...@@ -18,11 +18,17 @@ from paddle.fluid.optimizer import Optimizer
import paddle.fluid.core as core import paddle.fluid.core as core
import numpy as np import numpy as np
from . import ascend_parser from . import ascend_parser
from paddle.distributed import fleet
import hccl.manage.api as hccl
from collections import namedtuple
HcomGroupConfig = namedtuple('HcomGroupConfig', ['name', 'nranks', 'rank_ids'])
class AscendIRParser(object): class AscendIRParser(object):
def __init__(self): def __init__(self):
self.graph_idx = 0 self.graph_idx = 0
self.hcom_endpoints = {}
self.groups_to_create = []
def _construct_input_map(self, input_varlist): def _construct_input_map(self, input_varlist):
ret_map = {} ret_map = {}
...@@ -38,8 +44,37 @@ class AscendIRParser(object): ...@@ -38,8 +44,37 @@ class AscendIRParser(object):
ret_map[var.name] = ge_input ret_map[var.name] = ge_input
return ge_in_operator, ret_map return ge_in_operator, ret_map
def _endpoint_to_world_rank_id(self, endpoint):
    """Return the world rank id of ``endpoint``.

    The rank id is the position of ``endpoint`` in the list returned by
    ``fleet.worker_endpoints()``.

    Args:
        endpoint: The endpoint to look up; it is compared by equality
            against the entries of the worker-endpoint list.

    Returns:
        The zero-based index of ``endpoint`` in the worker-endpoint list.

    Raises:
        AssertionError: If ``endpoint`` is not in the worker-endpoint list.
    """
    world_endpoints = fleet.worker_endpoints()
    # Print the list that was actually searched, so the diagnostic matches
    # the membership check (device ids are a different list and would not
    # explain why the lookup failed).
    assert endpoint in world_endpoints, "endpoint (%s) not in worker_endpoints (%s) " % (endpoint, world_endpoints)
    return world_endpoints.index(endpoint)
def parse_op(self, op): def parse_op(self, op):
if op.type in ascend_parser.registerd_op: if op.type == 'c_gen_nccl_id':
endpoint = op.attr("endpoint")
other_endpoints = op.attr("other_endpoints")
rank = op.attr("rank")
nccl_id = op.output_arg_names[0]
# c_gen_nccl_id operator splits endpoints into local endpoint and other_endpoints
# we should combine these together to produce world_rank_ids
self.hcom_endpoints[nccl_id] = other_endpoints[:]
self.hcom_endpoints[nccl_id].insert(rank, endpoint)
print("nccl_id (%s) registered endpoints %s" % (nccl_id, self.hcom_endpoints[nccl_id]))
elif op.type == 'c_comm_init':
nccl_id = op.input_arg_names[0]
nranks = op.attr("nranks")
assert nranks == len(self.hcom_endpoints[nccl_id]), "nranks doesn't match endpoint count"
rank = op.attr("rank")
ring_id = op.attr("ring_id")
group_name = "hcom_group_" + str(ring_id)
global_rank_ids = [self._endpoint_to_world_rank_id(endpoint) for endpoint in self.hcom_endpoints[nccl_id]]
self.groups_to_create.append(HcomGroupConfig(name=group_name, nranks=nranks, rank_ids=global_rank_ids))
print("append to create group: %s, with rank_ids: %s" % (group_name, global_rank_ids))
elif op.type in ascend_parser.registerd_op:
print("Op[%s] has been registered, begin to parse it" % (op.type)) print("Op[%s] has been registered, begin to parse it" % (op.type))
op_parser = self.parser_factory.create_parse(ascend_parser.registerd_op[op.type]) op_parser = self.parser_factory.create_parse(ascend_parser.registerd_op[op.type])
op_parser.apply(op) op_parser.apply(op)
...@@ -137,6 +172,8 @@ class AscendOptimizer(Optimizer): ...@@ -137,6 +172,8 @@ class AscendOptimizer(Optimizer):
parameter_list=None, parameter_list=None,
no_grad_set=None, no_grad_set=None,
auto_dp=False): auto_dp=False):
minimized = None
if self.inner_opt:
minimized = self.inner_opt.minimize(loss, startup_program=startup_program) minimized = self.inner_opt.minimize(loss, startup_program=startup_program)
self.ascend_instance = core.AscendInstance() self.ascend_instance = core.AscendInstance()
...@@ -172,6 +209,10 @@ class AscendOptimizer(Optimizer): ...@@ -172,6 +209,10 @@ class AscendOptimizer(Optimizer):
startup_graph, main_graph = self.parser.parse_program( startup_graph, main_graph = self.parser.parse_program(
startup_program, main_block.program, input_varlist, self.fetch_list) startup_program, main_block.program, input_varlist, self.fetch_list)
for cfg in self.parser.groups_to_create:
hccl.create_group(cfg.name, cfg.nranks, cfg.rank_ids)
print("create group (%s), nranks: %d, rank_ids: %s" % (cfg.name, cfg.nranks, cfg.rank_ids))
self.ascend_instance.add_ascend_subgraph(0, startup_graph) self.ascend_instance.add_ascend_subgraph(0, startup_graph)
self.ascend_instance.add_ascend_subgraph(1, main_graph) self.ascend_instance.add_ascend_subgraph(1, main_graph)
......
...@@ -11,11 +11,11 @@ ...@@ -11,11 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import paddle.fluid.framework as framework import paddle.fluid.framework as framework
from paddle.fluid.optimizer import Optimizer from paddle.fluid.optimizer import Optimizer
import paddle.fluid.core as core import paddle.fluid.core as core
import numpy as np import numpy as np
from paddle.distributed import fleet
registerd_op = { registerd_op = {
"elementwise_add": "AddParser", "elementwise_add": "AddParser",
...@@ -555,7 +555,8 @@ class AllReduceParser(AscendParserBase): ...@@ -555,7 +555,8 @@ class AllReduceParser(AscendParserBase):
def _apply(self): def _apply(self):
x = self._get_ge_input(self.op.input_arg_names[0]) x = self._get_ge_input(self.op.input_arg_names[0])
reduction = self.reduction reduction = self.reduction
group = "hccl_world_group" #self.op.attr("group") ring_id = self.op.attr("ring_id")
group = "hcom_group_" + str(ring_id)
fusion = None #self.op.attr("fusion") fusion = None #self.op.attr("fusion")
fusion_id = None #self.op.attr("fusion_id") fusion_id = None #self.op.attr("fusion_id")
...@@ -658,6 +659,7 @@ class ReceiveParser(AscendParserBase): ...@@ -658,6 +659,7 @@ class ReceiveParser(AscendParserBase):
"shape", shape).set_attr_int32("dtype", dtype) "shape", shape).set_attr_int32("dtype", dtype)
return [receive], [[0]] return [receive], [[0]]
class ScaleParser(AscendParserBase): class ScaleParser(AscendParserBase):
def __init__(self, graph, var2geop): def __init__(self, graph, var2geop):
super(ScaleParser, self).__init__(graph, var2geop) super(ScaleParser, self).__init__(graph, var2geop)
...@@ -697,3 +699,5 @@ class ReshapeParser(AscendParserBase): ...@@ -697,3 +699,5 @@ class ReshapeParser(AscendParserBase):
reshape = core.GEOperatorFactory.create_operator("reshape" + self._accumulated_op_id(), "Reshape").set_input("x", data_x1_shape).set_input("shape", const_shape).set_attr_int32("axis", axis) reshape = core.GEOperatorFactory.create_operator("reshape" + self._accumulated_op_id(), "Reshape").set_input("x", data_x1_shape).set_input("shape", const_shape).set_attr_int32("axis", axis)
return [reshape, reshape], [[0],[1]] return [reshape, reshape], [[0],[1]]
...@@ -21,6 +21,11 @@ import paddle.fluid.core as core ...@@ -21,6 +21,11 @@ import paddle.fluid.core as core
import paddle import paddle
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.ascend import ascend_parser, ascend_optimizer
from collections import namedtuple
Block = namedtuple('Block', ['program'])
Loss = namedtuple('Loss', ['block'])
paddle.enable_static() paddle.enable_static()
...@@ -63,10 +68,6 @@ def init_communicator(startup_program, main_program, current_endpoint, endpoints ...@@ -63,10 +68,6 @@ def init_communicator(startup_program, main_program, current_endpoint, endpoints
'ring_id': ring_id, 'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Forward, OP_ROLE_KEY: OpRole.Forward,
}) })
block.create_var(
name="data",
persistable=True,
dtype='float32')
with fluid.program_guard(main_program): with fluid.program_guard(main_program):
op_type="c_allreduce_sum" op_type="c_allreduce_sum"
...@@ -79,6 +80,9 @@ def init_communicator(startup_program, main_program, current_endpoint, endpoints ...@@ -79,6 +80,9 @@ def init_communicator(startup_program, main_program, current_endpoint, endpoints
attrs={'ring_id': ring_id, attrs={'ring_id': ring_id,
'use_calc_stream': True}) 'use_calc_stream': True})
print("startup program:", startup_program)
print("main program:", main_program)
def train(world_endpoints, world_device_ids, local_device_ids,local_rank): def train(world_endpoints, world_device_ids, local_device_ids,local_rank):
startup_programs=[] startup_programs=[]
main_programs=[] main_programs=[]
...@@ -89,6 +93,7 @@ def train(world_endpoints, world_device_ids, local_device_ids,local_rank): ...@@ -89,6 +93,7 @@ def train(world_endpoints, world_device_ids, local_device_ids,local_rank):
groups[0]=[trainer_endpoints[0], trainer_endpoints[1]] groups[0]=[trainer_endpoints[0], trainer_endpoints[1]]
groups[1]=[trainer_endpoints[2], trainer_endpoints[3]] groups[1]=[trainer_endpoints[2], trainer_endpoints[3]]
groups[2]=[trainer_endpoints[0], trainer_endpoints[2]] groups[2]=[trainer_endpoints[0], trainer_endpoints[2]]
print("groups:", groups)
for i in range(len(trainer_endpoints)): for i in range(len(trainer_endpoints)):
startup_programs.append(fluid.Program()) startup_programs.append(fluid.Program())
...@@ -105,6 +110,20 @@ def train(world_endpoints, world_device_ids, local_device_ids,local_rank): ...@@ -105,6 +110,20 @@ def train(world_endpoints, world_device_ids, local_device_ids,local_rank):
print(startup_programs[local_rank]) print(startup_programs[local_rank])
print(main_programs[local_rank]) print(main_programs[local_rank])
print("local rank: ", local_rank)
print("local startup program: ", startup_programs[local_rank])
startup_program = startup_programs[local_rank]
main_program = main_programs[local_rank]
loss = Loss(Block(main_program))
optimizer = ascend_optimizer.AscendOptimizer(None, fetch_list=[])
optimizer.minimize(loss, startup_program, auto_dp=True)
exe = paddle.static.Executor(paddle.CPUPlace())
#exe.run(startup_program)
exe.run(main_program)
worker_endpoints=fleet.worker_endpoints() worker_endpoints=fleet.worker_endpoints()
world_device_ids=fleet.world_device_ids() world_device_ids=fleet.world_device_ids()
local_device_ids=fleet.local_device_ids() local_device_ids=fleet.local_device_ids()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册