Unverified commit de42d193, authored by gongweibao, committed by GitHub

Add paddle ascend distribution training supported (#30796)

Add paddle ascend distribution training supported
Parent ebb5d181
@@ -26,10 +26,12 @@ HcomGroupConfig = namedtuple('HcomGroupConfig', ['name', 'nranks', 'rank_ids'])
 class AscendIRParser(object):
-    def __init__(self):
+    def __init__(self, auto_dp=False, world_rank_size=1):
         self.graph_idx = 0
         self.hcom_endpoints = {}
         self.groups_to_create = []
+        self._auto_dp = auto_dp
+        self._world_rank_size = world_rank_size
 
     def _construct_input_map(self, input_varlist):
         ret_map = {}
@@ -91,13 +93,12 @@ class AscendIRParser(object):
             print("append to create group: %s, with rank_ids: %s" %
                   (group_name, global_rank_ids))
         elif op.type in ascend_parser.registerd_op:
-            print("Op[%s] has been registered, begin to parse it" % (op.type))
             op_parser = self.parser_factory.create_parse(
                 ascend_parser.registerd_op[op.type])
             op_parser.apply(op)
         else:
-            print("Op[%s] has not been registered, so we have to skip it" %
-                  (op.type))
+            assert False, "Op[%s] has not been registered, so we have to skip it" % (
+                op.type)
 
     def _parse_program(self,
                        graph_name,
@@ -161,6 +162,17 @@ class AscendIRParser(object):
         startup_graph = self._parse_program("startup", startup_program)
         main_graph = self._parse_program("main", main_program, input_varlist,
                                          fetch_list)
+
+        if self._auto_dp and self._world_rank_size > 1:
+            assert len(self.groups_to_create
+                       ) == 0, "can't parse program under auto_dp mode"
+            from paddle.distributed import fleet
+            self.groups_to_create.append(
+                HcomGroupConfig(
+                    name="hcom_group_0",
+                    nranks=fleet.world_size(),
+                    rank_ids=[x for x in range(fleet.world_size())]))
+
         return startup_graph, main_graph
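
For context, the auto_dp branch above registers exactly one HCCL group that spans every rank. A minimal sketch of the group config it appends, assuming a 4-rank world purely for illustration:

from collections import namedtuple

HcomGroupConfig = namedtuple('HcomGroupConfig', ['name', 'nranks', 'rank_ids'])

# Assumed example: world_rank_size == 4, so the parser appends a group
# equivalent to this one, covering ranks 0..3.
cfg = HcomGroupConfig(
    name="hcom_group_0", nranks=4, rank_ids=[x for x in range(4)])
print(cfg)  # HcomGroupConfig(name='hcom_group_0', nranks=4, rank_ids=[0, 1, 2, 3])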
@@ -196,7 +208,8 @@ class AscendOptimizer(Optimizer):
                  startup_program=None,
                  parameter_list=None,
                  no_grad_set=None,
-                 auto_dp=False):
+                 auto_dp=False,
+                 rank_table_file=None):
         minimized = None
         if self.inner_opt:
             minimized = self.inner_opt.minimize(
@@ -205,24 +218,25 @@ class AscendOptimizer(Optimizer):
         self.ascend_instance = core.AscendInstance()
 
         from paddle.distributed import fleet
-        if auto_dp and fleet.worker_num() > 1:
+        if auto_dp and fleet.world_size() > 1:
             from paddle.fluid.transpiler import ascend_transpiler
             t = ascend_transpiler.AscendTranspiler(startup_program,
                                                    loss.block.program)
             t.transpile()
-            print(loss.block.program)
+            #print(loss.block.program)
 
         # Config about Graph Engine can be found in https://support.huaweicloud.com/
         config = {
             "ge.exec.deviceId": str(fleet.local_device_ids()),
             "ge.graphRunMode": "1",
             "ge.exec.precision_mode": "must_keep_origin_dtype",
-
-            # if multi mode
-            "ge.exec.rankTableFile": os.getenv("RANK_TABLE_FILE"),
-            "ge.exec.rankId": str(fleet.worker_index()),
-            "ge.exec.isUseHcom": "1",
-            "ge.exec.deployMode": "0",
         }
+
+        # if multi trainers
+        if rank_table_file and fleet.world_size() > 1:
+            config["ge.exec.rankTableFile"] = rank_table_file
+            config["ge.exec.rankId"] = str(fleet.worker_index())
+            config["ge.exec.isUseHcom"] = "1"
+            config["ge.exec.deployMode"] = "0"
         print("ge_initialize config:", config)
         core.ge_initialize(config)
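
For reference, the rank table is no longer read unconditionally from the RANK_TABLE_FILE environment variable; it now comes from the new rank_table_file argument and is applied only when more than one trainer is present. With that branch taken, the Graph Engine options passed to ge_initialize end up equivalent to the dict below (device id, rank id, and the rank table path are placeholder values):

# Placeholder values for illustration only.
config = {
    "ge.exec.deviceId": "0",
    "ge.graphRunMode": "1",
    "ge.exec.precision_mode": "must_keep_origin_dtype",
    # set only when rank_table_file is given and fleet.world_size() > 1
    "ge.exec.rankTableFile": "/path/to/rank_table.json",
    "ge.exec.rankId": "0",
    "ge.exec.isUseHcom": "1",
    "ge.exec.deployMode": "0",
}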
@@ -230,7 +244,8 @@ class AscendOptimizer(Optimizer):
         self.ascend_instance.init_global_resources()
 
         main_block = loss.block
-        self.parser = AscendIRParser()
+        self.parser = AscendIRParser(
+            auto_dp=auto_dp, world_rank_size=fleet.world_size())
 
         input_varlist = self._get_input_varlist(main_block.program)
 
@@ -238,9 +253,9 @@ class AscendOptimizer(Optimizer):
             startup_program, main_block.program, input_varlist, self.fetch_list)
 
         for cfg in self.parser.groups_to_create:
-            hccl.create_group(cfg.name, cfg.nranks, cfg.rank_ids)
             print("create group (%s), nranks: %d, rank_ids: %s" %
                   (cfg.name, cfg.nranks, cfg.rank_ids))
+            hccl.create_group(cfg.name, cfg.nranks, cfg.rank_ids)
 
         self.ascend_instance.add_ascend_subgraph(0, startup_graph)
         self.ascend_instance.add_ascend_subgraph(1, main_graph)
...
@@ -170,6 +170,7 @@ class AscendParserBase(object):
             self.parser_name, len(index_list), output_num)
         for output_id in range(output_num):
             arguments = self.op.output(self.op.output_names[output_id])
+            #print("%d argument: %s" % (output_id, str(arguments)))
             if len(arguments) > 0:
                 assert len(arguments) == len(
                     index_list[output_id]
@@ -177,6 +178,8 @@ class AscendParserBase(object):
                     self.parser_name, output_id, len(index_list[output_id]),
                     len(arguments))
                 for i in range(len(arguments)):
+                    #print("assgin index_list[%d][%d] to %s" %
+                    #      (output_id, i, arguments[i]))
                     self.var2geop[arguments[i]] = geop_list[index_list[
                         output_id][i]]
@@ -789,6 +792,8 @@ class FillConstantParser(AscendParserBase):
             "Const").set_attr_tensor("value", tensor)
         self._mark_as_input(const)
         if self.op.block.var(self.op.output('Out')[0]).persistable:
+            #print("%s is Persistable in fill_constant" %
+            #      (self.op.output('Out')[0]))
             var = core.GEOperatorFactory.create_operator(
                 self.op.output('Out')[0], "Variable")
             var.update_output_desc("y",
@@ -800,6 +805,10 @@ class FillConstantParser(AscendParserBase):
                 "assign" + self._accumulated_op_id(), "Assign").set_input(
                     "value", const).set_input("ref", var)
             return [const], [[0]]
+        #else:
+        #    print(
+        #        "self.op.output('Out')[0]: %s is not persistable in fill_constant"
+        #        % (self.op.output('Out')[0]))
         return [const], [[0]]
@@ -853,6 +862,8 @@ class TruncatedNormalParser(AscendParserBase):
 
         ## wirte the output of truncatedNormal from startup_program to main_program
         if self.op.block.var(self.op.output('Out')[0]).persistable:
+            #print("%s is Persistable in truncated_normal" %
+            #      (self.op.output('Out')[0]))
             var = core.GEOperatorFactory.create_operator(
                 self.op.output('Out')[0], "Variable")
             var.update_output_desc("y",
@@ -867,6 +878,10 @@ class TruncatedNormalParser(AscendParserBase):
                 shape_tensor, mean_tensor, std_tensor, min_tensor, max_tensor,
                 truncated_normal
             ], [[-1]]
+        #else:
+        #    print(
+        #        "self.op.output('Out')[0] is not persistable in truncated_noraml"
+        #    )
         return [truncated_normal], [[0]]
@@ -1436,7 +1451,8 @@ class SqueezeParser(AscendParserBase):
             .set_input("x", tensor)\
             .set_attr_vec_int32("axes", axes)
         shape = core.GEOperatorFactory.create_operator(
-            "shape" + self._accumulated_op_id(), "Shape").set_input("x", data_squeezed)
+            "shape" + self._accumulated_op_id(),
+            "Shape").set_input("x", data_squeezed)
         return [shape, data_squeezed], [[1], [0]]
@@ -2172,4 +2188,3 @@ class AdamParser(AscendParserBase):
             "epsilon", epsilon).set_input("grad", grad)
         return [adam], [[0]]
-
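
Taken together, the changes are driven through the optimizer's minimize() call. A rough usage sketch under the new signature; the import path, the AscendOptimizer constructor arguments, and the rank table path are assumptions, only the auto_dp and rank_table_file parameters of minimize() come from this diff:

import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.ascend import ascend_optimizer

paddle.enable_static()
fleet.init(is_collective=True)

# Build a trivial static program with a scalar loss.
x = paddle.static.data(name="x", shape=[-1, 8], dtype="float32")
loss = paddle.mean(paddle.static.nn.fc(x, size=1))

sgd = paddle.optimizer.SGD(learning_rate=0.01)
# Constructor arguments below are assumed, not taken from this diff.
opt = ascend_optimizer.AscendOptimizer(sgd, fetch_list=[])
opt.minimize(
    loss,
    startup_program=paddle.static.default_startup_program(),
    auto_dp=True,  # transpile the program for data parallelism across ranks
    rank_table_file="/path/to/rank_table.json")  # placeholder path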