diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index e322b34575d5fdbcc82f27d9a50b0f8dd7bb634c..210dff45de23ea7696f2d56b9afb8cea566af8cd 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -63,7 +63,7 @@ _group_map_backend = {} # Name of the default group for init_parallel_env _default_group_name = "_default_pg" -_valid_backend_list = ['nccl', 'gloo', 'hccl', 'heter', 'xccl', 'bkcl'] +_valid_backend_list = ['nccl', 'gloo', 'heter', 'xccl', 'bkcl'] _default_store = None # the default tcp store _default_backend = None _default_timeout = datetime.timedelta(seconds=1800) diff --git a/python/paddle/distributed/fleet/launch.py b/python/paddle/distributed/fleet/launch.py index 847dea2f42a8eefb674fa6c79789ac8511e188c3..a55a47943b3d7bc44d8b5d62ee0cf70124767703 100755 --- a/python/paddle/distributed/fleet/launch.py +++ b/python/paddle/distributed/fleet/launch.py @@ -115,7 +115,7 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra "--backend", type=str, default=os.environ.get('PADDLE_DISTRI_BACKEND', 'auto'), - help="Specifize the backend, can be gloo|nccl|bkcl|auto|hccl|heter. " + help="Specifize the backend, can be gloo|nccl|bkcl|auto|heter. " "Default value is auto which perfers nccl or bkcl.", ) base_group.add_argument( diff --git a/python/paddle/distributed/fleet/launch_utils.py b/python/paddle/distributed/fleet/launch_utils.py index 458ce6540aa877159e2d41c46a601f6208987ac8..2638a42a6c5cc7ca3c2a83124ad25631cec72565 100755 --- a/python/paddle/distributed/fleet/launch_utils.py +++ b/python/paddle/distributed/fleet/launch_utils.py @@ -1988,14 +1988,13 @@ def check_backend(backend): 'bkcl', 'cncl', 'auto', - 'hccl', 'heter', 'xccl', ]: raise ValueError( "paddle.distributed initialize error, " "backend argument can only be one of " - "'nccl', 'gloo', 'bkcl', 'auto', 'hccl', 'heter', 'xccl' " + "'nccl', 'gloo', 'bkcl', 'auto', 'heter', 'xccl' " "but got %s" % backend ) diff --git a/python/paddle/distributed/fleet/meta_optimizers/ascend/__init__.py b/python/paddle/distributed/fleet/meta_optimizers/ascend/__init__.py deleted file mode 100644 index 185a92b8d94d3426d616c0624f0f2ee04339349e..0000000000000000000000000000000000000000 --- a/python/paddle/distributed/fleet/meta_optimizers/ascend/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py deleted file mode 100644 index e5a1e1b26014acf9c07832b13e5a15ea2b534cae..0000000000000000000000000000000000000000 --- a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py +++ /dev/null @@ -1,310 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from collections import namedtuple - -import hccl.manage.api as hccl - -from paddle.distributed import fleet -from paddle.framework import core -from paddle.optimizer import Optimizer - -from . import ascend_parser - -HcomGroupConfig = namedtuple('HcomGroupConfig', ['name', 'nranks', 'rank_ids']) - -__all__ = [] - - -class AscendIRParser: - def __init__(self, auto_dp=False, world_rank_size=1): - self.graph_idx = 0 - self.hcom_endpoints = {} - self.groups_to_create = [] - self._auto_dp = auto_dp - self._world_rank_size = world_rank_size - - def _construct_input_map(self, input_varlist): - ret_map = {} - ge_in_operator = [] - for id, var in enumerate(input_varlist): - if var.is_data: # input data - ge_input = core.GEOperatorFactory.create_operator( - var.name, "Data" - ).set_attr_int32("index", id) - ret_map[var.name] = ge_input - ge_in_operator.append(ge_input) - else: # param, learning ... - ge_input = core.GEOperatorFactory.create_operator( - var.name, "Variable" - ) - ge_input.update_output_desc( - "y", - core.GETensorDesc( - core.GEShape(var.shape), - core.GEFormat.FORMAT_ND, - core.GEDataType.DT_FLOAT, - ), - ) - ret_map[var.name] = ge_input - return ge_in_operator, ret_map - - def _endpoint_to_world_rank_id(self, endpoint): - world_endpoints = fleet.worker_endpoints() - assert ( - endpoint in world_endpoints - ), "endpoint ({}) not in worker_endpoints ({}) ".format( - endpoint, - fleet.world_device_ids(), - ) - return world_endpoints.index(endpoint) - - def parse_op(self, op): - if op.type == 'c_gen_nccl_id': - endpoint = op.attr("endpoint") - other_endpoints = op.attr("other_endpoints") - rank = op.attr("rank") - - nccl_id = op.output_arg_names[0] - - # c_gen_nccl_id operator splits endpoints into local endpoint and other_endpoints - # we should combine these together to produce world_rank_ids - self.hcom_endpoints[nccl_id] = other_endpoints[:] - self.hcom_endpoints[nccl_id].insert(rank, endpoint) - - print( - "nccl_id (%s) registered endpoints %s" - % (nccl_id, self.hcom_endpoints[nccl_id]) - ) - elif op.type == 'c_comm_init': - nccl_id = op.input_arg_names[0] - nranks = op.attr("nranks") - assert nranks == len( - self.hcom_endpoints[nccl_id] - ), "nranks doesn't match endpoint count" - rank = op.attr("rank") - ring_id = op.attr("ring_id") - - group_name = "hcom_group_" + str(ring_id) - global_rank_ids = [ - self._endpoint_to_world_rank_id(endpoint) - for endpoint in self.hcom_endpoints[nccl_id] - ] - self.groups_to_create.append( - HcomGroupConfig( - name=group_name, nranks=nranks, rank_ids=global_rank_ids - ) - ) - print( - "append to create group: %s, with rank_ids: %s" - % (group_name, global_rank_ids) - ) - elif op.type in ascend_parser.registerd_op: - op_parser = self.parser_factory.create_parse( - ascend_parser.registerd_op[op.type] - ) - op_parser.apply(op) - else: - raise AssertionError( - 'Op[%s] has not been registered, so we have to skip it' - % op.type - ) - - def _parse_program( - self, graph_name, program, input_varlist=[], fetch_list=[] - ): - begin_graph_idx = self.graph_idx - ge_in_operator = [] - ge_out_operator = [] - self.var2geop = {} - - block = program.global_block() - if len(block.ops) == 0: - print("There is no ops in program %s" % (graph_name)) - return [] - - graph = core.GEGraph(graph_name) - - ge_in_operator, self.var2geop = self._construct_input_map(input_varlist) - - self.parser_factory = ascend_parser.AscendParserFactory( - graph, self.var2geop - ) - for i, curop in list(enumerate(block.ops)): - self.parse_op(curop) - - # Set fetch_var for GE - for e in fetch_list: - name = e - if not isinstance(e, str): - name = e.name - ge_out_operator.append(self.var2geop[name]) - - # (Debug) If you want to print back prop vars, append/assign the varname in ge_out_operator here, such as: - # if graph_name == "main": - # ge_out_operator.append(self.var2geop["reduce_sum_0.tmp_0@GRAD"]) - - # Add ops that may be input of a graph, such as const. - for varname, geop in self.var2geop.items(): - if varname.startswith("geinput"): - ge_in_operator.append(geop) - - graph.set_inputs(ge_in_operator).set_outputs(ge_out_operator) - - # Remove ops of origin program - op_num = len(block.ops) - for i in range(op_num - 1, -1, -1): - block._remove_op(i) - - input_varlist = [var for var in input_varlist if var.is_data] - - block.append_op( - type="ascend_trigger", - inputs={"FeedList": input_varlist}, - outputs={"FetchList": fetch_list}, - attrs={'graph_idx': self.graph_idx}, - ) - self.graph_idx += 1 - return graph - - def parse_program( - self, startup_program, main_program, input_varlist, fetch_list - ): - startup_graph = self._parse_program("startup", startup_program) - main_graph = self._parse_program( - "main", main_program, input_varlist, fetch_list - ) - if self._auto_dp and self._world_rank_size > 1: - assert ( - len(self.groups_to_create) == 0 - ), "can't parse program under auto_dp mode" - - from paddle.distributed import fleet - - self.groups_to_create.append( - HcomGroupConfig( - name="hcom_group_0", - nranks=fleet.world_size(), - rank_ids=list(range(fleet.world_size())), - ) - ) - - return startup_graph, main_graph - - -# AscendOptimizer is a wrapper for basic optimizer now -# We will make it part of fleet meta_optimizer in the future -class AscendOptimizer(Optimizer): - def __init__(self, optimizer, fetch_list=[]): - self.inner_opt = optimizer - self.fetch_list = fetch_list - self.ascend_instance = None - - def __del__(self): - print("begin AscendOptimizer del") - if self.ascend_instance is not None: - self.ascend_instance.destroy_global_resources() - core.ge_finalize() - print("end AscendOptimizer del") - - def _can_apply(self): - if not self.user_defined_strategy.ascend: - return False - # TODO(hutuxian): other check here - return True - - def _disable_strategy(self, dist_strategy): - dist_strategy.ascend = False - dist_strategy.ascend_configs = {} - - def _get_input_varlist(self, program): - ret_list = [] - for var in program.list_vars(): - if var.is_data or var.persistable: - ret_list.append(var) - return ret_list - - def _set_auxiliary_var(self, key, val): - super()._set_auxiliary_var(key, val) - self.inner_opt._set_auxiliary_var(key, val) - - def minimize( - self, - loss, - startup_program=None, - parameter_list=None, - no_grad_set=None, - auto_dp=False, - rank_table_file=None, - precision_mode="must_keep_origin_dtype", - ): - minimized = None - if self.inner_opt: - minimized = self.inner_opt.minimize( - loss, startup_program=startup_program - ) - - self.ascend_instance = core.AscendInstance() - - from paddle.distributed import fleet - - if auto_dp and fleet.world_size() > 1: - from paddle.distributed.transpiler import ascend_transpiler - - t = ascend_transpiler.AscendTranspiler( - startup_program, loss.block.program - ) - t.transpile() - # print(loss.block.program) - - # Config about Graph Engine can be found in https://support.huaweicloud.com/ - config = { - "ge.exec.deviceId": str(fleet.local_device_ids()), - "ge.graphRunMode": "1", - "ge.exec.precision_mode": precision_mode, - } - # if multi trainers - if rank_table_file and fleet.world_size() > 1: - config["ge.exec.rankTableFile"] = rank_table_file - config["ge.exec.rankId"] = str(fleet.worker_index()) - config["ge.exec.isUseHcom"] = "1" - config["ge.exec.deployMode"] = "0" - print("ge_initialize config:", config) - core.ge_initialize(config) - - # Init Session - self.ascend_instance.init_global_resources() - - main_block = loss.block - self.parser = AscendIRParser( - auto_dp=auto_dp, world_rank_size=fleet.world_size() - ) - - input_varlist = self._get_input_varlist(main_block.program) - - startup_graph, main_graph = self.parser.parse_program( - startup_program, main_block.program, input_varlist, self.fetch_list - ) - - for cfg in self.parser.groups_to_create: - print( - "create group (%s), nranks: %d, rank_ids: %s" - % (cfg.name, cfg.nranks, cfg.rank_ids) - ) - hccl.create_group(cfg.name, cfg.nranks, cfg.rank_ids) - - self.ascend_instance.add_ascend_subgraph(0, startup_graph) - self.ascend_instance.add_ascend_subgraph(1, main_graph) - - return minimized diff --git a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py deleted file mode 100644 index d973541c8660fa88597095b81a96fad90817956d..0000000000000000000000000000000000000000 --- a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py +++ /dev/null @@ -1,2813 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from functools import reduce - -import numpy as np - -from paddle.framework import core - -__all__ = [] - -registerd_op = { # forwards - "elementwise_add": "AddParser", - "matmul": "MatMulParser", - "mul": "MulParser", - "relu": "ReluParser", - "softmax_with_cross_entropy": "SoftmaxWithCrossEntropyParser", - "shape": "ShapeParser", - "fill_constant": "FillConstantParser", - "reduce_sum": "ReduceSumParser", - "elementwise_mul": "DotMulParser", - "elementwise_div": "DotDivParser", - "elementwise_pow": "DotPowParser", - "elementwise_max": "MaxParser", - "elementwise_min": "MinParser", - "elementwise_sub": "DotSubParser", - "pow": "PowParser", - "gelu": "GeluParser", - "sqrt": "SqrtParser", - "log": "LogParser", - "sum": "SumParser", - "logical_not": "LogicalNotParser", - "gather": "GatherParser", - "scatter": "ScatterParser", - "cast": "CastParser", - "tanh": "TanhParser", - "stack": "StackParser", - "square": "SquareParser", - "unsqueeze2": "UnSqueezeParser", - "assign": "AssignParser", - "softmax": "SoftMaxParser", - "reshape2": "ReshapeParser", - "transpose2": "TransposeParser", - "layer_norm": "LayerNormParser", - "less_than": "LessParser", - "mean": "MeanParser", - "scale": "ScaleParser", - "slice": "SliceParser", - "top_k": "TopkParser", - "accuracy": "AccuracyParser", - # "increment": "IncrementParser", - "lookup_table": "LookupTableParser", - "truncated_gaussian_random": "TruncatedNormalParser", - "c_allgather": "AllGatherParser", - "c_allreduce_sum": "AllReduceSumParser", - "c_allreduce_max": "AllReduceMaxParser", - "c_broadcast": "BroadcastParser", - "c_reduce_scatter": "ReduceScatterParser", - "c_send": "SendParser", - "c_receive": "ReceiveParser", - "uniform_random": "UniformRandomParser", - "range": "RangeParser", - "equal": "EqualParser", - "expand": "ExpandParser", - "squeeze2": "SqueezeParser", - # backwords - "matmul_grad": "MatMulGradParser", - "mul_grad": "MulGradParser", - "relu_grad": "ReluGradParser", - "reduce_sum_grad": "ReduceSumGradParser", - "softmax_with_cross_entropy_grad": "SoftmaxWithCrossEntropyGradParser", - "tanh_grad": "TanhGradParser", - "log_grad": "LogGradParser", - "pow_grad": "PowGradParser", - "sqrt_grad": "SqrtGradParser", - "gelu_grad": "GeluGradParser", - "mean_grad": "MeanGradParser", - 'lookup_table_grad': "LookUpTableGradParser", - "elementwise_mul_grad": "DotMulGradParser", - "elementwise_add_grad": "DotAddGradParser", - "elementwise_div_grad": "DotDivGradParser", - "softmax_grad": "SoftmaxGradParser", - "slice_grad": "SliceGradParser", - "reshape2_grad": "ReshapeGradParser", - "gather_grad": "GatherGradParser", - "transpose2_grad": "TransposeGradParser", - "layer_norm_grad": "LayerNormGradParser", - # opt - "sgd": "SGDParser", - # "adam": "AdamParser", -} -global_cnt = -1 -global_input_cnt = -1 - - -class AscendHelper: - def __init__(self): - self.dtype2ge_map = { - 0: core.GEDataType.DT_BOOL, - 1: core.GEDataType.DT_INT16, - 2: core.GEDataType.DT_INT32, - 3: core.GEDataType.DT_INT64, - 4: core.GEDataType.DT_FLOAT16, - 5: core.GEDataType.DT_FLOAT, - 6: core.GEDataType.DT_DOUBLE, - } - self.dtype2np_map = { - 0: "bool", - 1: "int16", - 2: "int32", - 3: "int64", - 4: "float16", - 5: "float32", - 6: "float64", - } - self.dtype2paddle_inv_map = {"VarType.FP32": 0, "VarType.FP16": 1} - - def dtype2ge(self, dtype): - assert dtype in self.dtype2ge_map, "dtype[%d] is not supported %d" % ( - dtype - ) - return self.dtype2ge_map[dtype] - - def dtype2np(self, index): - assert index in self.dtype2np_map, "index[%d] is not supported %d" % ( - index - ) - return self.dtype2np_map[index] - - -class AscendParserFactory: - def __init__(self, graph, var2geop): - self.graph = graph - self.var2geop = var2geop - - def create_parse(self, parser_class): - try: - parser = globals()[parser_class](self.graph, self.var2geop) - return parser - except: - raise ValueError("parser class %s does not exist" % parser_class) - - -class AscendParserBase: - def __init__(self, graph, var2geop): - self.graph = graph - self.var2geop = var2geop - self.op = None - self.ascend_helper = AscendHelper() - - def _get_ge_input(self, input_var_name): - assert input_var_name in self.var2geop, "var %s not created before" % ( - input_var_name - ) - return self.var2geop[input_var_name] - - def update_output(self, geop_list, index_list): - output_num = len(self.op.output_names) - assert output_num == len(index_list), ( - "Parser[%s]'s output number[%d] is not equal to parameters number[%d]" - % (self.parser_name, len(index_list), output_num) - ) - for output_id in range(output_num): - arguments = self.op.output(self.op.output_names[output_id]) - if len(arguments) > 0: - assert len(arguments) == len(index_list[output_id]), ( - "Parser[%s]'s %dth argument number[%d] is not equal to paddle's number[%d]" - % ( - self.parser_name, - output_id, - len(index_list[output_id]), - len(arguments), - ) - ) - for i in range(len(arguments)): - self.var2geop[arguments[i]] = geop_list[ - index_list[output_id][i] - ] - - for geop in geop_list: - self.graph.add_op(geop) - - def apply(self, op): - self.op = op - assert ( - self.op.type == self.parser_name - ), f"op [{self.op.type}] != parser_name[{self.parser_name}]" - # print("begin to parse op %s" % (self.parser_name)) - geop_list, index_list = self._apply() - self.update_output(geop_list, index_list) - - def _mark_as_input(self, ge_tensor): - global global_input_cnt - global_input_cnt += 1 - self.var2geop["geinput." + str(global_input_cnt)] = ge_tensor - - def _accumulated_op_id(self): - global global_cnt - global_cnt += 1 - name = "." + str(global_cnt) - return name - - def _create_ge_tensor(self, shape, dtype, value): - tensor_desc = core.GETensorDesc( - core.GEShape(shape), - core.GEFormat.FORMAT_ND, - self.ascend_helper.dtype2ge(dtype), - ) - tensor = core.GETensor(tensor_desc) - - data = ( - (value * np.ones(shape)) - .reshape(shape) - .astype(self.ascend_helper.dtype2np(dtype)) - ) - buf = data.tobytes() - data_8 = np.frombuffer(buf, dtype=np.uint8) - tensor.set_data(data_8) - return tensor - - def _get_ge_tensor(self, shape, dtype, value_list): - tensor_desc = core.GETensorDesc( - core.GEShape(shape), - core.GEFormat.FORMAT_ND, - self.ascend_helper.dtype2ge(dtype), - ) - tensor = core.GETensor(tensor_desc) - - data = ( - np.array(value_list) - .reshape(shape) - .astype(self.ascend_helper.dtype2np(dtype)) - ) - buf = data.tobytes() - data_8 = np.frombuffer(buf, dtype=np.uint8) - tensor.set_data(data_8) - - tensor_const = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const" - ).set_attr_tensor("value", tensor) - - return tensor_const - - def _get_variable(self, shape, dtype, tensor): - if dtype == "int32": - type = core.GEDataType.DT_INT32 - elif dtype == "float32": - type = core.GEDataType.DT_FLOAT - - var = core.GEOperatorFactory.create_operator( - "variable" + self._accumulated_op_id(), "Variable" - ) - var.update_output_desc( - "y", - core.GETensorDesc( - core.GEShape(shape), core.GEFormat.FORMAT_ND, type - ), - ) - assign = ( - core.GEOperatorFactory.create_operator( - "assign" + self._accumulated_op_id(), "Assign" - ) - .set_input("value", tensor) - .set_input("ref", var) - ) - - return assign - - def _create_shape_tensor(self): - tensor_desc = core.GETensorDesc( - core.GEShape([2]), core.GEFormat.FORMAT_ND, core.GEDataType.DT_INT32 - ) - tensor = core.GETensor(tensor_desc) - - data = np.ones(2).astype("int32").reshape([2]) - data[0] = 64 - buf = data.tobytes() - data_8 = np.frombuffer(buf, dtype=np.uint8) - tensor.set_data(data_8) - return tensor - - def _get_GEtensor_shape(self, tensor): - tensor_shape = core.GEOperatorFactory.create_operator( - "shape" + self._accumulated_op_id(), "Shape" - ).set_input("x", tensor) - tensor_shape = ( - core.GEOperatorFactory.create_operator( - "cast" + self._accumulated_op_id(), "Cast" - ) - .set_input("x", tensor_shape) - .set_attr_int32("dst_type", 0) - ) - return tensor_shape - - -class AddParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "elementwise_add" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - y = self._get_ge_input(self.op.input_arg_names[1]) - add = ( - core.GEOperatorFactory.create_operator( - "add" + self._accumulated_op_id(), "Add" - ) - .set_input("x1", x) - .set_input("x2", y) - ) - return [add], [[0]] - - -class DotSubParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "elementwise_sub" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - y = self._get_ge_input(self.op.input_arg_names[1]) - sub = ( - core.GEOperatorFactory.create_operator( - "sub" + self._accumulated_op_id(), "Sub" - ) - .set_input("x1", x) - .set_input("x2", y) - ) - return [sub], [[0]] - - -class DotMulParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "elementwise_mul" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - y = self._get_ge_input(self.op.input_arg_names[1]) - mul = ( - core.GEOperatorFactory.create_operator( - "dotmul" + self._accumulated_op_id(), "Mul" - ) - .set_input("x1", x) - .set_input("x2", y) - ) - return [mul], [[0]] - - -class DotDivParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "elementwise_div" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - y = self._get_ge_input(self.op.input_arg_names[1]) - div = ( - core.GEOperatorFactory.create_operator( - "dotdiv" + self._accumulated_op_id(), "Div" - ) - .set_input("x1", x) - .set_input("x2", y) - ) - return [div], [[0]] - - -class DotPowParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "elementwise_pow" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - y = self._get_ge_input(self.op.input_arg_names[1]) - pow = ( - core.GEOperatorFactory.create_operator( - "dotpow" + self._accumulated_op_id(), "Pow" - ) - .set_input("x1", x) - .set_input("x2", y) - ) - return [pow], [[0]] - - -class LessParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "less_than" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - y = self._get_ge_input(self.op.input_arg_names[1]) - less_than = ( - core.GEOperatorFactory.create_operator( - "less_than" + self._accumulated_op_id(), "Less" - ) - .set_input("x1", x) - .set_input("x2", y) - ) - return [less_than], [[0]] - - -class MaxParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "elementwise_max" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - y = self._get_ge_input(self.op.input_arg_names[1]) - max_out = ( - core.GEOperatorFactory.create_operator( - "max" + self._accumulated_op_id(), "Maximum" - ) - .set_input("x1", x) - .set_input("x2", y) - ) - return [max_out], [[0]] - - -class MinParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "elementwise_min" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - y = self._get_ge_input(self.op.input_arg_names[1]) - min_out = ( - core.GEOperatorFactory.create_operator( - "min" + self._accumulated_op_id(), "Minimum" - ) - .set_input("x1", x) - .set_input("x2", y) - ) - return [min_out], [[0]] - - -# cal -class LogParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "log" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - log = core.GEOperatorFactory.create_operator( - "log" + self._accumulated_op_id(), "Log" - ).set_input("x", x) - return [log], [[0]] - - -class SqrtParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "sqrt" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - sqrt = core.GEOperatorFactory.create_operator( - "sqrt" + self._accumulated_op_id(), "Sqrt" - ).set_input("x", x) - return [sqrt], [[0]] - - -class PowParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "pow" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - factor = self.op.attr("factor") - pow_value = ( - core.GEOperatorFactory.create_operator( - "pow" + self._accumulated_op_id(), "Power" - ) - .set_input("x", x) - .set_attr_float("power", factor) - .set_attr_float("scale", 1.0) - .set_attr_float("shift", 0.0) - ) - return [pow_value], [[0]] - - -class SquareParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "square" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - square = core.GEOperatorFactory.create_operator( - "square" + self._accumulated_op_id(), "Square" - ).set_input("x", x) - return [square], [[0]] - - -class SumParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "sum" - - def _apply(self): - len_list = len(self.op.input_arg_names) - if len_list < 2: - raise AssertionError("the size of input list must large or equal 2") - x = self._get_ge_input(self.op.input_arg_names[0]) - y = self._get_ge_input(self.op.input_arg_names[1]) - sum = ( - core.GEOperatorFactory.create_operator( - "sum" + self._accumulated_op_id(), "Add" - ) - .set_input("x1", x) - .set_input("x2", y) - ) - for i in range(2, len_list): - y = self._get_ge_input(self.op.input_arg_names[i]) - sum = ( - core.GEOperatorFactory.create_operator( - "sum" + self._accumulated_op_id(), "Add" - ) - .set_input("x1", sum) - .set_input("x2", y) - ) - return [sum], [[0]] - - -class LogicalNotParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "logical_not" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - logical_not = core.GEOperatorFactory.create_operator( - "logical_not" + self._accumulated_op_id(), "LogicalNot" - ).set_input("x", x) - return [logical_not], [[0]] - - -class MeanParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "mean" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - mean = ( - core.GEOperatorFactory.create_operator( - "mean" + self._accumulated_op_id(), "ReduceMeanD" - ) - .set_input("x", x) - .set_attr_bool("keep_dims", False) - .set_attr_vec_int32("axes", []) - ) - return [mean], [[0]] - - -class ReduceSumParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "reduce_sum" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - axes = self.op.attr("dim") - keep_dims = self.op.attr("keep_dim") - reduce_all = self.op.attr("reduce_all") - x_shape = self.op.block.var(self.op.input_arg_names[0]).shape - if reduce_all: - axes = list(range(len(x_shape))) - reduce_sum = ( - core.GEOperatorFactory.create_operator( - "reduce_sum" + self._accumulated_op_id(), "ReduceSumD" - ) - .set_input("x", x, 0) - .set_attr_vec_int32("axes", axes) - .set_attr_bool("keep_dims", keep_dims) - ) - return [reduce_sum], [[0]] - - -# class IncrementParser(AscendParserBase): -# def __init__(self, graph, var2geop): -# super().__init__(graph, var2geop) -# self.parser_name = "increment" -# -# def _apply(self): -# x = self._get_ge_input(self.op.input_arg_names[0]) -# step = self.op.attr("step") #self._get_ge_input(self.op.input_arg_names[1]) -# print("step: ", step) -# -# increment = core.GEOperatorFactory.create_operator("adds" + self._accumulated_op_id(), "Adds").set_input("x", x).set_attr_float("value", step) #set_input("x2", bias) -# -# return [increment] - - -# matrix cal -class MatMulParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "matmul" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - y = self._get_ge_input(self.op.input_arg_names[1]) - transpose_x = self.op.attr("transpose_X") - transpose_y = self.op.attr("transpose_Y") - - x1_shape = self.op.block.var(self.op.input_arg_names[0]).shape - x2_shape = self.op.block.var(self.op.input_arg_names[1]).shape - - if len(x1_shape) > 2: - matmul = ( - core.GEOperatorFactory.create_operator( - "matmul" + self._accumulated_op_id(), "BatchMatMul" - ) - .set_input("x1", x) - .set_input("x2", y) - .set_attr_bool("adj_x1", transpose_x) - .set_attr_bool("adj_x2", transpose_y) - ) - elif len(x1_shape) == 2: - matmul = ( - core.GEOperatorFactory.create_operator( - "matmul" + self._accumulated_op_id(), "MatMul" - ) - .set_input("x1", x) - .set_input("x2", y) - .set_attr_bool("transpose_x1", transpose_x) - .set_attr_bool("transpose_x2", transpose_y) - ) - else: - raise AssertionError("not support") - return [matmul], [[0]] - - -class MulParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "mul" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - y = self._get_ge_input(self.op.input_arg_names[1]) - x_num_col_dims = self.op.attr("x_num_col_dims") - y_num_col_dims = self.op.attr("y_num_col_dims") - shape_x1 = self.op.block.var(self.op.input_arg_names[0]).shape - shape_x2 = self.op.block.var(self.op.input_arg_names[1]).shape - - if x_num_col_dims == 1 and y_num_col_dims == 1: - if len(shape_x1) == 2 and len(shape_x2) == 2: - matmul = ( - core.GEOperatorFactory.create_operator( - "mul" + self._accumulated_op_id(), "MatMul" - ) - .set_input("x1", x) - .set_input("x2", y) - ) - elif len(shape_x1) == 3 and len(shape_x2) == 2: - flatten_x1 = core.GEOperatorFactory.create_operator( - "flatten" + self._accumulated_op_id(), "Flatten" - ).set_input("x", x) - matmul = ( - core.GEOperatorFactory.create_operator( - "mul" + self._accumulated_op_id(), "MatMul" - ) - .set_input("x1", flatten_x1, 0) - .set_input("x2", y, 0) - ) - else: - raise AssertionError("not support") - else: - if len(shape_x1) == 3 and len(shape_x2) == 2: - assert x_num_col_dims == 2, "only support 2" - flatten_x1 = ( - core.GEOperatorFactory.create_operator( - "flatten" + self._accumulated_op_id(), "FlattenV2" - ) - .set_input("x", x) - .set_attr_int32("axis", 0) - .set_attr_int32("end_axis", 1) - ) - matmul_m = ( - core.GEOperatorFactory.create_operator( - "mul" + self._accumulated_op_id(), "MatMul" - ) - .set_input("x1", flatten_x1, 0) - .set_input("x2", y, 0) - ) - matmul_transpose = ( - core.GEOperatorFactory.create_operator( - "transpose" + self._accumulated_op_id(), "TransposeD" - ) - .set_input("x", matmul_m) - .set_attr_vec_int32("perm", [1, 0]) - ) - tensor = self._create_ge_tensor( - [3], 2, [shape_x2[1], shape_x1[0], shape_x1[1]] - ) - const_shape = core.GEOperatorFactory.create_operator( - "shape" + self._accumulated_op_id(), "Const" - ).set_attr_tensor("value", tensor) - reshape_matmul = ( - core.GEOperatorFactory.create_operator( - "reshape" + self._accumulated_op_id(), "Reshape" - ) - .set_input("x", matmul_transpose) - .set_input("shape", const_shape) - .set_attr_int32("axis", 0) - ) - matmul = ( - core.GEOperatorFactory.create_operator( - "transpose" + self._accumulated_op_id(), "TransposeD" - ) - .set_input("x", reshape_matmul) - .set_attr_vec_int32("perm", [1, 2, 0]) - ) - else: - raise AssertionError("not support") - - return [matmul], [[0]] - - -class LayerNormParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "layer_norm" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[2]) - scale = self._get_ge_input(self.op.input_arg_names[1]) - bias = self._get_ge_input(self.op.input_arg_names[0]) - epsilon = self.op.attr("epsilon") - begin_norm_axis = self.op.attr("begin_norm_axis") - x_dtype = self.op.block.var(self.op.input_arg_names[2]).dtype - - shape_tensor = core.GEOperatorFactory.create_operator( - "shape" + self._accumulated_op_id(), "Shape" - ).set_input("x", x) - scale_expand = ( - core.GEOperatorFactory.create_operator( - "broadcast_to_d" + self._accumulated_op_id(), "BroadcastTo" - ) - .set_input("x", scale) - .set_input("shape", shape_tensor) - ) - bias_expand = ( - core.GEOperatorFactory.create_operator( - "broadcast_to_d" + self._accumulated_op_id(), "BroadcastTo" - ) - .set_input("x", bias) - .set_input("shape", shape_tensor) - ) - layer_norm = ( - core.GEOperatorFactory.create_operator( - "layer_norm" + self._accumulated_op_id(), "LayerNorm" - ) - .set_input("x", x) - .set_input("gamma", scale_expand) - .set_input("beta", bias_expand) - .set_attr_int32("begin_norm_axis", begin_norm_axis) - .set_attr_int32("begin_params_axis", begin_norm_axis) - .set_attr_float("epsilon", epsilon) - ) - - cast_dtype = ( - 0 - if self.ascend_helper.dtype2paddle_inv_map[str(x_dtype)] == 0 - else 1 - ) - y = ( - core.GEOperatorFactory.create_operator( - "cast" + self._accumulated_op_id(), "Cast" - ) - .set_input("x", layer_norm, 0) - .set_attr_int32("dst_type", cast_dtype) - ) - mean = ( - core.GEOperatorFactory.create_operator( - "cast" + self._accumulated_op_id(), "Cast" - ) - .set_input("x", layer_norm, 1) - .set_attr_int32("dst_type", cast_dtype) - ) - variance = ( - core.GEOperatorFactory.create_operator( - "cast" + self._accumulated_op_id(), "Cast" - ) - .set_input("x", layer_norm, 2) - .set_attr_int32("dst_type", cast_dtype) - ) - return [y, mean, variance], [[1], [2], [0]] - - -# activate function -class ReluParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "relu" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - relu = core.GEOperatorFactory.create_operator( - "relu" + self._accumulated_op_id(), "Relu" - ).set_input("x", x) - return [relu], [[0]] - - -class GeluParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "gelu" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - gelu = core.GEOperatorFactory.create_operator( - "gelu" + self._accumulated_op_id(), "Gelu" - ).set_input("x", x) - return [gelu], [[0]] - - -class TanhParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "tanh" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - tanh = core.GEOperatorFactory.create_operator( - "tanh" + self._accumulated_op_id(), "Tanh" - ).set_input("x", x) - return [tanh], [[0]] - - -# loss function -class SoftmaxWithCrossEntropyParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "softmax_with_cross_entropy" - - def _apply(self): - label = self._get_ge_input(self.op.input_arg_names[0]) - logits = self._get_ge_input(self.op.input_arg_names[1]) - cls_num = self.op.block.var(self.op.input_arg_names[1]).shape[1] - - softmax = core.GEOperatorFactory.create_operator( - "softmax" + self._accumulated_op_id(), "SoftmaxV2" - ).set_input("x", logits) - label = ( - core.GEOperatorFactory.create_operator( - "cast" + self._accumulated_op_id(), "Cast" - ) - .set_input("x", label) - .set_attr_int32("dst_type", 3) - ) - - tensoron = self._create_ge_tensor([1], 5, 1) - on = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const" - ).set_attr_tensor("value", tensoron) - tensoroff = self._create_ge_tensor([1], 5, 0) - off = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const" - ).set_attr_tensor("value", tensoroff) - self._mark_as_input(on) - self._mark_as_input(off) - onehot = ( - core.GEOperatorFactory.create_operator( - "onehot" + self._accumulated_op_id(), "OneHotD" - ) - .set_input("x", label) - .set_input("on_value", on) - .set_input("off_value", off) - .set_attr_int32("depth", cls_num) - ) - squeeze = core.GEOperatorFactory.create_operator( - "mul" + self._accumulated_op_id(), "Squeeze" - ).set_input("x", onehot) - - loss_all = ( - core.GEOperatorFactory.create_operator( - "loss" + self._accumulated_op_id(), - "SoftmaxCrossEntropyWithLogits", - ) - .set_input("features", logits) - .set_input("labels", squeeze) - ) - loss = ( - core.GEOperatorFactory.create_operator( - "cast" + self._accumulated_op_id(), "Cast" - ) - .set_input("x", loss_all, 0) - .set_attr_int32("dst_type", 0) - ) - loss_expand = ( - core.GEOperatorFactory.create_operator( - "unsqueeze" + self._accumulated_op_id(), "Unsqueeze" - ) - .set_input("x", loss) - .set_attr_vec_int32("axes", [1]) - ) - return [label, softmax, loss_expand], [[2], [1]] - - -class SoftMaxParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "softmax" - - def _apply(self): - logits = self._get_ge_input(self.op.input_arg_names[0]) - axes = self.op.attr("axis") - - softmax = ( - core.GEOperatorFactory.create_operator( - "softmax" + self._accumulated_op_id(), "SoftmaxV2" - ) - .set_input("x", logits) - .set_attr_vec_int32("axes", [axes]) - ) - return [softmax], [[0]] - - -# general -class ShapeParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "shape" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - shape = core.GEOperatorFactory.create_operator( - "shape" + self._accumulated_op_id(), "Shape" - ).set_input("x", x) - return [shape], [[0]] - - -class FillConstantParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "fill_constant" - - def _apply(self): - shape = self.op.attr("shape") - dtype = self.op.attr("dtype") - value = self.op.attr("value") - - tensor = self._create_ge_tensor(shape, dtype, value) - const = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const" - ).set_attr_tensor("value", tensor) - self._mark_as_input(const) - if self.op.block.var(self.op.output('Out')[0]).persistable: - # print("%s is Persistable in fill_constant" % - # (self.op.output('Out')[0])) - var = core.GEOperatorFactory.create_operator( - self.op.output('Out')[0], "Variable" - ) - var.update_output_desc( - "y", - core.GETensorDesc( - core.GEShape(shape), - core.GEFormat.FORMAT_ND, - core.GEDataType.DT_FLOAT, - ), - ) - assign = ( - core.GEOperatorFactory.create_operator( - "assign" + self._accumulated_op_id(), "Assign" - ) - .set_input("value", const) - .set_input("ref", var) - ) - return [const], [[0]] - return [const], [[0]] - - -class TruncatedNormalParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "truncated_gaussian_random" - - def _apply(self): - shape = self.op.attr("shape") - dtype = self.op.attr("dtype") - mean = self.op.attr("mean") - std = self.op.attr("std") - seed = self.op.attr("seed") - - tensor1 = self._create_ge_tensor([len(shape)], 2, shape) - shape_tensor = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const" - ).set_attr_tensor("value", tensor1) - tensor2 = self._create_ge_tensor([1], dtype, mean) - mean_tensor = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const" - ).set_attr_tensor("value", tensor2) - tensor3 = self._create_ge_tensor([1], dtype, std) - std_tensor = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const" - ).set_attr_tensor("value", tensor3) - tensor4 = self._create_ge_tensor([1], dtype, mean - 2 * std) - min_tensor = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const" - ).set_attr_tensor("value", tensor4) - tensor5 = self._create_ge_tensor([1], dtype, mean + 2 * std) - max_tensor = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const" - ).set_attr_tensor("value", tensor5) - - self._mark_as_input(shape_tensor) - self._mark_as_input(mean_tensor) - self._mark_as_input(std_tensor) - self._mark_as_input(min_tensor) - self._mark_as_input(max_tensor) - - truncated_normal = ( - core.GEOperatorFactory.create_operator( - "truncated_normal" + self._accumulated_op_id(), - "ParameterizedTruncatedNormal", - ) - .set_input("shape", shape_tensor) - .set_input("means", mean_tensor) - .set_input("stdevs", std_tensor) - .set_input("min", min_tensor) - .set_input("max", max_tensor) - .set_attr_int32("seed", 0) - ) - - # wirte the output of truncatedNormal from startup_program to main_program - if self.op.block.var(self.op.output('Out')[0]).persistable: - # print("%s is Persistable in truncated_normal" % - # (self.op.output('Out')[0])) - var = core.GEOperatorFactory.create_operator( - self.op.output('Out')[0], "Variable" - ) - var.update_output_desc( - "y", - core.GETensorDesc( - core.GEShape(shape), - core.GEFormat.FORMAT_ND, - core.GEDataType.DT_FLOAT, - ), - ) - assign = ( - core.GEOperatorFactory.create_operator( - "assign" + self._accumulated_op_id(), "Assign" - ) - .set_input("value", truncated_normal) - .set_input("ref", var) - ) - return [ - shape_tensor, - mean_tensor, - std_tensor, - min_tensor, - max_tensor, - truncated_normal, - ], [[-1]] - # else: - # print( - # "self.op.output('Out')[0] is not persistable in truncated_noraml" - # ) - return [truncated_normal], [[0]] - - -class GatherParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "gather" - - def _apply(self): - index = self._get_ge_input(self.op.input_arg_names[0]) - x = self._get_ge_input(self.op.input_arg_names[1]) - clo = self.op.block.var(self.op.input_arg_names[1]).shape[-1] - - gather = ( - core.GEOperatorFactory.create_operator( - "gather" + self._accumulated_op_id(), "Gather" - ) - .set_input("x", x) - .set_input("indices", index) - .set_attr_bool("validate_indices", True) - ) - return [gather], [[0]] - - -class ScatterParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "scatter" - - def _apply(self): - index = self._get_ge_input(self.op.input_arg_names[0]) - x = self._get_ge_input(self.op.input_arg_names[1]) - updates = self._get_ge_input(self.op.input_arg_names[2]) - overwrite = self.op.attr("overwrite") - index_shape = self.op.block.var(self.op.input_arg_names[0]).shape - - if len(index_shape) == 1: - index = ( - core.GEOperatorFactory.create_operator( - "unsqueeze" + self.getid(), "Unsqueeze" - ) - .set_input("x", index) - .set_attr_vec_int32("axes", [1]) - ) - if not overwrite: - scatter_value = ( - core.GEOperatorFactory.create_operator( - "scatter" + self._accumulated_op_id(), "TensorScatterAdd" - ) - .set_input("x", x) - .set_input("indices", index) - .set_input("updates", updates) - ) - else: - scatter_value = ( - core.GEOperatorFactory.create_operator( - "scatter" + self._accumulated_op_id(), "TensorScatterUpdate" - ) - .set_input("x", x) - .set_input("indices", index) - .set_input("updates", updates) - ) - return [x, index, updates, scatter_value], [[-1]] - - -class CastParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "cast" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - dtype = self.op.attr("out_dtype") - cast = ( - core.GEOperatorFactory.create_operator( - "cast" + self._accumulated_op_id(), "Cast" - ) - .set_input("x", x) - .set_attr_int32("dst_type", dtype) - ) - return [cast], [[0]] - - -class AssignParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "assign" - - def _apply(self): - const = self._get_ge_input(self.op.input_arg_names[0]) - var = self._get_ge_input(self.op.input_arg_names[1]) - assign = ( - core.GEOperatorFactory.create_operator( - "assign" + self._accumulated_op_id(), "Assign" - ) - .set_input("value", const) - .set_input("ref", var) - ) - return [assign], [[0]] - - -class ScaleParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "scale" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - scale = self.op.attr("scale") - bias = self.op.attr("bias") - bias_after_scale = self.op.attr("bias_after_scale") - - if bias_after_scale: - scale_value = ( - core.GEOperatorFactory.create_operator( - "scale" + self._accumulated_op_id(), "Power" - ) - .set_input("x", x) - .set_attr_float("power", 1.0) - .set_attr_float("scale", scale) - .set_attr_float("shift", bias) - ) - else: - x_add_bias = ( - core.GEOperatorFactory.create_operator( - "adds" + self._accumulated_op_id(), "Adds" - ) - .set_input("x", x) - .set_attr_float("value", bias) - ) - scale_value = ( - core.GEOperatorFactory.create_operator( - "scale" + self._accumulated_op_id(), "Power" - ) - .set_input("x", x_add_bias) - .set_attr_float("power", 1.0) - .set_attr_float("scale", scale) - .set_attr_float("shift", 0.0) - ) - return [scale_value], [[0]] - - -class SliceParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "slice" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - axes = self.op.attr("axes") - starts = self.op.attr("starts") - ends = self.op.attr("ends") - - x_shape = self.op.block.var(self.op.input_arg_names[0]).shape - len_shape = len(x_shape) - axes_cor = list(range(len_shape)) - starts_cor, ends_cor = [], [] - cnt = 0 - for i in range(len_shape): - starts_cor.append(starts[cnt] if i in axes else 0) - if i in axes and ends[cnt] <= x_shape[i]: - ends_cor.append(ends[cnt]) - else: - ends_cor.append(x_shape[i]) - if i in axes: - cnt += 1 - size = [ends_cor[i] - starts_cor[i] for i in range(len(axes_cor))] - - assert ( - len(axes_cor) == len(starts_cor) == len(ends_cor) - ), "the three fields must have same size" - slice_value = ( - core.GEOperatorFactory.create_operator( - "slice" + self._accumulated_op_id(), "SliceD" - ) - .set_input("x", x) - .set_attr_vec_int32("offsets", starts_cor) - .set_attr_vec_int32("size", size) - ) - - return [slice_value], [[0]] - - -class ReshapeParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "reshape2" - - def _apply(self): - org_shape = self.op.block.var(self.op.input_arg_names[0]).shape - assert org_shape.count(-1) == 0, "do not allow the dim is -1" - shape = self.op.attr("shape") - for cnt in range(len(shape)): - if shape[cnt] == 0: - shape[cnt] = org_shape[cnt] - - if -1 in shape: - assert shape.count(-1) == 1, "only allow one dim is -1" - mul_res_org = reduce(lambda x, y: x * y, org_shape) - mul_res_refine = reduce(lambda x, y: x * y, shape) * -1 - idx = shape.index(-1) - shape[idx] = mul_res_org // mul_res_refine - - x = self._get_ge_input(self.op.input_arg_names[0]) - tensor = self._create_ge_tensor([len(shape)], 2, shape) - const_shape = core.GEOperatorFactory.create_operator( - "shape" + self._accumulated_op_id(), "Const" - ).set_attr_tensor("value", tensor) - reshape = ( - core.GEOperatorFactory.create_operator( - "reshape" + self._accumulated_op_id(), "Reshape" - ) - .set_input("x", x) - .set_input("shape", const_shape) - .set_attr_int32("axis", 0) - ) - x_shape = core.GEOperatorFactory.create_operator( - "shape" + self._accumulated_op_id(), "Shape" - ).set_input("x", x) - - return [x_shape, reshape], [[1], [0]] - - -class TransposeParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "transpose2" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - perm = self.op.attr("axis") - transpose = ( - core.GEOperatorFactory.create_operator( - "transpose" + self._accumulated_op_id(), "TransposeD" - ) - .set_input("x", x) - .set_attr_vec_int32("perm", perm) - ) - x_shape = core.GEOperatorFactory.create_operator( - "shape" + self._accumulated_op_id(), "Shape" - ).set_input("x", x) - - return [x_shape, transpose], [[1], [0]] - - -class AccuracyParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "accuracy" - - def _apply(self): - pred = self._get_ge_input(self.op.input_arg_names[0]) - label = self._get_ge_input(self.op.input_arg_names[1]) - logits = self._get_ge_input(self.op.input_arg_names[2]) - - pred = ( - core.GEOperatorFactory.create_operator( - "cast" + self._accumulated_op_id(), "Cast" - ) - .set_input("x", pred) - .set_attr_int32("dst_type", 3) - ) - label = ( - core.GEOperatorFactory.create_operator( - "cast" + self._accumulated_op_id(), "Cast" - ) - .set_input("x", label) - .set_attr_int32("dst_type", 3) - ) - equal = ( - core.GEOperatorFactory.create_operator( - "equal" + self._accumulated_op_id(), "Equal" - ) - .set_input("x1", pred) - .set_input("x2", label) - ) - cast = ( - core.GEOperatorFactory.create_operator( - "cast" + self._accumulated_op_id(), "Cast" - ) - .set_input("x", equal) - .set_attr_int32("dst_type", 0) - ) - acc = ( - core.GEOperatorFactory.create_operator( - "mean" + self._accumulated_op_id(), "ReduceMeanD" - ) - .set_input("x", cast) - .set_attr_bool("keep_dims", False) - .set_attr_vec_int32("axes", []) - ) - correct = ( - core.GEOperatorFactory.create_operator( - "sum" + self._accumulated_op_id(), "ReduceSumD" - ) - .set_input("x", cast) - .set_attr_bool("keep_dims", False) - .set_attr_vec_int32("axes", []) - ) - ones_tensor = core.GEOperatorFactory.create_operator( - "oneslike" + self._accumulated_op_id(), "OnesLike" - ).set_input("x", label) - ones_tensor = ( - core.GEOperatorFactory.create_operator( - "cast" + self._accumulated_op_id(), "Cast" - ) - .set_input("x", ones_tensor) - .set_attr_int32("dst_type", 0) - ) - total = ( - core.GEOperatorFactory.create_operator( - "sum" + self._accumulated_op_id(), "ReduceSumD" - ) - .set_input("x", ones_tensor) - .set_attr_bool("keep_dims", False) - .set_attr_vec_int32("axes", []) - ) - - return [acc, correct, total], [[0], [1], [2]] - - -class TopkParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "top_k" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - k = self.op.attr("k") - - tensor = self._create_ge_tensor([1], 2, k) - const_k = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const" - ).set_attr_tensor("value", tensor) - cast_x = ( - core.GEOperatorFactory.create_operator( - "cast" + self._accumulated_op_id(), "Cast" - ) - .set_input("x", x) - .set_attr_int32("dst_type", 1) - ) - topk = ( - core.GEOperatorFactory.create_operator( - "topk" + self._accumulated_op_id(), "TopK" - ) - .set_input("x", cast_x) - .set_input("k", const_k) - ) - value = ( - core.GEOperatorFactory.create_operator( - "cast" + self._accumulated_op_id(), "Cast" - ) - .set_input("x", topk, 0) - .set_attr_int32("dst_type", 0) - ) - index = ( - core.GEOperatorFactory.create_operator( - "cast" + self._accumulated_op_id(), "Cast" - ) - .set_input("x", topk, 1) - .set_attr_int32("dst_type", 0) - ) - return [value, index], [[1], [0]] - - -class LookupTableParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "lookup_table" - - def _apply(self): - ids = self._get_ge_input(self.op.input_arg_names[0]) - w = self._get_ge_input(self.op.input_arg_names[1]) - - ids_squeeze = ( - core.GEOperatorFactory.create_operator( - "squeeze" + self._accumulated_op_id(), "Squeeze" - ) - .set_input("x", ids) - .set_attr_vec_int32("axes", [-1]) - ) - out = ( - core.GEOperatorFactory.create_operator( - "lookup" + self._accumulated_op_id(), "Gather" - ) - .set_input("x", w) - .set_input("indices", ids_squeeze) - ) - return [out], [[0]] - - -class StackParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "stack" - - def _apply(self): - tiles = len(self.op.input_arg_names) - data_x_lst = [] - for index in range(tiles): - data_x_lst.append( - self._get_ge_input(self.op.input_arg_names[index]) - ) - axis = self.op.attr("axis") - - data_x = data_x_lst[0] - tensor = self._create_ge_tensor([1], 2, axis) - tensor_axis = core.GEOperatorFactory.create_operator( - "axis" + self._accumulated_op_id(), "Const" - ).set_attr_tensor("value", tensor) - expand = ( - core.GEOperatorFactory.create_operator( - "expand" + self._accumulated_op_id(), "ExpandDims" - ) - .set_input("x", data_x) - .set_input("axis", tensor_axis) - ) - - stack = ( - core.GEOperatorFactory.create_operator( - "stack" + self._accumulated_op_id(), "TileWithAxis" - ) - .set_input("x", expand) - .set_attr_int32("axis", axis) - .set_attr_int32("tiles", tiles) - ) - - return [stack], [[0]] - - -class UnSqueezeParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "unsqueeze2" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - axes = self.op.attr('axes') - - output = ( - core.GEOperatorFactory.create_operator( - "unsqueeze" + self._accumulated_op_id(), "Unsqueeze" - ) - .set_input("x", x) - .set_attr_vec_int32("axes", axes) - ) - shape = core.GEOperatorFactory.create_operator( - "shape" + self._accumulated_op_id(), "Shape" - ).set_input("x", output) - return [shape, output], [[1], [0]] - - -# parallel -class AllGatherParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "c_allgather" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - rank_size = self.op.attr("rank_size") - group = self.op.attr("group") - - allgather = ( - core.GEOperatorFactory.create_operator( - "allgather" + self._accumulated_op_id(), "HcomAllGather" - ) - .set_input("x", x) - .set_attr_int32("rank_size", rank_size) - .set_attr_string("group", group) - ) - return [allgather], [[0]] - - -class AllReduceParser(AscendParserBase): - def __init__(self, graph, var2geop, reduction): - super().__init__(graph, var2geop) - self.parser_name = "c_allreduce_" + reduction - self.reduction = reduction - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - reduction = self.reduction - ring_id = self.op.attr("ring_id") - group = "hcom_group_" + str(ring_id) - fusion = None # self.op.attr("fusion") - fusion_id = None # self.op.attr("fusion_id") - - allreduce = ( - core.GEOperatorFactory.create_operator( - "allreduce" + self._accumulated_op_id(), "HcomAllReduce" - ) - .set_input("x", x) - .set_attr_string("reduction", reduction) - .set_attr_string("group", group) - ) - if fusion is not None: - allreduce.set_attr_int32("fusion", fusion) - - if fusion_id is not None: - allreduce.set_attr_int32("fusion_id", fusion_id) - return [allreduce], [[0]] - - -class AllReduceSumParser(AllReduceParser): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop, 'sum') - - -class AllReduceMaxParser(AllReduceParser): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop, 'max') - - -class BroadcastParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "c_broadcast" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - root_rank = self.op.attr("root_rank") - group = self.op.attr("group") - - broadcast = ( - core.GEOperatorFactory.create_operator( - "broadcast" + self._accumulated_op_id(), "HcomBroadcast" - ) - .set_input("x", x) - .set_attr_int32("root_rank", root_rank) - .set_attr_string("group", group) - ) - return [broadcast], [[0]] - - -class ReduceScatterParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "c_reduce_scatter" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - reduction = self.op.attr("reduction") - group = self.op.attr("group") - rank_size = self.op.attr("rank_size") - - reduce_scatter = ( - core.GEOperatorFactory.create_operator( - "reducescatter" + self._accumulated_op_id(), "HcomReduceScatter" - ) - .set_input("x", x) - .set_attr_string("reduction", reduction) - .set_attr_string("group", group) - .set_attr_int32("rank_size", rank_size) - ) - return [reduce_scatter], [[0]] - - -class SendParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "c_send" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - sr_tag = self.op.attr("sr_tag") - dest_rank = self.op.attr("dest_rank") - group = self.op.attr("group") - - send = ( - core.GEOperatorFactory.create_operator( - "send" + self._accumulated_op_id(), "HcomSend" - ) - .set_input("x", x) - .set_attr_int32("sr_tag", sr_tag) - .set_attr_int32("dest_rank", dest_rank) - .set_attr_string("group", group) - ) - return [send], [[0]] - - -class ReceiveParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "c_receive" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - sr_tag = self.op.attr("sr_tag") - src_rank = self.op.attr("src_rank") - group = self.op.attr("group") - shape = self.op.attr("shape") - dtype = self.op.attr("dtype") - - receive = ( - core.GEOperatorFactory.create_operator( - "receive" + self._accumulated_op_id(), "HcomReceive" - ) - .set_input("x", x) - .set_attr_int32("sr_tag", sr_tag) - .set_attr_int32("src_rank", src_rank) - .set_attr_string("group", group) - .set_attr_vec_int32("shape", shape) - .set_attr_int32("dtype", dtype) - ) - return [receive], [[0]] - - -class RangeParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "range" - - def _apply(self): - # TODO not support range type yet - start = self._get_ge_input(self.op.input_arg_names[0]) - end = self._get_ge_input(self.op.input_arg_names[1]) - delta = self._get_ge_input(self.op.input_arg_names[2]) - - ge_range = ( - core.GEOperatorFactory.create_operator( - "range" + self._accumulated_op_id(), "Range" - ) - .set_input("start", end) - .set_input("limit", start) - .set_input("delta", delta) - ) - - return [ge_range], [[0]] - - -class UniformRandomParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "uniform_random" - - def _apply(self): - shape = self.op.attr("shape") - - min_v = self.op.attr("min") - max_v = self.op.attr("max") - seed = self.op.attr("seed") - dtype = self.op.attr("dtype") - assert max_v > min_v, ( - "assert max_v > min_v, but received " - + f"as max_v={max_v}, min_v={min_v} " - ) - - tensor1 = self._create_ge_tensor([len(shape)], 2, shape) - shape_tensor = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const" - ).set_attr_tensor("value", tensor1) - - ge_ur = ( - core.GEOperatorFactory.create_operator( - "uniform_random" + self._accumulated_op_id(), "RandomUniform" - ) - .set_input("shape", shape_tensor) - .set_attr_dtype("dtype", self.ascend_helper.dtype2ge(dtype)) - .set_attr_int32("seed", seed) - .set_attr_int32("seed2", seed) - ) - - scale = max_v - min_v - - scale_value = ( - core.GEOperatorFactory.create_operator( - "scale" + self._accumulated_op_id(), "Power" - ) - .set_input("x", ge_ur) - .set_attr_float("power", 1.0) - .set_attr_float("scale", scale) - .set_attr_float("shift", min_v) - ) - - return [scale_value], [[0]] - - -class EqualParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "equal" - - def _apply(self): - data_x1 = self._get_ge_input(self.op.input_arg_names[0]) - data_x2 = self._get_ge_input(self.op.input_arg_names[1]) - equal = ( - core.GEOperatorFactory.create_operator( - "equal" + self._accumulated_op_id(), "Equal" - ) - .set_input("x1", data_x1) - .set_input("x2", data_x2) - ) - return [equal], [[0]] - - -class ExpandParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "expand" - - def _apply(self): - data_x1_shape = self._get_ge_input(self.op.input_arg_names[0]) - expand_times = self.op.attr('expand_times') - - tensor = self._create_ge_tensor([len(expand_times)], 2, expand_times) - expand_tensor = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const" - ).set_attr_tensor("value", tensor) - - assign = ( - core.GEOperatorFactory.create_operator( - "tile" + self._accumulated_op_id(), "Tile" - ) - .set_input("x", data_x1_shape) - .set_input("multiples", expand_tensor) - ) - return [assign], [[0]] - - -class SqueezeParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "squeeze2" - - def _apply(self): - tensor = self._get_ge_input(self.op.input_arg_names[0]) - axes = self.op.attr("axes") - - data_squeezed = ( - core.GEOperatorFactory.create_operator( - "squeeze" + self._accumulated_op_id(), "Squeeze" - ) - .set_input("x", tensor) - .set_attr_vec_int32("axes", axes) - ) - shape = core.GEOperatorFactory.create_operator( - "shape" + self._accumulated_op_id(), "Shape" - ).set_input("x", data_squeezed) - return [shape, data_squeezed], [[1], [0]] - - -# ****************************************************************# -# *************************** *************************# -# *************************** *************************# -# *************************** GradParser *************************# -# *************************** *************************# -# *************************** *************************# -# ****************************************************************# -# grad -class ReduceSumGradParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "reduce_sum_grad" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - input = self._get_ge_input(self.op.input_arg_names[1]) - - shape_tensor = core.GEOperatorFactory.create_operator( - "shape" + self._accumulated_op_id(), "Shape" - ).set_input("x", input, 0) - tensoron = self._create_ge_tensor([1], 2, -1) - const = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const" - ).set_attr_tensor("value", tensoron) - self._mark_as_input(const) - - reduce_sum = ( - core.GEOperatorFactory.create_operator( - "broadcast_to_d" + self._accumulated_op_id(), "BroadcastTo" - ) - .set_input("x", x) - .set_input("shape", shape_tensor) - ) - # reduce_sum = core.GEOperatorFactory.create_operator("expand" + self._accumulated_op_id(), "ExpandDims").set_input("x", reduce_sum).set_input("axis", const) - - return [reduce_sum], [[0]] - - -class MatMulGradParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "matmul_grad" - - def _apply(self): - out_grad = self._get_ge_input(self.op.input_arg_names[0]) - x = self._get_ge_input(self.op.input_arg_names[1]) - y = self._get_ge_input(self.op.input_arg_names[2]) - transpose_x = self.op.attr("transpose_X") - transpose_y = self.op.attr("transpose_Y") - - out_grad_shape = self.op.block.var(self.op.input_arg_names[0]).shape - x_shape = self.op.block.var(self.op.input_arg_names[1]).shape - y_shape = self.op.block.var(self.op.input_arg_names[2]).shape - - if len(x_shape) > 2: - if transpose_y: - x_grad = ( - core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), - "BatchMatMul", - ) - .set_input("x1", out_grad) - .set_input("x2", y) - .set_attr_bool("adj_x1", False) - .set_attr_bool("adj_x2", False) - ) - y_grad = ( - core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), - "BatchMatMul", - ) - .set_input("x1", out_grad) - .set_input("x2", x) - .set_attr_bool("adj_x1", True) - .set_attr_bool("adj_x2", False) - ) - else: - x_grad = ( - core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), - "BatchMatMul", - ) - .set_input("x1", out_grad) - .set_input("x2", y) - .set_attr_bool("adj_x1", False) - .set_attr_bool("adj_x2", True) - ) - y_grad = ( - core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), - "BatchMatMul", - ) - .set_input("x1", x) - .set_input("x2", out_grad) - .set_attr_bool("adj_x1", True) - .set_attr_bool("adj_x2", False) - ) - else: - if transpose_y: - x_grad = ( - core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), "MatMul" - ) - .set_input("x1", out_grad) - .set_input("x2", y) - .set_attr_bool("transpose_x1", False) - .set_attr_bool("transpose_x2", False) - ) - y_grad = ( - core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), "MatMul" - ) - .set_input("x1", out_grad) - .set_input("x2", x) - .set_attr_bool("transpose_x1", True) - .set_attr_bool("transpose_x2", False) - ) - else: - x_grad = ( - core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), "MatMul" - ) - .set_input("x1", out_grad) - .set_input("x2", y) - .set_attr_bool("transpose_x1", False) - .set_attr_bool("transpose_x2", True) - ) - y_grad = ( - core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), "MatMul" - ) - .set_input("x1", x) - .set_input("x2", out_grad) - .set_attr_bool("transpose_x1", True) - .set_attr_bool("transpose_x2", False) - ) - - return [x_grad, y_grad], [[0], [1]] - - -class MulGradParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "mul_grad" - - def _apply(self): - out_grad = self._get_ge_input(self.op.input_arg_names[0]) - x = self._get_ge_input(self.op.input_arg_names[1]) - y = self._get_ge_input(self.op.input_arg_names[2]) - x_num_col_dims = self.op.attr("x_num_col_dims") - y_num_col_dims = self.op.attr("y_num_col_dims") - - shape_out_grad = self.op.block.var(self.op.input_arg_names[0]).shape - shape_x = self.op.block.var(self.op.input_arg_names[1]).shape - shape_y = self.op.block.var(self.op.input_arg_names[2]).shape - - if x_num_col_dims == 1 and y_num_col_dims == 1: - if len(shape_x) == 2 and len(shape_y) == 2: - x_grad = ( - core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), "MatMul" - ) - .set_input("x1", out_grad) - .set_input("x2", y) - .set_attr_bool("transpose_x1", False) - .set_attr_bool("transpose_x2", True) - ) - y_grad = ( - core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), "MatMul" - ) - .set_input("x1", x) - .set_input("x2", out_grad) - .set_attr_bool("transpose_x1", True) - .set_attr_bool("transpose_x2", False) - ) - elif len(shape_x) == 3 and len(shape_y) == 2: - flatten_x = core.GEOperatorFactory.create_operator( - "flatten" + self._accumulated_op_id(), "Flatten" - ).set_input("x", x) - x_grad = ( - core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), "MatMul" - ) - .set_input("x1", out_grad) - .set_input("x2", y) - .set_attr_bool("transpose_x1", False) - .set_attr_bool("transpose_x2", True) - ) - if len(shape_out_grad) == 2: - x_grad = ( - core.GEOperatorFactory.create_operator( - "unsqueeze" + self._accumulated_op_id(), "Unsqueeze" - ) - .set_input("x", x_grad) - .set_attr_vec_int32("axes", [1]) - ) - - y_grad = ( - core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), "MatMul" - ) - .set_input("x1", flatten_x) - .set_input("x2", out_grad) - .set_attr_bool("transpose_x1", True) - .set_attr_bool("transpose_x2", False) - ) - else: - if len(shape_x) == 3 and len(shape_y) == 2: - assert x_num_col_dims == 2, "only support 2" - flatten_x = ( - core.GEOperatorFactory.create_operator( - "flatten" + self._accumulated_op_id(), "FlattenV2" - ) - .set_input("x", x) - .set_attr_int32("axis", 0) - .set_attr_int32("end_axis", 1) - ) - flatten_out_grad = ( - core.GEOperatorFactory.create_operator( - "flatten" + self._accumulated_op_id(), "FlattenV2" - ) - .set_input("x", out_grad) - .set_attr_int32("axis", 0) - .set_attr_int32("end_axis", 1) - ) - - y_unsqueeze = ( - core.GEOperatorFactory.create_operator( - "unsqueeze" + self._accumulated_op_id(), "Unsqueeze" - ) - .set_input("x", y) - .set_attr_vec_int32("axes", [0]) - ) - y_stack = ( - core.GEOperatorFactory.create_operator( - "stack" + self._accumulated_op_id(), "TileWithAxis" - ) - .set_input("x", y_unsqueeze) - .set_attr_int32("axis", 0) - .set_attr_int32("tiles", shape_out_grad[0]) - ) - x_grad = ( - core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), - "BatchMatMul", - ) - .set_input("x1", out_grad) - .set_input("x2", y_stack) - .set_attr_bool("adj_x1", False) - .set_attr_bool("adj_x2", True) - ) - y_grad = ( - core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), "MatMul" - ) - .set_input("x1", flatten_x) - .set_input("x2", flatten_out_grad) - .set_attr_bool("transpose_x1", True) - .set_attr_bool("transpose_x2", False) - ) - - return [x_grad, y_grad], [[0], [1]] - - -class ReluGradParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "relu_grad" - - def _apply(self): - out = self._get_ge_input(self.op.input_arg_names[0]) - out_grad = self._get_ge_input(self.op.input_arg_names[1]) - relu_grad = ( - core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), "ReluGrad" - ) - .set_input("gradients", out_grad) - .set_input("features", out) - ) - return [relu_grad], [[0]] - - -class SoftmaxWithCrossEntropyGradParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "softmax_with_cross_entropy_grad" - - def _apply(self): - label = self._get_ge_input(self.op.input_arg_names[0]) - loss_grad = self._get_ge_input(self.op.input_arg_names[1]) - softmax = self._get_ge_input(self.op.input_arg_names[2]) - cls_num = self.op.block.var(self.op.input_arg_names[2]).shape[1] - - label_shape = self.op.block.var(self.op.input_arg_names[0]).shape - loss_grad_shape = self.op.block.var(self.op.input_arg_names[1]).shape - softmax_shape = self.op.block.var(self.op.input_arg_names[2]).shape - - tensoron = self._create_ge_tensor([1], 5, 1) - on = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const" - ).set_attr_tensor("value", tensoron) - tensoroff = self._create_ge_tensor([1], 5, 0) - off = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const" - ).set_attr_tensor("value", tensoroff) - self._mark_as_input(on) - self._mark_as_input(off) - - label = ( - core.GEOperatorFactory.create_operator( - "cast" + self._accumulated_op_id(), "Cast" - ) - .set_input("x", label) - .set_attr_int32("dst_type", 3) - ) - onehot = ( - core.GEOperatorFactory.create_operator( - "onehot" + self._accumulated_op_id(), "OneHotD" - ) - .set_input("x", label) - .set_input("on_value", on) - .set_input("off_value", off) - .set_attr_int32("depth", cls_num) - ) - squeeze = core.GEOperatorFactory.create_operator( - "suqeeze" + self._accumulated_op_id(), "Squeeze" - ).set_input("x", onehot) - sub = ( - core.GEOperatorFactory.create_operator( - "sub" + self._accumulated_op_id(), "Sub" - ) - .set_input("x1", softmax) - .set_input("x2", squeeze) - ) - grad = ( - core.GEOperatorFactory.create_operator( - "mul" + self._accumulated_op_id(), "Mul" - ) - .set_input("x1", loss_grad) - .set_input("x2", sub) - ) - - return [on, off, label, onehot, grad], [[-1]] - - -class DotMulGradParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "elementwise_mul_grad" - - def _apply(self): - out_grad = self._get_ge_input(self.op.input_arg_names[0]) - out_1 = self._get_ge_input(self.op.input_arg_names[1]) - out_2 = self._get_ge_input(self.op.input_arg_names[2]) - - x_grad = ( - core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), "Mul" - ) - .set_input("x1", out_grad) - .set_input("x2", out_2) - ) - y_grad = ( - core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), "Mul" - ) - .set_input("x1", out_1) - .set_input("x2", out_grad) - ) - - return [x_grad, y_grad], [[0], [1]] - - -class DotAddGradParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "elementwise_add_grad" - - def _apply(self): - out_grad = self._get_ge_input(self.op.input_arg_names[0]) - out_1 = self._get_ge_input(self.op.input_arg_names[1]) - out_2 = self._get_ge_input(self.op.input_arg_names[2]) - out_grad_shape = self.op.block.var(self.op.input_arg_names[0]).shape - out_1_shape = self.op.block.var(self.op.input_arg_names[1]).shape - out_2_shape = self.op.block.var(self.op.input_arg_names[2]).shape - - x_grad = out_grad - cur_time_x = len(out_grad_shape) - len(out_1_shape) - for i in range(cur_time_x): - x_grad = ( - core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), "ReduceSumD" - ) - .set_input("x", x_grad) - .set_attr_vec_int32("axes", [0]) - .set_attr_bool("keep_dims", False) - ) - for axis, size in enumerate(out_1_shape): - if size == 1: - x_grad = ( - core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), - "ReduceSumD", - ) - .set_input("x", x_grad) - .set_attr_vec_int32("axes", [axis]) - .set_attr_bool("keep_dims", True) - ) - - y_grad = out_grad - cur_time_y = len(out_grad_shape) - len(out_2_shape) - for i in range(cur_time_y): - y_grad = ( - core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), "ReduceSumD" - ) - .set_input("x", y_grad) - .set_attr_vec_int32("axes", [0]) - .set_attr_bool("keep_dims", False) - ) - for axis, size in enumerate(out_2_shape): - if size == 1: - y_grad = ( - core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), - "ReduceSumD", - ) - .set_input("x", y_grad) - .set_attr_vec_int32("axes", [axis]) - .set_attr_bool("keep_dims", True) - ) - - return [x_grad, y_grad], [[0], [1]] - - -class DotDivGradParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "elementwise_div_grad" - - def _apply(self): - out = self._get_ge_input(self.op.input_arg_names[0]) - out_grad = self._get_ge_input(self.op.input_arg_names[1]) - x = self._get_ge_input(self.op.input_arg_names[2]) - y = self._get_ge_input(self.op.input_arg_names[3]) - - y_power = ( - core.GEOperatorFactory.create_operator( - "power" + self._accumulated_op_id(), "Power" - ) - .set_input("x", y) - .set_attr_float("power", -1) - ) - - tensor_zeros = core.GEOperatorFactory.create_operator( - "zeroslike" + self._accumulated_op_id(), "ZerosLike" - ).set_input("x", x) - x_zero = ( - core.GEOperatorFactory.create_operator( - "equal" + self._accumulated_op_id(), "Equal" - ) - .set_input("x1", x) - .set_input("x2", tensor_zeros) - ) - x_nozero = core.GEOperatorFactory.create_operator( - "logical_not" + self._accumulated_op_id(), "LogicalNot" - ).set_input("x", x_zero) - x_nozero_f = ( - core.GEOperatorFactory.create_operator( - "cast" + self._accumulated_op_id(), "Cast" - ) - .set_input("x", x_nozero) - .set_attr_int32("dst_type", 0) - ) - x_grad_w = ( - core.GEOperatorFactory.create_operator( - "mul" + self._accumulated_op_id(), "Mul" - ) - .set_input("x1", x_nozero_f) - .set_input("x2", y_power) - ) - x_grad = ( - core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), "Mul" - ) - .set_input("x1", x_grad_w) - .set_input("x2", out_grad) - ) - - y_grad_w = ( - core.GEOperatorFactory.create_operator( - "mul" + self._accumulated_op_id(), "Mul" - ) - .set_input("x1", out) - .set_input("x2", y_power) - ) - y_grad = ( - core.GEOperatorFactory.create_operator( - "mul" + self._accumulated_op_id(), "Mul" - ) - .set_input("x1", y_grad_w) - .set_input("x2", out_grad) - ) - - return [x_grad, y_grad], [[0], [1]] - - -class SoftmaxGradParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "softmax_grad" - - def _apply(self): - out = self._get_ge_input(self.op.input_arg_names[0]) - out_grad = self._get_ge_input(self.op.input_arg_names[1]) - - x_grad = ( - core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), "SoftmaxGrad" - ) - .set_input("softmax", out) - .set_input("grad_softmax", out_grad) - ) - return [x_grad], [[0]] - - -class ReshapeGradParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "reshape2_grad" - - def _apply(self): - out_grad = self._get_ge_input(self.op.input_arg_names[0]) - x_shape = self._get_ge_input(self.op.input_arg_names[1]) - x_shape_list = self.op.block.var(self.op.input_arg_names[1]).shape - - if x_shape_list[0] == 0: - x_shape_delzero = x_shape_list[1:] - tensor = self._create_ge_tensor( - [len(x_shape_delzero)], 2, x_shape_delzero - ) - const_shape = core.GEOperatorFactory.create_operator( - "shape" + self._accumulated_op_id(), "Const" - ).set_attr_tensor("value", tensor) - x_grad = ( - core.GEOperatorFactory.create_operator( - "reshape" + self._accumulated_op_id(), "Reshape" - ) - .set_input("x", out_grad) - .set_input("shape", const_shape) - ) - - return [x_grad], [[0]] - - -class GatherGradParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "gather_grad" - - def _apply(self): - index = self._get_ge_input(self.op.input_arg_names[0]) - out_grad = self._get_ge_input(self.op.input_arg_names[1]) - x = self._get_ge_input(self.op.input_arg_names[2]) - - index_shape = self.op.block.var(self.op.input_arg_names[0]).shape - out_grad_shape = self.op.block.var(self.op.input_arg_names[1]).shape - x_shape = self.op.block.var(self.op.input_arg_names[2]).shape - - if len(index_shape) == 1: - index = ( - core.GEOperatorFactory.create_operator( - "unsqueeze" + self._accumulated_op_id(), "Unsqueeze" - ) - .set_input("x", index) - .set_attr_vec_int32("axes", [1]) - ) - - tensor_zeros = core.GEOperatorFactory.create_operator( - "zeroslike" + self._accumulated_op_id(), "ZerosLike" - ).set_input("x", x) - x_grad = ( - core.GEOperatorFactory.create_operator( - "scatter" + self._accumulated_op_id(), "TensorScatterUpdate" - ) - .set_input("x", tensor_zeros) - .set_input("indices", index) - .set_input("updates", out_grad) - ) - - return [tensor_zeros, x_grad], [[-1]] - - -class TransposeGradParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "transpose2_grad" - - def _apply(self): - out_grad = self._get_ge_input(self.op.input_arg_names[0]) - x = self._get_ge_input(self.op.input_arg_names[1]) - perm = self.op.attr("axis") - - x_shape = self.op.block.var(self.op.input_arg_names[1]).shape[1:] - out_grad_shape = self.op.block.var(self.op.input_arg_names[0]).shape - assert [out_grad_shape[x] for x in perm] == list(x_shape) - - x_grad = ( - core.GEOperatorFactory.create_operator( - "transpose" + self._accumulated_op_id(), "TransposeD" - ) - .set_input("x", out_grad) - .set_attr_vec_int32("perm", perm) - ) - - return [x_grad], [[0]] - - -class LayerNormGradParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "layer_norm_grad" - - def _apply(self): - bias = self._get_ge_input(self.op.input_arg_names[0]) - mean = self._get_ge_input(self.op.input_arg_names[1]) - scale = self._get_ge_input(self.op.input_arg_names[2]) - variance = self._get_ge_input(self.op.input_arg_names[3]) - x = self._get_ge_input(self.op.input_arg_names[4]) - out_grad = self._get_ge_input(self.op.input_arg_names[5]) - x_dtype = self.op.block.var(self.op.input_arg_names[4]).dtype - - x_grad = ( - core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), "LayerNormGrad" - ) - .set_input("dy", out_grad) - .set_input("x", x) - .set_input("variance", variance) - .set_input("mean", mean) - .set_input("gamma", scale) - ) - - cast_dtype = ( - 0 - if self.ascend_helper.dtype2paddle_inv_map[str(x_dtype)] == 0 - else 1 - ) - out_x_grad = ( - core.GEOperatorFactory.create_operator( - "cast" + self._accumulated_op_id(), "Cast" - ) - .set_input("x", x_grad, 0) - .set_attr_int32("dst_type", cast_dtype) - ) - out_scale_grad = ( - core.GEOperatorFactory.create_operator( - "cast" + self._accumulated_op_id(), "Cast" - ) - .set_input("x", x_grad, 1) - .set_attr_int32("dst_type", cast_dtype) - ) - out_bias_grad = ( - core.GEOperatorFactory.create_operator( - "cast" + self._accumulated_op_id(), "Cast" - ) - .set_input("x", x_grad, 2) - .set_attr_int32("dst_type", cast_dtype) - ) - - return [out_x_grad, out_scale_grad, out_bias_grad], [[2], [1], [0]] - - -class TanhGradParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = 'tanh_grad' - - def _apply(self): - y = self._get_ge_input(self.op.input_arg_names[0]) - out_grad = self._get_ge_input(self.op.input_arg_names[1]) - tanh_grad = ( - core.GEOperatorFactory.create_operator( - "tanh_grad" + self._accumulated_op_id(), "TanhGrad" - ) - .set_input("y", y) - .set_input("dy", out_grad) - ) - - return [tanh_grad], [[0]] - - -class LogGradParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = 'log_grad' - - def _apply(self): - grad = self._get_ge_input(self.op.input_arg_names[0]) - input = self._get_ge_input(self.op.input_arg_names[1]) - log_grad = ( - core.GEOperatorFactory.create_operator( - "log_grad" + self._accumulated_op_id(), "DivNoNan" - ) - .set_input("x1", grad) - .set_input("x2", input) - ) - return [log_grad], [[0]] - - -class SqrtGradParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "sqrt_grad" - - def _apply(self): - y = self._get_ge_input(self.op.input_arg_names[0]) - out_grad = self._get_ge_input(self.op.input_arg_names[1]) - sqrt_grad = ( - core.GEOperatorFactory.create_operator( - "sqrt_grad" + self._accumulated_op_id(), "SqrtGrad" - ) - .set_input("y", y) - .set_input("dy", out_grad) - ) - return [sqrt_grad] - - -class PowGradParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "pow_grad" - - def _apply(self): - grad = self._get_ge_input(self.op.input_arg_names[0]) - x = self._get_ge_input(self.op.input_arg_names[1]) - factor = self.op.attr("factor") - - shape_tensor = self._create_shape_tensor() - shape_tensor = core.GEOperatorFactory.create_operator( - "shape" + self._accumulated_op_id(), "Shape" - ).set_input("x", x) - factor_scale = self._create_ge_tensor([1], 5, factor) - factor_scale = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const" - ).set_attr_tensor("value", factor_scale) - factor_tensor = ( - core.GEOperatorFactory.create_operator( - "broadcast_to_d" + self._accumulated_op_id(), "BroadcastTo" - ) - .set_input("x", factor_scale) - .set_input("shape", shape_tensor) - ) - - x_power = ( - core.GEOperatorFactory.create_operator( - "x_power" + self._accumulated_op_id(), "Power" - ) - .set_input("x", x) - .set_attr_float("power", factor - 1) - ) - x_power_mul_factor = ( - core.GEOperatorFactory.create_operator( - "x_power_mul_factor" + self._accumulated_op_id(), "Mul" - ) - .set_input("x1", x) - .set_input("x2", factor_tensor) - ) - x_power_mul_factor_grad = ( - core.GEOperatorFactory.create_operator( - "x_power_mul_factor_grad" + self._accumulated_op_id(), "Mul" - ) - .set_input("x1", x_power_mul_factor) - .set_input("x2", grad) - ) - - return [x_power_mul_factor_grad], [[0]] - - -class GeluGradParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "gelu_grad" - - def _apply(self): - grad = self._get_ge_input(self.op.input_arg_names[0]) - x = self._get_ge_input(self.op.input_arg_names[1]) - - y = core.GEOperatorFactory.create_operator( - "gelu" + self._accumulated_op_id(), "Gelu" - ).set_input("x", x) - gelu_grad = ( - core.GEOperatorFactory.create_operator( - "gelu_grad" + self._accumulated_op_id(), "GeluGrad" - ) - .set_input("x", x) - .set_input("dy", grad) - .set_input("y", y) - ) - - return [gelu_grad], [[0]] - - -class MeanGradParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "mean_grad" - - def _apply(self): - grad = self._get_ge_input(self.op.input_arg_names[0]) - x = self._get_ge_input(self.op.input_arg_names[1]) - - ones_tensor = core.GEOperatorFactory.create_operator( - "one_tensor" + self._accumulated_op_id(), "OnesLike" - ).set_input("x", x) - sum = ( - core.GEOperatorFactory.create_operator( - "mean" + self._accumulated_op_id(), "ReduceSumD" - ) - .set_input("x", ones_tensor) - .set_attr_bool("keep_dims", False) - .set_attr_vec_int32("axes", []) - ) - mean = ( - core.GEOperatorFactory.create_operator( - "x_power" + self._accumulated_op_id(), "Power" - ) - .set_input("x", sum) - .set_attr_float("power", -1) - ) - - mean_grad = ( - core.GEOperatorFactory.create_operator( - "mean_grad" + self._accumulated_op_id(), "Mul" - ) - .set_input("x1", mean) - .set_input("x2", grad) - ) - - return [mean_grad], [[0]] - - -class SliceGradParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "slice_grad" - - def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - grad = self._get_ge_input(self.op.input_arg_names[1]) - axes = self.op.attr("axes") - starts = self.op.attr("starts") - ends = self.op.attr("ends") - - x_shape = self.op.block.var(self.op.input_arg_names[0]).shape - grad_shape = self.op.block.var(self.op.input_arg_names[1]).shape - - len_shape = len(x_shape) - axes_cor = list(range(len_shape)) - starts_cor, ends_cor = [], [] - cnt = 0 - for i in range(len_shape): - starts_cor.append(starts[cnt] if i in axes else 0) - if i in axes and ends[cnt] <= x_shape[i]: - ends_cor.append(x_shape[i] - ends[cnt]) - else: - ends_cor.append(0) - if i in axes: - cnt += 1 - - starts_cor[0] = 0 - ends_cor[0] = 0 - paddings = [[s, e] for (s, e) in zip(starts_cor, ends_cor)] - slice_value = ( - core.GEOperatorFactory.create_operator( - "slice_grad" + self._accumulated_op_id(), "PadD" - ) - .set_input("x", grad) - .set_attr_vec_vec_int64("paddings", paddings) - ) - - return [slice_value], [[0]] - - -class LookUpTableGradParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "lookup_table_grad" - - def _apply(self): - ids = self._get_ge_input(self.op.input_arg_names[0]) - grad = self._get_ge_input(self.op.input_arg_names[1]) - embedding = self._get_ge_input(self.op.input_arg_names[2]) - - shape_ids = self.op.block.var(self.op.input_arg_names[0]).shape - shape_grad = self.op.block.var(self.op.input_arg_names[1]).shape - shape_embedding = self.op.block.var(self.op.input_arg_names[2]).shape - - ids_flatten = ( - core.GEOperatorFactory.create_operator( - "flatten" + self._accumulated_op_id(), "FlattenV2" - ) - .set_input("x", ids) - .set_attr_int32("axis", 0) - .set_attr_int32("end_axis", 1) - ) - grad_flatten = ( - core.GEOperatorFactory.create_operator( - "flatten" + self._accumulated_op_id(), "FlattenV2" - ) - .set_input("x", grad) - .set_attr_int32("axis", 0) - .set_attr_int32("end_axis", 1) - ) - - tensor_zeros = core.GEOperatorFactory.create_operator( - "zeroslike" + self._accumulated_op_id(), "ZerosLike" - ).set_input("x", embedding) - embedding_grad = ( - core.GEOperatorFactory.create_operator( - "scatteradd" + self._accumulated_op_id(), "TensorScatterAdd" - ) - .set_input("x", tensor_zeros) - .set_input("indices", ids_flatten) - .set_input("updates", grad_flatten) - ) - - return [embedding_grad], [[0]] - - -class SGDParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "sgd" - - def _apply(self): - grad = self._get_ge_input(self.op.input_arg_names[0]) - lr = self._get_ge_input(self.op.input_arg_names[1]) - param = self._get_ge_input(self.op.input_arg_names[2]) - sgd = ( - core.GEOperatorFactory.create_operator( - "momentum" + self._accumulated_op_id(), "ApplyGradientDescent" - ) - .set_input("var", param) - .set_input("alpha", lr) - .set_input("delta", grad) - ) - return [sgd], [[0]] - - -class AdamParser(AscendParserBase): - def __init__(self, graph, var2geop): - super().__init__(graph, var2geop) - self.parser_name = "adam" - - def _apply(self): - beta1_power = self._get_ge_input(self.op.input_arg_names[0]) - beta2_power = self._get_ge_input(self.op.input_arg_names[1]) - grad = self._get_ge_input(self.op.input_arg_names[2]) - lr = self._get_ge_input(self.op.input_arg_names[3]) - moment1 = self._get_ge_input(self.op.input_arg_names[4]) - moment2 = self._get_ge_input(self.op.input_arg_names[5]) - param = self._get_ge_input(self.op.input_arg_names[6]) - beta1 = self.op.attr('beta1') - beta2 = self.op.attr('beta2') - epsilon = self.op.attr('epsilon') - - beta1 = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const" - ).set_attr_tensor("value", self._create_ge_tensor([1], 5, beta1)) - beta2 = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const" - ).set_attr_tensor("value", self._create_ge_tensor([1], 5, beta2)) - epsilon = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const" - ).set_attr_tensor("value", self._create_ge_tensor([1], 5, epsilon)) - - adam = ( - core.GEOperatorFactory.create_operator( - "adam" + self._accumulated_op_id(), "ApplyAdam" - ) - .set_input("var", param) - .set_input("m", moment1) - .set_input("v", moment2) - .set_input("beta1_power", beta1_power) - .set_input("beta2_power", beta2_power) - .set_input("lr", lr) - .set_input("beta1", beta1) - .set_input("beta2", beta2) - .set_input("epsilon", epsilon) - .set_input("grad", grad) - ) - - return [adam], [[0]] diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 858a48eafdc625445e1bb643bca0fc0f7099adb4..6d68234c0f440479c0cd76c2b4abcfe87bce1d5e 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -743,7 +743,6 @@ class ShardingOptimizer(MetaOptimizerBase): ) def _init_npu_pipeline_comm(self, startup_block): - # NOTE(wangxi): some bug with hccl, must set pp_degree be even number assert (self.pp_degree % 2) == 0 max_ring_id = -1 diff --git a/python/paddle/fluid/tests/unittests/ascend_group.py b/python/paddle/fluid/tests/unittests/ascend_group.py deleted file mode 100644 index 29c3a72bd85b9b1ce4fab7180cfe6c7e0efaefd0..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/tests/unittests/ascend_group.py +++ /dev/null @@ -1,169 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from collections import namedtuple - -import paddle -from paddle import fluid -from paddle.distributed import fleet -from paddle.distributed.fleet.meta_optimizers.ascend import ascend_optimizer -from paddle.fluid import core, unique_name -from paddle.fluid.layer_helper import LayerHelper - -Block = namedtuple('Block', ['program']) -Loss = namedtuple('Loss', ['block']) - -paddle.enable_static() - -OpRole = core.op_proto_and_checker_maker.OpRole -OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() -OP_ROLE_VAR_KEY = core.op_proto_and_checker_maker.kOpRoleVarAttrName() - -role = fleet.PaddleCloudRoleMaker(is_collective=True) -fleet.init(role) - - -def init_communicator( - startup_program, main_program, current_endpoint, endpoints, ring_id -): - nranks = len(endpoints) - other_endpoints = endpoints[:] - other_endpoints.remove(current_endpoint) - group_rank = endpoints.index(current_endpoint) - assert group_rank >= 0 - - block = startup_program.global_block() - nccl_id_var = block.create_var( - name=unique_name.generate('nccl_id'), - persistable=True, - type=core.VarDesc.VarType.RAW, - ) - block.append_op( - type='c_gen_nccl_id', - inputs={}, - outputs={'Out': nccl_id_var}, - attrs={ - 'rank': group_rank, - 'endpoint': current_endpoint, - 'other_endpoints': other_endpoints, - OP_ROLE_KEY: OpRole.Forward, - }, - ) - block.append_op( - type='c_comm_init', - inputs={'X': nccl_id_var}, - outputs={}, - attrs={ - 'nranks': nranks, - 'rank': group_rank, - 'ring_id': ring_id, - OP_ROLE_KEY: OpRole.Forward, - }, - ) - - # add input op for test - fill_var_name = "tensor@Filled" - fill_var = block.create_var( - name=fill_var_name, - shape=[10, 10], - dtype='float32', - persistable=False, - stop_gradient=True, - ) - block.append_op( - type="fill_constant", - outputs={"Out": fill_var_name}, - attrs={ - "shape": [10, 10], - "dtype": fill_var.dtype, - "value": 1.0, - "place_type": 1, - }, - ) - - with fluid.program_guard(main_program): - op_type = "c_allreduce_sum" - data = paddle.tensor.fill_constant( - shape=[1], dtype='float32', value=2.5 - ) - helper = LayerHelper(op_type, **locals()) - helper.append_op( - type=op_type, - inputs={'X': [data]}, - outputs={'Out': [data]}, - attrs={'ring_id': ring_id, 'use_calc_stream': True}, - ) - - print("startup program:", startup_program) - print("main program:", main_program) - - -def train(world_endpoints, world_device_ids, local_device_ids, local_rank): - startup_programs = [] - main_programs = [] - - # trainer_endpoints=["127.0.0.1:6071","127.0.0.1:6072","127.0.0.1:6073","127.0.0.1:6074"] - trainer_endpoints = world_endpoints - groups = [[], [], []] - groups[0] = [trainer_endpoints[0], trainer_endpoints[1]] - groups[1] = [trainer_endpoints[2], trainer_endpoints[3]] - groups[2] = [trainer_endpoints[0], trainer_endpoints[2]] - print("groups:", groups) - - for i in range(len(trainer_endpoints)): - startup_programs.append(fluid.Program()) - main_programs.append(fluid.Program()) - - for idx, group in enumerate(groups): - for te in group: - te_idx = trainer_endpoints.index(te) - startup_program = startup_programs[te_idx] - main_program = main_programs[te_idx] - init_communicator(startup_program, main_program, te, group, idx) - - print(len(startup_programs)) - print(startup_programs[local_rank]) - print(main_programs[local_rank]) - - print("local rank: ", local_rank) - print("local startup program: ", startup_programs[local_rank]) - - startup_program = startup_programs[local_rank] - main_program = main_programs[local_rank] - loss = Loss(Block(main_program)) - optimizer = ascend_optimizer.AscendOptimizer(None, fetch_list=[]) - optimizer.minimize( - loss, - startup_program, - auto_dp=True, - rank_table_file=os.getenv("RANK_TABLE_FILE", None), - ) - - exe = paddle.static.Executor(paddle.CPUPlace()) - exe.run(startup_program) - exe.run(main_program) - - -worker_endpoints = fleet.worker_endpoints() -world_device_ids = fleet.world_device_ids() -local_device_ids = fleet.local_device_ids() -local_rank = int(fleet.local_rank()) - -print("worker_endpoints:", worker_endpoints) -print("world_device_ids:", world_device_ids) -print("local_device_ids:", local_device_ids) -print("local_rank:", local_rank) - -train(worker_endpoints, world_device_ids, local_device_ids, local_rank) diff --git a/python/paddle/fluid/tests/unittests/hccl_tools.py b/python/paddle/fluid/tests/unittests/hccl_tools.py deleted file mode 100644 index f839da2a786ebb17b185d067e273fedcb8259bf5..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/tests/unittests/hccl_tools.py +++ /dev/null @@ -1,191 +0,0 @@ -# -*- coding:UTF-8 -*- - -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""generate hccl config file script""" -import json -import os -import socket -import sys -from argparse import ArgumentParser - - -def parse_args(): - """ - parse args . - - Args: - - Returns: - args. - - Examples: - >>> parse_args() - """ - parser = ArgumentParser( - description="mindspore distributed training launch " - "helper utilty that will generate hccl" - " config file" - ) - parser.add_argument( - "--device_num", - type=str, - default="[0,8)", - help="The number of the Ascend accelerators used. please note that the Ascend accelerators" - "used must be continuous, such [0,4) means to use four chips " - "0,1,2,3; [0,1) means to use chip 0; The first four chips are" - "a group, and the last four chips are a group. In addition to" - "the [0,8) chips are allowed, other cross-group such as [3,6)" - "are prohibited.", - ) - parser.add_argument( - "--visible_devices", - type=str, - default="0,1,2,3,4,5,6,7", - help="will use the visible devices sequentially", - ) - parser.add_argument("--server_ip", type=str, default="", help="server ip") - args = parser.parse_args() - return args - - -def get_host_ip(): - """ - get host ip - """ - ip = None - - try: - hostname = socket.gethostname() - ip = socket.gethostbyname(hostname) - except EOFError: - pass - - return ip - - -def main(): - print("start", __file__) - args = parse_args() - - # visible_devices - visible_devices = args.visible_devices.split(',') - print(f'visible_devices:{visible_devices}') - - # server_id - ip = get_host_ip() - if args.server_ip: - server_id = args.server_ip - elif ip: - server_id = ip - else: - raise ValueError("please input server ip!") - print(f'server_id:{server_id}') - - # device_num - first_num = int(args.device_num[1]) - last_num = int(args.device_num[3]) - if first_num < 0 or last_num > 8: - raise ValueError( - f"device num {args.device_num} must be in range [0,8] !" - ) - if first_num > last_num: - raise ValueError( - "First num {} of device num {} must less than last num {} !".format( - first_num, args.device_num, last_num - ) - ) - if first_num < 4: - if last_num > 4: - if first_num == 0 and last_num == 8: - pass - else: - raise ValueError( - "device num {} must be in the same group of [0,4] or [4,8] !".format( - args.device_num - ) - ) - - device_num_list = list(range(first_num, last_num)) - print("device_num_list:", device_num_list) - - assert len(visible_devices) >= len(device_num_list) - - # construct hccn_table - device_ips = {} - with open('/etc/hccn.conf', 'r') as fin: - for hccn_item in fin.readlines(): - if hccn_item.strip().startswith('address_'): - device_id, device_ip = hccn_item.split('=') - device_id = device_id.split('_')[1] - device_ips[device_id] = device_ip.strip() - - hccn_table = {'version': '1.0', 'server_count': '1', 'server_list': []} - device_list = [] - rank_id = 0 - for instance_id in device_num_list: - device_id = visible_devices[instance_id] - device_ip = device_ips[device_id] - device = { - 'device_id': device_id, - 'device_ip': device_ip, - 'rank_id': str(rank_id), - } - print( - 'rank_id:{}, device_id:{}, device_ip:{}'.format( - rank_id, device_id, device_ip - ) - ) - rank_id += 1 - device_list.append(device) - hccn_table['server_list'].append( - { - 'server_id': server_id, - 'device': device_list, - 'host_nic_ip': 'reserve', - } - ) - hccn_table['status'] = 'completed' - - # save hccn_table to file - table_path = os.getcwd() - table_fn = os.path.join( - table_path, - 'hccl_{}p_{}_{}.json'.format( - len(device_num_list), "".join(map(str, device_num_list)), server_id - ), - ) - with open(table_fn, 'w') as table_fp: - json.dump(hccn_table, table_fp, indent=4) - sys.stdout.flush() - print("Completed: hccl file was save in :", table_fn) - - -if __name__ == "__main__": - main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 94b11c7d1de42fd5bafce28b578ca3219c4d0afc..d4dcd597c6ff4002692396ea5ba188b5ccb56f9a 100755 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -700,11 +700,7 @@ class TestParallelDyGraphRunnerBase: nranks = len(args.endpoints.split(",")) if args.endpoints else 1 # if args.update_method == "nccl2": - if ( - args.update_method == "nccl2" - or args.update_method == "bkcl" - or args.update_method == "hccl" - ): + if args.update_method == "nccl2" or args.update_method == "bkcl": strategy = paddle.distributed.parallel.ParallelStrategy() strategy.nranks = nranks strategy.local_rank = args.trainer_id @@ -818,12 +814,12 @@ class TestParallelDyGraphRunnerBase: strategy.find_unused_parameters = True # 3. init parallel env - if args.update_method == "nccl2" or "bkcl" or "hccl": + if args.update_method == "nccl2" or "bkcl": fleet.init(is_collective=True, strategy=strategy) # 4. train model model, train_reader, opt = self.get_model() - if args.update_method == "nccl2" or "bkcl" or "hccl": + if args.update_method == "nccl2" or "bkcl": opt = fleet.distributed_optimizer(opt) model = fleet.distributed_model(model) @@ -860,7 +856,6 @@ def runtime_main(test_class): "local", "nccl2_reduce_layer", "gloo", - "hccl", ], ) parser.add_argument('--trainer_id', type=int, required=False, default=0) @@ -968,7 +963,6 @@ class TestDistBase(unittest.TestCase): self._nccl2_mode = False self._bkcl_mode = False self._gloo_mode = False # now, support gloo backend - self._hccl_mode = False self._pipeline_mode = False self._mp_mode = False self._diff_batch = False @@ -1767,14 +1761,6 @@ class TestDistBase(unittest.TestCase): check_error_log=check_error_log, log_name=log_name, ) - elif self._hccl_mode: - tr0_losses, tr1_losses = self._run_cluster_nccl2( - model_file, - required_envs, - update_method='hccl', - check_error_log=check_error_log, - log_name=log_name, - ) elif self._pipeline_mode: tr0_losses, tr1_losses = self._run_pipeline( model_file, required_envs, check_error_log, log_name=log_name diff --git a/python/setup.py.in b/python/setup.py.in index fa32dcf13c3d51bbbe9f4573837ee28d37cd90c4..723c23ed7c6f7c5a5aaf44867880b3f507bb2fcb 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -411,7 +411,6 @@ packages=['paddle', 'paddle.distributed.fleet.elastic', 'paddle.distributed.fleet.meta_optimizers', 'paddle.distributed.fleet.meta_optimizers.sharding', - 'paddle.distributed.fleet.meta_optimizers.ascend', 'paddle.distributed.fleet.meta_optimizers.dygraph_optimizer', 'paddle.distributed.fleet.runtime', 'paddle.distributed.rpc', diff --git a/setup.py b/setup.py index 288500feba8548ca58b2c9005f0fba6c06a7d5f5..491f09b976d56ca7fd7af1f1fa1d34b9f60d3d00 100644 --- a/setup.py +++ b/setup.py @@ -1387,7 +1387,6 @@ def get_setup_parameters(): 'paddle.distributed.fleet.elastic', 'paddle.distributed.fleet.meta_optimizers', 'paddle.distributed.fleet.meta_optimizers.sharding', - 'paddle.distributed.fleet.meta_optimizers.ascend', 'paddle.distributed.fleet.meta_optimizers.dygraph_optimizer', 'paddle.distributed.fleet.runtime', 'paddle.distributed.rpc',