Unverified commit 904cc443, authored by: Void Main, committed by: GitHub

[Feature] Build parser to support distributed training (#30658)

[Feature] Build parser to support distributed training
Parent 5b77b259
@@ -18,11 +18,17 @@ from paddle.fluid.optimizer import Optimizer
import paddle.fluid.core as core
import numpy as np
from . import ascend_parser
from paddle.distributed import fleet
import hccl.manage.api as hccl
from collections import namedtuple

HcomGroupConfig = namedtuple('HcomGroupConfig', ['name', 'nranks', 'rank_ids'])


class AscendIRParser(object):
    def __init__(self):
        self.graph_idx = 0
        self.hcom_endpoints = {}
        self.groups_to_create = []

    def _construct_input_map(self, input_varlist):
        ret_map = {}
@@ -38,8 +44,37 @@ class AscendIRParser(object):
            ret_map[var.name] = ge_input
        return ge_in_operator, ret_map
    def _endpoint_to_world_rank_id(self, endpoint):
        world_endpoints = fleet.worker_endpoints()
        assert endpoint in world_endpoints, "endpoint (%s) not in worker_endpoints (%s)" % (
            endpoint, world_endpoints)
        return world_endpoints.index(endpoint)
    def parse_op(self, op):
        if op.type == 'c_gen_nccl_id':
            endpoint = op.attr("endpoint")
            other_endpoints = op.attr("other_endpoints")
            rank = op.attr("rank")

            nccl_id = op.output_arg_names[0]

            # The c_gen_nccl_id operator splits a ring's endpoints into the local
            # endpoint and other_endpoints; combine them again to produce world_rank_ids.
            self.hcom_endpoints[nccl_id] = other_endpoints[:]
            self.hcom_endpoints[nccl_id].insert(rank, endpoint)

            print("nccl_id (%s) registered endpoints %s" % (nccl_id, self.hcom_endpoints[nccl_id]))
        elif op.type == 'c_comm_init':
            nccl_id = op.input_arg_names[0]
            nranks = op.attr("nranks")
            assert nranks == len(self.hcom_endpoints[nccl_id]), "nranks doesn't match endpoint count"
            rank = op.attr("rank")
            ring_id = op.attr("ring_id")

            group_name = "hcom_group_" + str(ring_id)
            global_rank_ids = [self._endpoint_to_world_rank_id(endpoint) for endpoint in self.hcom_endpoints[nccl_id]]
            self.groups_to_create.append(HcomGroupConfig(name=group_name, nranks=nranks, rank_ids=global_rank_ids))
            print("append to create group: %s, with rank_ids: %s" % (group_name, global_rank_ids))
        elif op.type in ascend_parser.registerd_op:
            print("Op[%s] has been registered, begin to parse it" % (op.type))
            op_parser = self.parser_factory.create_parse(ascend_parser.registerd_op[op.type])
            op_parser.apply(op)
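The two new branches above do the bookkeeping that later lets minimize() create HCCL groups: the c_gen_nccl_id branch reassembles a ring's full endpoint list, and the c_comm_init branch maps those endpoints to world rank ids and records an HcomGroupConfig. Below is a minimal standalone sketch of that logic, using made-up endpoint values (plain Python, no Paddle dependency; the concrete endpoints and ring_id are illustrative, not taken from the patch):

from collections import namedtuple

HcomGroupConfig = namedtuple('HcomGroupConfig', ['name', 'nranks', 'rank_ids'])

# Hypothetical world of four workers; their order defines the world rank ids.
world_endpoints = ["127.0.0.1:6170", "127.0.0.1:6171",
                   "127.0.0.1:6172", "127.0.0.1:6173"]

# c_gen_nccl_id view for rank 0 of a two-rank ring: the op carries the local
# endpoint separately from other_endpoints, so insert it back at its rank.
endpoint, rank = "127.0.0.1:6170", 0
other_endpoints = ["127.0.0.1:6172"]
ring_endpoints = other_endpoints[:]
ring_endpoints.insert(rank, endpoint)   # ['127.0.0.1:6170', '127.0.0.1:6172']

# c_comm_init view: translate ring endpoints into world rank ids and record
# the group that will later be created via hccl.create_group.
ring_id = 2
rank_ids = [world_endpoints.index(ep) for ep in ring_endpoints]
cfg = HcomGroupConfig(name="hcom_group_" + str(ring_id),
                      nranks=len(ring_endpoints), rank_ids=rank_ids)
print(cfg)   # HcomGroupConfig(name='hcom_group_2', nranks=2, rank_ids=[0, 2])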
@@ -137,6 +172,8 @@ class AscendOptimizer(Optimizer):
                 parameter_list=None,
                 no_grad_set=None,
                 auto_dp=False):
        minimized = None
        if self.inner_opt:
            minimized = self.inner_opt.minimize(loss, startup_program=startup_program)

        self.ascend_instance = core.AscendInstance()
@@ -172,6 +209,10 @@
        startup_graph, main_graph = self.parser.parse_program(
            startup_program, main_block.program, input_varlist, self.fetch_list)

        for cfg in self.parser.groups_to_create:
            hccl.create_group(cfg.name, cfg.nranks, cfg.rank_ids)
            print("create group (%s), nranks: %d, rank_ids: %s" % (cfg.name, cfg.nranks, cfg.rank_ids))

        self.ascend_instance.add_ascend_subgraph(0, startup_graph)
        self.ascend_instance.add_ascend_subgraph(1, main_graph)
@@ -11,11 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid.framework as framework
from paddle.fluid.optimizer import Optimizer
import paddle.fluid.core as core
import numpy as np
from paddle.distributed import fleet

registerd_op = {
    "elementwise_add": "AddParser",
@@ -555,7 +555,8 @@ class AllReduceParser(AscendParserBase):
    def _apply(self):
        x = self._get_ge_input(self.op.input_arg_names[0])
        reduction = self.reduction
        ring_id = self.op.attr("ring_id")
        group = "hcom_group_" + str(ring_id)  # replaces the hard-coded "hccl_world_group"
        fusion = None  #self.op.attr("fusion")
        fusion_id = None  #self.op.attr("fusion_id")
@@ -658,6 +659,7 @@ class ReceiveParser(AscendParserBase):
            "shape", shape).set_attr_int32("dtype", dtype)
        return [receive], [[0]]


class ScaleParser(AscendParserBase):
    def __init__(self, graph, var2geop):
        super(ScaleParser, self).__init__(graph, var2geop)
@@ -697,3 +699,5 @@ class ReshapeParser(AscendParserBase):
        reshape = core.GEOperatorFactory.create_operator(
            "reshape" + self._accumulated_op_id(),
            "Reshape").set_input("x", data_x1_shape).set_input(
                "shape", const_shape).set_attr_int32("axis", axis)

        return [reshape, reshape], [[0], [1]]
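For orientation: each entry in registerd_op maps a Paddle op type to a parser class name, which parser_factory.create_parse turns into a parser whose _apply() builds GE operators via core.GEOperatorFactory and returns them together with, per returned operator, the output indices to expose (as in the [receive], [[0]] and [reshape, reshape], [[0], [1]] returns above). A schematic parser following that convention is sketched next; it would live in ascend_parser.py alongside the classes above, and "FooParser", the "foo" key, and the GE "Identity" operator type are illustrative assumptions, not part of this patch:

# Hypothetical registration: registerd_op["foo"] = "FooParser"
class FooParser(AscendParserBase):
    def __init__(self, graph, var2geop):
        super(FooParser, self).__init__(graph, var2geop)

    def _apply(self):
        # Look up the GE tensor behind the Paddle op's first input variable.
        x = self._get_ge_input(self.op.input_arg_names[0])
        # Build one GE operator and expose its output 0 as the op's only result.
        y = core.GEOperatorFactory.create_operator(
            "identity" + self._accumulated_op_id(), "Identity").set_input("x", x)
        return [y], [[0]]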
@@ -21,6 +21,11 @@ import paddle.fluid.core as core
import paddle
from paddle.fluid.layer_helper import LayerHelper
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.ascend import ascend_parser, ascend_optimizer
from collections import namedtuple

Block = namedtuple('Block', ['program'])
Loss = namedtuple('Loss', ['block'])

paddle.enable_static()
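Block and Loss are lightweight stand-ins: as far as this change shows, AscendOptimizer.minimize only reaches into loss.block.program (see the parse_program call in the ascend_optimizer hunk above), so the test can pass a namedtuple chain instead of a real loss Variable. A tiny illustration, not part of the patch:

from collections import namedtuple

Block = namedtuple('Block', ['program'])
Loss = namedtuple('Loss', ['block'])

main_program = object()            # stands in for a fluid.Program
loss = Loss(Block(main_program))
assert loss.block.program is main_program   # the attribute chain minimize reads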
@@ -63,10 +68,6 @@ def init_communicator(startup_program, main_program, current_endpoint, endpoints
            'ring_id': ring_id,
            OP_ROLE_KEY: OpRole.Forward,
        })
    block.create_var(
        name="data",
        persistable=True,
        dtype='float32')

    with fluid.program_guard(main_program):
        op_type = "c_allreduce_sum"
@@ -79,6 +80,9 @@ def init_communicator(startup_program, main_program, current_endpoint, endpoints
            attrs={'ring_id': ring_id,
                   'use_calc_stream': True})

    print("startup program:", startup_program)
    print("main program:", main_program)


def train(world_endpoints, world_device_ids, local_device_ids, local_rank):
    startup_programs = []
    main_programs = []
@@ -89,6 +93,7 @@ def train(world_endpoints, world_device_ids, local_device_ids, local_rank):
    groups[0] = [trainer_endpoints[0], trainer_endpoints[1]]
    groups[1] = [trainer_endpoints[2], trainer_endpoints[3]]
    groups[2] = [trainer_endpoints[0], trainer_endpoints[2]]
    print("groups:", groups)

    for i in range(len(trainer_endpoints)):
        startup_programs.append(fluid.Program())
@@ -105,6 +110,20 @@ def train(world_endpoints, world_device_ids, local_device_ids, local_rank):
    print(startup_programs[local_rank])
    print(main_programs[local_rank])

    print("local rank: ", local_rank)
    print("local startup program: ", startup_programs[local_rank])

    startup_program = startup_programs[local_rank]
    main_program = main_programs[local_rank]
    loss = Loss(Block(main_program))
    optimizer = ascend_optimizer.AscendOptimizer(None, fetch_list=[])
    optimizer.minimize(loss, startup_program, auto_dp=True)

    exe = paddle.static.Executor(paddle.CPUPlace())
    #exe.run(startup_program)
    exe.run(main_program)


worker_endpoints = fleet.worker_endpoints()
world_device_ids = fleet.world_device_ids()
local_device_ids = fleet.local_device_ids()