未验证 提交 de42d193 编写于 作者: G gongweibao 提交者: GitHub

Add paddle ascend distribution training supported (#30796)

Add paddle ascend distribution training supported
上级 ebb5d181
......@@ -26,10 +26,12 @@ HcomGroupConfig = namedtuple('HcomGroupConfig', ['name', 'nranks', 'rank_ids'])
class AscendIRParser(object):
def __init__(self):
def __init__(self, auto_dp=False, world_rank_size=1):
self.graph_idx = 0
self.hcom_endpoints = {}
self.groups_to_create = []
self._auto_dp = auto_dp
self._world_rank_size = world_rank_size
def _construct_input_map(self, input_varlist):
ret_map = {}
......@@ -91,13 +93,12 @@ class AscendIRParser(object):
print("append to create group: %s, with rank_ids: %s" %
(group_name, global_rank_ids))
elif op.type in ascend_parser.registerd_op:
print("Op[%s] has been registered, begin to parse it" % (op.type))
op_parser = self.parser_factory.create_parse(
ascend_parser.registerd_op[op.type])
op_parser.apply(op)
else:
print("Op[%s] has not been registered, so we have to skip it" %
(op.type))
assert False, "Op[%s] has not been registered, so we have to skip it" % (
op.type)
def _parse_program(self,
graph_name,
......@@ -161,6 +162,17 @@ class AscendIRParser(object):
startup_graph = self._parse_program("startup", startup_program)
main_graph = self._parse_program("main", main_program, input_varlist,
fetch_list)
if self._auto_dp and self._world_rank_size > 1:
assert len(self.groups_to_create
) == 0, "can't parse program under auto_dp mode"
from paddle.distributed import fleet
self.groups_to_create.append(
HcomGroupConfig(
name="hcom_group_0",
nranks=fleet.world_size(),
rank_ids=[x for x in range(fleet.world_size())]))
return startup_graph, main_graph
......@@ -196,7 +208,8 @@ class AscendOptimizer(Optimizer):
startup_program=None,
parameter_list=None,
no_grad_set=None,
auto_dp=False):
auto_dp=False,
rank_table_file=None):
minimized = None
if self.inner_opt:
minimized = self.inner_opt.minimize(
......@@ -205,24 +218,25 @@ class AscendOptimizer(Optimizer):
self.ascend_instance = core.AscendInstance()
from paddle.distributed import fleet
if auto_dp and fleet.worker_num() > 1:
if auto_dp and fleet.world_size() > 1:
from paddle.fluid.transpiler import ascend_transpiler
t = ascend_transpiler.AscendTranspiler(startup_program,
loss.block.program)
t.transpile()
print(loss.block.program)
#print(loss.block.program)
# Config about Graph Engine can be found in https://support.huaweicloud.com/
config = {
"ge.exec.deviceId": str(fleet.local_device_ids()),
"ge.graphRunMode": "1",
"ge.exec.precision_mode": "must_keep_origin_dtype",
# if multi mode
"ge.exec.rankTableFile": os.getenv("RANK_TABLE_FILE"),
"ge.exec.rankId": str(fleet.worker_index()),
"ge.exec.isUseHcom": "1",
"ge.exec.deployMode": "0",
}
# if multi trainers
if rank_table_file and fleet.world_size() > 1:
config["ge.exec.rankTableFile"] = rank_table_file
config["ge.exec.rankId"] = str(fleet.worker_index())
config["ge.exec.isUseHcom"] = "1"
config["ge.exec.deployMode"] = "0"
print("ge_initialize config:", config)
core.ge_initialize(config)
......@@ -230,7 +244,8 @@ class AscendOptimizer(Optimizer):
self.ascend_instance.init_global_resources()
main_block = loss.block
self.parser = AscendIRParser()
self.parser = AscendIRParser(
auto_dp=auto_dp, world_rank_size=fleet.world_size())
input_varlist = self._get_input_varlist(main_block.program)
......@@ -238,9 +253,9 @@ class AscendOptimizer(Optimizer):
startup_program, main_block.program, input_varlist, self.fetch_list)
for cfg in self.parser.groups_to_create:
hccl.create_group(cfg.name, cfg.nranks, cfg.rank_ids)
print("create group (%s), nranks: %d, rank_ids: %s" %
(cfg.name, cfg.nranks, cfg.rank_ids))
hccl.create_group(cfg.name, cfg.nranks, cfg.rank_ids)
self.ascend_instance.add_ascend_subgraph(0, startup_graph)
self.ascend_instance.add_ascend_subgraph(1, main_graph)
......
......@@ -170,6 +170,7 @@ class AscendParserBase(object):
self.parser_name, len(index_list), output_num)
for output_id in range(output_num):
arguments = self.op.output(self.op.output_names[output_id])
#print("%d argument: %s" % (output_id, str(arguments)))
if len(arguments) > 0:
assert len(arguments) == len(
index_list[output_id]
......@@ -177,6 +178,8 @@ class AscendParserBase(object):
self.parser_name, output_id, len(index_list[output_id]),
len(arguments))
for i in range(len(arguments)):
#print("assgin index_list[%d][%d] to %s" %
# (output_id, i, arguments[i]))
self.var2geop[arguments[i]] = geop_list[index_list[
output_id][i]]
......@@ -789,6 +792,8 @@ class FillConstantParser(AscendParserBase):
"Const").set_attr_tensor("value", tensor)
self._mark_as_input(const)
if self.op.block.var(self.op.output('Out')[0]).persistable:
#print("%s is Persistable in fill_constant" %
# (self.op.output('Out')[0]))
var = core.GEOperatorFactory.create_operator(
self.op.output('Out')[0], "Variable")
var.update_output_desc("y",
......@@ -800,6 +805,10 @@ class FillConstantParser(AscendParserBase):
"assign" + self._accumulated_op_id(), "Assign").set_input(
"value", const).set_input("ref", var)
return [const], [[0]]
#else:
# print(
# "self.op.output('Out')[0]: %s is not persistable in fill_constant"
# % (self.op.output('Out')[0]))
return [const], [[0]]
......@@ -853,6 +862,8 @@ class TruncatedNormalParser(AscendParserBase):
## wirte the output of truncatedNormal from startup_program to main_program
if self.op.block.var(self.op.output('Out')[0]).persistable:
#print("%s is Persistable in truncated_normal" %
# (self.op.output('Out')[0]))
var = core.GEOperatorFactory.create_operator(
self.op.output('Out')[0], "Variable")
var.update_output_desc("y",
......@@ -867,6 +878,10 @@ class TruncatedNormalParser(AscendParserBase):
shape_tensor, mean_tensor, std_tensor, min_tensor, max_tensor,
truncated_normal
], [[-1]]
#else:
# print(
# "self.op.output('Out')[0] is not persistable in truncated_noraml"
# )
return [truncated_normal], [[0]]
......@@ -1436,7 +1451,8 @@ class SqueezeParser(AscendParserBase):
.set_input("x", tensor)\
.set_attr_vec_int32("axes", axes)
shape = core.GEOperatorFactory.create_operator(
"shape" + self._accumulated_op_id(), "Shape").set_input("x", data_squeezed)
"shape" + self._accumulated_op_id(),
"Shape").set_input("x", data_squeezed)
return [shape, data_squeezed], [[1], [0]]
......@@ -2172,4 +2188,3 @@ class AdamParser(AscendParserBase):
"epsilon", epsilon).set_input("grad", grad)
return [adam], [[0]]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册