未验证 提交 636fefd9 编写于 作者: G gongweibao 提交者: GitHub

code style (#30781)

code style
上级 88dfd067
...@@ -24,6 +24,7 @@ from collections import namedtuple ...@@ -24,6 +24,7 @@ from collections import namedtuple
HcomGroupConfig = namedtuple('HcomGroupConfig', ['name', 'nranks', 'rank_ids']) HcomGroupConfig = namedtuple('HcomGroupConfig', ['name', 'nranks', 'rank_ids'])
class AscendIRParser(object): class AscendIRParser(object):
def __init__(self): def __init__(self):
self.graph_idx = 0 self.graph_idx = 0
...@@ -34,19 +35,26 @@ class AscendIRParser(object): ...@@ -34,19 +35,26 @@ class AscendIRParser(object):
ret_map = {} ret_map = {}
ge_in_operator = [] ge_in_operator = []
for id, var in enumerate(input_varlist): for id, var in enumerate(input_varlist):
if var.is_data: # input data if var.is_data: # input data
ge_input = core.GEOperatorFactory.create_operator(var.name, "Data").set_attr_int32("index", id) ge_input = core.GEOperatorFactory.create_operator(
var.name, "Data").set_attr_int32("index", id)
ret_map[var.name] = ge_input ret_map[var.name] = ge_input
ge_in_operator.append(ge_input) ge_in_operator.append(ge_input)
else: # param, learning ... else: # param, learning ...
ge_input = core.GEOperatorFactory.create_operator(var.name, "Variable") ge_input = core.GEOperatorFactory.create_operator(var.name,
ge_input.update_output_desc("y", core.GETensorDesc(core.GEShape(var.shape), core.GEFormat.FORMAT_ND, core.GEDataType.DT_FLOAT)) "Variable")
ge_input.update_output_desc("y",
core.GETensorDesc(
core.GEShape(var.shape),
core.GEFormat.FORMAT_ND,
core.GEDataType.DT_FLOAT))
ret_map[var.name] = ge_input ret_map[var.name] = ge_input
return ge_in_operator, ret_map return ge_in_operator, ret_map
def _endpoint_to_world_rank_id(self, endpoint): def _endpoint_to_world_rank_id(self, endpoint):
world_endpoints = fleet.worker_endpoints() world_endpoints = fleet.worker_endpoints()
assert endpoint in world_endpoints, "endpoint (%s) not in worker_endpoints (%s) " % (endpoint, fleet.world_device_ids()) assert endpoint in world_endpoints, "endpoint (%s) not in worker_endpoints (%s) " % (
endpoint, fleet.world_device_ids())
return world_endpoints.index(endpoint) return world_endpoints.index(endpoint)
def parse_op(self, op): def parse_op(self, op):
...@@ -62,26 +70,40 @@ class AscendIRParser(object): ...@@ -62,26 +70,40 @@ class AscendIRParser(object):
self.hcom_endpoints[nccl_id] = other_endpoints[:] self.hcom_endpoints[nccl_id] = other_endpoints[:]
self.hcom_endpoints[nccl_id].insert(rank, endpoint) self.hcom_endpoints[nccl_id].insert(rank, endpoint)
print("nccl_id (%s) registered endpoints %s" % (nccl_id, self.hcom_endpoints[nccl_id])) print("nccl_id (%s) registered endpoints %s" %
(nccl_id, self.hcom_endpoints[nccl_id]))
elif op.type == 'c_comm_init': elif op.type == 'c_comm_init':
nccl_id = op.input_arg_names[0] nccl_id = op.input_arg_names[0]
nranks = op.attr("nranks") nranks = op.attr("nranks")
assert nranks == len(self.hcom_endpoints[nccl_id]), "nranks doesn't match endpoint count" assert nranks == len(self.hcom_endpoints[
nccl_id]), "nranks doesn't match endpoint count"
rank = op.attr("rank") rank = op.attr("rank")
ring_id = op.attr("ring_id") ring_id = op.attr("ring_id")
group_name = "hcom_group_" + str(ring_id) group_name = "hcom_group_" + str(ring_id)
global_rank_ids = [self._endpoint_to_world_rank_id(endpoint) for endpoint in self.hcom_endpoints[nccl_id]] global_rank_ids = [
self.groups_to_create.append(HcomGroupConfig(name=group_name, nranks=nranks, rank_ids=global_rank_ids)) self._endpoint_to_world_rank_id(endpoint)
print("append to create group: %s, with rank_ids: %s" % (group_name, global_rank_ids)) for endpoint in self.hcom_endpoints[nccl_id]
]
self.groups_to_create.append(
HcomGroupConfig(
name=group_name, nranks=nranks, rank_ids=global_rank_ids))
print("append to create group: %s, with rank_ids: %s" %
(group_name, global_rank_ids))
elif op.type in ascend_parser.registerd_op: elif op.type in ascend_parser.registerd_op:
print("Op[%s] has been registered, begin to parse it" % (op.type)) print("Op[%s] has been registered, begin to parse it" % (op.type))
op_parser = self.parser_factory.create_parse(ascend_parser.registerd_op[op.type]) op_parser = self.parser_factory.create_parse(
ascend_parser.registerd_op[op.type])
op_parser.apply(op) op_parser.apply(op)
else: else:
print("Op[%s] has not been registered, so we have to skip it" % (op.type)) print("Op[%s] has not been registered, so we have to skip it" %
(op.type))
def _parse_program(self, graph_name, program, input_varlist=[], fetch_list=[]):
def _parse_program(self,
graph_name,
program,
input_varlist=[],
fetch_list=[]):
begin_graph_idx = self.graph_idx begin_graph_idx = self.graph_idx
ge_in_operator = [] ge_in_operator = []
ge_out_operator = [] ge_out_operator = []
...@@ -96,7 +118,8 @@ class AscendIRParser(object): ...@@ -96,7 +118,8 @@ class AscendIRParser(object):
ge_in_operator, self.var2geop = self._construct_input_map(input_varlist) ge_in_operator, self.var2geop = self._construct_input_map(input_varlist)
self.parser_factory = ascend_parser.AscendParserFactory(graph, self.var2geop) self.parser_factory = ascend_parser.AscendParserFactory(graph,
self.var2geop)
for i, curop in list(enumerate(block.ops)): for i, curop in list(enumerate(block.ops)):
self.parse_op(curop) self.parse_op(curop)
...@@ -133,9 +156,11 @@ class AscendIRParser(object): ...@@ -133,9 +156,11 @@ class AscendIRParser(object):
self.graph_idx += 1 self.graph_idx += 1
return graph return graph
def parse_program(self, startup_program, main_program, input_varlist, fetch_list): def parse_program(self, startup_program, main_program, input_varlist,
fetch_list):
startup_graph = self._parse_program("startup", startup_program) startup_graph = self._parse_program("startup", startup_program)
main_graph = self._parse_program("main", main_program, input_varlist, fetch_list) main_graph = self._parse_program("main", main_program, input_varlist,
fetch_list)
return startup_graph, main_graph return startup_graph, main_graph
...@@ -174,14 +199,16 @@ class AscendOptimizer(Optimizer): ...@@ -174,14 +199,16 @@ class AscendOptimizer(Optimizer):
auto_dp=False): auto_dp=False):
minimized = None minimized = None
if self.inner_opt: if self.inner_opt:
minimized = self.inner_opt.minimize(loss, startup_program=startup_program) minimized = self.inner_opt.minimize(
loss, startup_program=startup_program)
self.ascend_instance = core.AscendInstance() self.ascend_instance = core.AscendInstance()
from paddle.distributed import fleet from paddle.distributed import fleet
if auto_dp and fleet.worker_num() > 1: if auto_dp and fleet.worker_num() > 1:
from paddle.fluid.transpiler import ascend_transpiler from paddle.fluid.transpiler import ascend_transpiler
t = ascend_transpiler.AscendTranspiler(startup_program, loss.block.program) t = ascend_transpiler.AscendTranspiler(startup_program,
loss.block.program)
t.transpile() t.transpile()
print(loss.block.program) print(loss.block.program)
...@@ -211,7 +238,8 @@ class AscendOptimizer(Optimizer): ...@@ -211,7 +238,8 @@ class AscendOptimizer(Optimizer):
for cfg in self.parser.groups_to_create: for cfg in self.parser.groups_to_create:
hccl.create_group(cfg.name, cfg.nranks, cfg.rank_ids) hccl.create_group(cfg.name, cfg.nranks, cfg.rank_ids)
print("create group (%s), nranks: %d, rank_ids: %s" % (cfg.name, cfg.nranks, cfg.rank_ids)) print("create group (%s), nranks: %d, rank_ids: %s" %
(cfg.name, cfg.nranks, cfg.rank_ids))
self.ascend_instance.add_ascend_subgraph(0, startup_graph) self.ascend_instance.add_ascend_subgraph(0, startup_graph)
self.ascend_instance.add_ascend_subgraph(1, main_graph) self.ascend_instance.add_ascend_subgraph(1, main_graph)
......
...@@ -69,11 +69,13 @@ class AscendHelper(object): ...@@ -69,11 +69,13 @@ class AscendHelper(object):
} }
def dtype2ge(self, dtype): def dtype2ge(self, dtype):
assert dtype in self.dtype2ge_map, "dtype[%d] is not supported %d" % (dtype) assert dtype in self.dtype2ge_map, "dtype[%d] is not supported %d" % (
dtype)
return self.dtype2ge_map[dtype] return self.dtype2ge_map[dtype]
def dtype2np(self, index): def dtype2np(self, index):
assert index in self.dtype2np_map, "index[%d] is not supported %d" % (dtype) assert index in self.dtype2np_map, "index[%d] is not supported %d" % (
dtype)
return self.dtype2np_map[index] return self.dtype2np_map[index]
...@@ -98,7 +100,8 @@ class AscendParserBase(object): ...@@ -98,7 +100,8 @@ class AscendParserBase(object):
self.ascend_helper = AscendHelper() self.ascend_helper = AscendHelper()
def _get_ge_input(self, input_var_name): def _get_ge_input(self, input_var_name):
assert input_var_name in self.var2geop, "var %s not created before" % (input_var_name) assert input_var_name in self.var2geop, "var %s not created before" % (
input_var_name)
return self.var2geop[input_var_name] return self.var2geop[input_var_name]
def update_output(self, geop_list, index_list): def update_output(self, geop_list, index_list):
...@@ -119,7 +122,8 @@ class AscendParserBase(object): ...@@ -119,7 +122,8 @@ class AscendParserBase(object):
for i in range(len(arguments)): for i in range(len(arguments)):
print("assgin index_list[%d][%d] to %s" % print("assgin index_list[%d][%d] to %s" %
(output_id, i, arguments[i])) (output_id, i, arguments[i]))
self.var2geop[arguments[i]] = geop_list[index_list[output_id][i]] self.var2geop[arguments[i]] = geop_list[index_list[
output_id][i]]
for geop in geop_list: for geop in geop_list:
self.graph.add_op(geop) self.graph.add_op(geop)
...@@ -483,11 +487,15 @@ class TruncatedNormalParser(AscendParserBase): ...@@ -483,11 +487,15 @@ class TruncatedNormalParser(AscendParserBase):
"const" + self._accumulated_op_id(), "Const").set_attr_tensor( "const" + self._accumulated_op_id(), "Const").set_attr_tensor(
"value", tensor3) "value", tensor3)
tensor4 = self._create_ge_tensor([1], dtype, mean-2*std) tensor4 = self._create_ge_tensor([1], dtype, mean - 2 * std)
min_tensor = core.GEOperatorFactory.create_operator("const" + self._accumulated_op_id(), "Const").set_attr_tensor("value", tensor4) min_tensor = core.GEOperatorFactory.create_operator(
"const" + self._accumulated_op_id(), "Const").set_attr_tensor(
"value", tensor4)
tensor5 = self._create_ge_tensor([1], dtype, mean+2*std) tensor5 = self._create_ge_tensor([1], dtype, mean + 2 * std)
max_tensor = core.GEOperatorFactory.create_operator("const" + self._accumulated_op_id(), "Const").set_attr_tensor("value", tensor5) max_tensor = core.GEOperatorFactory.create_operator(
"const" + self._accumulated_op_id(), "Const").set_attr_tensor(
"value", tensor5)
self._mark_as_input(shape_tensor) self._mark_as_input(shape_tensor)
self._mark_as_input(mean_tensor) self._mark_as_input(mean_tensor)
...@@ -546,10 +554,11 @@ class AllGatherParser(AscendParserBase): ...@@ -546,10 +554,11 @@ class AllGatherParser(AscendParserBase):
"rank_size", rank_size).set_attr_string("group", group) "rank_size", rank_size).set_attr_string("group", group)
return [allgather], [[0]] return [allgather], [[0]]
class AllReduceParser(AscendParserBase): class AllReduceParser(AscendParserBase):
def __init__(self, graph, var2geop, reduction): def __init__(self, graph, var2geop, reduction):
super(AllReduceParser, self).__init__(graph, var2geop) super(AllReduceParser, self).__init__(graph, var2geop)
self.parser_name = "c_allreduce_" + reduction self.parser_name = "c_allreduce_" + reduction
self.reduction = reduction self.reduction = reduction
def _apply(self): def _apply(self):
...@@ -557,8 +566,8 @@ class AllReduceParser(AscendParserBase): ...@@ -557,8 +566,8 @@ class AllReduceParser(AscendParserBase):
reduction = self.reduction reduction = self.reduction
ring_id = self.op.attr("ring_id") ring_id = self.op.attr("ring_id")
group = "hcom_group_" + str(ring_id) group = "hcom_group_" + str(ring_id)
fusion = None #self.op.attr("fusion") fusion = None #self.op.attr("fusion")
fusion_id = None #self.op.attr("fusion_id") fusion_id = None #self.op.attr("fusion_id")
allreduce = core.GEOperatorFactory.create_operator( allreduce = core.GEOperatorFactory.create_operator(
"allreduce" + self._accumulated_op_id(), "HcomAllReduce").set_input( "allreduce" + self._accumulated_op_id(), "HcomAllReduce").set_input(
...@@ -611,10 +620,10 @@ class ReduceScatterParser(AscendParserBase): ...@@ -611,10 +620,10 @@ class ReduceScatterParser(AscendParserBase):
rank_size = self.op.attr("rank_size") rank_size = self.op.attr("rank_size")
reduce_scatter = core.GEOperatorFactory.create_operator( reduce_scatter = core.GEOperatorFactory.create_operator(
"reducescatter" + self._accumulated_op_id(), "HcomReduceScatter").set_input( "reducescatter" + self._accumulated_op_id(),
"x", x).set_attr_string( "HcomReduceScatter").set_input("x", x).set_attr_string(
"reduction", reduction).set_attr_string( "reduction", reduction).set_attr_string(
"group", group).set_attr_int32("rank_size", rank_size) "group", group).set_attr_int32("rank_size", rank_size)
return [reduce_scatter], [[0]] return [reduce_scatter], [[0]]
...@@ -631,9 +640,8 @@ class SendParser(AscendParserBase): ...@@ -631,9 +640,8 @@ class SendParser(AscendParserBase):
send = core.GEOperatorFactory.create_operator( send = core.GEOperatorFactory.create_operator(
"send" + self._accumulated_op_id(), "HcomSend").set_input( "send" + self._accumulated_op_id(), "HcomSend").set_input(
"x", x).set_attr_int32( "x", x).set_attr_int32("sr_tag", sr_tag).set_attr_int32(
"sr_tag", sr_tag).set_attr_int32( "dest_rank", dest_rank).set_attr_string("group", group)
"dest_rank", dest_rank).set_attr_string("group", group)
return [send], [[0]] return [send], [[0]]
...@@ -652,11 +660,10 @@ class ReceiveParser(AscendParserBase): ...@@ -652,11 +660,10 @@ class ReceiveParser(AscendParserBase):
receive = core.GEOperatorFactory.create_operator( receive = core.GEOperatorFactory.create_operator(
"receive" + self._accumulated_op_id(), "HcomReceive").set_input( "receive" + self._accumulated_op_id(), "HcomReceive").set_input(
"x", x).set_attr_int32( "x", x).set_attr_int32("sr_tag", sr_tag).set_attr_int32(
"sr_tag", sr_tag).set_attr_int32( "src_rank", src_rank).set_attr_string(
"src_rank", src_rank).set_attr_string( "group", group).set_attr_vec_int32(
"group", group).set_attr_vec_int32( "shape", shape).set_attr_int32("dtype", dtype)
"shape", shape).set_attr_int32("dtype", dtype)
return [receive], [[0]] return [receive], [[0]]
...@@ -667,18 +674,30 @@ class ScaleParser(AscendParserBase): ...@@ -667,18 +674,30 @@ class ScaleParser(AscendParserBase):
def _apply(self): def _apply(self):
x = self._get_ge_input(self.op.input_arg_names[0]) x = self._get_ge_input(self.op.input_arg_names[0])
scale = self.op.attr("scale") #self.get_ge_input(self.op.input_arg_names[1]) scale = self.op.attr(
"scale") #self.get_ge_input(self.op.input_arg_names[1])
bias = self.op.attr("bias") bias = self.op.attr("bias")
bias_after_scale = self.op.attr("bias_after_scale") bias_after_scale = self.op.attr("bias_after_scale")
if bias_after_scale: if bias_after_scale:
scale_value = core.GEOperatorFactory.create_operator("scale" + self._accumulated_op_id(), "Power").set_input("x", x).set_attr_float("power", 1.0).set_attr_float("scale", scale).set_attr_float("shift", bias) scale_value = core.GEOperatorFactory.create_operator(
"scale" + self._accumulated_op_id(), "Power").set_input(
"x", x).set_attr_float("power", 1.0).set_attr_float(
"scale", scale).set_attr_float("shift", bias)
else: else:
x_add_bias = core.GEOperatorFactory.create_operator("adds" + self._accumulated_op_id(), "Adds").set_input("x", x).set_attr_float("value", bias) #set_input("x2", bias) x_add_bias = core.GEOperatorFactory.create_operator(
scale_value = core.GEOperatorFactory.create_operator("scale" + self._accumulated_op_id(), "Power").set_input("x", x_add_bias).set_attr_float("power", 1.0).set_attr_float("scale", scale).set_attr_float("shift", 0.0) "adds" + self._accumulated_op_id(), "Adds").set_input(
"x", x).set_attr_float("value",
bias) #set_input("x2", bias)
scale_value = core.GEOperatorFactory.create_operator(
"scale" + self._accumulated_op_id(), "Power").set_input(
"x", x_add_bias).set_attr_float(
"power", 1.0).set_attr_float(
"scale", scale).set_attr_float("shift", 0.0)
#tensor_zeros = core.GEOperatorFactory.create_operator("zeroslike" + self.getid(), "ZerosLike").set_input("x", x) #tensor_zeros = core.GEOperatorFactory.create_operator("zeroslike" + self.getid(), "ZerosLike").set_input("x", x)
#bias_ = self.create_ge_tensor([1], 5, bias) #bias_ = self.create_ge_tensor([1], 5, bias)
#const_bias = core.GEOperatorFactory.create_operator("const" + self.getid(), "Const").set_attr_tensor("value", tensor_bias) #const_bias = core.GEOperatorFactory.create_operator("const" + self.getid(), "Const").set_attr_tensor("value", tensor_bias)
return [scale_value],[[0]] return [scale_value], [[0]]
class ReshapeParser(AscendParserBase): class ReshapeParser(AscendParserBase):
def __init__(self, graph, var2geop): def __init__(self, graph, var2geop):
...@@ -695,9 +714,12 @@ class ReshapeParser(AscendParserBase): ...@@ -695,9 +714,12 @@ class ReshapeParser(AscendParserBase):
print("shape: ", shape) print("shape: ", shape)
data_x1_shape = self._get_ge_input(self.op.input_arg_names[0]) data_x1_shape = self._get_ge_input(self.op.input_arg_names[0])
tensor = self._create_ge_tensor([len(shape)], 2, shape) tensor = self._create_ge_tensor([len(shape)], 2, shape)
const_shape = core.GEOperatorFactory.create_operator("shape" + self._accumulated_op_id(), "Const").set_attr_tensor("value", tensor) const_shape = core.GEOperatorFactory.create_operator(
reshape = core.GEOperatorFactory.create_operator("reshape" + self._accumulated_op_id(), "Reshape").set_input("x", data_x1_shape).set_input("shape", const_shape).set_attr_int32("axis", axis) "shape" + self._accumulated_op_id(), "Const").set_attr_tensor(
"value", tensor)
return [reshape, reshape], [[0],[1]] reshape = core.GEOperatorFactory.create_operator(
"reshape" + self._accumulated_op_id(), "Reshape").set_input(
"x", data_x1_shape).set_input(
"shape", const_shape).set_attr_int32("axis", axis)
return [reshape, reshape], [[0], [1]]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册