# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid.framework as framework
from paddle.fluid.optimizer import Optimizer
import paddle.fluid.core as core
import numpy as np
from paddle.distributed import fleet
from functools import reduce

__all__ = []

# Paddle op type -> name of the parser class (defined in this module) that
# lowers it to Ascend GE operators. The names are presumably handed to
# AscendParserFactory.create_parse, which resolves a class name through
# globals(). Entries are grouped as forward ops, backward (gradient) ops and
# optimizer ops; commented entries are currently disabled.
registerd_op = dict(
    # ---- forward ops ----
    elementwise_add="AddParser",
    matmul="MatMulParser",
    mul="MulParser",
    relu="ReluParser",
    softmax_with_cross_entropy="SoftmaxWithCrossEntropyParser",
    shape="ShapeParser",
    fill_constant="FillConstantParser",
    reduce_sum="ReduceSumParser",
    elementwise_mul="DotMulParser",
    elementwise_div="DotDivParser",
    elementwise_pow="DotPowParser",
    elementwise_max="MaxParser",
    elementwise_min="MinParser",
    elementwise_sub="DotSubParser",
    pow="PowParser",
    gelu="GeluParser",
    sqrt="SqrtParser",
    log="LogParser",
    sum="SumParser",
    logical_not="LogicalNotParser",
    gather="GatherParser",
    scatter="ScatterParser",
    cast="CastParser",
    tanh="TanhParser",
    stack="StackParser",
    square="SquareParser",
    unsqueeze2="UnSqueezeParser",
    assign="AssignParser",
    softmax="SoftMaxParser",
    reshape2="ReshapeParser",
    transpose2="TransposeParser",
    layer_norm="LayerNormParser",
    less_than="LessParser",
    mean="MeanParser",
    scale="ScaleParser",
    slice="SliceParser",
    top_k="TopkParser",
    accuracy="AccuracyParser",
    # increment="IncrementParser",  (disabled)
    lookup_table="LookupTableParser",
    truncated_gaussian_random="TruncatedNormalParser",
    c_allgather="AllGatherParser",
    c_allreduce_sum="AllReduceSumParser",
    c_allreduce_max="AllReduceMaxParser",
    c_broadcast="BroadcastParser",
    c_reduce_scatter="ReduceScatterParser",
    c_send="SendParser",
    c_receive="ReceiveParser",
    uniform_random="UniformRandomParser",
    range="RangeParser",
    equal="EqualParser",
    expand="ExpandParser",
    squeeze2="SqueezeParser",
    # ---- backward (gradient) ops ----
    matmul_grad="MatMulGradParser",
    mul_grad="MulGradParser",
    relu_grad="ReluGradParser",
    reduce_sum_grad="ReduceSumGradParser",
    softmax_with_cross_entropy_grad="SoftmaxWithCrossEntropyGradParser",
    tanh_grad="TanhGradParser",
    log_grad="LogGradParser",
    pow_grad="PowGradParser",
    sqrt_grad="SqrtGradParser",
    gelu_grad="GeluGradParser",
    mean_grad="MeanGradParser",
    lookup_table_grad="LookUpTableGradParser",
    elementwise_mul_grad="DotMulGradParser",
    elementwise_add_grad="DotAddGradParser",
    elementwise_div_grad="DotDivGradParser",
    softmax_grad="SoftmaxGradParser",
    slice_grad="SliceGradParser",
    reshape2_grad="ReshapeGradParser",
    gather_grad="GatherGradParser",
    transpose2_grad="TransposeGradParser",
    layer_norm_grad="LayerNormGradParser",
    # ---- optimizer ops ----
    sgd="SGDParser",
    # adam="AdamParser",  (disabled)
)

# Module-wide counters. global_cnt feeds _accumulated_op_id() (unique ".<n>"
# suffixes for GE operator names); global_input_cnt feeds _mark_as_input()
# ("geinput.<n>" keys in var2geop). Both are incremented before use, so the
# first value handed out is 0.
global_cnt = global_input_cnt = -1


class AscendHelper(object):
    """Lookup tables that translate integer dtype codes to GE / numpy types.

    The integer codes are the ``dtype`` arguments the parsers in this module
    pass around (e.g. ``_create_ge_tensor(shape, dtype, value)``).
    """

    def __init__(self):
        # dtype code -> GE data type enum.
        self.dtype2ge_map = {
            0: core.GEDataType.DT_BOOL,
            1: core.GEDataType.DT_INT16,
            2: core.GEDataType.DT_INT32,
            3: core.GEDataType.DT_INT64,
            4: core.GEDataType.DT_FLOAT16,
            5: core.GEDataType.DT_FLOAT,
            6: core.GEDataType.DT_DOUBLE
        }
        # dtype code -> numpy dtype name (same keys as dtype2ge_map).
        self.dtype2np_map = {
            0: "bool",
            1: "int16",
            2: "int32",
            3: "int64",
            4: "float16",
            5: "float32",
            6: "float64"
        }
        # str(Paddle VarType) -> code consulted by LayerNormParser when it
        # picks the Cast dst_type for its outputs.
        self.dtype2paddle_inv_map = {"VarType.FP32": 0, "VarType.FP16": 1}

    def dtype2ge(self, dtype):
        """Return the GE data type for dtype code ``dtype``.

        Raises:
            AssertionError: if ``dtype`` has no mapping.
        """
        # The previous message had two %d placeholders for a single argument,
        # so a failing check raised TypeError instead of AssertionError.
        assert dtype in self.dtype2ge_map, (
            "dtype[%s] is not supported" % (dtype, ))
        return self.dtype2ge_map[dtype]

    def dtype2np(self, index):
        """Return the numpy dtype name for dtype code ``index``.

        Raises:
            AssertionError: if ``index`` has no mapping.
        """
        assert index in self.dtype2np_map, (
            "index[%s] is not supported" % (index, ))
        return self.dtype2np_map[index]


class AscendParserFactory(object):
    """Instantiate the parser classes of this module by class name."""

    def __init__(self, graph, var2geop):
        # Handed unchanged to every parser created by this factory.
        self.graph = graph
        self.var2geop = var2geop

    def create_parse(self, parser_class):
        """Return a new ``parser_class`` instance bound to this graph.

        Args:
            parser_class: name of a parser class defined in this module.

        Raises:
            ValueError: if no object with that name exists in this module.
        """
        # Resolve the name first and construct outside any try block: the
        # former bare ``except:`` also swallowed errors raised by the parser's
        # own constructor and misreported them as a missing class.
        parser_cls = globals().get(parser_class)
        if parser_cls is None:
            raise ValueError("parser class %s does not exist" %
                             (parser_class, ))
        return parser_cls(self.graph, self.var2geop)


class AscendParserBase(object):
    """Common machinery for lowering one Paddle op into GE operators.

    Subclasses set ``self.parser_name`` to the Paddle op type they handle and
    implement ``_apply()``, which returns ``(geop_list, index_list)`` in the
    form consumed by ``update_output``.
    """

    def __init__(self, graph, var2geop):
        # graph: object exposing add_op(); var2geop: dict from Paddle
        # variable names to the GE operators that produce them.
        self.graph = graph
        self.var2geop = var2geop
        self.op = None  # Paddle op currently being parsed; set by apply().
        self.ascend_helper = AscendHelper()

    def _get_ge_input(self, input_var_name):
        """Return the GE operator registered for ``input_var_name``."""
        assert input_var_name in self.var2geop, "var %s not created before" % (
            input_var_name)
        return self.var2geop[input_var_name]

    def update_output(self, geop_list, index_list):
        """Bind the op's output arguments to GE operators and add them to the graph.

        Args:
            geop_list: GE operators produced by ``_apply``; every one of them
                is added to ``self.graph``.
            index_list: one entry per output slot of ``self.op`` (in
                ``self.op.output_names`` order). Entry ``k`` holds, for each
                argument of slot ``k``, the index into ``geop_list`` of the
                GE operator producing it.
        """
        output_names = self.op.output_names
        output_num = len(output_names)
        assert output_num == len(
            index_list
        ), "Parser[%s]'s output number[%d] is not equal to parameters number[%d]" % (
            self.parser_name, len(index_list), output_num)
        for output_id, output_name in enumerate(output_names):
            arguments = self.op.output(output_name)
            if len(arguments) > 0:
                assert len(arguments) == len(
                    index_list[output_id]
                ), "Parser[%s]'s %dth argument number[%d] is not equal to paddle's number[%d]" % (
                    self.parser_name, output_id, len(
                        index_list[output_id]), len(arguments))
                for arg_name, geop_index in zip(arguments,
                                                index_list[output_id]):
                    self.var2geop[arg_name] = geop_list[geop_index]

        for geop in geop_list:
            self.graph.add_op(geop)

    def apply(self, op):
        """Lower ``op`` through ``_apply()`` and register its outputs."""
        self.op = op
        assert self.op.type == self.parser_name, "op [%s] != parser_name[%s]" % (
            self.op.type, self.parser_name)
        geop_list, index_list = self._apply()
        self.update_output(geop_list, index_list)

    def _mark_as_input(self, ge_tensor):
        """Store ``ge_tensor`` in var2geop under a fresh "geinput.<n>" key."""
        global global_input_cnt
        global_input_cnt += 1
        self.var2geop["geinput." + str(global_input_cnt)] = ge_tensor

    def _accumulated_op_id(self):
        """Return a fresh ".<n>" suffix that keeps GE operator names unique."""
        global global_cnt
        global_cnt += 1
        return "." + str(global_cnt)

    def _pack_ge_tensor(self, shape, ge_dtype, data):
        """Return a GETensor of ``shape``/``ge_dtype`` holding ``data``'s bytes.

        ``data`` must already be a numpy array of the matching numpy dtype;
        its raw bytes are handed to ``set_data`` as a flat uint8 array.
        """
        tensor_desc = core.GETensorDesc(core.GEShape(shape),
                                        core.GEFormat.FORMAT_ND, ge_dtype)
        tensor = core.GETensor(tensor_desc)
        tensor.set_data(np.frombuffer(data.tobytes(), dtype=np.uint8))
        return tensor

    def _create_ge_tensor(self, shape, dtype, value):
        """Return a GETensor of ``shape`` filled from ``value``.

        Args:
            shape: list of dimensions.
            dtype: dtype code understood by AscendHelper (e.g. 2 -> int32,
                5 -> float32).
            value: scalar, or anything broadcastable against
                ``np.ones(shape)`` (the product of the two is stored).
        """
        ge_dtype = self.ascend_helper.dtype2ge(dtype)
        data = (value * np.ones((shape))).reshape(shape).astype(
            self.ascend_helper.dtype2np(dtype))
        return self._pack_ge_tensor(shape, ge_dtype, data)

    def _get_ge_tensor(self, shape, dtype, value_list):
        """Return a GE ``Const`` operator holding ``value_list`` reshaped to ``shape``."""
        ge_dtype = self.ascend_helper.dtype2ge(dtype)
        data = np.array(value_list).reshape(shape).astype(
            self.ascend_helper.dtype2np(dtype))
        tensor = self._pack_ge_tensor(shape, ge_dtype, data)
        return core.GEOperatorFactory.create_operator(
            "const" + self._accumulated_op_id(),
            "Const").set_attr_tensor("value", tensor)

    def _get_variable(self, shape, dtype, tensor):
        """Return a GE ``Assign`` op that writes ``tensor`` into a new ``Variable``.

        Args:
            shape: list of dimensions of the variable.
            dtype: "int32" or "float32".
            tensor: GE operator whose output is assigned to the variable.

        Raises:
            ValueError: for any other ``dtype``. Previously this case hit an
                unbound local and raised UnboundLocalError after an operator
                had already been created.
        """
        if dtype == "int32":
            ge_dtype = core.GEDataType.DT_INT32
        elif dtype == "float32":
            ge_dtype = core.GEDataType.DT_FLOAT
        else:
            raise ValueError("unsupported variable dtype: %s" % (dtype, ))

        var = core.GEOperatorFactory.create_operator(
            "variable" + self._accumulated_op_id(), "Variable")
        var.update_output_desc(
            "y",
            core.GETensorDesc(core.GEShape(shape), core.GEFormat.FORMAT_ND,
                              ge_dtype))
        return core.GEOperatorFactory.create_operator(
            "assign" + self._accumulated_op_id(),
            "Assign").set_input("value", tensor).set_input("ref", var)

    def _create_shape_tensor(self):
        """Return a 1-D int32 GETensor holding the constant values [64, 1]."""
        # NOTE(review): the 64 is hard-coded; its meaning is not evident here.
        data = np.array([64, 1], dtype="int32")
        return self._pack_ge_tensor([2], core.GEDataType.DT_INT32, data)

    def _get_GEtensor_shape(self, tensor):
        """Return a GE op yielding the runtime shape of ``tensor``, cast to dst_type 0.

        NOTE(review): dst_type 0 is presumably GE DT_FLOAT — confirm.
        """
        tensor_shape = core.GEOperatorFactory.create_operator(
            "shape" + self._accumulated_op_id(),
            "Shape").set_input("x", tensor)
        tensor_shape = core.GEOperatorFactory.create_operator(
            "cast" + self._accumulated_op_id(),
            "Cast").set_input("x", tensor_shape).set_attr_int32("dst_type", 0)
        return tensor_shape


class AddParser(AscendParserBase):
    """Lower Paddle ``elementwise_add`` to the GE ``Add`` operator.

    NOTE(review): the op's ``axis`` attribute is not read, so broadcasting
    follows GE ``Add`` rules only — confirm for ops with axis != -1.
    """

    def __init__(self, graph, var2geop):
        super(AddParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_add"

    def _apply(self):
        arg_names = self.op.input_arg_names
        lhs = self._get_ge_input(arg_names[0])
        rhs = self._get_ge_input(arg_names[1])
        add_op = core.GEOperatorFactory.create_operator(
            "add" + self._accumulated_op_id(), "Add")
        add_op = add_op.set_input("x1", lhs).set_input("x2", rhs)
        # Single output slot, produced by geop_list[0].
        return [add_op], [[0]]


class DotSubParser(AscendParserBase):
    """Lower Paddle ``elementwise_sub`` to the GE ``Sub`` operator (x1 - x2)."""

    def __init__(self, graph, var2geop):
        super(DotSubParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_sub"

    def _apply(self):
        arg_names = self.op.input_arg_names
        lhs = self._get_ge_input(arg_names[0])
        rhs = self._get_ge_input(arg_names[1])
        sub_op = core.GEOperatorFactory.create_operator(
            "sub" + self._accumulated_op_id(), "Sub")
        sub_op = sub_op.set_input("x1", lhs)
        sub_op = sub_op.set_input("x2", rhs)
        return [sub_op], [[0]]


class DotMulParser(AscendParserBase):
    """Lower Paddle ``elementwise_mul`` to the GE ``Mul`` operator."""

    def __init__(self, graph, var2geop):
        super(DotMulParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_mul"

    def _apply(self):
        arg_names = self.op.input_arg_names
        lhs = self._get_ge_input(arg_names[0])
        rhs = self._get_ge_input(arg_names[1])
        mul_op = core.GEOperatorFactory.create_operator(
            "dotmul" + self._accumulated_op_id(), "Mul")
        mul_op = mul_op.set_input("x1", lhs)
        mul_op = mul_op.set_input("x2", rhs)
        return [mul_op], [[0]]


class DotDivParser(AscendParserBase):
    """Lower Paddle ``elementwise_div`` to the GE ``Div`` operator (x1 / x2)."""

    def __init__(self, graph, var2geop):
        super(DotDivParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_div"

    def _apply(self):
        arg_names = self.op.input_arg_names
        numerator = self._get_ge_input(arg_names[0])
        denominator = self._get_ge_input(arg_names[1])
        div_op = core.GEOperatorFactory.create_operator(
            "dotdiv" + self._accumulated_op_id(), "Div")
        div_op = div_op.set_input("x1", numerator)
        div_op = div_op.set_input("x2", denominator)
        return [div_op], [[0]]


class DotPowParser(AscendParserBase):
    """Lower Paddle ``elementwise_pow`` to the GE ``Pow`` operator (x1 ** x2)."""

    def __init__(self, graph, var2geop):
        super(DotPowParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_pow"

    def _apply(self):
        arg_names = self.op.input_arg_names
        base = self._get_ge_input(arg_names[0])
        exponent = self._get_ge_input(arg_names[1])
        pow_op = core.GEOperatorFactory.create_operator(
            "dotpow" + self._accumulated_op_id(), "Pow")
        pow_op = pow_op.set_input("x1", base)
        pow_op = pow_op.set_input("x2", exponent)
        return [pow_op], [[0]]


class LessParser(AscendParserBase):
    """Lower Paddle ``less_than`` to the GE ``Less`` operator (x1 < x2)."""

    def __init__(self, graph, var2geop):
        super(LessParser, self).__init__(graph, var2geop)
        self.parser_name = "less_than"

    def _apply(self):
        arg_names = self.op.input_arg_names
        lhs = self._get_ge_input(arg_names[0])
        rhs = self._get_ge_input(arg_names[1])
        less_op = core.GEOperatorFactory.create_operator(
            "less_than" + self._accumulated_op_id(), "Less")
        less_op = less_op.set_input("x1", lhs)
        less_op = less_op.set_input("x2", rhs)
        return [less_op], [[0]]


class MaxParser(AscendParserBase):
    """Lower Paddle ``elementwise_max`` to the GE ``Maximum`` operator."""

    def __init__(self, graph, var2geop):
        super(MaxParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_max"

    def _apply(self):
        arg_names = self.op.input_arg_names
        lhs = self._get_ge_input(arg_names[0])
        rhs = self._get_ge_input(arg_names[1])
        maximum = core.GEOperatorFactory.create_operator(
            "max" + self._accumulated_op_id(), "Maximum")
        maximum = maximum.set_input("x1", lhs)
        maximum = maximum.set_input("x2", rhs)
        return [maximum], [[0]]


class MinParser(AscendParserBase):
    """Lower Paddle ``elementwise_min`` to the GE ``Minimum`` operator."""

    def __init__(self, graph, var2geop):
        super(MinParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_min"

    def _apply(self):
        arg_names = self.op.input_arg_names
        lhs = self._get_ge_input(arg_names[0])
        rhs = self._get_ge_input(arg_names[1])
        minimum = core.GEOperatorFactory.create_operator(
            "min" + self._accumulated_op_id(), "Minimum")
        minimum = minimum.set_input("x1", lhs)
        minimum = minimum.set_input("x2", rhs)
        return [minimum], [[0]]


## cal
class LogParser(AscendParserBase):
    """Lower Paddle ``log`` to the GE ``Log`` operator."""

    def __init__(self, graph, var2geop):
        super(LogParser, self).__init__(graph, var2geop)
        self.parser_name = "log"

    def _apply(self):
        source = self._get_ge_input(self.op.input_arg_names[0])
        log_op = core.GEOperatorFactory.create_operator(
            "log" + self._accumulated_op_id(), "Log")
        return [log_op.set_input("x", source)], [[0]]


class SqrtParser(AscendParserBase):
    """Lower Paddle ``sqrt`` to the GE ``Sqrt`` operator."""

    def __init__(self, graph, var2geop):
        super(SqrtParser, self).__init__(graph, var2geop)
        self.parser_name = "sqrt"

    def _apply(self):
        source = self._get_ge_input(self.op.input_arg_names[0])
        sqrt_op = core.GEOperatorFactory.create_operator(
            "sqrt" + self._accumulated_op_id(), "Sqrt")
        return [sqrt_op.set_input("x", source)], [[0]]


class PowParser(AscendParserBase):
    """Lower Paddle ``pow`` (scalar exponent attr ``factor``) to GE ``Power``."""

    def __init__(self, graph, var2geop):
        super(PowParser, self).__init__(graph, var2geop)
        self.parser_name = "pow"

    def _apply(self):
        base = self._get_ge_input(self.op.input_arg_names[0])
        exponent = self.op.attr("factor")
        power_op = core.GEOperatorFactory.create_operator(
            "pow" + self._accumulated_op_id(), "Power")
        power_op = power_op.set_input("x", base)
        power_op = power_op.set_attr_float("power", exponent)
        # Only the exponent comes from the op; scale/shift are pinned to 1.0
        # and 0.0 (presumably GE Power computes (shift + scale * x) ** power).
        power_op = power_op.set_attr_float("scale", 1.0)
        power_op = power_op.set_attr_float("shift", 0.0)
        return [power_op], [[0]]


class SquareParser(AscendParserBase):
    """Lower Paddle ``square`` to the GE ``Square`` operator."""

    def __init__(self, graph, var2geop):
        super(SquareParser, self).__init__(graph, var2geop)
        self.parser_name = "square"

    def _apply(self):
        source = self._get_ge_input(self.op.input_arg_names[0])
        square_op = core.GEOperatorFactory.create_operator(
            "square" + self._accumulated_op_id(), "Square")
        return [square_op.set_input("x", source)], [[0]]


class SumParser(AscendParserBase):
    """Lower Paddle ``sum`` to a left-to-right chain of GE ``Add`` operators.

    At least two inputs are required (an assertion fails otherwise).
    """

    def __init__(self, graph, var2geop):
        super(SumParser, self).__init__(graph, var2geop)
        self.parser_name = "sum"

    def _apply(self):
        arg_names = self.op.input_arg_names
        assert len(arg_names) >= 2, "the size of input list must large or equal 2"
        first = self._get_ge_input(arg_names[0])
        second = self._get_ge_input(arg_names[1])
        running = core.GEOperatorFactory.create_operator(
            "sum" + self._accumulated_op_id(), "Add")
        running = running.set_input("x1", first).set_input("x2", second)
        # Fold every remaining input into the running total.
        for arg_name in arg_names[2:]:
            addend = self._get_ge_input(arg_name)
            running = core.GEOperatorFactory.create_operator(
                "sum" + self._accumulated_op_id(),
                "Add").set_input("x1", running).set_input("x2", addend)
        return [running], [[0]]


class LogicalNotParser(AscendParserBase):
    """Lower Paddle ``logical_not`` to the GE ``LogicalNot`` operator."""

    def __init__(self, graph, var2geop):
        super(LogicalNotParser, self).__init__(graph, var2geop)
        self.parser_name = "logical_not"

    def _apply(self):
        source = self._get_ge_input(self.op.input_arg_names[0])
        not_op = core.GEOperatorFactory.create_operator(
            "logical_not" + self._accumulated_op_id(), "LogicalNot")
        return [not_op.set_input("x", source)], [[0]]


class MeanParser(AscendParserBase):
    """Lower Paddle ``mean`` to the GE ``ReduceMeanD`` operator.

    Uses keep_dims=False and empty ``axes`` (presumably "reduce over every
    axis" in GE — confirm).
    """

    def __init__(self, graph, var2geop):
        super(MeanParser, self).__init__(graph, var2geop)
        self.parser_name = "mean"

    def _apply(self):
        source = self._get_ge_input(self.op.input_arg_names[0])
        mean_op = core.GEOperatorFactory.create_operator(
            "mean" + self._accumulated_op_id(), "ReduceMeanD")
        mean_op = mean_op.set_input("x", source)
        mean_op = mean_op.set_attr_bool("keep_dims", False)
        mean_op = mean_op.set_attr_vec_int32("axes", [])
        return [mean_op], [[0]]


class ReduceSumParser(AscendParserBase):
    """Lower Paddle ``reduce_sum`` to the GE ``ReduceSumD`` operator.

    Reads the op attributes ``dim``, ``keep_dim`` and ``reduce_all``; when
    ``reduce_all`` is set, every axis of the input is reduced instead of
    ``dim``.
    """

    def __init__(self, graph, var2geop):
        super(ReduceSumParser, self).__init__(graph, var2geop)
        self.parser_name = "reduce_sum"

    def _apply(self):
        in_name = self.op.input_arg_names[0]
        source = self._get_ge_input(in_name)
        axes = self.op.attr("dim")
        keep_dims = self.op.attr("keep_dim")
        reduce_all = self.op.attr("reduce_all")
        in_shape = self.op.block.var(in_name).shape
        axes = list(range(len(in_shape))) if reduce_all else axes
        reduce_op = core.GEOperatorFactory.create_operator(
            "reduce_sum" + self._accumulated_op_id(), "ReduceSumD")
        reduce_op = reduce_op.set_input("x", source, 0)
        reduce_op = reduce_op.set_attr_vec_int32("axes", axes)
        reduce_op = reduce_op.set_attr_bool("keep_dims", keep_dims)
        return [reduce_op], [[0]]


#class IncrementParser(AscendParserBase):
#    def __init__(self, graph, var2geop):
#        super(IncrementParser, self).__init__(graph, var2geop)
#        self.parser_name = "increment"
#
#    def _apply(self):
#        x = self._get_ge_input(self.op.input_arg_names[0])
#        step = self.op.attr("step") #self._get_ge_input(self.op.input_arg_names[1])
#        print("step: ", step)
#
#        increment = core.GEOperatorFactory.create_operator("adds" + self._accumulated_op_id(), "Adds").set_input("x", x).set_attr_float("value", step) #set_input("x2", bias)
#
#        return [increment]
#
# NOTE(review): disabled (also commented out in registerd_op). If re-enabled,
# _apply must return (geop_list, index_list), e.g. `return [increment], [[0]]`,
# because AscendParserBase.apply unpacks two values.


## matrix cal
class MatMulParser(AscendParserBase):
    """Lower Paddle ``matmul`` to GE ``MatMul`` or ``BatchMatMul``.

    The GE op is chosen from the rank of the first input:

    * rank 2: ``MatMul`` with transpose_x1/transpose_x2;
    * rank > 2: ``BatchMatMul`` with adj_x1/adj_x2;
    * anything else fails an assertion ("not support").

    NOTE(review): only transpose_X / transpose_Y are read. Paddle's matmul
    also has an ``alpha`` scale attribute, which is ignored here — confirm
    callers always use alpha == 1.
    """

    def __init__(self, graph, var2geop):
        super(MatMulParser, self).__init__(graph, var2geop)
        self.parser_name = "matmul"

    def _apply(self):
        arg_names = self.op.input_arg_names
        x = self._get_ge_input(arg_names[0])
        y = self._get_ge_input(arg_names[1])
        transpose_x = self.op.attr("transpose_X")
        transpose_y = self.op.attr("transpose_Y")
        x1_shape, _x2_shape = [
            self.op.block.var(name).shape
            for name in (arg_names[0], arg_names[1])
        ]

        rank = len(x1_shape)
        if rank == 2:
            matmul = core.GEOperatorFactory.create_operator(
                "matmul" + self._accumulated_op_id(), "MatMul")
            matmul = matmul.set_input("x1", x).set_input("x2", y)
            matmul = matmul.set_attr_bool("transpose_x1", transpose_x)
            matmul = matmul.set_attr_bool("transpose_x2", transpose_y)
        elif rank > 2:
            matmul = core.GEOperatorFactory.create_operator(
                "matmul" + self._accumulated_op_id(), "BatchMatMul")
            matmul = matmul.set_input("x1", x).set_input("x2", y)
            matmul = matmul.set_attr_bool("adj_x1", transpose_x)
            matmul = matmul.set_attr_bool("adj_x2", transpose_y)
        else:
            assert False, "not support"
        return [matmul], [[0]]


class MulParser(AscendParserBase):
    """Lower Paddle ``mul`` (flattening matrix multiply) to GE operators.

    Supported cases (anything else fails an assertion):

    * x_num_col_dims == y_num_col_dims == 1, 2-D x and 2-D y: ``MatMul``.
    * same attributes, 3-D x and 2-D y: ``Flatten`` on x, then ``MatMul``.
    * otherwise, 3-D x and 2-D y with x_num_col_dims == 2: ``FlattenV2``
      (axis 0..1) on x, ``MatMul``, ``TransposeD`` perm [1, 0], ``Reshape``
      to [y_shape[1], x_shape[0], x_shape[1]], then ``TransposeD``
      perm [1, 2, 0].
    """

    def __init__(self, graph, var2geop):
        super(MulParser, self).__init__(graph, var2geop)
        self.parser_name = "mul"

    def _apply(self):
        arg_names = self.op.input_arg_names
        x = self._get_ge_input(arg_names[0])
        y = self._get_ge_input(arg_names[1])
        x_col_dims = self.op.attr("x_num_col_dims")
        y_col_dims = self.op.attr("y_num_col_dims")
        x_shape = self.op.block.var(arg_names[0]).shape
        y_shape = self.op.block.var(arg_names[1]).shape
        create = core.GEOperatorFactory.create_operator

        if x_col_dims == 1 and y_col_dims == 1:
            if len(x_shape) == 2 and len(y_shape) == 2:
                result = create("mul" + self._accumulated_op_id(),
                                "MatMul").set_input("x1", x).set_input("x2", y)
            elif len(x_shape) == 3 and len(y_shape) == 2:
                flat_x = create("flatten" + self._accumulated_op_id(),
                                "Flatten").set_input("x", x)
                result = create("mul" + self._accumulated_op_id(),
                                "MatMul").set_input("x1", flat_x,
                                                    0).set_input("x2", y, 0)
            else:
                assert False, "not support"
        elif len(x_shape) == 3 and len(y_shape) == 2:
            assert x_col_dims == 2, "only support 2"
            flat_x = create("flatten" + self._accumulated_op_id(),
                            "FlattenV2").set_input("x", x).set_attr_int32(
                                "axis", 0).set_attr_int32("end_axis", 1)
            product = create("mul" + self._accumulated_op_id(),
                             "MatMul").set_input("x1", flat_x,
                                                 0).set_input("x2", y, 0)
            product_t = create("transpose" + self._accumulated_op_id(),
                               "TransposeD").set_input(
                                   "x", product).set_attr_vec_int32(
                                       "perm", [1, 0])
            # int32 (dtype code 2) target shape for the Reshape below.
            target_shape = self._create_ge_tensor(
                [3], 2, [y_shape[1], x_shape[0], x_shape[1]])
            shape_const = create("shape" + self._accumulated_op_id(),
                                 "Const").set_attr_tensor("value",
                                                          target_shape)
            reshaped = create("reshape" + self._accumulated_op_id(),
                              "Reshape").set_input(
                                  "x", product_t).set_input(
                                      "shape",
                                      shape_const).set_attr_int32("axis", 0)
            result = create("transpose" + self._accumulated_op_id(),
                            "TransposeD").set_input(
                                "x", reshaped).set_attr_vec_int32(
                                    "perm", [1, 2, 0])
        else:
            assert False, "not support"

        return [result], [[0]]


class LayerNormParser(AscendParserBase):
    """Lower Paddle ``layer_norm`` to the GE ``LayerNorm`` operator.

    ``input_arg_names`` is read as (bias, scale, x).

    * Scale and bias are broadcast to x's runtime shape via ``Shape`` +
      ``BroadcastTo``.
    * ``begin_norm_axis`` is used for both begin_norm_axis and
      begin_params_axis.
    * Each of the three LayerNorm outputs goes through a ``Cast`` with
      dst_type 0 when x is FP32 and 1 when x is FP16. Other dtypes raise
      KeyError from the lookup.

    NOTE(review): index_list [[1], [2], [0]] presumably maps the op's output
    slots in the order (Mean, Variance, Y) — confirm.
    """

    def __init__(self, graph, var2geop):
        super(LayerNormParser, self).__init__(graph, var2geop)
        self.parser_name = "layer_norm"

    def _apply(self):
        arg_names = self.op.input_arg_names
        x = self._get_ge_input(arg_names[2])
        scale = self._get_ge_input(arg_names[1])
        bias = self._get_ge_input(arg_names[0])
        epsilon = self.op.attr("epsilon")
        begin_norm_axis = self.op.attr("begin_norm_axis")
        x_dtype = self.op.block.var(arg_names[2]).dtype
        create = core.GEOperatorFactory.create_operator

        x_shape = create("shape" + self._accumulated_op_id(),
                         "Shape").set_input("x", x)
        # Broadcast scale first, then bias (keeps op-id order stable).
        gamma, beta = [
            create("broadcast_to_d" + self._accumulated_op_id(),
                   "BroadcastTo").set_input("x", param).set_input(
                       "shape", x_shape) for param in (scale, bias)
        ]

        norm = create("layer_norm" + self._accumulated_op_id(), "LayerNorm")
        norm = norm.set_input("x", x).set_input("gamma", gamma).set_input(
            "beta", beta)
        norm = norm.set_attr_int32("begin_norm_axis",
                                   begin_norm_axis).set_attr_int32(
                                       "begin_params_axis", begin_norm_axis)
        norm = norm.set_attr_float("epsilon", epsilon)

        is_fp32 = self.ascend_helper.dtype2paddle_inv_map[str(x_dtype)] == 0
        cast_dtype = 0 if is_fp32 else 1
        # One Cast per LayerNorm output slot: 0 -> y, 1 -> mean, 2 -> variance.
        y, mean, variance = [
            create("cast" + self._accumulated_op_id(), "Cast").set_input(
                "x", norm, slot).set_attr_int32("dst_type", cast_dtype)
            for slot in (0, 1, 2)
        ]
        return [y, mean, variance], [[1], [2], [0]]


## activate function
class ReluParser(AscendParserBase):
    """Lower Paddle ``relu`` to the GE ``Relu`` operator."""

    def __init__(self, graph, var2geop):
        super(ReluParser, self).__init__(graph, var2geop)
        self.parser_name = "relu"

    def _apply(self):
        source = self._get_ge_input(self.op.input_arg_names[0])
        relu_op = core.GEOperatorFactory.create_operator(
            "relu" + self._accumulated_op_id(), "Relu")
        return [relu_op.set_input("x", source)], [[0]]


class GeluParser(AscendParserBase):
    """Lower Paddle ``gelu`` to the GE ``Gelu`` operator."""

    def __init__(self, graph, var2geop):
        super(GeluParser, self).__init__(graph, var2geop)
        self.parser_name = "gelu"

    def _apply(self):
        source = self._get_ge_input(self.op.input_arg_names[0])
        gelu_op = core.GEOperatorFactory.create_operator(
            "gelu" + self._accumulated_op_id(), "Gelu")
        return [gelu_op.set_input("x", source)], [[0]]


class TanhParser(AscendParserBase):
    """Lower Paddle ``tanh`` to the GE ``Tanh`` operator."""

    def __init__(self, graph, var2geop):
        super(TanhParser, self).__init__(graph, var2geop)
        self.parser_name = "tanh"

    def _apply(self):
        source = self._get_ge_input(self.op.input_arg_names[0])
        tanh_op = core.GEOperatorFactory.create_operator(
            "tanh" + self._accumulated_op_id(), "Tanh")
        return [tanh_op.set_input("x", source)], [[0]]


## loss function
class SoftmaxWithCrossEntropyParser(AscendParserBase):
    """Lower Paddle ``softmax_with_cross_entropy`` to GE operators.

    Inputs are read as ``input_arg_names[0]`` = label and ``[1]`` = logits.
    The class count is ``logits.shape[1]``.

    * The label is cast (dst_type 3), one-hot encoded with float32
      on/off constants 1 and 0, and squeezed.
    * The squeezed labels and the raw logits feed
      ``SoftmaxCrossEntropyWithLogits``.
    * Its first output is cast (dst_type 0) and unsqueezed at axis 1.
    * A separate ``SoftmaxV2`` over the logits supplies the softmax output.

    NOTE(review): index_list [[2], [1]] presumably maps the op's output slots
    in the order (Loss, Softmax) — confirm.
    """

    def __init__(self, graph, var2geop):
        super(SoftmaxWithCrossEntropyParser, self).__init__(graph, var2geop)
        self.parser_name = "softmax_with_cross_entropy"

    def _apply(self):
        arg_names = self.op.input_arg_names
        raw_label = self._get_ge_input(arg_names[0])
        logits = self._get_ge_input(arg_names[1])
        cls_num = self.op.block.var(arg_names[1]).shape[1]
        create = core.GEOperatorFactory.create_operator

        softmax = create("softmax" + self._accumulated_op_id(),
                         "SoftmaxV2").set_input("x", logits)
        label = create("cast" + self._accumulated_op_id(),
                       "Cast").set_input("x", raw_label).set_attr_int32(
                           "dst_type", 3)

        def scalar_const(value):
            # 1-element float32 (dtype code 5) tensor wrapped in a Const op.
            tensor = self._create_ge_tensor([1], 5, value)
            return create("const" + self._accumulated_op_id(),
                          "Const").set_attr_tensor("value", tensor)

        on = scalar_const(1)
        off = scalar_const(0)
        self._mark_as_input(on)
        self._mark_as_input(off)

        onehot = create("onehot" + self._accumulated_op_id(), "OneHotD")
        onehot = onehot.set_input("x", label).set_input(
            "on_value", on).set_input("off_value", off).set_attr_int32(
                "depth", cls_num)
        squeeze = create("mul" + self._accumulated_op_id(),
                         "Squeeze").set_input("x", onehot)

        loss_all = create("loss" + self._accumulated_op_id(),
                          "SoftmaxCrossEntropyWithLogits").set_input(
                              "features", logits).set_input("labels", squeeze)
        loss = create("cast" + self._accumulated_op_id(),
                      "Cast").set_input("x", loss_all, 0).set_attr_int32(
                          "dst_type", 0)
        loss_expand = create("unsqueeze" + self._accumulated_op_id(),
                             "Unsqueeze").set_input(
                                 "x", loss).set_attr_vec_int32("axes", [1])
        return [label, softmax, loss_expand], [[2], [1]]


788
class SoftMaxParser(AscendParserBase):
    """Lower Paddle's ``softmax`` op onto the GE ``SoftmaxV2`` operator."""

    def __init__(self, graph, var2geop):
        super(SoftMaxParser, self).__init__(graph, var2geop)
        self.parser_name = "softmax"

    def _apply(self):
        logits = self._get_ge_input(self.op.input_arg_names[0])
        # The op's ``axis`` attribute is wrapped in a list for SoftmaxV2's
        # ``axes`` attribute.
        axis = self.op.attr("axis")
        op_name = "softmax" + self._accumulated_op_id()
        softmax = core.GEOperatorFactory.create_operator(op_name, "SoftmaxV2")
        softmax = softmax.set_input("x", logits)
        softmax = softmax.set_attr_vec_int32("axes", [axis])
        return [softmax], [[0]]
803 804


805
## general
806
class ShapeParser(AscendParserBase):
    """Lower Paddle's ``shape`` op onto the GE ``Shape`` operator."""

    def __init__(self, graph, var2geop):
        super(ShapeParser, self).__init__(graph, var2geop)
        self.parser_name = "shape"

    def _apply(self):
        input_tensor = self._get_ge_input(self.op.input_arg_names[0])
        op_name = "shape" + self._accumulated_op_id()
        shape_op = core.GEOperatorFactory.create_operator(op_name, "Shape")
        return [shape_op.set_input("x", input_tensor)], [[0]]


class FillConstantParser(AscendParserBase):
    """Lower ``fill_constant`` onto a GE ``Const`` operator.

    When the output variable is persistable, a GE ``Variable`` named after the
    output and an ``Assign`` from the constant into it are also built. The
    returned operator list is the same in both cases.
    """

    def __init__(self, graph, var2geop):
        super(FillConstantParser, self).__init__(graph, var2geop)
        self.parser_name = "fill_constant"

    def _apply(self):
        shape = self.op.attr("shape")
        dtype = self.op.attr("dtype")
        value = self.op.attr("value")

        const_tensor = self._create_ge_tensor(shape, dtype, value)
        const = core.GEOperatorFactory.create_operator(
            "const" + self._accumulated_op_id(), "Const")
        const = const.set_attr_tensor("value", const_tensor)
        self._mark_as_input(const)

        out_name = self.op.output('Out')[0]
        if self.op.block.var(out_name).persistable:
            # NOTE(review): the Variable's desc is always DT_FLOAT whatever
            # ``dtype`` is -- confirm this is intended for non-float outputs.
            var = core.GEOperatorFactory.create_operator(out_name, "Variable")
            var.update_output_desc(
                "y",
                core.GETensorDesc(core.GEShape(shape), core.GEFormat.FORMAT_ND,
                                  core.GEDataType.DT_FLOAT))
            # The Assign op is built here but not included in the return value.
            core.GEOperatorFactory.create_operator(
                "assign" + self._accumulated_op_id(),
                "Assign").set_input("value", const).set_input("ref", var)
        return [const], [[0]]
849 850 851


class TruncatedNormalParser(AscendParserBase):
    """Lower ``truncated_gaussian_random`` onto GE ``ParameterizedTruncatedNormal``.

    Samples follow N(mean, std) truncated to [mean - 2 * std, mean + 2 * std].
    If the output variable is persistable, the sample is also assigned into a
    GE ``Variable`` named after that output.
    """

    def __init__(self, graph, var2geop):
        super(TruncatedNormalParser, self).__init__(graph, var2geop)
        self.parser_name = "truncated_gaussian_random"

    def _apply(self):
        shape = self.op.attr("shape")
        dtype = self.op.attr("dtype")
        mean = self.op.attr("mean")
        std = self.op.attr("std")
        seed = self.op.attr("seed")

        tensor1 = self._create_ge_tensor([len(shape)], 2, shape)
        shape_tensor = core.GEOperatorFactory.create_operator(
            "const" + self._accumulated_op_id(),
            "Const").set_attr_tensor("value", tensor1)

        tensor2 = self._create_ge_tensor([1], dtype, mean)
        mean_tensor = core.GEOperatorFactory.create_operator(
            "const" + self._accumulated_op_id(),
            "Const").set_attr_tensor("value", tensor2)

        tensor3 = self._create_ge_tensor([1], dtype, std)
        std_tensor = core.GEOperatorFactory.create_operator(
            "const" + self._accumulated_op_id(),
            "Const").set_attr_tensor("value", tensor3)

        # Truncation bounds: two standard deviations around the mean.
        tensor4 = self._create_ge_tensor([1], dtype, mean - 2 * std)
        min_tensor = core.GEOperatorFactory.create_operator(
            "const" + self._accumulated_op_id(),
            "Const").set_attr_tensor("value", tensor4)

        tensor5 = self._create_ge_tensor([1], dtype, mean + 2 * std)
        max_tensor = core.GEOperatorFactory.create_operator(
            "const" + self._accumulated_op_id(),
            "Const").set_attr_tensor("value", tensor5)

        self._mark_as_input(shape_tensor)
        self._mark_as_input(mean_tensor)
        self._mark_as_input(std_tensor)
        self._mark_as_input(min_tensor)
        self._mark_as_input(max_tensor)

        truncated_normal = core.GEOperatorFactory.create_operator(
            "truncated_normal" + self._accumulated_op_id(),
            "ParameterizedTruncatedNormal")
        truncated_normal = truncated_normal.set_input("shape", shape_tensor)
        truncated_normal = truncated_normal.set_input("means", mean_tensor)
        truncated_normal = truncated_normal.set_input("stdevs", std_tensor)
        truncated_normal = truncated_normal.set_input("min", min_tensor)
        truncated_normal = truncated_normal.set_input("max", max_tensor)
        # Forward the op's ``seed`` attribute to GE rather than ignoring it.
        truncated_normal = truncated_normal.set_attr_int32("seed", seed)

        # Write the output of truncated_normal from startup_program to
        # main_program.
        out_name = self.op.output('Out')[0]
        if self.op.block.var(out_name).persistable:
            var = core.GEOperatorFactory.create_operator(out_name, "Variable")
            var.update_output_desc(
                "y",
                core.GETensorDesc(core.GEShape(shape), core.GEFormat.FORMAT_ND,
                                  core.GEDataType.DT_FLOAT))
            core.GEOperatorFactory.create_operator(
                "assign" + self._accumulated_op_id(),
                "Assign").set_input("value",
                                    truncated_normal).set_input("ref", var)
            return [
                shape_tensor, mean_tensor, std_tensor, min_tensor, max_tensor,
                truncated_normal
            ], [[-1]]
        return [truncated_normal], [[0]]
923 924


925
class GatherParser(AscendParserBase):
    """Lower Paddle's ``gather`` op onto the GE ``Gather`` operator."""

    def __init__(self, graph, var2geop):
        super(GatherParser, self).__init__(graph, var2geop)
        self.parser_name = "gather"

    def _apply(self):
        # Read inputs by slot name rather than by position in
        # ``input_arg_names`` so extra or reordered slots cannot swap X and
        # Index.
        index = self._get_ge_input(self.op.input('Index')[0])
        x = self._get_ge_input(self.op.input('X')[0])

        gather = core.GEOperatorFactory.create_operator(
            "gather" + self._accumulated_op_id(),
            "Gather").set_input("x", x).set_input(
                "indices", index).set_attr_bool("validate_indices", True)
        return [gather], [[0]]


class ScatterParser(AscendParserBase):
    """Lower Paddle's ``scatter`` op onto GE TensorScatter operators.

    ``overwrite=True`` maps to ``TensorScatterUpdate`` and ``overwrite=False``
    maps to ``TensorScatterAdd``.
    """

    def __init__(self, graph, var2geop):
        super(ScatterParser, self).__init__(graph, var2geop)
        self.parser_name = "scatter"

    def _apply(self):
        # Read inputs by slot name rather than by position in
        # ``input_arg_names`` so X, Ids and Updates cannot be mixed up.
        index_name = self.op.input('Ids')[0]
        index = self._get_ge_input(index_name)
        x = self._get_ge_input(self.op.input('X')[0])
        updates = self._get_ge_input(self.op.input('Updates')[0])
        overwrite = self.op.attr("overwrite")
        index_shape = self.op.block.var(index_name).shape

        if len(index_shape) == 1:
            # Give 1-D indices a trailing unit axis ([N] -> [N, 1]).
            index = core.GEOperatorFactory.create_operator(
                "unsqueeze" + self._accumulated_op_id(),
                "Unsqueeze").set_input("x", index).set_attr_vec_int32(
                    "axes", [1])

        # NOTE(review): Paddle documents overwrite=False as zeroing the target
        # rows before accumulating the updates, whereas TensorScatterAdd
        # presumably adds onto the existing values of x -- confirm they match.
        scatter_type = ("TensorScatterUpdate"
                        if overwrite else "TensorScatterAdd")
        scatter_value = core.GEOperatorFactory.create_operator(
            "scatter" + self._accumulated_op_id(), scatter_type).set_input(
                "x", x).set_input("indices", index).set_input(
                    "updates", updates)
        return [x, index, updates, scatter_value], [[-1]]
973 974 975


class CastParser(AscendParserBase):
    """Lower Paddle's ``cast`` op onto the GE ``Cast`` operator."""

    def __init__(self, graph, var2geop):
        super(CastParser, self).__init__(graph, var2geop)
        self.parser_name = "cast"

    def _apply(self):
        source = self._get_ge_input(self.op.input_arg_names[0])
        # NOTE(review): ``out_dtype`` is forwarded unchanged as ``dst_type``,
        # while other parsers in this file pass fixed Cast codes (0, 1, 3);
        # confirm that Paddle's dtype numbering agrees with GE's here.
        target_dtype = self.op.attr("out_dtype")
        op_name = "cast" + self._accumulated_op_id()
        cast = core.GEOperatorFactory.create_operator(op_name, "Cast")
        cast = cast.set_input("x", source)
        cast = cast.set_attr_int32("dst_type", target_dtype)
        return [cast], [[0]]


class AssignParser(AscendParserBase):
    """Lower ``assign`` onto GE ``Assign``: first input -> value, second -> ref."""

    def __init__(self, graph, var2geop):
        super(AssignParser, self).__init__(graph, var2geop)
        self.parser_name = "assign"

    def _apply(self):
        arg_names = self.op.input_arg_names
        value_op = self._get_ge_input(arg_names[0])
        ref_op = self._get_ge_input(arg_names[1])
        op_name = "assign" + self._accumulated_op_id()
        assign = core.GEOperatorFactory.create_operator(op_name, "Assign")
        assign = assign.set_input("value", value_op)
        assign = assign.set_input("ref", ref_op)
        return [assign], [[0]]


class ScaleParser(AscendParserBase):
    """Lower Paddle's ``scale`` op onto GE ``Power`` (power=1.0), plus ``Adds``.

    With ``bias_after_scale`` the bias is passed as Power's ``shift``.
    Otherwise the bias is first added with ``Adds`` and Power uses shift 0.0.
    """

    def __init__(self, graph, var2geop):
        super(ScaleParser, self).__init__(graph, var2geop)
        self.parser_name = "scale"

    def _apply(self):
        x = self._get_ge_input(self.op.input_arg_names[0])
        scale = self.op.attr("scale")
        bias = self.op.attr("bias")
        bias_after_scale = self.op.attr("bias_after_scale")

        if bias_after_scale:
            power_input, shift = x, bias
        else:
            power_input = core.GEOperatorFactory.create_operator(
                "adds" + self._accumulated_op_id(),
                "Adds").set_input("x", x).set_attr_float("value", bias)
            shift = 0.0

        scale_value = core.GEOperatorFactory.create_operator(
            "scale" + self._accumulated_op_id(), "Power")
        scale_value = scale_value.set_input("x", power_input)
        scale_value = scale_value.set_attr_float("power", 1.0)
        scale_value = scale_value.set_attr_float("scale", scale)
        scale_value = scale_value.set_attr_float("shift", shift)
        return [scale_value], [[0]]


1037
class SliceParser(AscendParserBase):
    """Lower Paddle's ``slice`` op onto GE ``SliceD`` (per-dim offsets/sizes)."""

    def __init__(self, graph, var2geop):
        super(SliceParser, self).__init__(graph, var2geop)
        self.parser_name = "slice"

    def _apply(self):
        x = self._get_ge_input(self.op.input_arg_names[0])
        axes = self.op.attr("axes")
        starts = self.op.attr("starts")
        ends = self.op.attr("ends")
        assert len(axes) == len(starts) == len(
            ends), "the three fields must have same size"

        x_shape = self.op.block.var(self.op.input_arg_names[0]).shape
        len_shape = len(x_shape)
        # Pair each sliced axis with its own (start, end) so ``axes`` may be
        # listed in any order; a negative axis counts back from the last dim.
        bounds = {
            (axis + len_shape if axis < 0 else axis): (start, end)
            for axis, start, end in zip(axes, starts, ends)
        }

        offsets, size = [], []
        for i, dim in enumerate(x_shape):
            # Unsliced dims keep their full extent.
            start, end = bounds.get(i, (0, dim))
            if dim >= 0:
                # Negative positions count back from the end of a known dim.
                if start < 0:
                    start = max(start + dim, 0)
                if end < 0:
                    end = max(end + dim, 0)
            # An end past the dim (or any end when the dim is -1) clamps to it.
            end = min(end, dim)
            offsets.append(start)
            size.append(end - start)

        slice_value = core.GEOperatorFactory.create_operator(
            "slice" + self._accumulated_op_id(),
            "SliceD").set_input("x", x).set_attr_vec_int32(
                "offsets", offsets).set_attr_vec_int32("size", size)

        return [slice_value], [[0]]


1074
class ReshapeParser(AscendParserBase):
    """Lower ``reshape2`` onto GE ``Reshape`` with a constant target shape.

    A 0 in the ``shape`` attribute copies the matching input dim, and a single
    -1 is inferred from the element count. A ``Shape`` op on the input is also
    built and returned alongside the reshape result.
    """

    def __init__(self, graph, var2geop):
        super(ReshapeParser, self).__init__(graph, var2geop)
        self.parser_name = "reshape2"

    def _apply(self):
        org_shape = self.op.block.var(self.op.input_arg_names[0]).shape
        assert org_shape.count(-1) == 0, "do not allow the dim is -1"

        shape = self.op.attr("shape")
        for pos, dim in enumerate(shape):
            if dim == 0:
                shape[pos] = org_shape[pos]

        if -1 in shape:
            assert shape.count(-1) == 1, "only allow one dim is -1"
            total_elems = reduce(lambda a, b: a * b, org_shape)
            # The product includes the -1, so negate it to get the known part.
            known_elems = -reduce(lambda a, b: a * b, shape)
            shape[shape.index(-1)] = total_elems // known_elems

        x = self._get_ge_input(self.op.input_arg_names[0])
        shape_ge_tensor = self._create_ge_tensor([len(shape)], 2, shape)
        factory = core.GEOperatorFactory
        const_shape = factory.create_operator(
            "shape" + self._accumulated_op_id(), "Const")
        const_shape = const_shape.set_attr_tensor("value", shape_ge_tensor)

        reshape = factory.create_operator(
            "reshape" + self._accumulated_op_id(), "Reshape")
        reshape = reshape.set_input("x", x).set_input("shape", const_shape)
        reshape = reshape.set_attr_int32("axis", 0)

        x_shape = factory.create_operator(
            "shape" + self._accumulated_op_id(), "Shape").set_input("x", x)
        return [x_shape, reshape], [[1], [0]]


class TransposeParser(AscendParserBase):
    """Lower ``transpose2`` onto GE ``TransposeD`` plus a ``Shape`` of the input."""

    def __init__(self, graph, var2geop):
        super(TransposeParser, self).__init__(graph, var2geop)
        self.parser_name = "transpose2"

    def _apply(self):
        source = self._get_ge_input(self.op.input_arg_names[0])
        perm = self.op.attr("axis")
        factory = core.GEOperatorFactory

        transpose = factory.create_operator(
            "transpose" + self._accumulated_op_id(), "TransposeD")
        transpose = transpose.set_input("x", source)
        transpose = transpose.set_attr_vec_int32("perm", perm)

        source_shape = factory.create_operator(
            "shape" + self._accumulated_op_id(), "Shape")
        source_shape = source_shape.set_input("x", source)
        return [source_shape, transpose], [[1], [0]]


class AccuracyParser(AscendParserBase):
    """Lower Paddle's ``accuracy`` op onto GE Cast/Equal/Reduce operators.

    Returns three ops: the mean of the element-wise match between the first
    two inputs (``acc``), the sum of matches (``correct``), and the number of
    label elements (``total``).
    """

    def __init__(self, graph, var2geop):
        super(AccuracyParser, self).__init__(graph, var2geop)
        self.parser_name = "accuracy"

    def _apply(self):
        arg_names = self.op.input_arg_names
        pred = self._get_ge_input(arg_names[0])
        label = self._get_ge_input(arg_names[1])
        # The third input is looked up but not used by this lowering.
        self._get_ge_input(arg_names[2])
        factory = core.GEOperatorFactory

        def cast_to(tensor, dst_type):
            # Wrap ``tensor`` in a GE Cast with the given dst_type code.
            return factory.create_operator(
                "cast" + self._accumulated_op_id(), "Cast").set_input(
                    "x", tensor).set_attr_int32("dst_type", dst_type)

        def reduce_all(prefix, op_type, tensor):
            # Reduce ``tensor`` with empty ``axes`` and keep_dims=False.
            return factory.create_operator(
                prefix + self._accumulated_op_id(), op_type).set_input(
                    "x", tensor).set_attr_bool(
                        "keep_dims", False).set_attr_vec_int32("axes", [])

        # Cast both sides to a common dtype (code 3) before comparing them.
        pred_cast = cast_to(pred, 3)
        label_cast = cast_to(label, 3)
        equal = factory.create_operator(
            "equal" + self._accumulated_op_id(), "Equal").set_input(
                "x1", pred_cast).set_input("x2", label_cast)
        matches = cast_to(equal, 0)

        acc = reduce_all("mean", "ReduceMeanD", matches)
        correct = reduce_all("sum", "ReduceSumD", matches)

        ones = factory.create_operator(
            "oneslike" + self._accumulated_op_id(),
            "OnesLike").set_input("x", label_cast)
        ones = cast_to(ones, 0)
        total = reduce_all("sum", "ReduceSumD", ones)

        return [acc, correct, total], [[0], [1], [2]]


class TopkParser(AscendParserBase):
    """Lower Paddle's ``top_k`` op onto GE ``TopK`` with surrounding casts.

    The input is cast (dst_type 1) before ``TopK``; both TopK outputs (values
    at index 0, indices at index 1) are cast with dst_type 0.
    """

    def __init__(self, graph, var2geop):
        super(TopkParser, self).__init__(graph, var2geop)
        self.parser_name = "top_k"

    def _apply(self):
        x = self._get_ge_input(self.op.input_arg_names[0])
        k = self.op.attr("k")
        k_ge_tensor = self._create_ge_tensor([1], 2, k)
        factory = core.GEOperatorFactory

        k_const = factory.create_operator(
            "const" + self._accumulated_op_id(), "Const")
        k_const = k_const.set_attr_tensor("value", k_ge_tensor)

        x_cast = factory.create_operator(
            "cast" + self._accumulated_op_id(), "Cast")
        x_cast = x_cast.set_input("x", x).set_attr_int32("dst_type", 1)

        topk = factory.create_operator(
            "topk" + self._accumulated_op_id(), "TopK")
        topk = topk.set_input("x", x_cast).set_input("k", k_const)

        # Built in output order 0 then 1, so op ids stay sequential.
        values, indices = [
            factory.create_operator(
                "cast" + self._accumulated_op_id(), "Cast").set_input(
                    "x", topk, out_idx).set_attr_int32("dst_type", 0)
            for out_idx in (0, 1)
        ]
        return [values, indices], [[1], [0]]


class LookupTableParser(AscendParserBase):
    """Lower ``lookup_table`` onto GE ``Squeeze`` of the ids followed by ``Gather``."""

    def __init__(self, graph, var2geop):
        super(LookupTableParser, self).__init__(graph, var2geop)
        self.parser_name = "lookup_table"

    def _apply(self):
        arg_names = self.op.input_arg_names
        ids = self._get_ge_input(arg_names[0])
        table = self._get_ge_input(arg_names[1])
        factory = core.GEOperatorFactory

        # Drop the trailing axis of the ids before looking them up.
        flat_ids = factory.create_operator(
            "squeeze" + self._accumulated_op_id(), "Squeeze")
        flat_ids = flat_ids.set_input("x", ids)
        flat_ids = flat_ids.set_attr_vec_int32("axes", [-1])

        looked_up = factory.create_operator(
            "lookup" + self._accumulated_op_id(), "Gather")
        looked_up = looked_up.set_input("x", table)
        looked_up = looked_up.set_input("indices", flat_ids)
        return [looked_up], [[0]]


class StackParser(AscendParserBase):
    """Lower ``stack`` onto GE ``ExpandDims`` + ``TileWithAxis``.

    NOTE(review): only the first input is used. It is expanded at ``axis`` and
    tiled once per input, so the result equals a real stack only when all
    inputs are identical -- confirm this is intended.
    """

    def __init__(self, graph, var2geop):
        super(StackParser, self).__init__(graph, var2geop)
        self.parser_name = "stack"

    def _apply(self):
        arg_names = self.op.input_arg_names
        tiles = len(arg_names)
        inputs = [self._get_ge_input(name) for name in arg_names]
        axis = self.op.attr("axis")

        first_input = inputs[0]
        axis_ge_tensor = self._create_ge_tensor([1], 2, axis)
        factory = core.GEOperatorFactory

        axis_const = factory.create_operator(
            "axis" + self._accumulated_op_id(), "Const")
        axis_const = axis_const.set_attr_tensor("value", axis_ge_tensor)

        expanded = factory.create_operator(
            "expand" + self._accumulated_op_id(), "ExpandDims")
        expanded = expanded.set_input("x", first_input)
        expanded = expanded.set_input("axis", axis_const)

        stacked = factory.create_operator(
            "stack" + self._accumulated_op_id(), "TileWithAxis")
        stacked = stacked.set_input("x", expanded)
        stacked = stacked.set_attr_int32("axis", axis)
        stacked = stacked.set_attr_int32("tiles", tiles)
        return [stacked], [[0]]


class UnSqueezeParser(AscendParserBase):
    """Lower ``unsqueeze2`` onto GE ``Unsqueeze`` plus a ``Shape`` of its result."""

    def __init__(self, graph, var2geop):
        super(UnSqueezeParser, self).__init__(graph, var2geop)
        self.parser_name = "unsqueeze2"

    def _apply(self):
        source = self._get_ge_input(self.op.input_arg_names[0])
        axes = self.op.attr('axes')
        factory = core.GEOperatorFactory

        unsqueezed = factory.create_operator(
            "unsqueeze" + self._accumulated_op_id(), "Unsqueeze")
        unsqueezed = unsqueezed.set_input("x", source)
        unsqueezed = unsqueezed.set_attr_vec_int32("axes", axes)

        result_shape = factory.create_operator(
            "shape" + self._accumulated_op_id(), "Shape")
        result_shape = result_shape.set_input("x", unsqueezed)
        return [result_shape, unsqueezed], [[1], [0]]


## parallel
class AllGatherParser(AscendParserBase):
    """Lower ``c_allgather`` onto GE ``HcomAllGather``."""

    def __init__(self, graph, var2geop):
        super(AllGatherParser, self).__init__(graph, var2geop)
        self.parser_name = "c_allgather"

    def _apply(self):
        source = self._get_ge_input(self.op.input_arg_names[0])
        rank_size = self.op.attr("rank_size")
        group = self.op.attr("group")

        op_name = "allgather" + self._accumulated_op_id()
        allgather = core.GEOperatorFactory.create_operator(
            op_name, "HcomAllGather")
        allgather = allgather.set_input("x", source)
        allgather = allgather.set_attr_int32("rank_size", rank_size)
        allgather = allgather.set_attr_string("group", group)
        return [allgather], [[0]]


class AllReduceParser(AscendParserBase):
    """Lower ``c_allreduce_<reduction>`` onto GE ``HcomAllReduce``.

    The HCCL group name is derived from the op's ``ring_id`` attribute as
    ``hcom_group_<ring_id>``.
    """

    def __init__(self, graph, var2geop, reduction):
        super(AllReduceParser, self).__init__(graph, var2geop)
        self.parser_name = "c_allreduce_" + reduction
        self.reduction = reduction

    def _apply(self):
        source = self._get_ge_input(self.op.input_arg_names[0])
        group = "hcom_group_" + str(self.op.attr("ring_id"))

        op_name = "allreduce" + self._accumulated_op_id()
        allreduce = core.GEOperatorFactory.create_operator(
            op_name, "HcomAllReduce")
        allreduce = allreduce.set_input("x", source)
        allreduce = allreduce.set_attr_string("reduction", self.reduction)
        allreduce = allreduce.set_attr_string("group", group)

        # Optional fusion attributes; both are currently disabled (None) and
        # the op's own ``fusion``/``fusion_id`` attrs are not read.
        optional_attrs = (("fusion", None), ("fusion_id", None))
        for attr_name, attr_value in optional_attrs:
            if attr_value is not None:
                allreduce.set_attr_int32(attr_name, attr_value)
        return [allreduce], [[0]]


class AllReduceSumParser(AllReduceParser):
    """``c_allreduce_sum``: an AllReduceParser fixed to ``reduction='sum'``."""

    def __init__(self, graph, var2geop):
        super(AllReduceSumParser, self).__init__(
            graph, var2geop, reduction='sum')


class AllReduceMaxParser(AllReduceParser):
    """``c_allreduce_max``: an AllReduceParser fixed to ``reduction='max'``."""

    def __init__(self, graph, var2geop):
        super(AllReduceMaxParser, self).__init__(
            graph, var2geop, reduction='max')


class BroadcastParser(AscendParserBase):
    """Lower ``c_broadcast`` onto GE ``HcomBroadcast``."""

    def __init__(self, graph, var2geop):
        super(BroadcastParser, self).__init__(graph, var2geop)
        self.parser_name = "c_broadcast"

    def _apply(self):
        source = self._get_ge_input(self.op.input_arg_names[0])
        root_rank = self.op.attr("root_rank")
        group = self.op.attr("group")

        op_name = "broadcast" + self._accumulated_op_id()
        broadcast = core.GEOperatorFactory.create_operator(
            op_name, "HcomBroadcast")
        broadcast = broadcast.set_input("x", source)
        broadcast = broadcast.set_attr_int32("root_rank", root_rank)
        broadcast = broadcast.set_attr_string("group", group)
        return [broadcast], [[0]]


class ReduceScatterParser(AscendParserBase):
    """Lower ``c_reduce_scatter`` onto GE ``HcomReduceScatter``."""

    def __init__(self, graph, var2geop):
        super(ReduceScatterParser, self).__init__(graph, var2geop)
        self.parser_name = "c_reduce_scatter"

    def _apply(self):
        source = self._get_ge_input(self.op.input_arg_names[0])
        reduction = self.op.attr("reduction")
        group = self.op.attr("group")
        rank_size = self.op.attr("rank_size")

        op_name = "reducescatter" + self._accumulated_op_id()
        reduce_scatter = core.GEOperatorFactory.create_operator(
            op_name, "HcomReduceScatter")
        reduce_scatter = reduce_scatter.set_input("x", source)
        reduce_scatter = reduce_scatter.set_attr_string("reduction", reduction)
        reduce_scatter = reduce_scatter.set_attr_string("group", group)
        reduce_scatter = reduce_scatter.set_attr_int32("rank_size", rank_size)
        return [reduce_scatter], [[0]]


class SendParser(AscendParserBase):
    """Lower ``c_send`` onto GE ``HcomSend``."""

    def __init__(self, graph, var2geop):
        super(SendParser, self).__init__(graph, var2geop)
        self.parser_name = "c_send"

    def _apply(self):
        payload = self._get_ge_input(self.op.input_arg_names[0])
        sr_tag = self.op.attr("sr_tag")
        dest_rank = self.op.attr("dest_rank")
        group = self.op.attr("group")

        op_name = "send" + self._accumulated_op_id()
        send = core.GEOperatorFactory.create_operator(op_name, "HcomSend")
        send = send.set_input("x", payload)
        send = send.set_attr_int32("sr_tag", sr_tag)
        send = send.set_attr_int32("dest_rank", dest_rank)
        send = send.set_attr_string("group", group)
        return [send], [[0]]


class ReceiveParser(AscendParserBase):
    """Lower ``c_receive`` onto GE ``HcomReceive`` (shape/dtype from op attrs)."""

    def __init__(self, graph, var2geop):
        super(ReceiveParser, self).__init__(graph, var2geop)
        self.parser_name = "c_receive"

    def _apply(self):
        source = self._get_ge_input(self.op.input_arg_names[0])
        sr_tag = self.op.attr("sr_tag")
        src_rank = self.op.attr("src_rank")
        group = self.op.attr("group")
        shape = self.op.attr("shape")
        dtype = self.op.attr("dtype")

        op_name = "receive" + self._accumulated_op_id()
        receive = core.GEOperatorFactory.create_operator(op_name,
                                                         "HcomReceive")
        receive = receive.set_input("x", source)
        receive = receive.set_attr_int32("sr_tag", sr_tag)
        receive = receive.set_attr_int32("src_rank", src_rank)
        receive = receive.set_attr_string("group", group)
        receive = receive.set_attr_vec_int32("shape", shape)
        receive = receive.set_attr_int32("dtype", dtype)
        return [receive], [[0]]


class RangeParser(AscendParserBase):
    """Lower Paddle's ``range`` op onto the GE ``Range`` operator."""

    def __init__(self, graph, var2geop):
        super(RangeParser, self).__init__(graph, var2geop)
        self.parser_name = "range"

    def _apply(self):
        # TODO not support range type yet
        # Read inputs by slot name so each GE input is wired from the matching
        # Paddle slot regardless of the order of ``input_arg_names``.
        start = self._get_ge_input(self.op.input('Start')[0])
        end = self._get_ge_input(self.op.input('End')[0])
        step = self._get_ge_input(self.op.input('Step')[0])

        ge_range = core.GEOperatorFactory.create_operator(
            "range" + self._accumulated_op_id(), "Range").set_input(
                "start", start).set_input("limit", end).set_input(
                    "delta", step)

        return [ge_range], [[0]]


class UniformRandomParser(AscendParserBase):
    """Lower ``uniform_random`` onto GE ``RandomUniform`` plus an affine ``Power``.

    The ``RandomUniform`` sample is passed through ``Power`` with power 1.0,
    scale ``max - min`` and shift ``min``.
    """

    def __init__(self, graph, var2geop):
        super(UniformRandomParser, self).__init__(graph, var2geop)
        self.parser_name = "uniform_random"

    def _apply(self):
        shape = self.op.attr("shape")
        min_v = self.op.attr("min")
        max_v = self.op.attr("max")
        seed = self.op.attr("seed")
        dtype = self.op.attr("dtype")
        assert max_v > min_v, ("assert max_v > min_v, but received " +
                               "as max_v={}, min_v={} ".format(max_v, min_v))

        shape_ge_tensor = self._create_ge_tensor([len(shape)], 2, shape)
        factory = core.GEOperatorFactory

        shape_const = factory.create_operator(
            "const" + self._accumulated_op_id(), "Const")
        shape_const = shape_const.set_attr_tensor("value", shape_ge_tensor)

        sample = factory.create_operator(
            "uniform_random" + self._accumulated_op_id(), "RandomUniform")
        sample = sample.set_input("shape", shape_const)
        sample = sample.set_attr_dtype("dtype",
                                       self.ascend_helper.dtype2ge(dtype))
        sample = sample.set_attr_int32("seed", seed)
        sample = sample.set_attr_int32("seed2", seed)

        rescaled = factory.create_operator(
            "scale" + self._accumulated_op_id(), "Power")
        rescaled = rescaled.set_input("x", sample)
        rescaled = rescaled.set_attr_float("power", 1.0)
        rescaled = rescaled.set_attr_float("scale", max_v - min_v)
        rescaled = rescaled.set_attr_float("shift", min_v)
        return [rescaled], [[0]]


class EqualParser(AscendParserBase):
    """Lower Paddle's ``equal`` op onto the GE ``Equal`` operator."""

    def __init__(self, graph, var2geop):
        super(EqualParser, self).__init__(graph, var2geop)
        self.parser_name = "equal"

    def _apply(self):
        arg_names = self.op.input_arg_names
        lhs = self._get_ge_input(arg_names[0])
        rhs = self._get_ge_input(arg_names[1])
        op_name = "equal" + self._accumulated_op_id()
        equal = core.GEOperatorFactory.create_operator(op_name, "Equal")
        equal = equal.set_input("x1", lhs).set_input("x2", rhs)
        return [equal], [[0]]


class ExpandParser(AscendParserBase):
    """Lower ``expand`` onto GE ``Tile``, with ``expand_times`` as the multiples."""

    def __init__(self, graph, var2geop):
        super(ExpandParser, self).__init__(graph, var2geop)
        self.parser_name = "expand"

    def _apply(self):
        source = self._get_ge_input(self.op.input_arg_names[0])
        expand_times = self.op.attr('expand_times')

        multiples_tensor = self._create_ge_tensor([len(expand_times)], 2,
                                                  expand_times)
        factory = core.GEOperatorFactory
        multiples = factory.create_operator(
            "const" + self._accumulated_op_id(), "Const")
        multiples = multiples.set_attr_tensor("value", multiples_tensor)

        tiled = factory.create_operator("tile" + self._accumulated_op_id(),
                                        "Tile")
        tiled = tiled.set_input("x", source).set_input("multiples", multiples)
        return [tiled], [[0]]


class SqueezeParser(AscendParserBase):
    """Lower Paddle's ``squeeze2`` op to GE ``Squeeze`` plus a ``Shape`` node."""

    def __init__(self, graph, var2geop):
        super(SqueezeParser, self).__init__(graph, var2geop)
        self.parser_name = "squeeze2"

    def _apply(self):
        """Squeeze the op's ``axes`` out of the input tensor.

        Returns:
            ``([shape, squeezed], [[1], [0]])``: the op's first output maps to
            the squeezed tensor and its second output to the ``Shape`` of the
            squeezed tensor (presumably ``Out`` and ``XShape``).
        """
        tensor = self._get_ge_input(self.op.input_arg_names[0])
        axes = self.op.attr("axes")

        data_squeezed = core.GEOperatorFactory.create_operator(
            "squeeze" + self._accumulated_op_id(),
            "Squeeze").set_input("x", tensor).set_attr_vec_int32("axes", axes)
        shape = core.GEOperatorFactory.create_operator(
            "shape" + self._accumulated_op_id(),
            "Shape").set_input("x", data_squeezed)
        return [shape, data_squeezed], [[1], [0]]


#****************************************************************#
#***************************            *************************#
#***************************            *************************#
#*************************** GradParser *************************#
#***************************            *************************#
#***************************            *************************#
#****************************************************************#
## grad
class ReduceSumGradParser(AscendParserBase):
    """Lower ``reduce_sum_grad`` by broadcasting the incoming gradient."""

    def __init__(self, graph, var2geop):
        super(ReduceSumGradParser, self).__init__(graph, var2geop)
        self.parser_name = "reduce_sum_grad"

    def _apply(self):
        """Broadcast input 0 to the runtime shape of input 1.

        Input 0 is presumably the ``Out@GRAD`` gradient and input 1 the
        forward input. The reduced axes are not consulted, so this relies on
        ``BroadcastTo`` being able to align the gradient with the input shape.
        NOTE(review): likely requires ``keep_dim=True`` or reduction over
        leading axes only; confirm.

        Returns:
            ``([x_grad], [[0]])``.
        """
        out_grad = self._get_ge_input(self.op.input_arg_names[0])
        x = self._get_ge_input(self.op.input_arg_names[1])

        shape_tensor = core.GEOperatorFactory.create_operator(
            "shape" + self._accumulated_op_id(),
            "Shape").set_input("x", x, 0)
        # This constant is passed to _mark_as_input but is not consumed by the
        # returned node; it is kept so the generated graph stays unchanged.
        tensoron = self._create_ge_tensor([1], 2, -1)
        const = core.GEOperatorFactory.create_operator(
            "const" + self._accumulated_op_id(),
            "Const").set_attr_tensor("value", tensoron)
        self._mark_as_input(const)

        x_grad = core.GEOperatorFactory.create_operator(
            "broadcast_to_d" + self._accumulated_op_id(),
            "BroadcastTo").set_input("x", out_grad).set_input("shape",
                                                              shape_tensor)
        return [x_grad], [[0]]


class MatMulGradParser(AscendParserBase):
    """Lower ``matmul_grad`` to GE ``MatMul`` / ``BatchMatMul`` operators."""

    def __init__(self, graph, var2geop):
        super(MatMulGradParser, self).__init__(graph, var2geop)
        self.parser_name = "matmul_grad"

    def _matmul(self, x1, x2, adj_x1, adj_x2, batched):
        """Create one ``op(x1) @ op(x2)`` node, ``op`` optionally transposing.

        ``BatchMatMul`` (attrs ``adj_x1``/``adj_x2``) is used when ``batched``
        is true, otherwise ``MatMul`` (attrs ``transpose_x1``/``transpose_x2``).
        """
        if batched:
            op_type, attr_x1, attr_x2 = "BatchMatMul", "adj_x1", "adj_x2"
        else:
            op_type, attr_x1, attr_x2 = ("MatMul", "transpose_x1",
                                         "transpose_x2")
        return core.GEOperatorFactory.create_operator(
            self.parser_name + self._accumulated_op_id(),
            op_type).set_input("x1", x1).set_input("x2", x2).set_attr_bool(
                attr_x1, adj_x1).set_attr_bool(attr_x2, adj_x2)

    def _apply(self):
        """Compute the gradients of ``out = op_x(X) @ op_y(Y)``.

        Inputs by position: ``out_grad``, ``x``, ``y``. All four combinations
        of the ``transpose_X`` / ``transpose_Y`` attributes are handled. When
        ``x`` has rank > 2, ``BatchMatMul`` is used, otherwise ``MatMul``.

        Returns:
            ``([x_grad, y_grad], [[0], [1]])``.
        """
        out_grad = self._get_ge_input(self.op.input_arg_names[0])
        x = self._get_ge_input(self.op.input_arg_names[1])
        y = self._get_ge_input(self.op.input_arg_names[2])
        transpose_x = self.op.attr("transpose_X")
        transpose_y = self.op.attr("transpose_Y")

        x_shape = self.op.block.var(self.op.input_arg_names[1]).shape
        batched = len(x_shape) > 2

        if not transpose_x and not transpose_y:
            # out = X @ Y   ->  dX = dOut @ Y^T,  dY = X^T @ dOut
            x_grad = self._matmul(out_grad, y, False, True, batched)
            y_grad = self._matmul(x, out_grad, True, False, batched)
        elif not transpose_x:
            # out = X @ Y^T ->  dX = dOut @ Y,    dY = dOut^T @ X
            x_grad = self._matmul(out_grad, y, False, False, batched)
            y_grad = self._matmul(out_grad, x, True, False, batched)
        elif not transpose_y:
            # out = X^T @ Y ->  dX = Y @ dOut^T,  dY = X @ dOut
            x_grad = self._matmul(y, out_grad, False, True, batched)
            y_grad = self._matmul(x, out_grad, False, False, batched)
        else:
            # out = X^T @ Y^T -> dX = Y^T @ dOut^T, dY = dOut^T @ X^T
            x_grad = self._matmul(y, out_grad, True, True, batched)
            y_grad = self._matmul(out_grad, x, True, True, batched)

        return [x_grad, y_grad], [[0], [1]]


class MulGradParser(AscendParserBase):
    """Lower ``mul_grad`` (Paddle's flatten-then-matmul ``mul``) to GE ops."""

    def __init__(self, graph, var2geop):
        super(MulGradParser, self).__init__(graph, var2geop)
        self.parser_name = "mul_grad"

    def _apply(self):
        """Compute ``[x_grad, y_grad]`` for the supported ``mul`` layouts.

        Inputs by position: ``out_grad``, ``x``, ``y``. Supported cases:

        * ``x_num_col_dims == y_num_col_dims == 1`` with 2-D ``x`` and ``y``;
        * ``x_num_col_dims == y_num_col_dims == 1`` with 3-D ``x``, 2-D ``y``;
        * any other ``*_num_col_dims`` with 3-D ``x`` and 2-D ``y``, where
          ``x_num_col_dims`` must be 2 (asserted).

        Returns:
            ``([x_grad, y_grad], [[0], [1]])``.

        Raises:
            NotImplementedError: for any other combination of ranks and
                ``*_num_col_dims``.
        """
        out_grad = self._get_ge_input(self.op.input_arg_names[0])
        x = self._get_ge_input(self.op.input_arg_names[1])
        y = self._get_ge_input(self.op.input_arg_names[2])
        x_num_col_dims = self.op.attr("x_num_col_dims")
        y_num_col_dims = self.op.attr("y_num_col_dims")

        shape_out_grad = self.op.block.var(self.op.input_arg_names[0]).shape
        shape_x = self.op.block.var(self.op.input_arg_names[1]).shape
        shape_y = self.op.block.var(self.op.input_arg_names[2]).shape

        # Stay None when no supported branch matches, so the failure below is
        # explicit instead of an UnboundLocalError at the return statement.
        x_grad = None
        y_grad = None
        if x_num_col_dims == 1 and y_num_col_dims == 1:
            if len(shape_x) == 2 and len(shape_y) == 2:
                x_grad = core.GEOperatorFactory.create_operator(
                    self.parser_name + self._accumulated_op_id(),
                    "MatMul").set_input("x1", out_grad).set_input(
                        "x2", y).set_attr_bool("transpose_x1",
                                               False).set_attr_bool(
                                                   "transpose_x2", True)
                y_grad = core.GEOperatorFactory.create_operator(
                    self.parser_name + self._accumulated_op_id(),
                    "MatMul").set_input("x1", x).set_input(
                        "x2", out_grad).set_attr_bool("transpose_x1",
                                                      True).set_attr_bool(
                                                          "transpose_x2", False)
            elif len(shape_x) == 3 and len(shape_y) == 2:
                flatten_x = core.GEOperatorFactory.create_operator(
                    "flatten" + self._accumulated_op_id(),
                    "Flatten").set_input("x", x)
                x_grad = core.GEOperatorFactory.create_operator(
                    self.parser_name + self._accumulated_op_id(),
                    "MatMul").set_input("x1", out_grad).set_input(
                        "x2", y).set_attr_bool("transpose_x1",
                                               False).set_attr_bool(
                                                   "transpose_x2", True)
                if len(shape_out_grad) == 2:
                    x_grad = core.GEOperatorFactory.create_operator(
                        "unsqueeze" + self._accumulated_op_id(),
                        "Unsqueeze").set_input("x", x_grad).set_attr_vec_int32(
                            "axes", [1])

                y_grad = core.GEOperatorFactory.create_operator(
                    self.parser_name + self._accumulated_op_id(),
                    "MatMul").set_input("x1", flatten_x).set_input(
                        "x2", out_grad).set_attr_bool("transpose_x1",
                                                      True).set_attr_bool(
                                                          "transpose_x2", False)
        else:
            if len(shape_x) == 3 and len(shape_y) == 2:
                assert x_num_col_dims == 2, "only support 2"
                flatten_x = core.GEOperatorFactory.create_operator(
                    "flatten" + self._accumulated_op_id(),
                    "FlattenV2").set_input("x", x).set_attr_int32(
                        "axis", 0).set_attr_int32("end_axis", 1)
                flatten_out_grad = core.GEOperatorFactory.create_operator(
                    "flatten" + self._accumulated_op_id(),
                    "FlattenV2").set_input("x", out_grad).set_attr_int32(
                        "axis", 0).set_attr_int32("end_axis", 1)

                # Replicate y along a new leading axis so it can be
                # batch-multiplied against the 3-D out_grad.
                y_unsqueeze = core.GEOperatorFactory.create_operator(
                    "unsqueeze" + self._accumulated_op_id(),
                    "Unsqueeze").set_input("x",
                                           y).set_attr_vec_int32("axes", [0])
                y_stack = core.GEOperatorFactory.create_operator(
                    "stack" + self._accumulated_op_id(),
                    "TileWithAxis").set_input("x", y_unsqueeze).set_attr_int32(
                        "axis", 0).set_attr_int32("tiles", shape_out_grad[0])
                x_grad = core.GEOperatorFactory.create_operator(
                    self.parser_name + self._accumulated_op_id(),
                    "BatchMatMul").set_input("x1", out_grad).set_input(
                        "x2", y_stack).set_attr_bool("adj_x1",
                                                     False).set_attr_bool(
                                                         "adj_x2", True)
                y_grad = core.GEOperatorFactory.create_operator(
                    self.parser_name + self._accumulated_op_id(),
                    "MatMul").set_input("x1", flatten_x).set_input(
                        "x2", flatten_out_grad).set_attr_bool(
                            "transpose_x1",
                            True).set_attr_bool("transpose_x2", False)

        if x_grad is None or y_grad is None:
            raise NotImplementedError(
                "mul_grad: unsupported case x_num_col_dims=%s, "
                "y_num_col_dims=%s, rank(x)=%d, rank(y)=%d" %
                (x_num_col_dims, y_num_col_dims, len(shape_x), len(shape_y)))
        return [x_grad, y_grad], [[0], [1]]


class ReluGradParser(AscendParserBase):
    """Lower ``relu_grad`` to a GE ``ReluGrad`` operator."""

    def __init__(self, graph, var2geop):
        super(ReluGradParser, self).__init__(graph, var2geop)
        self.parser_name = "relu_grad"

    def _apply(self):
        """Build ``ReluGrad(gradients=out_grad, features=out)``.

        Inputs by position: the forward output, then its gradient.

        Returns:
            ``([relu_grad], [[0]])``.
        """
        out = self._get_ge_input(self.op.input_arg_names[0])
        out_grad = self._get_ge_input(self.op.input_arg_names[1])
        relu_grad = core.GEOperatorFactory.create_operator(
            self.parser_name + self._accumulated_op_id(),
            "ReluGrad").set_input("gradients",
                                  out_grad).set_input("features", out)
        return [relu_grad], [[0]]


class SoftmaxWithCrossEntropyGradParser(AscendParserBase):
    """Lower ``softmax_with_cross_entropy_grad`` to GE ops."""

    def __init__(self, graph, var2geop):
        super(SoftmaxWithCrossEntropyGradParser, self).__init__(graph, var2geop)
        self.parser_name = "softmax_with_cross_entropy_grad"

    def _apply(self):
        """Compute ``loss_grad * (softmax - one_hot(label))``.

        Inputs by position: ``label``, ``loss_grad``, ``softmax``. The class
        count is taken from ``softmax``'s static shape at dim 1.
        NOTE(review): this assumes a hard-label, last-axis, 2-D softmax; the
        ``soft_label``/``ignore_index``/``axis`` attributes are not read.

        Returns:
            All created nodes and ``[[-1]]``, which selects the last node
            (the final ``Mul``).
        """
        label = self._get_ge_input(self.op.input_arg_names[0])
        loss_grad = self._get_ge_input(self.op.input_arg_names[1])
        softmax = self._get_ge_input(self.op.input_arg_names[2])
        cls_num = self.op.block.var(self.op.input_arg_names[2]).shape[1]

        # on/off values of the one-hot encoding (dtype code 5).
        tensoron = self._create_ge_tensor([1], 5, 1)
        on = core.GEOperatorFactory.create_operator(
            "const" + self._accumulated_op_id(),
            "Const").set_attr_tensor("value", tensoron)
        tensoroff = self._create_ge_tensor([1], 5, 0)
        off = core.GEOperatorFactory.create_operator(
            "const" + self._accumulated_op_id(),
            "Const").set_attr_tensor("value", tensoroff)
        self._mark_as_input(on)
        self._mark_as_input(off)

        label = core.GEOperatorFactory.create_operator(
            "cast" + self._accumulated_op_id(),
            "Cast").set_input("x", label).set_attr_int32("dst_type", 3)
        onehot = core.GEOperatorFactory.create_operator(
            "onehot" + self._accumulated_op_id(),
            "OneHotD").set_input("x",
                                 label).set_input("on_value", on).set_input(
                                     "off_value",
                                     off).set_attr_int32("depth", cls_num)
        # Presumably drops the label's size-1 trailing axis that OneHotD keeps.
        squeeze = core.GEOperatorFactory.create_operator(
            "suqeeze" + self._accumulated_op_id(),
            "Squeeze").set_input("x", onehot)
        sub = core.GEOperatorFactory.create_operator(
            "sub" + self._accumulated_op_id(),
            "Sub").set_input("x1", softmax).set_input("x2", squeeze)
        grad = core.GEOperatorFactory.create_operator(
            "mul" + self._accumulated_op_id(),
            "Mul").set_input("x1", loss_grad).set_input("x2", sub)

        return [on, off, label, onehot, grad], [[-1]]


class DotMulGradParser(AscendParserBase):
    """Lower ``elementwise_mul_grad`` to two GE ``Mul`` operators."""

    def __init__(self, graph, var2geop):
        super(DotMulGradParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_mul_grad"

    def _apply(self):
        """Compute ``x_grad = out_grad * y`` and ``y_grad = x * out_grad``.

        Inputs by position: ``out_grad``, ``x``, ``y``.
        NOTE(review): unlike the elementwise_add grad parser, no reduction over
        broadcast axes is performed, so this assumes ``x`` and ``y`` have the
        same shape; confirm.

        Returns:
            ``([x_grad, y_grad], [[0], [1]])``.
        """
        out_grad = self._get_ge_input(self.op.input_arg_names[0])
        out_1 = self._get_ge_input(self.op.input_arg_names[1])
        out_2 = self._get_ge_input(self.op.input_arg_names[2])

        x_grad = core.GEOperatorFactory.create_operator(
            self.parser_name + self._accumulated_op_id(),
            "Mul").set_input("x1", out_grad).set_input("x2", out_2)
        y_grad = core.GEOperatorFactory.create_operator(
            self.parser_name + self._accumulated_op_id(),
            "Mul").set_input("x1", out_1).set_input("x2", out_grad)

        return [x_grad, y_grad], [[0], [1]]


class DotAddGradParser(AscendParserBase):
    """Lower ``elementwise_add_grad`` by summing over broadcast axes."""

    def __init__(self, graph, var2geop):
        super(DotAddGradParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_add_grad"

    def _reduce_grad_to_shape(self, grad, grad_shape, target_shape):
        """Sum ``grad`` back down to ``target_shape``.

        First, the leading axes that ``target_shape`` lacks are summed away
        (``keep_dims=False``). Then every axis where ``target_shape`` has size
        1 is summed with ``keep_dims=True``. Trailing alignment is assumed.
        Paddle's ``axis`` attribute is not read.
        """
        for _ in range(len(grad_shape) - len(target_shape)):
            grad = core.GEOperatorFactory.create_operator(
                self.parser_name + self._accumulated_op_id(),
                "ReduceSumD").set_input("x", grad).set_attr_vec_int32(
                    "axes", [0]).set_attr_bool("keep_dims", False)
        for axis, size in enumerate(target_shape):
            if size == 1:
                grad = core.GEOperatorFactory.create_operator(
                    self.parser_name + self._accumulated_op_id(),
                    "ReduceSumD").set_input("x", grad).set_attr_vec_int32(
                        "axes", [axis]).set_attr_bool("keep_dims", True)
        return grad

    def _apply(self):
        """Reduce ``out_grad`` to the static shapes of ``x`` and ``y``.

        Inputs by position: ``out_grad``, ``x``, ``y``.

        Returns:
            ``([x_grad, y_grad], [[0], [1]])``.
        """
        out_grad = self._get_ge_input(self.op.input_arg_names[0])
        out_1 = self._get_ge_input(self.op.input_arg_names[1])
        out_2 = self._get_ge_input(self.op.input_arg_names[2])
        out_grad_shape = self.op.block.var(self.op.input_arg_names[0]).shape
        out_1_shape = self.op.block.var(self.op.input_arg_names[1]).shape
        out_2_shape = self.op.block.var(self.op.input_arg_names[2]).shape

        x_grad = self._reduce_grad_to_shape(out_grad, out_grad_shape,
                                            out_1_shape)
        y_grad = self._reduce_grad_to_shape(out_grad, out_grad_shape,
                                            out_2_shape)
        return [x_grad, y_grad], [[0], [1]]


class DotDivGradParser(AscendParserBase):
    """Lower ``elementwise_div_grad`` (``out = x / y``) to GE ops."""

    def __init__(self, graph, var2geop):
        super(DotDivGradParser, self).__init__(graph, var2geop)
        self.parser_name = "elementwise_div_grad"

    def _apply(self):
        """Compute the gradients of ``out = x / y``.

        Inputs by position: ``out``, ``out_grad``, ``x``, ``y``.

        * ``x_grad = mask(x != 0) * (1 / y) * out_grad``.
          NOTE(review): the analytic gradient is ``out_grad / y`` regardless
          of ``x``; the zero-mask on ``x`` is kept as-is but looks suspicious.
        * ``y_grad = -(out / y) * out_grad``, since d(x/y)/dy = -x/y**2
          = -out/y.

        Returns:
            ``([x_grad, y_grad], [[0], [1]])``.
        """
        out = self._get_ge_input(self.op.input_arg_names[0])
        out_grad = self._get_ge_input(self.op.input_arg_names[1])
        x = self._get_ge_input(self.op.input_arg_names[2])
        y = self._get_ge_input(self.op.input_arg_names[3])

        y_power = core.GEOperatorFactory.create_operator(
            "power" + self._accumulated_op_id(),
            "Power").set_input("x", y).set_attr_float("power", -1)

        tensor_zeros = core.GEOperatorFactory.create_operator(
            "zeroslike" + self._accumulated_op_id(),
            "ZerosLike").set_input("x", x)
        x_zero = core.GEOperatorFactory.create_operator(
            "equal" + self._accumulated_op_id(),
            "Equal").set_input("x1", x).set_input("x2", tensor_zeros)
        x_nozero = core.GEOperatorFactory.create_operator(
            "logical_not" + self._accumulated_op_id(),
            "LogicalNot").set_input("x", x_zero)
        x_nozero_f = core.GEOperatorFactory.create_operator(
            "cast" + self._accumulated_op_id(),
            "Cast").set_input("x", x_nozero).set_attr_int32("dst_type", 0)
        x_grad_w = core.GEOperatorFactory.create_operator(
            "mul" + self._accumulated_op_id(),
            "Mul").set_input("x1", x_nozero_f).set_input("x2", y_power)
        x_grad = core.GEOperatorFactory.create_operator(
            self.parser_name + self._accumulated_op_id(),
            "Mul").set_input("x1", x_grad_w).set_input("x2", out_grad)

        y_grad_w = core.GEOperatorFactory.create_operator(
            "mul" + self._accumulated_op_id(),
            "Mul").set_input("x1", out).set_input("x2", y_power)
        # The derivative w.r.t. the divisor is negative: -out / y.
        y_grad_w_neg = core.GEOperatorFactory.create_operator(
            "neg" + self._accumulated_op_id(),
            "Neg").set_input("x", y_grad_w)
        y_grad = core.GEOperatorFactory.create_operator(
            "mul" + self._accumulated_op_id(),
            "Mul").set_input("x1", y_grad_w_neg).set_input("x2", out_grad)

        return [x_grad, y_grad], [[0], [1]]


class SoftmaxGradParser(AscendParserBase):
    """Lower ``softmax_grad`` to a GE ``SoftmaxGrad`` operator."""

    def __init__(self, graph, var2geop):
        super(SoftmaxGradParser, self).__init__(graph, var2geop)
        self.parser_name = "softmax_grad"

    def _apply(self):
        """Build ``SoftmaxGrad(softmax=out, grad_softmax=out_grad)``.

        Returns:
            ``([x_grad], [[0]])``.
        """
        out = self._get_ge_input(self.op.input_arg_names[0])
        out_grad = self._get_ge_input(self.op.input_arg_names[1])

        x_grad = core.GEOperatorFactory.create_operator(
            self.parser_name + self._accumulated_op_id(),
            "SoftmaxGrad").set_input("softmax",
                                     out).set_input("grad_softmax", out_grad)
        return [x_grad], [[0]]


class ReshapeGradParser(AscendParserBase):
    """Lower ``reshape2_grad`` by reshaping the gradient to the input shape."""

    def __init__(self, graph, var2geop):
        super(ReshapeGradParser, self).__init__(graph, var2geop)
        self.parser_name = "reshape2_grad"

    def _apply(self):
        """Reshape ``out_grad`` to the static shape recorded by input 1.

        Input 1 (presumably reshape2's ``XShape``) is expected to carry a
        leading 0 in front of the real input shape. That 0 is stripped when
        present. Otherwise the shape is used unchanged; previously that case
        raised ``NameError``.

        Returns:
            ``([x_grad], [[0]])``.
        """
        out_grad = self._get_ge_input(self.op.input_arg_names[0])
        x_shape = self._get_ge_input(self.op.input_arg_names[1])
        x_shape_list = self.op.block.var(self.op.input_arg_names[1]).shape

        if len(x_shape_list) > 0 and x_shape_list[0] == 0:
            target_shape = x_shape_list[1:]
        else:
            target_shape = x_shape_list
        tensor = self._create_ge_tensor([len(target_shape)], 2, target_shape)
        const_shape = core.GEOperatorFactory.create_operator(
            "shape" + self._accumulated_op_id(),
            "Const").set_attr_tensor("value", tensor)
        x_grad = core.GEOperatorFactory.create_operator(
            "reshape" + self._accumulated_op_id(),
            "Reshape").set_input("x", out_grad).set_input("shape", const_shape)

        return [x_grad], [[0]]


class GatherGradParser(AscendParserBase):
    """Lower ``gather_grad`` by scattering the gradient into zeros."""

    def __init__(self, graph, var2geop):
        super(GatherGradParser, self).__init__(graph, var2geop)
        self.parser_name = "gather_grad"

    def _apply(self):
        """Scatter ``out_grad`` into ``zeros_like(x)`` at ``index``.

        Inputs by position: ``index``, ``out_grad``, ``x``. A 1-D index is
        unsqueezed to shape ``[N, 1]`` before scattering.
        NOTE(review): ``TensorScatterUpdate`` overwrites, so gradients for
        duplicate indices are presumably not accumulated; confirm whether an
        add-scatter is needed.

        Returns:
            ``([tensor_zeros, x_grad], [[-1]])``. The ``-1`` selects the last
            node, the scatter result.
        """
        index = self._get_ge_input(self.op.input_arg_names[0])
        out_grad = self._get_ge_input(self.op.input_arg_names[1])
        x = self._get_ge_input(self.op.input_arg_names[2])

        index_shape = self.op.block.var(self.op.input_arg_names[0]).shape

        if len(index_shape) == 1:
            index = core.GEOperatorFactory.create_operator(
                "unsqueeze" + self._accumulated_op_id(),
                "Unsqueeze").set_input("x",
                                       index).set_attr_vec_int32("axes", [1])

        tensor_zeros = core.GEOperatorFactory.create_operator(
            "zeroslike" + self._accumulated_op_id(),
            "ZerosLike").set_input("x", x)
        x_grad = core.GEOperatorFactory.create_operator(
            "scatter" + self._accumulated_op_id(),
            "TensorScatterUpdate").set_input("x", tensor_zeros).set_input(
                "indices", index).set_input("updates", out_grad)

        return [tensor_zeros, x_grad], [[-1]]


class TransposeGradParser(AscendParserBase):
    """Lower ``transpose2_grad`` to a GE ``TransposeD`` operator."""

    def __init__(self, graph, var2geop):
        super(TransposeGradParser, self).__init__(graph, var2geop)
        self.parser_name = "transpose2_grad"

    def _apply(self):
        """Transpose ``out_grad`` back with the inverse of the forward ``axis``.

        The forward op gives ``out.shape[i] == x.shape[perm[i]]``, so the
        gradient has to be permuted by ``inv_perm``, where
        ``inv_perm[perm[i]] == i``. Using ``perm`` itself is only correct when
        ``perm`` is its own inverse. The ``x`` shape is read from input 1 with
        its first entry dropped (presumably the ``XShape`` placeholder).

        Returns:
            ``([x_grad], [[0]])``.
        """
        out_grad = self._get_ge_input(self.op.input_arg_names[0])
        x = self._get_ge_input(self.op.input_arg_names[1])
        perm = self.op.attr("axis")

        rank = len(perm)
        inv_perm = [0] * rank
        for out_axis, in_axis in enumerate(perm):
            # The modulo normalizes a possible negative axis; it is a no-op
            # for axes already in [0, rank).
            inv_perm[in_axis % rank] = out_axis

        x_shape = self.op.block.var(self.op.input_arg_names[1]).shape[1:]
        out_grad_shape = self.op.block.var(self.op.input_arg_names[0]).shape
        assert [out_grad_shape[axis] for axis in inv_perm] == list(x_shape)

        x_grad = core.GEOperatorFactory.create_operator(
            "transpose" + self._accumulated_op_id(),
            "TransposeD").set_input("x", out_grad).set_attr_vec_int32(
                "perm", inv_perm)

        return [x_grad], [[0]]


class LayerNormGradParser(AscendParserBase):
    """Lower ``layer_norm_grad`` to GE ``LayerNormGrad`` plus output casts."""

    def __init__(self, graph, var2geop):
        super(LayerNormGradParser, self).__init__(graph, var2geop)
        self.parser_name = "layer_norm_grad"

    def _apply(self):
        """Build ``LayerNormGrad`` and cast each of its three outputs.

        Inputs by position: ``bias``, ``mean``, ``scale``, ``variance``, ``x``,
        ``out_grad`` (``bias`` is fetched but not fed to any node). Outputs 0,
        1 and 2 of ``LayerNormGrad`` are cast to dtype code 0 when ``x``'s
        mapped dtype is 0, otherwise 1.

        Returns:
            ``([x_grad, scale_grad, bias_grad], [[2], [1], [0]])``. The op's
            first output slot therefore maps to the bias gradient (presumably
            because output names sort as Bias@GRAD, Scale@GRAD, X@GRAD).
        """
        bias = self._get_ge_input(self.op.input_arg_names[0])
        mean = self._get_ge_input(self.op.input_arg_names[1])
        scale = self._get_ge_input(self.op.input_arg_names[2])
        variance = self._get_ge_input(self.op.input_arg_names[3])
        x = self._get_ge_input(self.op.input_arg_names[4])
        out_grad = self._get_ge_input(self.op.input_arg_names[5])
        x_dtype = self.op.block.var(self.op.input_arg_names[4]).dtype

        x_grad = core.GEOperatorFactory.create_operator(
            self.parser_name + self._accumulated_op_id(),
            "LayerNormGrad").set_input("dy", out_grad).set_input(
                "x", x).set_input("variance",
                                  variance).set_input("mean", mean).set_input(
                                      "gamma", scale)

        cast_dtype = 0 if self.ascend_helper.dtype2paddle_inv_map[str(
            x_dtype)] == 0 else 1
        out_x_grad = core.GEOperatorFactory.create_operator(
            "cast" + self._accumulated_op_id(),
            "Cast").set_input("x", x_grad,
                              0).set_attr_int32("dst_type", cast_dtype)
        out_scale_grad = core.GEOperatorFactory.create_operator(
            "cast" + self._accumulated_op_id(),
            "Cast").set_input("x", x_grad,
                              1).set_attr_int32("dst_type", cast_dtype)
        out_bias_grad = core.GEOperatorFactory.create_operator(
            "cast" + self._accumulated_op_id(),
            "Cast").set_input("x", x_grad,
                              2).set_attr_int32("dst_type", cast_dtype)

        return [out_x_grad, out_scale_grad, out_bias_grad], [[2], [1], [0]]


class TanhGradParser(AscendParserBase):
    """Lower ``tanh_grad`` to a GE ``TanhGrad`` operator."""

    def __init__(self, graph, var2geop):
        super(TanhGradParser, self).__init__(graph, var2geop)
        self.parser_name = 'tanh_grad'

    def _apply(self):
        """Build ``TanhGrad(y=out, dy=out_grad)``.

        Returns:
            ``([tanh_grad], [[0]])``.
        """
        y = self._get_ge_input(self.op.input_arg_names[0])
        out_grad = self._get_ge_input(self.op.input_arg_names[1])
        tanh_grad = core.GEOperatorFactory.create_operator(
            "tanh_grad" + self._accumulated_op_id(),
            "TanhGrad").set_input("y", y).set_input("dy", out_grad)

        return [tanh_grad], [[0]]


class LogGradParser(AscendParserBase):
    """Lower ``log_grad`` (d log(x)/dx = 1/x) to a GE ``DivNoNan`` operator."""

    def __init__(self, graph, var2geop):
        super(LogGradParser, self).__init__(graph, var2geop)
        self.parser_name = 'log_grad'

    def _apply(self):
        """Compute ``grad / x`` with ``DivNoNan``.

        ``DivNoNan`` presumably yields 0 where ``x`` is 0.

        Returns:
            ``([log_grad], [[0]])``.
        """
        grad = self._get_ge_input(self.op.input_arg_names[0])
        input = self._get_ge_input(self.op.input_arg_names[1])
        log_grad = core.GEOperatorFactory.create_operator(
            "log_grad" + self._accumulated_op_id(),
            "DivNoNan").set_input("x1", grad).set_input("x2", input)
        return [log_grad], [[0]]


class SqrtGradParser(AscendParserBase):
    """Lower ``sqrt_grad`` to a GE ``SqrtGrad`` operator."""

    def __init__(self, graph, var2geop):
        super(SqrtGradParser, self).__init__(graph, var2geop)
        self.parser_name = "sqrt_grad"

    def _apply(self):
        """Build ``SqrtGrad(y=out, dy=out_grad)``.

        Returns:
            ``([sqrt_grad], [[0]])``. Like every other parser, ``_apply``
            returns an ``(ops, index_lists)`` pair; the index list was
            previously missing.
        """
        y = self._get_ge_input(self.op.input_arg_names[0])
        out_grad = self._get_ge_input(self.op.input_arg_names[1])
        sqrt_grad = core.GEOperatorFactory.create_operator(
            "sqrt_grad" + self._accumulated_op_id(),
            "SqrtGrad").set_input("y", y).set_input("dy", out_grad)
        return [sqrt_grad], [[0]]


class PowGradParser(AscendParserBase):
    """Lower ``pow_grad`` for ``out = x ** factor``."""

    def __init__(self, graph, var2geop):
        super(PowGradParser, self).__init__(graph, var2geop)
        self.parser_name = "pow_grad"

    def _apply(self):
        """Compute ``grad * factor * x ** (factor - 1)``.

        ``factor`` is read from the op attribute and materialized as a
        ``Const`` (``_create_ge_tensor`` dtype code 5). It is then broadcast to
        the runtime ``Shape`` of ``x`` before multiplication.

        Returns:
            ``([x_power_mul_factor_grad], [[0]])``.
        """
        grad = self._get_ge_input(self.op.input_arg_names[0])
        x = self._get_ge_input(self.op.input_arg_names[1])
        factor = self.op.attr("factor")

        shape_tensor = core.GEOperatorFactory.create_operator(
            "shape" + self._accumulated_op_id(), "Shape").set_input("x", x)
        factor_scale = self._create_ge_tensor([1], 5, factor)
        factor_scale = core.GEOperatorFactory.create_operator(
            "const" + self._accumulated_op_id(),
            "Const").set_attr_tensor("value", factor_scale)
        factor_tensor = core.GEOperatorFactory.create_operator(
            "broadcast_to_d" + self._accumulated_op_id(),
            "BroadcastTo").set_input("x", factor_scale).set_input(
                "shape", shape_tensor)

        x_power = core.GEOperatorFactory.create_operator(
            "x_power" + self._accumulated_op_id(),
            "Power").set_input("x", x).set_attr_float("power", factor - 1)
        # d(x ** factor)/dx = factor * x ** (factor - 1): scale the power
        # term, not x itself.
        x_power_mul_factor = core.GEOperatorFactory.create_operator(
            "x_power_mul_factor" + self._accumulated_op_id(),
            "Mul").set_input("x1", x_power).set_input("x2", factor_tensor)
        x_power_mul_factor_grad = core.GEOperatorFactory.create_operator(
            "x_power_mul_factor_grad" + self._accumulated_op_id(),
            "Mul").set_input("x1", x_power_mul_factor).set_input("x2", grad)

        return [x_power_mul_factor_grad], [[0]]


class GeluGradParser(AscendParserBase):
    """Lower ``gelu_grad`` to GE ``Gelu`` + ``GeluGrad`` operators."""

    def __init__(self, graph, var2geop):
        super(GeluGradParser, self).__init__(graph, var2geop)
        self.parser_name = "gelu_grad"

    def _apply(self):
        """Recompute ``Gelu(x)`` and feed it with ``x`` and ``grad`` to ``GeluGrad``.

        ``GeluGrad`` takes the forward output ``y`` as an input, so the
        forward activation is rebuilt here from ``x``.

        Returns:
            ``([gelu_grad], [[0]])``.
        """
        grad = self._get_ge_input(self.op.input_arg_names[0])
        x = self._get_ge_input(self.op.input_arg_names[1])

        y = core.GEOperatorFactory.create_operator(
            "gelu" + self._accumulated_op_id(), "Gelu").set_input("x", x)
        gelu_grad = core.GEOperatorFactory.create_operator(
            "gelu_grad" + self._accumulated_op_id(),
            "GeluGrad").set_input("x", x).set_input("dy",
                                                    grad).set_input("y", y)

        return [gelu_grad], [[0]]


class MeanGradParser(AscendParserBase):
    """Translate Paddle's ``mean_grad`` op into GE operators."""

    def __init__(self, graph, var2geop):
        super(MeanGradParser, self).__init__(graph, var2geop)
        self.parser_name = "mean_grad"

    def _apply(self):
        """Emit GE nodes computing ``reciprocal(count) * upstream_grad``.

        ``count`` is produced by ``ReduceSumD`` over ``OnesLike(x)``. The
        ``axes`` list is empty and ``keep_dims`` is False; this relies on GE
        reducing every axis when ``axes`` is empty (confirm). The reciprocal
        comes from ``Power(count, -1)``.

        NOTE(review): the final ``Mul`` combines the reciprocal with the
        upstream gradient only. Nothing here broadcasts the result to the
        shape of ``x``, so confirm that downstream consumers accept the
        resulting shape.

        Returns:
            ``([mean_grad], [[0]])``.
        """
        arg_names = self.op.input_arg_names
        # Input 0 is read as the upstream gradient, input 1 as x.
        upstream = self._get_ge_input(arg_names[0])
        x = self._get_ge_input(arg_names[1])

        make_op = core.GEOperatorFactory.create_operator

        ones = make_op("one_tensor" + self._accumulated_op_id(), "OnesLike")
        ones = ones.set_input("x", x)

        count = make_op("mean" + self._accumulated_op_id(), "ReduceSumD")
        count = count.set_input("x", ones)
        count = count.set_attr_bool("keep_dims", False)
        count = count.set_attr_vec_int32("axes", [])

        reciprocal = make_op("x_power" + self._accumulated_op_id(), "Power")
        reciprocal = reciprocal.set_input("x", count)
        reciprocal = reciprocal.set_attr_float("power", -1)

        grad_x = make_op("mean_grad" + self._accumulated_op_id(), "Mul")
        grad_x = grad_x.set_input("x1", reciprocal)
        grad_x = grad_x.set_input("x2", upstream)

        return [grad_x], [[0]]


class SliceGradParser(AscendParserBase):
    """Translate Paddle's ``slice_grad`` op into a GE ``PadD`` node.

    The gradient of a slice is the upstream gradient zero-padded back to the
    shape of the sliced input. On each sliced axis, ``start`` zeros go before
    the kept range and ``dim - end`` zeros go after it.
    """

    def __init__(self, graph, var2geop):
        super(SliceGradParser, self).__init__(graph, var2geop)
        self.parser_name = "slice_grad"

    @staticmethod
    def _axis_padding(start, end, dim):
        """Return the ``[before, after]`` zero padding for one sliced axis.

        Args:
            start: slice start index on this axis (may be negative).
            end: exclusive slice end index (may be negative or exceed dim).
            dim: extent of the input along this axis. A negative value is
                treated as "unknown at graph-build time".
        """
        if dim < 0:
            # Extent unknown: negative indices cannot be resolved, so keep the
            # raw start and pad after only when end <= dim (prior behaviour).
            return [start, dim - end if end <= dim else 0]
        if start < 0:
            start += dim
        if end < 0:
            end += dim
        # Clamp so that 0 <= start <= end <= dim. The three pieces
        # (before + kept + after) then always add up to dim.
        start = min(max(start, 0), dim)
        end = min(max(end, start), dim)
        return [start, dim - end]

    def _apply(self):
        """Emit the ``PadD`` node producing the gradient of the slice input.

        Returns:
            ``([slice_value], [[0]])``.
        """
        grad = self._get_ge_input(self.op.input_arg_names[1])
        x_shape = self.op.block.var(self.op.input_arg_names[0]).shape
        rank = len(x_shape)

        # Map each sliced axis to its (start, end). Keying a dict by the
        # normalized axis removes any dependence on ``axes`` being sorted
        # and accepts negative axes.
        bounds = {}
        for axis, start, end in zip(self.op.attr("axes"),
                                    self.op.attr("starts"),
                                    self.op.attr("ends")):
            bounds[axis + rank if axis < 0 else axis] = (start, end)

        paddings = [
            self._axis_padding(bounds[i][0], bounds[i][1], dim)
            if i in bounds else [0, 0]
            for i, dim in enumerate(x_shape)
        ]
        # Axis 0 is never padded. This is kept from the previous
        # implementation, presumably because axis 0 is the batch axis
        # (confirm). A partial slice on axis 0 therefore still yields a
        # gradient with the wrong extent on that axis.
        if paddings:
            paddings[0] = [0, 0]
        # NOTE(review): ``decrease_axis`` is not consulted. If the forward
        # slice dropped axes, grad has lower rank than x and PadD receives
        # more padding pairs than grad has dims. Confirm.

        slice_value = core.GEOperatorFactory.create_operator(
            "slice_grad" + self._accumulated_op_id(), "PadD").set_input(
                "x", grad).set_attr_vec_vec_int64("paddings", paddings)

        return [slice_value], [[0]]


class LookUpTableGradParser(AscendParserBase):
    """Translate Paddle's ``lookup_table_grad`` op into GE operators.

    The embedding gradient is built by scatter-adding the flattened upstream
    gradient into a zero tensor shaped like the embedding table, at the
    positions given by the flattened ids.
    """

    def __init__(self, graph, var2geop):
        super(LookUpTableGradParser, self).__init__(graph, var2geop)
        self.parser_name = "lookup_table_grad"

    def _apply(self):
        """Emit FlattenV2 / ZerosLike / TensorScatterAdd nodes.

        Inputs are read positionally as ids, upstream grad and the embedding
        table.

        Returns:
            ``([embedding_grad], [[0]])``.
        """
        ids = self._get_ge_input(self.op.input_arg_names[0])
        grad = self._get_ge_input(self.op.input_arg_names[1])
        embedding = self._get_ge_input(self.op.input_arg_names[2])

        # Merge dims 0..1 of ids and grad (FlattenV2 axis=0, end_axis=1).
        # NOTE(review): presumably ids is [batch, seq, 1] and grad is
        # [batch, seq, emb_dim]. That would give [batch*seq, 1] indices and
        # [batch*seq, emb_dim] updates. Confirm against callers.
        ids_flatten = core.GEOperatorFactory.create_operator(
            "flatten" + self._accumulated_op_id(), "FlattenV2").set_input(
                "x", ids).set_attr_int32("axis",
                                         0).set_attr_int32("end_axis", 1)
        grad_flatten = core.GEOperatorFactory.create_operator(
            "flatten" + self._accumulated_op_id(), "FlattenV2").set_input(
                "x", grad).set_attr_int32("axis",
                                          0).set_attr_int32("end_axis", 1)

        # Accumulate into zeros so that repeated ids sum their gradients.
        tensor_zeros = core.GEOperatorFactory.create_operator(
            "zeroslike" + self._accumulated_op_id(),
            "ZerosLike").set_input("x", embedding)
        embedding_grad = core.GEOperatorFactory.create_operator(
            "scatteradd" + self._accumulated_op_id(),
            "TensorScatterAdd").set_input("x", tensor_zeros).set_input(
                "indices", ids_flatten).set_input("updates", grad_flatten)

        return [embedding_grad], [[0]]


class SGDParser(AscendParserBase):
    """Translate Paddle's ``sgd`` op into a GE ``ApplyGradientDescent``."""

    def __init__(self, graph, var2geop):
        super(SGDParser, self).__init__(graph, var2geop)
        self.parser_name = "sgd"

    def _apply(self):
        """Emit an ``ApplyGradientDescent`` node updating the parameter.

        Inputs are read positionally as grad, learning rate and param.

        Returns:
            ``([node], [[0]])``.
        """
        names = self.op.input_arg_names
        grad, lr, param = (self._get_ge_input(names[i]) for i in range(3))

        # The node name prefix is "momentum", kept exactly as before.
        node = core.GEOperatorFactory.create_operator(
            "momentum" + self._accumulated_op_id(), "ApplyGradientDescent")
        for port, tensor in (("var", param), ("alpha", lr), ("delta", grad)):
            node = node.set_input(port, tensor)
        return [node], [[0]]


class AdamParser(AscendParserBase):
    """Translate Paddle's ``adam`` optimizer op into a GE ``ApplyAdam``."""

    def __init__(self, graph, var2geop):
        super(AdamParser, self).__init__(graph, var2geop)
        self.parser_name = "adam"

    def _apply(self):
        """Emit ``Const`` nodes for the hyper-parameters and an ``ApplyAdam``.

        Inputs are read positionally as beta1_pow, beta2_pow, grad, lr,
        moment1, moment2 and param. ``beta1``, ``beta2`` and ``epsilon`` are
        read from op attributes and materialised as ``[1]``-shaped GE
        ``Const`` nodes.

        NOTE(review): only the ``ApplyAdam`` node is returned (output index
        ``[[0]]``). Nothing visible here advances beta1_pow / beta2_pow;
        confirm that happens elsewhere.

        Returns:
            ``([adam], [[0]])``.
        """
        names = self.op.input_arg_names
        (beta1_pow, beta2_pow, grad, lr, moment1, moment2,
         param) = [self._get_ge_input(names[i]) for i in range(7)]
        hyper = [self.op.attr(key) for key in ('beta1', 'beta2', 'epsilon')]

        factory = core.GEOperatorFactory

        def scalar_const(value):
            # dtype code 5 is passed to _create_ge_tensor (presumably
            # Paddle's FP32 VarType; confirm).
            node = factory.create_operator(
                "const" + self._accumulated_op_id(), "Const")
            return node.set_attr_tensor(
                "value", self._create_ge_tensor([1], 5, value))

        beta1_node, beta2_node, eps_node = [scalar_const(v) for v in hyper]

        adam = factory.create_operator(
            "adam" + self._accumulated_op_id(), "ApplyAdam")
        wiring = (
            ("var", param),
            ("m", moment1),
            ("v", moment2),
            ("beta1_power", beta1_pow),
            ("beta2_power", beta2_pow),
            ("lr", lr),
            ("beta1", beta1_node),
            ("beta2", beta2_node),
            ("epsilon", eps_node),
            ("grad", grad),
        )
        for port, tensor in wiring:
            adam = adam.set_input(port, tensor)

        return [adam], [[0]]