diff --git a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py index c3aec546156e1873817634a943508e6964afb6f0..89a19b64795eefe9408de75131b1bae3c9d45a8f 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py @@ -130,7 +130,7 @@ class AscendIRParser(object): name = e.name ge_out_operator.append(self.var2geop[name]) - # (Debug) If you want to print back prop vars, append/assign the varname in ge_out_operator here, such as: + # (Debug) If you want to print back prop vars, append/assign the varname in ge_out_operator here, such as: # if graph_name == "main": # ge_out_operator.append(self.var2geop["reduce_sum_0.tmp_0@GRAD"]) @@ -233,6 +233,7 @@ class AscendOptimizer(Optimizer): self.parser = AscendIRParser() input_varlist = self._get_input_varlist(main_block.program) + startup_graph, main_graph = self.parser.parse_program( startup_program, main_block.program, input_varlist, self.fetch_list) diff --git a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py index b7f21b4051ceca195f55d894e37f53263acaf2ca..1dc67f59dd80c2f81eebdc473b20d5361cc9deff 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py +++ b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_parser.py @@ -17,32 +17,83 @@ import paddle.fluid.core as core import numpy as np from paddle.distributed import fleet -registerd_op = { - "elementwise_add": "AddParser", - "matmul": "MatMulParser", - "mul": "MulParser", - "relu": "ReluParser", - "softmax_with_cross_entropy": "SoftmaxWithCrossEntropyParser", - "shape": "ShapeParser", - "fill_constant": "FillConstantParser", - "reduce_sum": "ReduceSumParser", - "reduce_sum_grad": "ReduceSumGradParser", - "matmul_grad": "MatMulGradParser", - "mul_grad": "MulGradParser", - "reshape2": "ReshapeParser", - "scale": "ScaleParser", - "relu_grad": "ReluGradParser", - "softmax_with_cross_entropy_grad": "SoftmaxWithCrossEntropyGradParser", - "truncated_gaussian_random": "TruncatedNormalParser", - "sgd": "SGDParser", - "c_allgather": "AllGatherParser", - "c_allreduce_sum": "AllReduceSumParser", - "c_allreduce_max": "AllReduceMaxParser", - "c_broadcast": "BroadcastParser", - "c_reduce_scatter": "ReduceScatterParser", - "c_send": "SendParser", - "c_receive": "ReceiveParser" -} +registerd_op = {## forwards + "elementwise_add": "AddParser", + "matmul": "MatMulParser", + "mul": "MulParser", + "relu": "ReluParser", + "softmax_with_cross_entropy": "SoftmaxWithCrossEntropyParser", + "shape": "ShapeParser", + "fill_constant": "FillConstantParser", + "reduce_sum": "ReduceSumParser", + "elementwise_mul": "DotMulParser", + "elementwise_div": "DotDivParser", + "elementwise_pow": "DotPowParser", + "elementwise_max": "MaxParser", + "elementwise_min": "MinParser", + "elementwise_sub": "DotSubParser", + "pow": "PowParser", + "gelu": "GeluParser", + "sqrt": "SqrtParser", + "log": "LogParser", + "sum": "SumParser", + "logical_not": "LogicalNotParser", + "gather": "GatherParser", + "scatter": "ScatterParser", + "cast": "CastParser", + "tanh": "TanhParser", + "stack": "StackParser", + "square": "SquareParser", + "unsqueeze2": "UnSqueezeParser", + "assign": "AssignParser", + "softmax": "SoftMaxParser", + "reshape2": "ReshapeParser", + "transpose2": "TransposeParser", + "layer_norm": "LayerNormParser", + "less_than": "LessParser", + "mean": "MeanParser", + "scale": "ScaleParser", + "slice": "SliceParser", + "top_k": "TopkParser", + "accuracy": "AccuracyParser", + #"increment": "IncrementParser", + "lookup_table": "LookupTableParser", + "truncated_gaussian_random": "TruncatedNormalParser", + "c_allgather": "AllGatherParser", + "c_allreduce_sum": "AllReduceSumParser", + "c_allreduce_max": "AllReduceMaxParser", + "c_broadcast": "BroadcastParser", + "c_reduce_scatter": "ReduceScatterParser", + "c_send": "SendParser", + "c_receive": "ReceiveParser", + + ## backwords + "matmul_grad": "MatMulGradParser", + "mul_grad": "MulGradParser", + "relu_grad": "ReluGradParser", + "reduce_sum_grad": "ReduceSumGradParser", + "softmax_with_cross_entropy_grad": "SoftmaxWithCrossEntropyGradParser", + "tanh_grad":"TanhGradParser", + "log_grad":"LogGradParser", + "pow_grad": "PowGradParser", + "sqrt_grad": "SqrtGradParser", + "gelu_grad": "GeluGradParser", + "mean_grad": "MeanGradParser", + 'lookup_table_grad': "LookUpTableGradParser", + "elementwise_mul_grad": "DotMulGradParser", + "elementwise_add_grad": "DotAddGradParser", + "elementwise_div_grad": "DotDivGradParser", + "softmax_grad": "SoftmaxGradParser", + "slice_grad": "SliceGradParser", + "reshape2_grad": "ReshapeGradParser", + "gather_grad": "GatherGradParser", + "transpose2_grad": "TransposeGradParser", + "layer_norm_grad": "LayerNormGradParser", + + ## opt + "sgd": "SGDParser", + #"adam": "AdamParser", + } global_cnt = -1 global_input_cnt = -1 @@ -67,6 +118,7 @@ class AscendHelper(object): 5: "float32", 6: "float64" } + self.dtype2paddle_inv_map = {"VarType.FP32": 0, "VarType.FP16": 1} def dtype2ge(self, dtype): assert dtype in self.dtype2ge_map, "dtype[%d] is not supported %d" % ( @@ -159,7 +211,65 @@ class AscendParserBase(object): tensor.set_data(data_8) return tensor + def _get_ge_tensor(self, shape, dtype, value_list): + tensor_desc = core.GETensorDesc( + core.GEShape(shape), core.GEFormat.FORMAT_ND, + self.ascend_helper.dtype2ge(dtype)) + tensor = core.GETensor(tensor_desc) + + data = np.array(value_list).reshape(shape).astype( + self.ascend_helper.dtype2np(dtype)) + buf = data.tobytes() + data_8 = np.frombuffer(buf, dtype=np.uint8) + tensor.set_data(data_8) + + tensor_const = core.GEOperatorFactory.create_operator( + "const" + self._accumulated_op_id(), + "Const").set_attr_tensor("value", tensor) + + return tensor_const + + def _get_variable(self, shape, dtype, tensor): + if dtype == "int32": + type = core.GEDataType.DT_INT32 + elif dtype == "float32": + type = core.GEDataType.DT_FLOAT + + var = core.GEOperatorFactory.create_operator( + "variable" + self._accumulated_op_id(), "Variable") + var.update_output_desc("y", + core.GETensorDesc( + core.GEShape(shape), core.GEFormat.FORMAT_ND, + type)) + assign = core.GEOperatorFactory.create_operator( + "assign" + self._accumulated_op_id(), "Assign").set_input( + "value", tensor).set_input("ref", var) + + return assign + + def _create_shape_tensor(self): + tensor_desc = core.GETensorDesc( + core.GEShape([2]), core.GEFormat.FORMAT_ND, + core.GEDataType.DT_INT32) + tensor = core.GETensor(tensor_desc) + data = np.ones((2)).astype("int32").reshape([2]) + data[0] = 64 + buf = data.tobytes() + data_8 = np.frombuffer(buf, dtype=np.uint8) + tensor.set_data(data_8) + return tensor + + def _get_GEtensor_shape(self, tensor): + tensor_shape = core.GEOperatorFactory.create_operator( + "shape" + self._accumulated_op_id(), "Shape").set_input("x", tensor) + tensor_shape = core.GEOperatorFactory.create_operator( + "cast" + self._accumulated_op_id(), "Cast").set_input( + "x", tensor_shape).set_attr_int32("dst_type", 0) + return tensor_shape + + +### elementwise_op class AddParser(AscendParserBase): def __init__(self, graph, var2geop): super(AddParser, self).__init__(graph, var2geop) @@ -169,109 +279,276 @@ class AddParser(AscendParserBase): x = self._get_ge_input(self.op.input_arg_names[0]) y = self._get_ge_input(self.op.input_arg_names[1]) add = core.GEOperatorFactory.create_operator( - "add" + self._accumulated_op_id(), "Add").set_input( - "x1", x).set_input("x2", y) + "add" + self._accumulated_op_id(), + "Add").set_input("x1", x).set_input("x2", y) return [add], [[0]] -class ReduceSumParser(AscendParserBase): +class DotSubParser(AscendParserBase): def __init__(self, graph, var2geop): - super(ReduceSumParser, self).__init__(graph, var2geop) - self.parser_name = "reduce_sum" + super(DotSubParser, self).__init__(graph, var2geop) + self.parser_name = "elementwise_sub" def _apply(self): x = self._get_ge_input(self.op.input_arg_names[0]) - axes = self.op.attr("dim") - keep_dims = self.op.attr("keep_dim") - reduce_sum = core.GEOperatorFactory.create_operator( - "reduce_sum" + self._accumulated_op_id(), "ReduceSumD").set_input( - "x", x, 0).set_attr_vec_int32("axes", axes).set_attr_bool( - "keep_dims", keep_dims) - return [reduce_sum], [[0]] + y = self._get_ge_input(self.op.input_arg_names[1]) + sub = core.GEOperatorFactory.create_operator( + "sub" + self._accumulated_op_id(), + "Sub").set_input("x1", x).set_input("x2", y) + return [sub], [[0]] -class ReduceSumGradParser(AscendParserBase): +class DotMulParser(AscendParserBase): def __init__(self, graph, var2geop): - super(ReduceSumGradParser, self).__init__(graph, var2geop) - self.parser_name = "reduce_sum_grad" + super(DotMulParser, self).__init__(graph, var2geop) + self.parser_name = "elementwise_mul" def _apply(self): x = self._get_ge_input(self.op.input_arg_names[0]) - input = self._get_ge_input(self.op.input_arg_names[1]) + y = self._get_ge_input(self.op.input_arg_names[1]) + mul = core.GEOperatorFactory.create_operator( + "dotmul" + self._accumulated_op_id(), + "Mul").set_input("x1", x).set_input("x2", y) + return [mul], [[0]] - shape_tensor = core.GEOperatorFactory.create_operator( - "shape" + self._accumulated_op_id(), "Shape").set_input("x", input, - 0) - axis_const = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const").set_attr_tensor( - "value", self._create_ge_tensor([1], 2, -1)) - self._mark_as_input(axis_const) - broadcast = core.GEOperatorFactory.create_operator( - "broadcast_to_d" + self._accumulated_op_id(), - "BroadcastTo").set_input("x", x).set_input("shape", shape_tensor) - # unsqueeze cannot get right result, but ExpandDims seems have the same functionality. - reduce_sum_grad = core.GEOperatorFactory.create_operator( - "expand" + self._accumulated_op_id(), "ExpandDims").set_input( - "x", broadcast).set_input("axis", axis_const) - return [shape_tensor, axis_const, broadcast, reduce_sum_grad], [[3]] +class DotDivParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(DotDivParser, self).__init__(graph, var2geop) + self.parser_name = "elementwise_div" + def _apply(self): + x = self._get_ge_input(self.op.input_arg_names[0]) + y = self._get_ge_input(self.op.input_arg_names[1]) + div = core.GEOperatorFactory.create_operator( + "dotdiv" + self._accumulated_op_id(), + "Div").set_input("x1", x).set_input("x2", y) + return [div], [[0]] -class MatMulParser(AscendParserBase): + +class DotPowParser(AscendParserBase): def __init__(self, graph, var2geop): - super(MatMulParser, self).__init__(graph, var2geop) - self.parser_name = "matmul" + super(DotPowParser, self).__init__(graph, var2geop) + self.parser_name = "elementwise_pow" def _apply(self): - x1 = self._get_ge_input(self.op.input_arg_names[0]) - x2 = self._get_ge_input(self.op.input_arg_names[1]) - matmul = core.GEOperatorFactory.create_operator( - "matmul" + self._accumulated_op_id(), "MatMul").set_input( - "x1", x1).set_input("x2", x2) - return [matmul], [[0]] + x = self._get_ge_input(self.op.input_arg_names[0]) + y = self._get_ge_input(self.op.input_arg_names[1]) + pow = core.GEOperatorFactory.create_operator( + "dotpow" + self._accumulated_op_id(), + "Pow").set_input("x1", x1).set_input("x2", y) + return [pow], [[0]] -class MatMulGradParser(AscendParserBase): +class LessParser(AscendParserBase): def __init__(self, graph, var2geop): - super(MatMulGradParser, self).__init__(graph, var2geop) - self.parser_name = "matmul_grad" + super(LessParser, self).__init__(graph, var2geop) + self.parser_name = "less_than" def _apply(self): - out_grad = self._get_ge_input(self.op.input_arg_names[0]) - x = self._get_ge_input(self.op.input_arg_names[1]) - y = self._get_ge_input(self.op.input_arg_names[2]) + x = self._get_ge_input(self.op.input_arg_names[0]) + y = self._get_ge_input(self.op.input_arg_names[1]) + less_than = core.GEOperatorFactory.create_operator( + "less_than" + self._accumulated_op_id(), + "Less").set_input("x1", x).set_input("x2", y) + return [less_than], [[0]] - x_grad = core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), "MatMul").set_input( - "x1", out_grad).set_input("x2", y).set_attr_bool( - "transpose_x1", False).set_attr_bool("transpose_x2", True) - y_grad = core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), "MatMul").set_input( - "x1", x).set_input("x2", out_grad).set_attr_bool( - "transpose_x1", True).set_attr_bool("transpose_x2", False) - return [x_grad, y_grad], [[0], [1]] +class MaxParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(MaxParser, self).__init__(graph, var2geop) + self.parser_name = "elementwise_max" -class MulGradParser(AscendParserBase): + def _apply(self): + x = self._get_ge_input(self.op.input_arg_names[0]) + y = self._get_ge_input(self.op.input_arg_names[1]) + max_out = core.GEOperatorFactory.create_operator( + "max" + self._accumulated_op_id(), + "Maximum").set_input("x1", x).set_input("x2", y) + return [max_out], [[0]] + + +class MinParser(AscendParserBase): def __init__(self, graph, var2geop): - super(MulGradParser, self).__init__(graph, var2geop) - self.parser_name = "mul_grad" + super(MinParser, self).__init__(graph, var2geop) + self.parser_name = "elementwise_min" def _apply(self): - out_grad = self._get_ge_input(self.op.input_arg_names[0]) - x = self._get_ge_input(self.op.input_arg_names[1]) - y = self._get_ge_input(self.op.input_arg_names[2]) + x = self._get_ge_input(self.op.input_arg_names[0]) + y = self._get_ge_input(self.op.input_arg_names[1]) + min_out = core.GEOperatorFactory.create_operator( + "min" + self._accumulated_op_id(), + "Minimum").set_input("x1", x).set_input("x2", y) + return [min_out], [[0]] - x_grad = core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), "MatMul").set_input( - "x1", out_grad).set_input("x2", y).set_attr_bool( - "transpose_x1", False).set_attr_bool("transpose_x2", True) - y_grad = core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), "MatMul").set_input( - "x1", x).set_input("x2", out_grad).set_attr_bool( - "transpose_x1", True).set_attr_bool("transpose_x2", False) - return [x_grad, y_grad], [[0], [1]] +## cal +class LogParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(LogParser, self).__init__(graph, var2geop) + self.parser_name = "log" + + def _apply(self): + x = self._get_ge_input(self.op.input_arg_names[0]) + log = core.GEOperatorFactory.create_operator( + "log" + self._accumulated_op_id(), "Log").set_input("x", x) + return [log], [[0]] + + +class SqrtParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(SqrtParser, self).__init__(graph, var2geop) + self.parser_name = "sqrt" + + def _apply(self): + x = self._get_ge_input(self.op.input_arg_names[0]) + sqrt = core.GEOperatorFactory.create_operator( + "sqrt" + self._accumulated_op_id(), "Sqrt").set_input("x", x) + return [sqrt], [[0]] + + +class PowParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(PowParser, self).__init__(graph, var2geop) + self.parser_name = "pow" + + def _apply(self): + x = self._get_ge_input(self.op.input_arg_names[0]) + factor = self.op.attr("factor") + pow_value = core.GEOperatorFactory.create_operator( + "pow" + self._accumulated_op_id(), + "Power").set_input("x", x).set_attr_float( + "power", factor).set_attr_float("scale", 1.0).set_attr_float( + "shift", 0.0) + return [pow_value], [[0]] + + +class SquareParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(SquareParser, self).__init__(graph, var2geop) + self.parser_name = "square" + + def _apply(self): + x = self._get_ge_input(self.op.input_arg_names[0]) + square = core.GEOperatorFactory.create_operator( + "square" + self._accumulated_op_id(), "Square").set_input("x", x) + return [square], [[0]] + + +class SumParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(SumParser, self).__init__(graph, var2geop) + self.parser_name = "sum" + + def _apply(self): + len_list = len(self.op.input_arg_names) + if len_list < 2: + assert False, "the size of input list must large or equal 2" + x = self._get_ge_input(self.op.input_arg_names[0]) + y = self._get_ge_input(self.op.input_arg_names[1]) + sum = core.GEOperatorFactory.create_operator( + "sum" + self._accumulated_op_id(), + "Add").set_input("x1", x).set_input("x2", y) + for i in range(2, len_list): + y = self._get_ge_input(self.op.input_arg_names[i]) + sum = core.GEOperatorFactory.create_operator( + "sum" + self._accumulated_op_id(), + "Add").set_input("x1", sum).set_input("x2", y) + return [sum], [[0]] + + +class LogicalNotParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(LogicalNotParser, self).__init__(graph, var2geop) + self.parser_name = "logical_not" + + def _apply(self): + x = self._get_ge_input(self.op.input_arg_names[0]) + logical_not = core.GEOperatorFactory.create_operator( + "logical_not" + self._accumulated_op_id(), + "LogicalNot").set_input("x", x) + return [logical_not], [[0]] + + +class MeanParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(MeanParser, self).__init__(graph, var2geop) + self.parser_name = "mean" + + def _apply(self): + x = self._get_ge_input(self.op.input_arg_names[0]) + mean = core.GEOperatorFactory.create_operator( + "mean" + self._accumulated_op_id(), + "ReduceMeanD").set_input("x", x).set_attr_bool( + "keep_dims", False).set_attr_vec_int32("axes", []) + return [mean], [[0]] + + +class ReduceSumParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(ReduceSumParser, self).__init__(graph, var2geop) + self.parser_name = "reduce_sum" + + def _apply(self): + x = self._get_ge_input(self.op.input_arg_names[0]) + axes = self.op.attr("dim") + keep_dims = self.op.attr("keep_dim") + reduce_all = self.op.attr("reduce_all") + x_shape = self.op.block.var(self.op.input_arg_names[0]).shape + if reduce_all: + axes = list(range(len(x_shape))) + reduce_sum = core.GEOperatorFactory.create_operator( + "reduce_sum" + self._accumulated_op_id(), + "ReduceSumD").set_input("x", x, 0).set_attr_vec_int32( + "axes", axes).set_attr_bool("keep_dims", keep_dims) + return [reduce_sum], [[0]] + + +#class IncrementParser(AscendParserBase): +# def __init__(self, graph, var2geop): +# super(IncrementParser, self).__init__(graph, var2geop) +# self.parser_name = "increment" +# +# def _apply(self): +# x = self._get_ge_input(self.op.input_arg_names[0]) +# step = self.op.attr("step") #self._get_ge_input(self.op.input_arg_names[1]) +# print("step: ", step) +# +# increment = core.GEOperatorFactory.create_operator("adds" + self._accumulated_op_id(), "Adds").set_input("x", x).set_attr_float("value", step) #set_input("x2", bias) +# +# return [increment] + + +## matrix cal +class MatMulParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(MatMulParser, self).__init__(graph, var2geop) + self.parser_name = "matmul" + + def _apply(self): + x = self._get_ge_input(self.op.input_arg_names[0]) + y = self._get_ge_input(self.op.input_arg_names[1]) + transpose_x = self.op.attr("transpose_X") + transpose_y = self.op.attr("transpose_Y") + + x1_shape = self.op.block.var(self.op.input_arg_names[0]).shape + x2_shape = self.op.block.var(self.op.input_arg_names[1]).shape + + if len(x1_shape) > 2: + matmul = core.GEOperatorFactory.create_operator( + "matmul" + self._accumulated_op_id(), "BatchMatMul").set_input( + "x1", x).set_input("x2", y).set_attr_bool( + "adj_x1", + transpose_x).set_attr_bool("adj_x2", transpose_y) + elif len(x1_shape) == 2: + matmul = core.GEOperatorFactory.create_operator( + "matmul" + self._accumulated_op_id(), + "MatMul").set_input("x1", x).set_input("x2", y).set_attr_bool( + "transpose_x1", transpose_x).set_attr_bool("transpose_x2", + transpose_y) + else: + assert False, "not support" + return [matmul], [[0]] class MulParser(AscendParserBase): @@ -282,13 +559,105 @@ class MulParser(AscendParserBase): def _apply(self): x = self._get_ge_input(self.op.input_arg_names[0]) y = self._get_ge_input(self.op.input_arg_names[1]) + x_num_col_dims = self.op.attr("x_num_col_dims") + y_num_col_dims = self.op.attr("y_num_col_dims") + shape_x1 = self.op.block.var(self.op.input_arg_names[0]).shape + shape_x2 = self.op.block.var(self.op.input_arg_names[1]).shape + + if x_num_col_dims == 1 and y_num_col_dims == 1: + if len(shape_x1) == 2 and len(shape_x2) == 2: + matmul = core.GEOperatorFactory.create_operator( + "mul" + self._accumulated_op_id(), + "MatMul").set_input("x1", x).set_input("x2", y) + elif len(shape_x1) == 3 and len(shape_x2) == 2: + flatten_x1 = core.GEOperatorFactory.create_operator( + "flatten" + self._accumulated_op_id(), + "Flatten").set_input("x", x) + matmul = core.GEOperatorFactory.create_operator( + "mul" + self._accumulated_op_id(), "MatMul").set_input( + "x1", flatten_x1, 0).set_input("x2", y, 0) + else: + assert False, "not support" + else: + if len(shape_x1) == 3 and len(shape_x2) == 2: + assert x_num_col_dims == 2, "only support 2" + flatten_x1 = core.GEOperatorFactory.create_operator( + "flatten" + self._accumulated_op_id(), + "FlattenV2").set_input("x", x).set_attr_int32( + "axis", 0).set_attr_int32("end_axis", 1) + matmul_m = core.GEOperatorFactory.create_operator( + "mul" + self._accumulated_op_id(), "MatMul").set_input( + "x1", flatten_x1, 0).set_input("x2", y, 0) + matmul_transpose = core.GEOperatorFactory.create_operator( + "transpose" + self._accumulated_op_id(), + "TransposeD").set_input( + "x", matmul_m).set_attr_vec_int32("perm", [1, 0]) + tensor = self._create_ge_tensor( + [3], 2, [shape_x2[1], shape_x1[0], shape_x1[1]]) + const_shape = core.GEOperatorFactory.create_operator( + "shape" + self._accumulated_op_id(), + "Const").set_attr_tensor("value", tensor) + reshape_matmul = core.GEOperatorFactory.create_operator( + "reshape" + self._accumulated_op_id(), "Reshape").set_input( + "x", matmul_transpose).set_input( + "shape", const_shape).set_attr_int32("axis", 0) + matmul = core.GEOperatorFactory.create_operator( + "transpose" + self._accumulated_op_id(), + "TransposeD").set_input( + "x", + reshape_matmul).set_attr_vec_int32("perm", [1, 2, 0]) + else: + assert False, "not support" - matmul = core.GEOperatorFactory.create_operator( - "mul" + self._accumulated_op_id(), "MatMul").set_input( - "x1", x).set_input("x2", y) return [matmul], [[0]] +class LayerNormParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(LayerNormParser, self).__init__(graph, var2geop) + self.parser_name = "layer_norm" + + def _apply(self): + x = self._get_ge_input(self.op.input_arg_names[2]) + scale = self._get_ge_input(self.op.input_arg_names[1]) + bias = self._get_ge_input(self.op.input_arg_names[0]) + epsilon = self.op.attr("epsilon") + begin_norm_axis = self.op.attr("begin_norm_axis") + x_dtype = self.op.block.var(self.op.input_arg_names[2]).dtype + + shape_tensor = core.GEOperatorFactory.create_operator( + "shape" + self._accumulated_op_id(), "Shape").set_input("x", x) + scale_expand = core.GEOperatorFactory.create_operator( + "broadcast_to_d" + self._accumulated_op_id(), + "BroadcastTo").set_input("x", + scale).set_input("shape", shape_tensor) + bias_expand = core.GEOperatorFactory.create_operator( + "broadcast_to_d" + self._accumulated_op_id(), + "BroadcastTo").set_input("x", bias).set_input("shape", shape_tensor) + layer_norm = core.GEOperatorFactory.create_operator( + "layer_norm" + self._accumulated_op_id(), + "LayerNorm").set_input("x", x).set_input( + "gamma", + scale_expand).set_input("beta", bias_expand).set_attr_int32( + "begin_norm_axis", begin_norm_axis).set_attr_int32( + "begin_params_axis", + begin_norm_axis).set_attr_float("epsilon", epsilon) + + cast_dtype = 0 if self.ascend_helper.dtype2paddle_inv_map[str( + x_dtype)] == 0 else 1 + y = core.GEOperatorFactory.create_operator( + "cast" + self._accumulated_op_id(), "Cast").set_input( + "x", layer_norm, 0).set_attr_int32("dst_type", cast_dtype) + mean = core.GEOperatorFactory.create_operator( + "cast" + self._accumulated_op_id(), "Cast").set_input( + "x", layer_norm, 1).set_attr_int32("dst_type", cast_dtype) + variance = core.GEOperatorFactory.create_operator( + "cast" + self._accumulated_op_id(), "Cast").set_input( + "x", layer_norm, 2).set_attr_int32("dst_type", cast_dtype) + return [y, mean, variance], [[1], [2], [0]] + + +## activate function class ReluParser(AscendParserBase): def __init__(self, graph, var2geop): super(ReluParser, self).__init__(graph, var2geop) @@ -301,20 +670,31 @@ class ReluParser(AscendParserBase): return [relu], [[0]] -class ReluGradParser(AscendParserBase): +class GeluParser(AscendParserBase): def __init__(self, graph, var2geop): - super(ReluGradParser, self).__init__(graph, var2geop) - self.parser_name = "relu_grad" + super(GeluParser, self).__init__(graph, var2geop) + self.parser_name = "gelu" def _apply(self): - out = self._get_ge_input(self.op.input_arg_names[0]) - out_grad = self._get_ge_input(self.op.input_arg_names[1]) - relu_grad = core.GEOperatorFactory.create_operator( - self.parser_name + self._accumulated_op_id(), "ReluGrad").set_input( - "gradients", out_grad).set_input("features", out) - return [relu_grad], [[0]] + x = self._get_ge_input(self.op.input_arg_names[0]) + gelu = core.GEOperatorFactory.create_operator( + "gelu" + self._accumulated_op_id(), "Gelu").set_input("x", x) + return [gelu], [[0]] + + +class TanhParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(TanhParser, self).__init__(graph, var2geop) + self.parser_name = "tanh" + def _apply(self): + x = self._get_ge_input(self.op.input_arg_names[0]) + tanh = core.GEOperatorFactory.create_operator( + "tanh" + self._accumulated_op_id(), "Tanh").set_input("x", x) + return [tanh], [[0]] + +## loss function class SoftmaxWithCrossEntropyParser(AscendParserBase): def __init__(self, graph, var2geop): super(SoftmaxWithCrossEntropyParser, self).__init__(graph, var2geop) @@ -323,80 +703,61 @@ class SoftmaxWithCrossEntropyParser(AscendParserBase): def _apply(self): label = self._get_ge_input(self.op.input_arg_names[0]) logits = self._get_ge_input(self.op.input_arg_names[1]) - cls_num = self.op.block.var(self.op.input_arg_names[1]).shape[1] + softmax = core.GEOperatorFactory.create_operator( - "softmax" + self._accumulated_op_id(), "SoftmaxV2").set_input( - "x", logits) + "softmax" + self._accumulated_op_id(), + "SoftmaxV2").set_input("x", logits) label = core.GEOperatorFactory.create_operator( "cast" + self._accumulated_op_id(), "Cast").set_input( "x", label).set_attr_int32("dst_type", 3) tensoron = self._create_ge_tensor([1], 5, 1) - on_const = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const").set_attr_tensor( - "value", tensoron) - self._mark_as_input(on_const) + on = core.GEOperatorFactory.create_operator( + "const" + self._accumulated_op_id(), + "Const").set_attr_tensor("value", tensoron) tensoroff = self._create_ge_tensor([1], 5, 0) - off_const = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const").set_attr_tensor( - "value", tensoroff) - self._mark_as_input(off_const) + off = core.GEOperatorFactory.create_operator( + "const" + self._accumulated_op_id(), + "Const").set_attr_tensor("value", tensoroff) + self._mark_as_input(on) + self._mark_as_input(off) onehot = core.GEOperatorFactory.create_operator( "onehot" + self._accumulated_op_id(), "OneHotD").set_input( - "x", label).set_input("on_value", on_const).set_input( - "off_value", off_const).set_attr_int32("depth", cls_num) + "x", label).set_input("on_value", on).set_input( + "off_value", off).set_attr_int32("depth", cls_num) squeeze = core.GEOperatorFactory.create_operator( "mul" + self._accumulated_op_id(), "Squeeze").set_input("x", onehot) - loss = core.GEOperatorFactory.create_operator( + + loss_all = core.GEOperatorFactory.create_operator( "loss" + self._accumulated_op_id(), "SoftmaxCrossEntropyWithLogits").set_input( "features", logits).set_input("labels", squeeze) - - return [label, softmax, on_const, off_const, onehot, squeeze, - loss], [[6], [1]] + loss = core.GEOperatorFactory.create_operator( + "cast" + self._accumulated_op_id(), "Cast").set_input( + "x", loss_all, 0).set_attr_int32("dst_type", 0) + loss_expand = core.GEOperatorFactory.create_operator( + "unsqueeze" + self._accumulated_op_id(), + "Unsqueeze").set_input("x", loss).set_attr_vec_int32("axes", [1]) + return [label, softmax, loss_expand], [[2], [1]] -class SoftmaxWithCrossEntropyGradParser(AscendParserBase): +class SoftMaxParser(AscendParserBase): def __init__(self, graph, var2geop): - super(SoftmaxWithCrossEntropyGradParser, self).__init__(graph, var2geop) - self.parser_name = "softmax_with_cross_entropy_grad" + super(SoftMaxParser, self).__init__(graph, var2geop) + self.parser_name = "softmax" def _apply(self): - label = self._get_ge_input(self.op.input_arg_names[0]) - loss_grad = self._get_ge_input(self.op.input_arg_names[1]) - softmax = self._get_ge_input(self.op.input_arg_names[2]) - cls_num = self.op.block.var(self.op.input_arg_names[2]).shape[1] + logits = self._get_ge_input(self.op.input_arg_names[0]) + axes = self.op.attr("axis") - tensoron = self._create_ge_tensor([1], 5, 1) - on_const = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const").set_attr_tensor( - "value", tensoron) - self._mark_as_input(on_const) - tensoroff = self._create_ge_tensor([1], 5, 0) - off_const = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const").set_attr_tensor( - "value", tensoroff) - self._mark_as_input(off_const) - label = core.GEOperatorFactory.create_operator( - "cast" + self._accumulated_op_id(), "Cast").set_input( - "x", label).set_attr_int32("dst_type", 3) - onehot = core.GEOperatorFactory.create_operator( - "onehot" + self._accumulated_op_id(), "OneHotD").set_input( - "x", label).set_input("on_value", on_const).set_input( - "off_value", off_const).set_attr_int32("depth", cls_num) - # the fuck onehot will add a demension, so must call squeeze afterward - squeeze = core.GEOperatorFactory.create_operator( - "mul" + self._accumulated_op_id(), "Squeeze").set_input("x", onehot) - sub = core.GEOperatorFactory.create_operator( - "sub" + self._accumulated_op_id(), "Sub").set_input( - "x1", softmax).set_input("x2", squeeze) - grad = core.GEOperatorFactory.create_operator( - "mul" + self._accumulated_op_id(), "Mul").set_input( - "x1", loss_grad).set_input("x2", sub) - return [on_const, off_const, label, onehot, squeeze, sub, grad], [[-1]] + softmax = core.GEOperatorFactory.create_operator( + "softmax" + self._accumulated_op_id(), "SoftmaxV2").set_input( + "x", logits).set_attr_vec_int32("axes", [axes]) + return [softmax], [[0]] +## general class ShapeParser(AscendParserBase): def __init__(self, graph, var2geop): super(ShapeParser, self).__init__(graph, var2geop) @@ -418,16 +779,15 @@ class FillConstantParser(AscendParserBase): shape = self.op.attr("shape") dtype = self.op.attr("dtype") value = self.op.attr("value") - print("shape: ", shape) - print("dtype: ", dtype) - print("value: ", value) + tensor = self._create_ge_tensor(shape, dtype, value) const = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const").set_attr_tensor( - "value", tensor) + "const" + self._accumulated_op_id(), + "Const").set_attr_tensor("value", tensor) self._mark_as_input(const) if self.op.block.var(self.op.output('Out')[0]).persistable: - print("%s fill_constant" % (self.op.output('Out')[0])) + print("%s is Persistable in fill_constant" % + (self.op.output('Out')[0])) var = core.GEOperatorFactory.create_operator( self.op.output('Out')[0], "Variable") var.update_output_desc("y", @@ -441,24 +801,9 @@ class FillConstantParser(AscendParserBase): return [const], [[0]] else: print( - "self.op.output('Out')[0] is not persistable in fill_constant") - return [const], [[0]] - - -class SGDParser(AscendParserBase): - def __init__(self, graph, var2geop): - super(SGDParser, self).__init__(graph, var2geop) - self.parser_name = "sgd" - - def _apply(self): - grad = self._get_ge_input(self.op.input_arg_names[0]) - lr = self._get_ge_input(self.op.input_arg_names[1]) - param = self._get_ge_input(self.op.input_arg_names[2]) - sgd = core.GEOperatorFactory.create_operator( - "momentum" + self._accumulated_op_id(), - "ApplyGradientDescent").set_input("var", param).set_input( - "alpha", lr).set_input("delta", grad) - return [sgd], [[0]] + "self.op.output('Out')[0]: %s is not persistable in fill_constant" + % (self.op.output('Out')[0])) + return [const], [[0]] class TruncatedNormalParser(AscendParserBase): @@ -472,30 +817,27 @@ class TruncatedNormalParser(AscendParserBase): mean = self.op.attr("mean") std = self.op.attr("std") seed = self.op.attr("seed") + tensor1 = self._create_ge_tensor([len(shape)], 2, shape) shape_tensor = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const").set_attr_tensor( - "value", tensor1) - + "const" + self._accumulated_op_id(), + "Const").set_attr_tensor("value", tensor1) tensor2 = self._create_ge_tensor([1], dtype, mean) mean_tensor = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const").set_attr_tensor( - "value", tensor2) - + "const" + self._accumulated_op_id(), + "Const").set_attr_tensor("value", tensor2) tensor3 = self._create_ge_tensor([1], dtype, std) std_tensor = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const").set_attr_tensor( - "value", tensor3) - + "const" + self._accumulated_op_id(), + "Const").set_attr_tensor("value", tensor3) tensor4 = self._create_ge_tensor([1], dtype, mean - 2 * std) min_tensor = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const").set_attr_tensor( - "value", tensor4) - + "const" + self._accumulated_op_id(), + "Const").set_attr_tensor("value", tensor4) tensor5 = self._create_ge_tensor([1], dtype, mean + 2 * std) max_tensor = core.GEOperatorFactory.create_operator( - "const" + self._accumulated_op_id(), "Const").set_attr_tensor( - "value", tensor5) + "const" + self._accumulated_op_id(), + "Const").set_attr_tensor("value", tensor5) self._mark_as_input(shape_tensor) self._mark_as_input(mean_tensor) @@ -516,7 +858,6 @@ class TruncatedNormalParser(AscendParserBase): if self.op.block.var(self.op.output('Out')[0]).persistable: print("%s is Persistable in truncated_normal" % (self.op.output('Out')[0])) - #var = core.GEOperatorFactory.create_operator(self.op.output('Out')[0], "Variable").set_input("x", truncated_normal) var = core.GEOperatorFactory.create_operator( self.op.output('Out')[0], "Variable") var.update_output_desc("y", @@ -535,24 +876,354 @@ class TruncatedNormalParser(AscendParserBase): print( "self.op.output('Out')[0] is not persistable in truncated_noraml" ) - return [truncated_normal], [[0]] #[assign] + return [truncated_normal], [[0]] -class AllGatherParser(AscendParserBase): +class GatherParser(AscendParserBase): def __init__(self, graph, var2geop): - super(AllGatherParser, self).__init__(graph, var2geop) - self.parser_name = "c_allgather" + super(GatherParser, self).__init__(graph, var2geop) + self.parser_name = "gather" def _apply(self): - x = self._get_ge_input(self.op.input_arg_names[0]) - rank_size = self.op.attr("rank_size") - group = self.op.attr("group") + index = self._get_ge_input(self.op.input_arg_names[0]) + x = self._get_ge_input(self.op.input_arg_names[1]) + clo = self.op.block.var(self.op.input_arg_names[1]).shape[-1] - allgather = core.GEOperatorFactory.create_operator( - "allgather" + self._accumulated_op_id(), "HcomAllGather").set_input( - "x", x).set_attr_int32( - "rank_size", rank_size).set_attr_string("group", group) - return [allgather], [[0]] + gather = core.GEOperatorFactory.create_operator( + "gather" + self._accumulated_op_id(), "Gather").set_input( + "x", x).set_input("indices", index).set_attr_bool( + "validate_indices", True) + return [gather], [[0]] + + +class ScatterParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(ScatterParser, self).__init__(graph, var2geop) + self.parser_name = "scatter" + + def _apply(self): + index = self._get_ge_input(self.op.input_arg_names[0]) + x = self._get_ge_input(self.op.input_arg_names[1]) + updates = self._get_ge_input(self.op.input_arg_names[2]) + overwrite = self.op.attr("overwrite") + index_shape = self.op.block.var(self.op.input_arg_names[0]).shape + + if len(index_shape) == 1: + index = core.GEOperatorFactory.create_operator( + "unsqueeze" + self.getid(), "Unsqueeze").set_input( + "x", index).set_attr_vec_int32("axes", [1]) + if not overwrite: + scatter_value = core.GEOperatorFactory.create_operator( + "scatter" + self._accumulated_op_id(), + "TensorScatterAdd").set_input( + "x", x_var).set_input("indices", index_var).set_input( + "updates", updatesi_var) + else: + scatter_value = core.GEOperatorFactory.create_operator( + "scatter" + self._accumulated_op_id(), + "TensorScatterUpdate").set_input( + "x", x_var).set_input("indices", index_var).set_input( + "updates", updates_var) + return [x_var, index_var, updates_var, scatter_value], [[-1]] + + +class CastParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(CastParser, self).__init__(graph, var2geop) + self.parser_name = "cast" + + def _apply(self): + x = self._get_ge_input(self.op.input_arg_names[0]) + dtype = self.op.attr("out_dtype") + cast = core.GEOperatorFactory.create_operator( + "cast" + self._accumulated_op_id(), "Cast").set_input( + "x", x).set_attr_int32("dst_type", dtype) + return [cast] + + +class AssignParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(AssignParser, self).__init__(graph, var2geop) + self.parser_name = "assign" + + def _apply(self): + const = self._get_ge_input(self.op.input_arg_names[0]) + var = self._get_ge_input(self.op.input_arg_names[1]) + assign = core.GEOperatorFactory.create_operator( + "assign" + self._accumulated_op_id(), "Assign").set_input( + "value", const).set_input("ref", var) + return [assign], [[0]] + + +class ScaleParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(ScaleParser, self).__init__(graph, var2geop) + self.parser_name = "scale" + + def _apply(self): + x = self._get_ge_input(self.op.input_arg_names[0]) + scale = self.op.attr("scale") + bias = self.op.attr("bias") + bias_after_scale = self.op.attr("bias_after_scale") + + if bias_after_scale: + scale_value = core.GEOperatorFactory.create_operator( + "scale" + self._accumulated_op_id(), "Power").set_input( + "x", x).set_attr_float("power", 1.0).set_attr_float( + "scale", scale).set_attr_float("shift", bias) + else: + x_add_bias = core.GEOperatorFactory.create_operator( + "adds" + self._accumulated_op_id(), "Adds").set_input( + "x", x).set_attr_float("value", bias) + scale_value = core.GEOperatorFactory.create_operator( + "scale" + self._accumulated_op_id(), "Power").set_input( + "x", + x_add_bias).set_attr_float("power", 1.0).set_attr_float( + "scale", scale).set_attr_float("shift", 0.0) + return [scale_value], [[0]] + + +class SliceParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(SliceParser, self).__init__(graph, var2geop) + self.parser_name = "slice" + + def _apply(self): + x = self._get_ge_input(self.op.input_arg_names[0]) + axes = self.op.attr("axes") + starts = self.op.attr("starts") + ends = self.op.attr("ends") + + x_shape = self.op.block.var(self.op.input_arg_names[0]).shape + len_shape = len(x_shape) + axes_cor = list(range(len_shape)) + starts_cor, ends_cor = [], [] + cnt = 0 + for i in range(len_shape): + starts_cor.append(starts[cnt] if i in axes else 0) + if i in axes and ends[cnt] <= x_shape[i]: + ends_cor.append(ends[cnt]) + else: + ends_cor.append(x_shape[i]) + if i in axes: + cnt += 1 + size = [ends_cor[i] - starts_cor[i] for i in range(len(axes_cor))] + + assert len(axes_cor) == len(starts_cor) == len( + ends_cor), "the three fields must have same size" + slice_value = core.GEOperatorFactory.create_operator( + "slice" + self._accumulated_op_id(), "SliceD").set_input( + "x", x).set_attr_vec_int32( + "offsets", starts_cor).set_attr_vec_int32("size", size) + + return [slice_value], [[0]] + + +class ReshapeParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(ReshapeParser, self).__init__(graph, var2geop) + self.parser_name = "reshape2" + + def _apply(self): + org_shape = self.op.block.var(self.op.input_arg_names[0]).shape + assert org_shape.count(-1) == 0, "do not allow the dim is -1" + shape = self.op.attr("shape") + for cnt in range(len(shape)): + if shape[cnt] == 0: + shape[cnt] = org_shape[cnt] + + if -1 in shape: + assert shape.count(-1) == 1, "only allow one dim is -1" + mul_res_org = reduce(lambda x, y: x * y, org_shape) + mul_res_refine = reduce(lambda x, y: x * y, shape) * -1 + idx = shape.index(-1) + shape[idx] = mul_res_org // mul_res_refine + + x = self._get_ge_input(self.op.input_arg_names[0]) + tensor = self._create_ge_tensor([len(shape)], 2, shape) + const_shape = core.GEOperatorFactory.create_operator( + "shape" + self._accumulated_op_id(), + "Const").set_attr_tensor("value", tensor) + reshape = core.GEOperatorFactory.create_operator( + "reshape" + self._accumulated_op_id(), "Reshape").set_input( + "x", + x).set_input("shape", const_shape).set_attr_int32("axis", 0) + x_shape = core.GEOperatorFactory.create_operator( + "shape" + self._accumulated_op_id(), "Shape").set_input("x", x) + + return [x_shape, reshape], [[1], [0]] + + +class TransposeParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(TransposeParser, self).__init__(graph, var2geop) + self.parser_name = "transpose2" + + def _apply(self): + x = self._get_ge_input(self.op.input_arg_names[0]) + perm = self.op.attr("axis") + transpose = core.GEOperatorFactory.create_operator( + "transpose" + self._accumulated_op_id(), "TransposeD").set_input( + "x", x).set_attr_vec_int32("perm", perm) + x_shape = core.GEOperatorFactory.create_operator( + "shape" + self._accumulated_op_id(), "Shape").set_input("x", x) + + return [x_shape, transpose], [[1], [0]] + + +class AccuracyParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(AccuracyParser, self).__init__(graph, var2geop) + self.parser_name = "accuracy" + + def _apply(self): + pred = self._get_ge_input(self.op.input_arg_names[0]) + label = self._get_ge_input(self.op.input_arg_names[1]) + logits = self._get_ge_input(self.op.input_arg_names[2]) + + pred = core.GEOperatorFactory.create_operator( + "cast" + self._accumulated_op_id(), "Cast").set_input( + "x", pred).set_attr_int32("dst_type", 3) + label = core.GEOperatorFactory.create_operator( + "cast" + self._accumulated_op_id(), "Cast").set_input( + "x", label).set_attr_int32("dst_type", 3) + equal = core.GEOperatorFactory.create_operator( + "equal" + self._accumulated_op_id(), "Equal").set_input( + "x1", pred).set_input("x2", label) + cast = core.GEOperatorFactory.create_operator( + "cast" + self._accumulated_op_id(), "Cast").set_input( + "x", equal).set_attr_int32("dst_type", 0) + acc = core.GEOperatorFactory.create_operator( + "mean" + self._accumulated_op_id(), "ReduceMeanD").set_input( + "x", cast).set_attr_bool("keep_dims", False).set_attr_vec_int32( + "axes", []) + correct = core.GEOperatorFactory.create_operator( + "sum" + self._accumulated_op_id(), "ReduceSumD").set_input( + "x", cast).set_attr_bool("keep_dims", False).set_attr_vec_int32( + "axes", []) + ones_tensor = core.GEOperatorFactory.create_operator( + "oneslike" + self._accumulated_op_id(), + "OnesLike").set_input("x", label) + ones_tensor = core.GEOperatorFactory.create_operator( + "cast" + self._accumulated_op_id(), "Cast").set_input( + "x", ones_tensor).set_attr_int32("dst_type", 0) + total = core.GEOperatorFactory.create_operator( + "sum" + self._accumulated_op_id(), "ReduceSumD").set_input( + "x", ones_tensor).set_attr_bool( + "keep_dims", False).set_attr_vec_int32("axes", []) + + return [acc, correct, total], [[0], [1], [2]] + + +class TopkParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(TopkParser, self).__init__(graph, var2geop) + self.parser_name = "top_k" + + def _apply(self): + x = self._get_ge_input(self.op.input_arg_names[0]) + k = self.op.attr("k") + + tensor = self._create_ge_tensor([1], 2, k) + const_k = core.GEOperatorFactory.create_operator( + "const" + self._accumulated_op_id(), + "Const").set_attr_tensor("value", tensor) + cast_x = core.GEOperatorFactory.create_operator( + "cast" + self._accumulated_op_id(), + "Cast").set_input("x", x).set_attr_int32("dst_type", 1) + topk = core.GEOperatorFactory.create_operator( + "topk" + self._accumulated_op_id(), + "TopK").set_input("x", cast_x).set_input("k", const_k) + value = core.GEOperatorFactory.create_operator( + "cast" + self._accumulated_op_id(), "Cast").set_input( + "x", topk, 0).set_attr_int32("dst_type", 0) + index = core.GEOperatorFactory.create_operator( + "cast" + self._accumulated_op_id(), "Cast").set_input( + "x", topk, 1).set_attr_int32("dst_type", 0) + return [value, index], [[1], [0]] + + +class LookupTableParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(LookupTableParser, self).__init__(graph, var2geop) + self.parser_name = "lookup_table" + + def _apply(self): + ids = self._get_ge_input(self.op.input_arg_names[0]) + w = self._get_ge_input(self.op.input_arg_names[1]) + + ids_squeeze = core.GEOperatorFactory.create_operator( + "squeeze" + self._accumulated_op_id(), "Squeeze").set_input( + "x", ids).set_attr_vec_int32("axes", [-1]) + out = core.GEOperatorFactory.create_operator( + "lookup" + self._accumulated_op_id(), "Gather").set_input( + "x", w).set_input("indices", ids_squeeze) + return [out], [[0]] + + +class StackParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(StackParser, self).__init__(graph, var2geop) + self.parser_name = "stack" + + def _apply(self): + tiles = len(self.op.input_arg_names) + data_x_lst = [] + for index in range(tiles): + data_x_lst.append( + self._get_ge_input(self.op.input_arg_names[index])) + axis = self.op.attr("axis") + + data_x = data_x_lst[0] + tensor = self._create_ge_tensor([1], 2, axis) + tensor_axis = core.GEOperatorFactory.create_operator( + "axis" + self._accumulated_op_id(), + "Const").set_attr_tensor("value", tensor) + expand = core.GEOperatorFactory.create_operator( + "expand" + self._accumulated_op_id(), + "ExpandDims").set_input("x", data_x).set_input("axis", tensor_axis) + + stack = core.GEOperatorFactory.create_operator( + "stack" + self._accumulated_op_id(), + "TileWithAxis").set_input("x", expand).set_attr_int32( + "axis", axis).set_attr_int32("tiles", tiles) + + return [stack], [[0]] + + +class UnSqueezeParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(UnSqueezeParser, self).__init__(graph, var2geop) + self.parser_name = "unsqueeze2" + + def _apply(self): + x = self._get_ge_input(self.op.input_arg_names[0]) + axes = self.op.attr('axes') + + output = core.GEOperatorFactory.create_operator( + "unsqueeze" + self._accumulated_op_id(), + "Unsqueeze").set_input("x", x).set_attr_vec_int32("axes", axes) + shape = core.GEOperatorFactory.create_operator( + "shape" + self._accumulated_op_id(), "Shape").set_input("x", output) + return [shape, output], [[1], [0]] + + +## parallel +class AllGatherParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(AllGatherParser, self).__init__(graph, var2geop) + self.parser_name = "c_allgather" + + def _apply(self): + x = self._get_ge_input(self.op.input_arg_names[0]) + rank_size = self.op.attr("rank_size") + group = self.op.attr("group") + + allgather = core.GEOperatorFactory.create_operator( + "allgather" + self._accumulated_op_id(), "HcomAllGather").set_input( + "x", x).set_attr_int32( + "rank_size", rank_size).set_attr_string("group", group) + return [allgather], [[0]] class AllReduceParser(AscendParserBase): @@ -667,59 +1338,735 @@ class ReceiveParser(AscendParserBase): return [receive], [[0]] -class ScaleParser(AscendParserBase): +#****************************************************************# +#*************************** *************************# +#*************************** *************************# +#*************************** GradParser *************************# +#*************************** *************************# +#*************************** *************************# +#****************************************************************# +## grad +class ReduceSumGradParser(AscendParserBase): def __init__(self, graph, var2geop): - super(ScaleParser, self).__init__(graph, var2geop) - self.parser_name = "scale" + super(ReduceSumGradParser, self).__init__(graph, var2geop) + self.parser_name = "reduce_sum_grad" def _apply(self): x = self._get_ge_input(self.op.input_arg_names[0]) - scale = self.op.attr( - "scale") #self.get_ge_input(self.op.input_arg_names[1]) - bias = self.op.attr("bias") - bias_after_scale = self.op.attr("bias_after_scale") - if bias_after_scale: - scale_value = core.GEOperatorFactory.create_operator( - "scale" + self._accumulated_op_id(), "Power").set_input( - "x", x).set_attr_float("power", 1.0).set_attr_float( - "scale", scale).set_attr_float("shift", bias) + input = self._get_ge_input(self.op.input_arg_names[1]) + + shape_tensor = core.GEOperatorFactory.create_operator( + "shape" + self._accumulated_op_id(), + "Shape").set_input("x", input, 0) + tensoron = self._create_ge_tensor([1], 2, -1) + const = core.GEOperatorFactory.create_operator( + "const" + self._accumulated_op_id(), + "Const").set_attr_tensor("value", tensoron) + self._mark_as_input(const) + + reduce_sum = core.GEOperatorFactory.create_operator( + "broadcast_to_d" + self._accumulated_op_id(), + "BroadcastTo").set_input("x", x).set_input("shape", shape_tensor) + #reduce_sum = core.GEOperatorFactory.create_operator("expand" + self._accumulated_op_id(), "ExpandDims").set_input("x", reduce_sum).set_input("axis", const) + + return [reduce_sum], [[0]] + + +class MatMulGradParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(MatMulGradParser, self).__init__(graph, var2geop) + self.parser_name = "matmul_grad" + + def _apply(self): + out_grad = self._get_ge_input(self.op.input_arg_names[0]) + x = self._get_ge_input(self.op.input_arg_names[1]) + y = self._get_ge_input(self.op.input_arg_names[2]) + transpose_x = self.op.attr("transpose_X") + transpose_y = self.op.attr("transpose_Y") + + out_grad_shape = self.op.block.var(self.op.input_arg_names[0]).shape + x_shape = self.op.block.var(self.op.input_arg_names[1]).shape + y_shape = self.op.block.var(self.op.input_arg_names[2]).shape + + if len(x_shape) > 2: + if transpose_y: + x_grad = core.GEOperatorFactory.create_operator( + self.parser_name + self._accumulated_op_id(), + "BatchMatMul").set_input("x1", out_grad).set_input( + "x2", y).set_attr_bool( + "adj_x1", False).set_attr_bool("adj_x2", False) + y_grad = core.GEOperatorFactory.create_operator( + self.parser_name + self._accumulated_op_id(), + "BatchMatMul").set_input("x1", out_grad).set_input( + "x2", x).set_attr_bool( + "adj_x1", True).set_attr_bool("adj_x2", False) + else: + x_grad = core.GEOperatorFactory.create_operator( + self.parser_name + self._accumulated_op_id(), + "BatchMatMul").set_input("x1", out_grad).set_input( + "x2", y).set_attr_bool( + "adj_x1", False).set_attr_bool("adj_x2", True) + y_grad = core.GEOperatorFactory.create_operator( + self.parser_name + self._accumulated_op_id(), + "BatchMatMul").set_input("x1", x).set_input( + "x2", out_grad).set_attr_bool( + "adj_x1", True).set_attr_bool("adj_x2", False) else: - x_add_bias = core.GEOperatorFactory.create_operator( - "adds" + self._accumulated_op_id(), "Adds").set_input( - "x", x).set_attr_float("value", - bias) #set_input("x2", bias) - scale_value = core.GEOperatorFactory.create_operator( - "scale" + self._accumulated_op_id(), "Power").set_input( - "x", x_add_bias).set_attr_float( - "power", 1.0).set_attr_float( - "scale", scale).set_attr_float("shift", 0.0) - #tensor_zeros = core.GEOperatorFactory.create_operator("zeroslike" + self.getid(), "ZerosLike").set_input("x", x) - #bias_ = self.create_ge_tensor([1], 5, bias) - #const_bias = core.GEOperatorFactory.create_operator("const" + self.getid(), "Const").set_attr_tensor("value", tensor_bias) - return [scale_value], [[0]] + if transpose_y: + x_grad = core.GEOperatorFactory.create_operator( + self.parser_name + self._accumulated_op_id(), + "MatMul").set_input("x1", out_grad).set_input( + "x2", y).set_attr_bool( + "transpose_x1", False).set_attr_bool("transpose_x2", + False) + y_grad = core.GEOperatorFactory.create_operator( + self.parser_name + self._accumulated_op_id(), + "MatMul").set_input("x1", out_grad).set_input( + "x2", x).set_attr_bool( + "transpose_x1", True).set_attr_bool("transpose_x2", + False) + else: + x_grad = core.GEOperatorFactory.create_operator( + self.parser_name + self._accumulated_op_id(), + "MatMul").set_input("x1", out_grad).set_input( + "x2", y).set_attr_bool( + "transpose_x1", False).set_attr_bool("transpose_x2", + True) + y_grad = core.GEOperatorFactory.create_operator( + self.parser_name + self._accumulated_op_id(), + "MatMul").set_input("x1", x).set_input( + "x2", out_grad).set_attr_bool( + "transpose_x1", True).set_attr_bool("transpose_x2", + False) + return [x_grad, y_grad], [[0], [1]] -class ReshapeParser(AscendParserBase): + +class MulGradParser(AscendParserBase): def __init__(self, graph, var2geop): - super(ReshapeParser, self).__init__(graph, var2geop) - self.parser_name = "reshape2" + super(MulGradParser, self).__init__(graph, var2geop) + self.parser_name = "mul_grad" def _apply(self): - print("swbuf:", self.op.input_arg_names) - shape = self.op.attr("shape") - axis = 0 - if shape[0] == -1: - axis = 1 - shape = shape[1:] - print("shape: ", shape) - data_x1_shape = self._get_ge_input(self.op.input_arg_names[0]) - tensor = self._create_ge_tensor([len(shape)], 2, shape) + out_grad = self._get_ge_input(self.op.input_arg_names[0]) + x = self._get_ge_input(self.op.input_arg_names[1]) + y = self._get_ge_input(self.op.input_arg_names[2]) + x_num_col_dims = self.op.attr("x_num_col_dims") + y_num_col_dims = self.op.attr("y_num_col_dims") + + shape_out_grad = self.op.block.var(self.op.input_arg_names[0]).shape + shape_x = self.op.block.var(self.op.input_arg_names[1]).shape + shape_y = self.op.block.var(self.op.input_arg_names[2]).shape + + if x_num_col_dims == 1 and y_num_col_dims == 1: + if len(shape_x) == 2 and len(shape_y) == 2: + x_grad = core.GEOperatorFactory.create_operator( + self.parser_name + self._accumulated_op_id(), + "MatMul").set_input("x1", out_grad).set_input( + "x2", y).set_attr_bool( + "transpose_x1", False).set_attr_bool("transpose_x2", + True) + y_grad = core.GEOperatorFactory.create_operator( + self.parser_name + self._accumulated_op_id(), + "MatMul").set_input("x1", x).set_input( + "x2", out_grad).set_attr_bool( + "transpose_x1", True).set_attr_bool("transpose_x2", + False) + elif len(shape_x) == 3 and len(shape_y) == 2: + flatten_x = core.GEOperatorFactory.create_operator( + "flatten" + self._accumulated_op_id(), + "Flatten").set_input("x", x) + x_grad = core.GEOperatorFactory.create_operator( + self.parser_name + self._accumulated_op_id(), + "MatMul").set_input( + "x1", out_grad).set_input("x2", y).set_attr_bool( + "transpose_x1", + False).set_attr_bool("transpose_x2", True) + if len(shape_out_grad) == 2: + x_grad = core.GEOperatorFactory.create_operator( + "unsqueeze" + self._accumulated_op_id(), + "Unsqueeze").set_input("x", x_grad).set_attr_vec_int32( + "axes", [1]) + + y_grad = core.GEOperatorFactory.create_operator( + self.parser_name + self._accumulated_op_id(), + "MatMul").set_input( + "x1", + flatten_x).set_input("x2", out_grad).set_attr_bool( + "transpose_x1", + True).set_attr_bool("transpose_x2", False) + else: + if len(shape_x) == 3 and len(shape_y) == 2: + assert x_num_col_dims == 2, "only support 2" + flatten_x = core.GEOperatorFactory.create_operator( + "flatten" + self._accumulated_op_id(), + "FlattenV2").set_input("x", x).set_attr_int32( + "axis", 0).set_attr_int32("end_axis", 1) + flatten_out_grad = core.GEOperatorFactory.create_operator( + "flatten" + self._accumulated_op_id(), + "FlattenV2").set_input("x", out_grad).set_attr_int32( + "axis", 0).set_attr_int32("end_axis", 1) + + y_unsqueeze = core.GEOperatorFactory.create_operator( + "unsqueeze" + self._accumulated_op_id(), + "Unsqueeze").set_input("x", + y).set_attr_vec_int32("axes", [0]) + x_grad = core.GEOperatorFactory.create_operator( + self.parser_name + self._accumulated_op_id(), + "BatchMatMul").set_input("x1", out_grad).set_input( + "x2", y_unsqueeze).set_attr_bool( + "adj_x1", False).set_attr_bool("adj_x2", True) + y_grad = core.GEOperatorFactory.create_operator( + self.parser_name + self._accumulated_op_id(), + "MatMul").set_input("x1", flatten_x).set_input( + "x2", flatten_out_grad).set_attr_bool( + "transpose_x1", + True).set_attr_bool("transpose_x2", False) + + return [x_grad, y_grad], [[0], [1]] + + +class ReluGradParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(ReluGradParser, self).__init__(graph, var2geop) + self.parser_name = "relu_grad" + + def _apply(self): + out = self._get_ge_input(self.op.input_arg_names[0]) + out_grad = self._get_ge_input(self.op.input_arg_names[1]) + relu_grad = core.GEOperatorFactory.create_operator( + self.parser_name + self._accumulated_op_id(), "ReluGrad").set_input( + "gradients", out_grad).set_input("features", out) + return [relu_grad], [[0]] + + +class SoftmaxWithCrossEntropyGradParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(SoftmaxWithCrossEntropyGradParser, self).__init__(graph, var2geop) + self.parser_name = "softmax_with_cross_entropy_grad" + + def _apply(self): + label = self._get_ge_input(self.op.input_arg_names[0]) + loss_grad = self._get_ge_input(self.op.input_arg_names[1]) + softmax = self._get_ge_input(self.op.input_arg_names[2]) + cls_num = self.op.block.var(self.op.input_arg_names[2]).shape[1] + + label_shape = self.op.block.var(self.op.input_arg_names[0]).shape + loss_grad_shape = self.op.block.var(self.op.input_arg_names[1]).shape + softmax_shape = self.op.block.var(self.op.input_arg_names[2]).shape + + tensoron = self._create_ge_tensor([1], 5, 1) + on = core.GEOperatorFactory.create_operator( + "const" + self._accumulated_op_id(), + "Const").set_attr_tensor("value", tensoron) + tensoroff = self._create_ge_tensor([1], 5, 0) + off = core.GEOperatorFactory.create_operator( + "const" + self._accumulated_op_id(), + "Const").set_attr_tensor("value", tensoroff) + self._mark_as_input(on) + self._mark_as_input(off) + + label = core.GEOperatorFactory.create_operator( + "cast" + self._accumulated_op_id(), "Cast").set_input( + "x", label).set_attr_int32("dst_type", 3) + onehot = core.GEOperatorFactory.create_operator( + "onehot" + self._accumulated_op_id(), "OneHotD").set_input( + "x", label).set_input("on_value", on).set_input( + "off_value", off).set_attr_int32("depth", cls_num) + squeeze = core.GEOperatorFactory.create_operator( + "suqeeze" + self._accumulated_op_id(), + "Squeeze").set_input("x", onehot) + sub = core.GEOperatorFactory.create_operator( + "sub" + self._accumulated_op_id(), "Sub").set_input( + "x1", softmax).set_input("x2", squeeze) + grad = core.GEOperatorFactory.create_operator( + "mul" + self._accumulated_op_id(), + "Mul").set_input("x1", loss_grad).set_input("x2", sub) + + return [on, off, label, onehot, grad], [[-1]] + + +class DotMulGradParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(DotMulGradParser, self).__init__(graph, var2geop) + self.parser_name = "elementwise_mul_grad" + + def _apply(self): + out_grad = self._get_ge_input(self.op.input_arg_names[0]) + out_1 = self._get_ge_input(self.op.input_arg_names[1]) + out_2 = self._get_ge_input(self.op.input_arg_names[2]) + + x_grad = core.GEOperatorFactory.create_operator( + self.parser_name + self._accumulated_op_id(), + "Mul").set_input("x1", out_grad).set_input("x2", out_2) + y_grad = core.GEOperatorFactory.create_operator( + self.parser_name + self._accumulated_op_id(), + "Mul").set_input("x1", out_1).set_input("x2", out_grad) + + return [x_grad, y_grad], [[0], [1]] + + +class DotAddGradParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(DotAddGradParser, self).__init__(graph, var2geop) + self.parser_name = "elementwise_add_grad" + + def _apply(self): + out_grad = self._get_ge_input(self.op.input_arg_names[0]) + out_1 = self._get_ge_input(self.op.input_arg_names[1]) + out_2 = self._get_ge_input(self.op.input_arg_names[2]) + out_grad_shape = self.op.block.var(self.op.input_arg_names[0]).shape + out_1_shape = self.op.block.var(self.op.input_arg_names[1]).shape + out_2_shape = self.op.block.var(self.op.input_arg_names[2]).shape + + x_grad = out_grad + cur_time_x = len(out_grad_shape) - len(out_1_shape) + for i in range(cur_time_x): + x_grad = core.GEOperatorFactory.create_operator( + self.parser_name + self._accumulated_op_id(), + "ReduceSumD").set_input("x", x_grad).set_attr_vec_int32( + "axes", [0]).set_attr_bool("keep_dims", False) + for axis, size in enumerate(out_1_shape): + if size == 1: + x_grad = core.GEOperatorFactory.create_operator( + self.parser_name + self._accumulated_op_id(), + "ReduceSumD").set_input("x", x_grad).set_attr_vec_int32( + "axes", [axis]).set_attr_bool("keep_dims", True) + + y_grad = out_grad + cur_time_y = len(out_grad_shape) - len(out_2_shape) + for i in range(cur_time_y): + y_grad = core.GEOperatorFactory.create_operator( + self.parser_name + self._accumulated_op_id(), + "ReduceSumD").set_input("x", y_grad).set_attr_vec_int32( + "axes", [0]).set_attr_bool("keep_dims", False) + for axis, size in enumerate(out_2_shape): + if size == 1: + y_grad = core.GEOperatorFactory.create_operator( + self.parser_name + self._accumulated_op_id(), + "ReduceSumD").set_input("x", y_grad).set_attr_vec_int32( + "axes", [axis]).set_attr_bool("keep_dims", True) + + return [x_grad, y_grad], [[0], [1]] + + +class DotDivGradParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(DotDivGradParser, self).__init__(graph, var2geop) + self.parser_name = "elementwise_div_grad" + + def _apply(self): + out = self._get_ge_input(self.op.input_arg_names[0]) + out_grad = self._get_ge_input(self.op.input_arg_names[1]) + x = self._get_ge_input(self.op.input_arg_names[2]) + y = self._get_ge_input(self.op.input_arg_names[3]) + + y_power = core.GEOperatorFactory.create_operator( + "power" + self._accumulated_op_id(), "Power").set_input( + "x", y).set_attr_float("power", -1) + + tensor_zeros = core.GEOperatorFactory.create_operator( + "zeroslike" + self._accumulated_op_id(), + "ZerosLike").set_input("x", x) + x_zero = core.GEOperatorFactory.create_operator( + "equal" + self._accumulated_op_id(), "Equal").set_input( + "x1", x).set_input("x2", tensor_zeros) + x_nozero = core.GEOperatorFactory.create_operator( + "logical_not" + self._accumulated_op_id(), + "LogicalNot").set_input("x", x_zero) + x_nozero_f = core.GEOperatorFactory.create_operator( + "cast" + self._accumulated_op_id(), "Cast").set_input( + "x", x_nozero).set_attr_int32("dst_type", 0) + x_grad_w = core.GEOperatorFactory.create_operator( + "mul" + self._accumulated_op_id(), "Mul").set_input( + "x1", x_nozero_f).set_input("x2", y_power) + x_grad = core.GEOperatorFactory.create_operator( + self.parser_name + self._accumulated_op_id(), + "Mul").set_input("x1", x_grad_w).set_input("x2", out_grad) + + y_grad_w = core.GEOperatorFactory.create_operator( + "mul" + self._accumulated_op_id(), "Mul").set_input( + "x1", out).set_input("x2", y_power) + y_grad = core.GEOperatorFactory.create_operator( + "mul" + self._accumulated_op_id(), "Mul").set_input( + "x1", y_grad_w).set_input("x2", out_grad) + + return [x_grad, y_grad], [[0], [1]] + + +class SoftmaxGradParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(SoftmaxGradParser, self).__init__(graph, var2geop) + self.parser_name = "softmax_grad" + + def _apply(self): + out = self._get_ge_input(self.op.input_arg_names[0]) + out_grad = self._get_ge_input(self.op.input_arg_names[1]) + + x_grad = core.GEOperatorFactory.create_operator( + self.parser_name + self._accumulated_op_id(), + "SoftmaxGrad").set_input("softmax", out).set_input("grad_softmax", + out_grad) + return [x_grad], [[0]] + + +class ReshapeGradParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(ReshapeGradParser, self).__init__(graph, var2geop) + self.parser_name = "reshape2_grad" + + def _apply(self): + out_grad = self._get_ge_input(self.op.input_arg_names[0]) + x_shape = self._get_ge_input(self.op.input_arg_names[1]) + x_shape_list = self.op.block.var(self.op.input_arg_names[1]).shape + + if x_shape_list[0] == 0: + x_shape_delzero = x_shape_list[1:] + tensor = self._create_ge_tensor([len(x_shape_delzero)], 2, + x_shape_delzero) const_shape = core.GEOperatorFactory.create_operator( - "shape" + self._accumulated_op_id(), "Const").set_attr_tensor( - "value", tensor) - reshape = core.GEOperatorFactory.create_operator( + "shape" + self._accumulated_op_id(), + "Const").set_attr_tensor("value", tensor) + x_grad = core.GEOperatorFactory.create_operator( "reshape" + self._accumulated_op_id(), "Reshape").set_input( - "x", data_x1_shape).set_input( - "shape", const_shape).set_attr_int32("axis", axis) + "x", out_grad).set_input("shape", const_shape) + + return [x_grad], [[0]] + + +class GatherGradParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(GatherGradParser, self).__init__(graph, var2geop) + self.parser_name = "gather_grad" + + def _apply(self): + index = self._get_ge_input(self.op.input_arg_names[0]) + out_grad = self._get_ge_input(self.op.input_arg_names[1]) + x = self._get_ge_input(self.op.input_arg_names[2]) + + index_shape = self.op.block.var(self.op.input_arg_names[0]).shape + out_grad_shape = self.op.block.var(self.op.input_arg_names[1]).shape + x_shape = self.op.block.var(self.op.input_arg_names[2]).shape + + if len(index_shape) == 1: + index = core.GEOperatorFactory.create_operator( + "unsqueeze" + self._accumulated_op_id(), "Unsqueeze").set_input( + "x", index).set_attr_vec_int32("axes", [1]) + + tensor_zeros = core.GEOperatorFactory.create_operator( + "zeroslike" + self._accumulated_op_id(), + "ZerosLike").set_input("x", x) + x_grad = core.GEOperatorFactory.create_operator( + "scatter" + self._accumulated_op_id(), + "TensorScatterUpdate").set_input("x", tensor_zeros).set_input( + "indices", index).set_input("updates", out_grad) + + return [tensor_zeros, x_grad], [[-1]] + + +class TransposeGradParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(TransposeGradParser, self).__init__(graph, var2geop) + self.parser_name = "transpose2_grad" + + def _apply(self): + out_grad = self._get_ge_input(self.op.input_arg_names[0]) + x = self._get_ge_input(self.op.input_arg_names[1]) + perm = self.op.attr("axis") - return [reshape, reshape], [[0], [1]] + x_shape = self.op.block.var(self.op.input_arg_names[1]).shape[1:] + out_grad_shape = self.op.block.var(self.op.input_arg_names[0]).shape + assert list(map(lambda x: out_grad_shape[x], perm)) == list(x_shape) + + x_grad = core.GEOperatorFactory.create_operator( + "transpose" + self._accumulated_op_id(), "TransposeD").set_input( + "x", out_grad).set_attr_vec_int32("perm", perm) + + return [x_grad], [[0]] + + +class LayerNormGradParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(LayerNormGradParser, self).__init__(graph, var2geop) + self.parser_name = "layer_norm_grad" + + def _apply(self): + bias = self._get_ge_input(self.op.input_arg_names[0]) + mean = self._get_ge_input(self.op.input_arg_names[1]) + scale = self._get_ge_input(self.op.input_arg_names[2]) + variance = self._get_ge_input(self.op.input_arg_names[3]) + x = self._get_ge_input(self.op.input_arg_names[4]) + out_grad = self._get_ge_input(self.op.input_arg_names[5]) + x_dtype = self.op.block.var(self.op.input_arg_names[4]).dtype + + x_grad = core.GEOperatorFactory.create_operator( + self.parser_name + self._accumulated_op_id(), + "LayerNormGrad").set_input("dy", out_grad).set_input( + "x", x).set_input("variance", variance).set_input( + "mean", mean).set_input("gamma", scale) + + cast_dtype = 0 if self.ascend_helper.dtype2paddle_inv_map[str( + x_dtype)] == 0 else 1 + out_x_grad = core.GEOperatorFactory.create_operator( + "cast" + self._accumulated_op_id(), "Cast").set_input( + "x", x_grad, 0).set_attr_int32("dst_type", cast_dtype) + out_scale_grad = core.GEOperatorFactory.create_operator( + "cast" + self._accumulated_op_id(), "Cast").set_input( + "x", x_grad, 1).set_attr_int32("dst_type", cast_dtype) + out_bias_grad = core.GEOperatorFactory.create_operator( + "cast" + self._accumulated_op_id(), "Cast").set_input( + "x", x_grad, 2).set_attr_int32("dst_type", cast_dtype) + + return [out_x_grad, out_scale_grad, out_bias_grad], [[2], [1], [0]] + + +class TanhGradParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(TanhGradParser, self).__init__(graph, var2geop) + self.parser_name = 'tanh_grad' + + def _apply(self): + y = self._get_ge_input(self.op.input_arg_names[0]) + out_grad = self._get_ge_input(self.op.input_arg_names[1]) + tanh_grad = core.GEOperatorFactory.create_operator( + "tanh_grad" + self._accumulated_op_id(), + "TanhGrad").set_input("y", y).set_input("dy", out_grad) + + return [tanh_grad], [[0]] + + +class LogGradParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(LogGradParser, self).__init__(graph, var2geop) + self.parser_name = 'log_grad' + + def _apply(self): + grad = self._get_ge_input(self.op.input_arg_names[0]) + input = self._get_ge_input(self.op.input_arg_names[1]) + log_grad = core.GEOperatorFactory.create_operator( + "log_grad" + self._accumulated_op_id(), + "DivNoNan").set_input("x1", grad).set_input("x2", input) + return [log_grad], [[0]] + + +class SqrtGradParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(SqrtGradParser, self).__init__(graph, var2geop) + self.parser_name = "sqrt_grad" + + def _apply(self): + y = self._get_ge_input(self.op.input_arg_names[0]) + out_grad = self._get_ge_input(self.op.input_arg_names[1]) + sqrt_grad = core.GEOperatorFactory.create_operator( + "sqrt_grad" + self._accumulated_op_id(), + "SqrtGrad").set_input("y", y).set_input("dy", out_grad) + return [sqrt_grad] + + +class PowGradParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(PowGradParser, self).__init__(graph, var2geop) + self.parser_name = "pow_grad" + + def _apply(self): + grad = self._get_ge_input(self.op.input_arg_names[0]) + x = self._get_ge_input(self.op.input_arg_names[1]) + factor = self.op.attr("factor") + + shape_tensor = self._create_shape_tensor() + shape_tensor = core.GEOperatorFactory.create_operator( + "shape" + self._accumulated_op_id(), "Shape").set_input("x", x) + factor_scale = self._create_ge_tensor([1], 5, factor) + factor_scale = core.GEOperatorFactory.create_operator( + "const" + self._accumulated_op_id(), + "Const").set_attr_tensor("value", factor_scale) + factor_tensor = core.GEOperatorFactory.create_operator( + "broadcast_to_d" + self._accumulated_op_id(), + "BroadcastTo").set_input( + "x", factor_scale).set_input("shape", shape_tensor) + + x_power = core.GEOperatorFactory.create_operator( + "x_power" + self._accumulated_op_id(), "Power").set_input( + "x", x).set_attr_float("power", factor - 1) + x_power_mul_factor = core.GEOperatorFactory.create_operator( + "x_power_mul_factor" + self._accumulated_op_id(), "Mul").set_input( + "x1", x).set_input("x2", factor_tensor) + x_power_mul_factor_grad = core.GEOperatorFactory.create_operator( + "x_power_mul_factor_grad" + self._accumulated_op_id(), + "Mul").set_input("x1", x_power_mul_factor).set_input("x2", grad) + + return [x_power_mul_factor_grad], [[0]] + + +class GeluGradParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(GeluGradParser, self).__init__(graph, var2geop) + self.parser_name = "gelu_grad" + + def _apply(self): + grad = self._get_ge_input(self.op.input_arg_names[0]) + x = self._get_ge_input(self.op.input_arg_names[1]) + + y = core.GEOperatorFactory.create_operator( + "gelu" + self._accumulated_op_id(), "Gelu").set_input("x", x) + gelu_grad = core.GEOperatorFactory.create_operator( + "gelu_grad" + self._accumulated_op_id(), "GeluGrad").set_input( + "x", x).set_input("dy", grad).set_input("y", y) + + return [gelu_grad], [[0]] + + +class MeanGradParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(MeanGradParser, self).__init__(graph, var2geop) + self.parser_name = "mean_grad" + + def _apply(self): + grad = self._get_ge_input(self.op.input_arg_names[0]) + x = self._get_ge_input(self.op.input_arg_names[1]) + + ones_tensor = core.GEOperatorFactory.create_operator( + "one_tensor" + self._accumulated_op_id(), + "OnesLike").set_input("x", x) + sum = core.GEOperatorFactory.create_operator( + "mean" + self._accumulated_op_id(), "ReduceSumD").set_input( + "x", ones_tensor).set_attr_bool( + "keep_dims", False).set_attr_vec_int32("axes", []) + mean = core.GEOperatorFactory.create_operator( + "x_power" + self._accumulated_op_id(), "Power").set_input( + "x", sum).set_attr_float("power", -1) + + mean_grad = core.GEOperatorFactory.create_operator( + "mean_grad" + self._accumulated_op_id(), + "Mul").set_input("x1", mean).set_input("x2", grad) + + return [mean_grad], [[0]] + + +class SliceGradParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(SliceGradParser, self).__init__(graph, var2geop) + self.parser_name = "slice_grad" + + def _apply(self): + x = self._get_ge_input(self.op.input_arg_names[0]) + grad = self._get_ge_input(self.op.input_arg_names[1]) + axes = self.op.attr("axes") + starts = self.op.attr("starts") + ends = self.op.attr("ends") + + x_shape = self.op.block.var(self.op.input_arg_names[0]).shape + grad_shape = self.op.block.var(self.op.input_arg_names[1]).shape + + len_shape = len(x_shape) + axes_cor = list(range(len_shape)) + starts_cor, ends_cor = [], [] + cnt = 0 + for i in range(len_shape): + starts_cor.append(starts[cnt] if i in axes else 0) + if i in axes and ends[cnt] <= x_shape[i]: + ends_cor.append(x_shape[i] - ends[cnt]) + else: + ends_cor.append(0) + if i in axes: + cnt += 1 + + starts_cor[0] = 0 + ends_cor[0] = 0 + paddings = [[s, e] for (s, e) in zip(starts_cor, ends_cor)] + slice_value = core.GEOperatorFactory.create_operator( + "slice_grad" + self._accumulated_op_id(), "PadD").set_input( + "x", grad).set_attr_vec_vec_int64("paddings", paddings) + + return [slice_value], [[0]] + + +class LookUpTableGradParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(LookUpTableGradParser, self).__init__(graph, var2geop) + self.parser_name = "lookup_table_grad" + + def _apply(self): + ids = self._get_ge_input(self.op.input_arg_names[0]) + grad = self._get_ge_input(self.op.input_arg_names[1]) + embedding = self._get_ge_input(self.op.input_arg_names[2]) + + shape_ids = self.op.block.var(self.op.input_arg_names[0]).shape + shape_grad = self.op.block.var(self.op.input_arg_names[1]).shape + shape_embedding = self.op.block.var(self.op.input_arg_names[2]).shape + + ids_flatten = core.GEOperatorFactory.create_operator( + "flatten" + self._accumulated_op_id(), "FlattenV2").set_input( + "x", + ids).set_attr_int32("axis", 0).set_attr_int32("end_axis", 1) + grad_flatten = core.GEOperatorFactory.create_operator( + "flatten" + self._accumulated_op_id(), "FlattenV2").set_input( + "x", + grad).set_attr_int32("axis", 0).set_attr_int32("end_axis", 1) + + tensor_zeros = core.GEOperatorFactory.create_operator( + "zeroslike" + self._accumulated_op_id(), + "ZerosLike").set_input("x", embedding) + embedding_grad = core.GEOperatorFactory.create_operator( + "scatteradd" + self._accumulated_op_id(), + "TensorScatterAdd").set_input( + "x", tensor_zeros).set_input("indices", ids_flatten).set_input( + "updates", grad_flatten) + + return [embedding_grad], [[0]] + + +class SGDParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(SGDParser, self).__init__(graph, var2geop) + self.parser_name = "sgd" + + def _apply(self): + grad = self._get_ge_input(self.op.input_arg_names[0]) + lr = self._get_ge_input(self.op.input_arg_names[1]) + param = self._get_ge_input(self.op.input_arg_names[2]) + sgd = core.GEOperatorFactory.create_operator( + "momentum" + self._accumulated_op_id(), + "ApplyGradientDescent").set_input("var", param).set_input( + "alpha", lr).set_input("delta", grad) + return [sgd], [[0]] + + +class AdamParser(AscendParserBase): + def __init__(self, graph, var2geop): + super(AdamParser, self).__init__(graph, var2geop) + self.parser_name = "adam" + + def _apply(self): + beta1_power = self._get_ge_input(self.op.input_arg_names[0]) + beta2_power = self._get_ge_input(self.op.input_arg_names[1]) + grad = self._get_ge_input(self.op.input_arg_names[2]) + lr = self._get_ge_input(self.op.input_arg_names[3]) + moment1 = self._get_ge_input(self.op.input_arg_names[4]) + moment2 = self._get_ge_input(self.op.input_arg_names[5]) + param = self._get_ge_input(self.op.input_arg_names[6]) + beta1 = self.op.attr('beta1') + beta2 = self.op.attr('beta2') + epsilon = self.op.attr('epsilon') + + beta1 = core.GEOperatorFactory.create_operator( + "const" + self._accumulated_op_id(), "Const").set_attr_tensor( + "value", self._create_ge_tensor([1], 5, beta1)) + beta2 = core.GEOperatorFactory.create_operator( + "const" + self._accumulated_op_id(), "Const").set_attr_tensor( + "value", self._create_ge_tensor([1], 5, beta2)) + epsilon = core.GEOperatorFactory.create_operator( + "const" + self._accumulated_op_id(), "Const").set_attr_tensor( + "value", self._create_ge_tensor([1], 5, epsilon)) + + adam = core.GEOperatorFactory.create_operator( + "adam" + self._accumulated_op_id(), + "ApplyAdam").set_input("var", param).set_input( + "m", moment1).set_input("v", moment2).set_input( + "beta1_power", beta1_power).set_input( + "beta2_power", beta2_power).set_input( + "lr", lr).set_input("beta1", beta1).set_input( + "beta2", beta2).set_input( + "epsilon", epsilon).set_input("grad", grad) + + return [adam], [[0]]