Unverified Commit 636fefd9, authored by G gongweibao, committed by GitHub

code style (#30781)

code style
Parent 88dfd067
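
This commit is a pure code-style pass: long statements are re-wrapped to the line-length limit and blank lines are inserted before top-level classes; no behavior changes. A minimal sketch of the recurring wrapping pattern, taken from the first hunk below (the exact formatter settings are an assumption; the project enforces a yapf-based style check):

    # Before: one long chained call on a single line.
    ge_input = core.GEOperatorFactory.create_operator(var.name, "Data").set_attr_int32("index", id)

    # After: break inside the call parentheses and indent the continuation.
    ge_input = core.GEOperatorFactory.create_operator(
        var.name, "Data").set_attr_int32("index", id)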
@@ -24,6 +24,7 @@ from collections import namedtuple
 HcomGroupConfig = namedtuple('HcomGroupConfig', ['name', 'nranks', 'rank_ids'])

+
 class AscendIRParser(object):
     def __init__(self):
         self.graph_idx = 0
@@ -35,18 +36,25 @@ class AscendIRParser(object):
         ge_in_operator = []
         for id, var in enumerate(input_varlist):
             if var.is_data:  # input data
-                ge_input = core.GEOperatorFactory.create_operator(var.name, "Data").set_attr_int32("index", id)
+                ge_input = core.GEOperatorFactory.create_operator(
+                    var.name, "Data").set_attr_int32("index", id)
                 ret_map[var.name] = ge_input
                 ge_in_operator.append(ge_input)
             else:  # param, learning ...
-                ge_input = core.GEOperatorFactory.create_operator(var.name, "Variable")
-                ge_input.update_output_desc("y", core.GETensorDesc(core.GEShape(var.shape), core.GEFormat.FORMAT_ND, core.GEDataType.DT_FLOAT))
+                ge_input = core.GEOperatorFactory.create_operator(var.name,
+                                                                  "Variable")
+                ge_input.update_output_desc("y",
+                                            core.GETensorDesc(
+                                                core.GEShape(var.shape),
+                                                core.GEFormat.FORMAT_ND,
+                                                core.GEDataType.DT_FLOAT))
                 ret_map[var.name] = ge_input
         return ge_in_operator, ret_map

     def _endpoint_to_world_rank_id(self, endpoint):
         world_endpoints = fleet.worker_endpoints()
-        assert endpoint in world_endpoints, "endpoint (%s) not in worker_endpoints (%s) " % (endpoint, fleet.world_device_ids())
+        assert endpoint in world_endpoints, "endpoint (%s) not in worker_endpoints (%s) " % (
+            endpoint, fleet.world_device_ids())
         return world_endpoints.index(endpoint)

     def parse_op(self, op):
@@ -62,26 +70,40 @@ class AscendIRParser(object):
             self.hcom_endpoints[nccl_id] = other_endpoints[:]
             self.hcom_endpoints[nccl_id].insert(rank, endpoint)
-            print("nccl_id (%s) registered endpoints %s" % (nccl_id, self.hcom_endpoints[nccl_id]))
+            print("nccl_id (%s) registered endpoints %s" %
+                  (nccl_id, self.hcom_endpoints[nccl_id]))
         elif op.type == 'c_comm_init':
             nccl_id = op.input_arg_names[0]
             nranks = op.attr("nranks")
-            assert nranks == len(self.hcom_endpoints[nccl_id]), "nranks doesn't match endpoint count"
+            assert nranks == len(self.hcom_endpoints[
+                nccl_id]), "nranks doesn't match endpoint count"
             rank = op.attr("rank")
             ring_id = op.attr("ring_id")

             group_name = "hcom_group_" + str(ring_id)
-            global_rank_ids = [self._endpoint_to_world_rank_id(endpoint) for endpoint in self.hcom_endpoints[nccl_id]]
-            self.groups_to_create.append(HcomGroupConfig(name=group_name, nranks=nranks, rank_ids=global_rank_ids))
-            print("append to create group: %s, with rank_ids: %s" % (group_name, global_rank_ids))
+            global_rank_ids = [
+                self._endpoint_to_world_rank_id(endpoint)
+                for endpoint in self.hcom_endpoints[nccl_id]
+            ]
+            self.groups_to_create.append(
+                HcomGroupConfig(
+                    name=group_name, nranks=nranks, rank_ids=global_rank_ids))
+            print("append to create group: %s, with rank_ids: %s" %
+                  (group_name, global_rank_ids))
         elif op.type in ascend_parser.registerd_op:
             print("Op[%s] has been registered, begin to parse it" % (op.type))
-            op_parser = self.parser_factory.create_parse(ascend_parser.registerd_op[op.type])
+            op_parser = self.parser_factory.create_parse(
+                ascend_parser.registerd_op[op.type])
             op_parser.apply(op)
         else:
-            print("Op[%s] has not been registered, so we have to skip it" % (op.type))
+            print("Op[%s] has not been registered, so we have to skip it" %
+                  (op.type))

-    def _parse_program(self, graph_name, program, input_varlist=[], fetch_list=[]):
+    def _parse_program(self,
+                       graph_name,
+                       program,
+                       input_varlist=[],
+                       fetch_list=[]):
         begin_graph_idx = self.graph_idx
         ge_in_operator = []
         ge_out_operator = []
@@ -96,7 +118,8 @@ class AscendIRParser(object):
         ge_in_operator, self.var2geop = self._construct_input_map(input_varlist)

-        self.parser_factory = ascend_parser.AscendParserFactory(graph, self.var2geop)
+        self.parser_factory = ascend_parser.AscendParserFactory(graph,
+                                                                self.var2geop)
         for i, curop in list(enumerate(block.ops)):
             self.parse_op(curop)
@@ -133,9 +156,11 @@ class AscendIRParser(object):
         self.graph_idx += 1
         return graph

-    def parse_program(self, startup_program, main_program, input_varlist, fetch_list):
+    def parse_program(self, startup_program, main_program, input_varlist,
+                      fetch_list):
         startup_graph = self._parse_program("startup", startup_program)
-        main_graph = self._parse_program("main", main_program, input_varlist, fetch_list)
+        main_graph = self._parse_program("main", main_program, input_varlist,
+                                         fetch_list)
         return startup_graph, main_graph
@@ -174,14 +199,16 @@ class AscendOptimizer(Optimizer):
                  auto_dp=False):
         minimized = None
         if self.inner_opt:
-            minimized = self.inner_opt.minimize(loss, startup_program=startup_program)
+            minimized = self.inner_opt.minimize(
+                loss, startup_program=startup_program)

         self.ascend_instance = core.AscendInstance()

         from paddle.distributed import fleet
         if auto_dp and fleet.worker_num() > 1:
             from paddle.fluid.transpiler import ascend_transpiler
-            t = ascend_transpiler.AscendTranspiler(startup_program, loss.block.program)
+            t = ascend_transpiler.AscendTranspiler(startup_program,
+                                                   loss.block.program)
             t.transpile()
             print(loss.block.program)
@@ -211,7 +238,8 @@ class AscendOptimizer(Optimizer):
         for cfg in self.parser.groups_to_create:
             hccl.create_group(cfg.name, cfg.nranks, cfg.rank_ids)
-            print("create group (%s), nranks: %d, rank_ids: %s" % (cfg.name, cfg.nranks, cfg.rank_ids))
+            print("create group (%s), nranks: %d, rank_ids: %s" %
+                  (cfg.name, cfg.nranks, cfg.rank_ids))

         self.ascend_instance.add_ascend_subgraph(0, startup_graph)
         self.ascend_instance.add_ascend_subgraph(1, main_graph)
@@ -69,11 +69,13 @@ class AscendHelper(object):
         }

     def dtype2ge(self, dtype):
-        assert dtype in self.dtype2ge_map, "dtype[%d] is not supported %d" % (dtype)
+        assert dtype in self.dtype2ge_map, "dtype[%d] is not supported" % (
+            dtype)
         return self.dtype2ge_map[dtype]

     def dtype2np(self, index):
-        assert index in self.dtype2np_map, "index[%d] is not supported %d" % (dtype)
+        assert index in self.dtype2np_map, "index[%d] is not supported" % (
+            index)
         return self.dtype2np_map[index]
@@ -98,7 +100,8 @@ class AscendParserBase(object):
         self.ascend_helper = AscendHelper()

     def _get_ge_input(self, input_var_name):
-        assert input_var_name in self.var2geop, "var %s not created before" % (input_var_name)
+        assert input_var_name in self.var2geop, "var %s not created before" % (
+            input_var_name)
         return self.var2geop[input_var_name]

     def update_output(self, geop_list, index_list):
@@ -119,7 +122,8 @@ class AscendParserBase(object):
                 for i in range(len(arguments)):
                     print("assign index_list[%d][%d] to %s" %
                           (output_id, i, arguments[i]))
-                    self.var2geop[arguments[i]] = geop_list[index_list[output_id][i]]
+                    self.var2geop[arguments[i]] = geop_list[index_list[
+                        output_id][i]]

         for geop in geop_list:
             self.graph.add_op(geop)
@@ -483,11 +487,15 @@ class TruncatedNormalParser(AscendParserBase):
             "const" + self._accumulated_op_id(), "Const").set_attr_tensor(
                 "value", tensor3)

-        tensor4 = self._create_ge_tensor([1], dtype, mean-2*std)
-        min_tensor = core.GEOperatorFactory.create_operator("const" + self._accumulated_op_id(), "Const").set_attr_tensor("value", tensor4)
+        tensor4 = self._create_ge_tensor([1], dtype, mean - 2 * std)
+        min_tensor = core.GEOperatorFactory.create_operator(
+            "const" + self._accumulated_op_id(), "Const").set_attr_tensor(
+                "value", tensor4)

-        tensor5 = self._create_ge_tensor([1], dtype, mean+2*std)
-        max_tensor = core.GEOperatorFactory.create_operator("const" + self._accumulated_op_id(), "Const").set_attr_tensor("value", tensor5)
+        tensor5 = self._create_ge_tensor([1], dtype, mean + 2 * std)
+        max_tensor = core.GEOperatorFactory.create_operator(
+            "const" + self._accumulated_op_id(), "Const").set_attr_tensor(
+                "value", tensor5)

         self._mark_as_input(shape_tensor)
         self._mark_as_input(mean_tensor)
@@ -546,6 +554,7 @@ class AllGatherParser(AscendParserBase):
             "rank_size", rank_size).set_attr_string("group", group)
         return [allgather], [[0]]

+
 class AllReduceParser(AscendParserBase):
     def __init__(self, graph, var2geop, reduction):
         super(AllReduceParser, self).__init__(graph, var2geop)
@@ -611,8 +620,8 @@ class ReduceScatterParser(AscendParserBase):
         rank_size = self.op.attr("rank_size")

         reduce_scatter = core.GEOperatorFactory.create_operator(
-            "reducescatter" + self._accumulated_op_id(), "HcomReduceScatter").set_input(
-                "x", x).set_attr_string(
+            "reducescatter" + self._accumulated_op_id(),
+            "HcomReduceScatter").set_input("x", x).set_attr_string(
                 "reduction", reduction).set_attr_string(
                     "group", group).set_attr_int32("rank_size", rank_size)
         return [reduce_scatter], [[0]]
@@ -631,8 +640,7 @@ class SendParser(AscendParserBase):
         send = core.GEOperatorFactory.create_operator(
             "send" + self._accumulated_op_id(), "HcomSend").set_input(
-                "x", x).set_attr_int32(
-                    "sr_tag", sr_tag).set_attr_int32(
+                "x", x).set_attr_int32("sr_tag", sr_tag).set_attr_int32(
                     "dest_rank", dest_rank).set_attr_string("group", group)
         return [send], [[0]]
@@ -652,8 +660,7 @@ class ReceiveParser(AscendParserBase):
         receive = core.GEOperatorFactory.create_operator(
             "receive" + self._accumulated_op_id(), "HcomReceive").set_input(
-                "x", x).set_attr_int32(
-                    "sr_tag", sr_tag).set_attr_int32(
+                "x", x).set_attr_int32("sr_tag", sr_tag).set_attr_int32(
                     "src_rank", src_rank).set_attr_string(
                         "group", group).set_attr_vec_int32(
                             "shape", shape).set_attr_int32("dtype", dtype)
@@ -667,18 +674,30 @@ class ScaleParser(AscendParserBase):
     def _apply(self):
         x = self._get_ge_input(self.op.input_arg_names[0])
-        scale = self.op.attr("scale") #self.get_ge_input(self.op.input_arg_names[1])
+        scale = self.op.attr(
+            "scale")  #self.get_ge_input(self.op.input_arg_names[1])
         bias = self.op.attr("bias")
         bias_after_scale = self.op.attr("bias_after_scale")
         if bias_after_scale:
-            scale_value = core.GEOperatorFactory.create_operator("scale" + self._accumulated_op_id(), "Power").set_input("x", x).set_attr_float("power", 1.0).set_attr_float("scale", scale).set_attr_float("shift", bias)
+            scale_value = core.GEOperatorFactory.create_operator(
+                "scale" + self._accumulated_op_id(), "Power").set_input(
+                    "x", x).set_attr_float("power", 1.0).set_attr_float(
+                        "scale", scale).set_attr_float("shift", bias)
         else:
-            x_add_bias = core.GEOperatorFactory.create_operator("adds" + self._accumulated_op_id(), "Adds").set_input("x", x).set_attr_float("value", bias) #set_input("x2", bias)
-            scale_value = core.GEOperatorFactory.create_operator("scale" + self._accumulated_op_id(), "Power").set_input("x", x_add_bias).set_attr_float("power", 1.0).set_attr_float("scale", scale).set_attr_float("shift", 0.0)
+            x_add_bias = core.GEOperatorFactory.create_operator(
+                "adds" + self._accumulated_op_id(), "Adds").set_input(
+                    "x", x).set_attr_float("value",
+                                           bias)  #set_input("x2", bias)
+            scale_value = core.GEOperatorFactory.create_operator(
+                "scale" + self._accumulated_op_id(), "Power").set_input(
+                    "x", x_add_bias).set_attr_float(
+                        "power", 1.0).set_attr_float(
+                            "scale", scale).set_attr_float("shift", 0.0)
             #tensor_zeros = core.GEOperatorFactory.create_operator("zeroslike" + self.getid(), "ZerosLike").set_input("x", x)
             #bias_ = self.create_ge_tensor([1], 5, bias)
             #const_bias = core.GEOperatorFactory.create_operator("const" + self.getid(), "Const").set_attr_tensor("value", tensor_bias)
-        return [scale_value],[[0]]
+        return [scale_value], [[0]]

+
 class ReshapeParser(AscendParserBase):
     def __init__(self, graph, var2geop):
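
For reference, ScaleParser above lowers Paddle's scale op to GE operators; a sketch of the intended arithmetic, assuming GE's Power computes (scale * x + shift) ** power (an assumption inferred from the attributes set in this hunk):

    # bias_after_scale=True:  y = scale * x + bias
    #     -> Power(power=1.0, scale=scale, shift=bias)
    # bias_after_scale=False: y = scale * (x + bias)
    #     -> Adds(value=bias), then Power(power=1.0, scale=scale, shift=0.0)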
@@ -695,9 +714,12 @@ class ReshapeParser(AscendParserBase):
             print("shape: ", shape)
             data_x1_shape = self._get_ge_input(self.op.input_arg_names[0])
             tensor = self._create_ge_tensor([len(shape)], 2, shape)
-            const_shape = core.GEOperatorFactory.create_operator("shape" + self._accumulated_op_id(), "Const").set_attr_tensor("value", tensor)
-            reshape = core.GEOperatorFactory.create_operator("reshape" + self._accumulated_op_id(), "Reshape").set_input("x", data_x1_shape).set_input("shape", const_shape).set_attr_int32("axis", axis)
-        return [reshape, reshape], [[0],[1]]
+            const_shape = core.GEOperatorFactory.create_operator(
+                "shape" + self._accumulated_op_id(), "Const").set_attr_tensor(
+                    "value", tensor)
+            reshape = core.GEOperatorFactory.create_operator(
+                "reshape" + self._accumulated_op_id(), "Reshape").set_input(
+                    "x", data_x1_shape).set_input(
+                        "shape", const_shape).set_attr_int32("axis", axis)
+        return [reshape, reshape], [[0], [1]]