prim2code.py 24.5 KB
Newer Older
S
SunAhong1993 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

S
SunAhong1993 已提交
15
# Module-level counter that get_value() uses to mint unique "attr_<n>" names
# when an attribute is found in different_attrs under "layer_id/<id>_<key>".
NO_OUTPUT_COUNT = 0
S
SunAhong1993 已提交
16 17


S
SunAhong1993 已提交
18 19 20 21 22 23 24 25 26 27 28
def gen_codes(code_list, indent=0):
    """Indent each source line and terminate it with a newline.

    Lines that are empty or whitespace-only become a bare newline with no
    indentation, so the generated source carries no trailing whitespace.

    Args:
        code_list: iterable of source lines without trailing newlines.
        indent: indentation depth, in units of four spaces.

    Returns:
        A new list of newline-terminated strings.
    """
    prefix = "    " * indent
    return [
        '\n' if not text.strip() else prefix + text + '\n'
        for text in code_list
    ]


S
SunAhong1993 已提交
29
def get_value(layer, key, layer_id=None, different_attrs=None):
    """Return the source-code text for operand ``key`` of ``layer``.

    After optimization (e.g. ConstantFuser) an operand that used to be an
    input may have been replaced by a literal stored in ``layer.attrs``, so
    both ``layer.inputs`` and ``layer.attrs`` are consulted.

    Args:
        layer: layer whose operand is requested.
        key: operand name looked up in ``layer.inputs`` / ``layer.attrs``.
        layer_id: id of the layer; used for the fallback lookup name
            ``"layer_id/<id>_<key>"`` when the layer cannot be addressed by
            its first output name.
        different_attrs: optional list of attribute names that must be
            referenced by name instead of inlined. A matching
            ``"layer_id/<id>_<key>"`` entry is replaced in place by a freshly
            minted ``"attr_<n>"`` name.

    Returns:
        The input variable name, a name from ``different_attrs``, or
        ``str()`` of the literal attribute value.

    Raises:
        KeyError: if ``key`` is in neither ``layer.inputs`` nor
            ``layer.attrs`` and no name from ``different_attrs`` applies.
    """
    # Without this declaration the "+= 1" below makes NO_OUTPUT_COUNT local
    # to the function, and reading it raises UnboundLocalError.
    global NO_OUTPUT_COUNT
    if key in layer.inputs:
        return layer.inputs[key]
    if different_attrs is None:
        return str(layer.attrs[key])
    # Layers with an empty ``outputs`` list cannot be addressed by
    # "<output>_<key>"; go straight to the layer-id based lookup for them.
    if layer.outputs:
        key_name = "{}_{}".format(layer.outputs[0], key)
        if key_name in different_attrs:
            return key_name
    if layer_id is None:
        return str(layer.attrs[key])
    key_name = "{}_{}".format("layer_id/{}".format(layer_id), key)
    if key_name in different_attrs:
        new_key_name = "attr_{}".format(NO_OUTPUT_COUNT)
        NO_OUTPUT_COUNT += 1
        different_attrs[different_attrs.index(key_name)] = new_key_name
        return new_key_name
    return str(layer.attrs[key])


S
SunAhong1993 已提交
56 57 58 59 60 61
def prim_add(layer,
             indent=1,
             init_func=[],
             forward_func=[],
             layer_id=None,
             different_attrs=None):
    """Emit ``<output> = <x> + <y>`` into ``forward_func``.

    Operands are resolved with ``get_value(layer, key, layer_id,
    different_attrs)``. The previous call passed ``different_attrs`` into
    the ``layer_id`` slot, so per-instance attributes were always inlined.
    """
    x = get_value(layer, "x", layer_id, different_attrs)
    y = get_value(layer, "y", layer_id, different_attrs)
    line = "{} = {} + {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
68 69 70 71 72 73
def prim_add_(layer,
              indent=1,
              init_func=[],
              forward_func=[],
              layer_id=None,
              different_attrs=None):
    """Emit ``<output> = <x> + <alpha> * <y>`` into ``forward_func``.

    ``alpha`` is resolved through ``get_value`` like ``x`` and ``y`` (as
    prim_rsub does), so it may come from ``layer.inputs`` as well as
    ``layer.attrs``. Operands are resolved with ``get_value(layer, key,
    layer_id, different_attrs)``.
    """
    x = get_value(layer, "x", layer_id, different_attrs)
    alpha = get_value(layer, "alpha", layer_id, different_attrs)
    y = get_value(layer, "y", layer_id, different_attrs)
    line = "{} = {} + {} * {}".format(layer.outputs[0], x, alpha, y)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
81 82 83 84 85
def prim_and(layer,
             indent=1,
             init_func=[],
             forward_func=[],
             layer_id=None,
             different_attrs=None):
    """Emit ``<output> = <x> and <y>`` into ``forward_func``.

    Operands are resolved with ``get_value(layer, key, layer_id,
    different_attrs)`` so attributes listed in ``different_attrs`` are
    referenced by name instead of inlined as literals.
    """
    x = get_value(layer, "x", layer_id, different_attrs)
    y = get_value(layer, "y", layer_id, different_attrs)
    line = "{} = {} and {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
93 94 95 96 97 98
def prim_append(layer,
                indent=1,
                init_func=[],
                forward_func=[],
                layer_id=None,
                different_attrs=None):
    """Emit ``<list>.append(<element>)`` into ``forward_func``.

    Both operands are resolved with ``get_value`` using ``layer_id`` and
    ``different_attrs``; the list operand is resolved first.
    """
    target = get_value(layer, "list", layer_id, different_attrs)
    element = get_value(layer, "element", layer_id, different_attrs)
    statement = "{}.append({})".format(target, element)
    forward_func.extend(gen_codes([statement], indent=indent))


S
SunAhong1993 已提交
105 106 107 108 109 110
def prim_assert(layer,
                indent=1,
                init_func=[],
                forward_func=[],
                layer_id=None,
                different_attrs=None):
    """Emit an ``assert`` statement for an equality check.

    Only ``layer.attrs["type"] == "eq"`` is supported. When
    ``layer.attrs["value"]`` is a list, the generated condition accepts any
    of its elements (``k == v0 or k == v1 ...``).

    Raises:
        Exception: for any assertion type other than ``"eq"``.
    """
    if layer.attrs["type"] != "eq":
        raise Exception("Not implement yet!")
    key = get_value(layer, "key")
    candidates = layer.attrs["value"] if "value" in layer.attrs else key
    if isinstance(candidates, list):
        condition = " or ".join(
            "{} == {}".format(key, candidate) for candidate in candidates)
        line = "assert {}, 'The {} must be {}!'".format(
            condition, key, get_value(layer, "value"))
    else:
        expected = get_value(layer, "value")
        line = "assert {} == {}, 'The {} must be {}!'".format(
            key, expected, key, expected)
    forward_func.extend(gen_codes([line], indent=indent))
S
SunAhong1993 已提交
131 132


S
SunAhong1993 已提交
133 134 135 136 137 138
def prim_check_dim(layer,
                   indent=1,
                   init_func=[],
                   forward_func=[],
                   layer_id=None,
                   different_attrs=None):
    """Emit code that normalizes a possibly negative dimension index.

    Generated code::

        if <dim> < 0:
            <output> = <dim> + <len>
        else:
            <output> = <dim>

    Operands are resolved with ``get_value(layer, key, layer_id,
    different_attrs)``.
    """
    dim = get_value(layer, "dim", layer_id, different_attrs)
    length = get_value(layer, "len", layer_id, different_attrs)
    output = layer.outputs[0]
    lines = [
        "if {} < 0:".format(dim),
        "    {} = {} + {}".format(output, dim, length),
        "else:",
        "    {} = {}".format(output, dim),
    ]
    forward_func.extend(gen_codes(lines, indent=indent))


S
SunAhong1993 已提交
149 150 151 152 153 154
def prim_constant(layer,
                  indent=1,
                  init_func=[],
                  forward_func=[],
                  layer_id=None,
                  different_attrs=None):
    """Emit ``<output> = <value>`` with the value taken from ``attrs["value"]``.

    The constant is formatted as-is; ``layer_id`` and ``different_attrs``
    are accepted for a uniform signature but not used.
    """
    target, constant = layer.outputs[0], layer.attrs["value"]
    code_lines = gen_codes(["{} = {}".format(target, constant)], indent=indent)
    forward_func.extend(code_lines)


S
SunAhong1993 已提交
159 160 161 162 163
def prim_contain(layer,
                 indent=1,
                 init_func=[],
                 forward_func=[],
                 layer_id=None,
                 different_attrs=None):
    """Emit ``<output> = <element> in <input>`` into ``forward_func``.

    Operands are resolved with ``get_value(layer, key, layer_id,
    different_attrs)`` so attributes listed in ``different_attrs`` are
    referenced by name instead of inlined as literals.
    """
    element = get_value(layer, "element", layer_id, different_attrs)
    container = get_value(layer, "input", layer_id, different_attrs)
    line = "{} = {} in {}".format(layer.outputs[0], element, container)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
171 172 173 174 175 176
def prim_dict(layer,
              indent=1,
              init_func=[],
              forward_func=[],
              layer_id=None,
              different_attrs=None):
    """Emit ``<output> = dict()``, i.e. bind the output to an empty dict."""
    statement = "{} = dict()".format(layer.outputs[0])
    code_lines = gen_codes([statement], indent=indent)
    forward_func.extend(code_lines)
S
SunAhong1993 已提交
179 180 181 182 183 184 185 186


def prim_dict_construct(layer,
                        indent=1,
                        init_func=[],
                        forward_func=[],
                        layer_id=None,
                        different_attrs=None):
    """Emit code that builds a dict from the layer's ``key{i}``/``value{i}``.

    Pairs are numbered consecutively from 0. Either half of a pair may live
    in ``layer.inputs`` or, after constant folding, in ``layer.attrs``, so
    the pair count is derived from which ``key{i}`` names exist. Using
    ``len(layer.inputs)`` over-iterates when keys are inputs and drops pairs
    when values were folded into attrs.

    Operands are resolved with ``get_value(layer, key, layer_id,
    different_attrs)``.
    """
    output = layer.outputs[0]
    lines = ["{} = dict()".format(output)]
    i = 0
    while True:
        key = "key{}".format(i)
        if key not in layer.inputs and key not in layer.attrs:
            break
        lines.append("{}[{}] = {}".format(
            output,
            get_value(layer, key, layer_id, different_attrs),
            get_value(layer, "value{}".format(i), layer_id, different_attrs)))
        i += 1
    forward_func.extend(gen_codes(lines, indent=indent))
S
SunAhong1993 已提交
197 198 199 200 201 202 203 204


def prim_div(layer,
             indent=1,
             init_func=[],
             forward_func=[],
             layer_id=None,
             different_attrs=None):
    """Emit ``<output> = <x> / <y>`` into ``forward_func``.

    Operands are resolved with ``get_value(layer, key, layer_id,
    different_attrs)`` so attributes listed in ``different_attrs`` are
    referenced by name instead of inlined as literals.
    """
    x = get_value(layer, "x", layer_id, different_attrs)
    y = get_value(layer, "y", layer_id, different_attrs)
    line = "{} = {} / {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
211 212 213 214 215
def prim_eq(layer,
            indent=1,
            init_func=[],
            forward_func=[],
            layer_id=None,
            different_attrs=None):
    """Emit ``<output> = <x> == <y>`` into ``forward_func``.

    Operands are resolved with ``get_value(layer, key, layer_id,
    different_attrs)`` so attributes listed in ``different_attrs`` are
    referenced by name instead of inlined as literals.
    """
    x = get_value(layer, "x", layer_id, different_attrs)
    y = get_value(layer, "y", layer_id, different_attrs)
    line = "{} = {} == {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
223 224 225 226 227 228 229 230
def prim_equal(layer,
               indent=1,
               init_func=[],
               forward_func=[],
               layer_id=None,
               different_attrs=None):
    """Emit ``<output> = <input>``, binding the output to the input value.

    The input is resolved with ``get_value(layer, "input", layer_id,
    different_attrs)``.
    """
    source = get_value(layer, "input", layer_id, different_attrs)
    line = "{} = {}".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
234 235 236 237 238 239
def prim_exception(layer,
                   indent=1,
                   init_func=[],
                   forward_func=[],
                   layer_id=None,
                   different_attrs=None):
    """Emit ``raise RaiseException(<input>)`` into ``forward_func``.

    ``RaiseException`` is not defined here; presumably the generated module
    provides it — TODO confirm. The message is resolved with
    ``get_value(layer, "input", layer_id, different_attrs)``.
    """
    message = get_value(layer, "input", layer_id, different_attrs)
    line = "raise RaiseException({})".format(message)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
245 246 247 248 249 250 251 252
def prim_float(layer,
               indent=1,
               init_func=[],
               forward_func=[],
               layer_id=None,
               different_attrs=None):
    """Emit ``<output> = float(<input>)`` into ``forward_func``.

    The input is resolved with ``get_value(layer, "input", layer_id,
    different_attrs)``.
    """
    source = get_value(layer, "input", layer_id, different_attrs)
    line = "{} = float({})".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
256 257 258 259 260 261
def prim_floor(layer,
               indent=1,
               init_func=[],
               forward_func=[],
               layer_id=None,
               different_attrs=None):
    """Emit ``<output> = math.floor(<input>)`` into ``forward_func``.

    The generated code references ``math``; presumably the generated module
    imports it — TODO confirm. The input is resolved with
    ``get_value(layer, "input", layer_id, different_attrs)``.
    """
    source = get_value(layer, "input", layer_id, different_attrs)
    line = "{} = math.floor({})".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
267 268 269 270 271 272
def prim_floordiv(layer,
                  indent=1,
                  init_func=[],
                  forward_func=[],
                  layer_id=None,
                  different_attrs=None):
    """Emit ``<output> = <x> // <y>`` into ``forward_func``.

    Operands are resolved with ``get_value(layer, key, layer_id,
    different_attrs)`` so attributes listed in ``different_attrs`` are
    referenced by name instead of inlined as literals.
    """
    x = get_value(layer, "x", layer_id, different_attrs)
    y = get_value(layer, "y", layer_id, different_attrs)
    line = "{} = {} // {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
279 280 281 282 283 284
def prim_getitem(layer,
                 indent=1,
                 init_func=[],
                 forward_func=[],
                 layer_id=None,
                 different_attrs=None):
    """Emit ``<output> = <list>[<index>]`` into ``forward_func``.

    Operands are resolved with ``get_value(layer, key, layer_id,
    different_attrs)``; the container is resolved before the index.
    """
    container = get_value(layer, "list", layer_id, different_attrs)
    index = get_value(layer, "index", layer_id, different_attrs)
    line = "{} = {}[{}]".format(layer.outputs[0], container, index)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
291 292 293 294 295
def prim_gt(layer,
            indent=1,
            init_func=[],
            forward_func=[],
            layer_id=None,
            different_attrs=None):
    """Emit ``<output> = <x> > <y>`` into ``forward_func``.

    Operands are resolved with ``get_value(layer, key, layer_id,
    different_attrs)`` so attributes listed in ``different_attrs`` are
    referenced by name instead of inlined as literals.
    """
    x = get_value(layer, "x", layer_id, different_attrs)
    y = get_value(layer, "y", layer_id, different_attrs)
    line = "{} = {} > {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
303 304 305 306 307 308
def prim_if(layer,
            indent=1,
            init_func=[],
            forward_func=[],
            layer_id=None,
            different_attrs=None):
    """Emit an ``if`` statement whose branches come from ``layer.blocks``.

    ``layer.blocks[0]`` supplies the true branch; a ``pass`` is emitted when
    it produces no forward lines, because an empty ``if`` body is a syntax
    error in the generated code. ``layer.blocks[1]`` is emitted under
    ``else:`` only when it has layers and produces at least one forward
    line. Init lines of both sub-blocks are appended to ``init_func``.

    The condition is resolved with ``get_value(layer, "input", layer_id,
    different_attrs)``.
    """
    condition = get_value(layer, "input", layer_id, different_attrs)
    line = "if {} :".format(condition)
    forward_func.extend(gen_codes([line], indent=indent))
    then_init, then_forward = layer.blocks[0].gen_dygraph_code(
        indent=indent + 1)
    init_func.extend(then_init)
    forward_func.extend(then_forward)
    if len(then_forward) == 0:
        forward_func.extend(gen_codes(["pass"], indent=indent + 1))
    else_block = layer.blocks[1]
    if len(else_block.layers) > 0:
        else_init, else_forward = else_block.gen_dygraph_code(
            indent=indent + 1)
        if len(else_forward) != 0:
            forward_func.extend(gen_codes(["else:"], indent=indent))
        init_func.extend(else_init)
        forward_func.extend(else_forward)
S
SunAhong1993 已提交
324 325


S
SunAhong1993 已提交
326 327 328 329 330 331 332 333
def prim_int(layer,
             indent=1,
             init_func=[],
             forward_func=[],
             layer_id=None,
             different_attrs=None):
    """Emit ``<output> = int(<input>)`` into ``forward_func``.

    The input is resolved with ``get_value(layer, "input", layer_id,
    different_attrs)``.
    """
    source = get_value(layer, "input", layer_id, different_attrs)
    line = "{} = int({})".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
337 338 339 340 341
def prim_is(layer,
            indent=1,
            init_func=[],
            forward_func=[],
            layer_id=None,
            different_attrs=None):
    """Emit ``<output> = <x> is <y>`` into ``forward_func``.

    Operands are resolved with ``get_value(layer, key, layer_id,
    different_attrs)`` so attributes listed in ``different_attrs`` are
    referenced by name instead of inlined as literals.
    """
    x = get_value(layer, "x", layer_id, different_attrs)
    y = get_value(layer, "y", layer_id, different_attrs)
    line = "{} = {} is {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
349 350 351 352 353
def prim_isinstance(layer,
                    indent=1,
                    init_func=[],
                    forward_func=[],
                    layer_id=None,
                    different_attrs=None):
    """Emit ``<output> = isinstance(<input>, <cls>)`` into ``forward_func``.

    ``<cls>`` is read directly from ``layer.attrs["cls"]``; the input is
    resolved with ``get_value(layer, "input", layer_id, different_attrs)``.
    """
    source = get_value(layer, "input", layer_id, different_attrs)
    line = "{} = isinstance({}, {})".format(
        layer.outputs[0], source, layer.attrs["cls"])
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
361 362 363 364 365
def prim_isnot(layer,
               indent=1,
               init_func=[],
               forward_func=[],
               layer_id=None,
               different_attrs=None):
    """Emit ``<output> = <x> is not <y>`` into ``forward_func``.

    Operands are resolved with ``get_value(layer, key, layer_id,
    different_attrs)`` so attributes listed in ``different_attrs`` are
    referenced by name instead of inlined as literals.
    """
    x = get_value(layer, "x", layer_id, different_attrs)
    y = get_value(layer, "y", layer_id, different_attrs)
    line = "{} = {} is not {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
373 374 375 376 377
def prim_le(layer,
            indent=1,
            init_func=[],
            forward_func=[],
            layer_id=None,
            different_attrs=None):
    """Emit ``<output> = <x> <= <y>`` into ``forward_func``.

    Operands are resolved with ``get_value(layer, key, layer_id,
    different_attrs)`` so attributes listed in ``different_attrs`` are
    referenced by name instead of inlined as literals.
    """
    x = get_value(layer, "x", layer_id, different_attrs)
    y = get_value(layer, "y", layer_id, different_attrs)
    line = "{} = {} <= {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
385 386 387 388 389 390 391 392
def prim_len(layer,
             indent=1,
             init_func=[],
             forward_func=[],
             layer_id=None,
             different_attrs=None):
    """Emit ``<output> = len(<input>)`` into ``forward_func``.

    The input is resolved with ``get_value(layer, "input", layer_id,
    different_attrs)``.
    """
    source = get_value(layer, "input", layer_id, different_attrs)
    line = "{} = len({})".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
396 397 398 399 400 401
def prim_len2list(layer,
                  indent=1,
                  init_func=[],
                  forward_func=[],
                  layer_id=None,
                  different_attrs=None):
    """Emit code that builds ``[0, 1, ..., <len> - 1]`` with an explicit loop.

    Generated code::

        <output> = []
        for i in range(<len>):
            <output>.append(i)

    NOTE(review): the generated loop variable ``i`` may shadow a variable of
    the same name in the generated forward code.

    ``<len>`` is resolved with ``get_value(layer, "len", layer_id,
    different_attrs)``.
    """
    output = layer.outputs[0]
    length = get_value(layer, "len", layer_id, different_attrs)
    lines = [
        "{} = []".format(output),
        "for i in range({}):".format(length),
        "    {}.append(i)".format(output),
    ]
    forward_func.extend(gen_codes(lines, indent=indent))


S
SunAhong1993 已提交
410 411 412 413 414
def prim_lt(layer,
            indent=1,
            init_func=[],
            forward_func=[],
            layer_id=None,
            different_attrs=None):
    """Emit ``<output> = <x> < <y>`` into ``forward_func``.

    Operands are resolved with ``get_value(layer, key, layer_id,
    different_attrs)`` so attributes listed in ``different_attrs`` are
    referenced by name instead of inlined as literals.
    """
    x = get_value(layer, "x", layer_id, different_attrs)
    y = get_value(layer, "y", layer_id, different_attrs)
    line = "{} = {} < {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
422 423 424 425 426 427
def prim_list(layer,
              indent=1,
              init_func=[],
              forward_func=[],
              layer_id=None,
              different_attrs=None):
    """Emit ``<output> = [<input0>, <input1>, ...]`` into ``forward_func``.

    Elements are named ``input{i}`` and may live in ``layer.inputs`` or
    ``layer.attrs``; the element count is the total size of both. Each
    element is resolved with ``get_value(layer, key, layer_id,
    different_attrs)``, in index order.
    """
    count = len(layer.inputs) + len(layer.attrs)
    items = [
        get_value(layer, "input{}".format(i), layer_id, different_attrs)
        for i in range(count)
    ]
    line = "{} = [{}]".format(layer.outputs[0], ', '.join(items))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
438 439 440 441 442 443 444 445
def prim_list_unpack(layer,
                     indent=1,
                     init_func=[],
                     forward_func=[],
                     layer_id=None,
                     different_attrs=None):
    """Emit ``<out0>, <out1>, ... = <input>`` into ``forward_func``.

    With a single output this is a plain assignment, not an unpack. The
    input is resolved with ``get_value(layer, "input", layer_id,
    different_attrs)``.
    """
    source = get_value(layer, "input", layer_id, different_attrs)
    line = "{} = {}".format(", ".join(layer.outputs), source)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
449 450 451 452 453 454
def prim_loop(layer,
              indent=1,
              init_func=[],
              forward_func=[],
              layer_id=None,
              different_attrs=None):
    """Emit ``for <outputs[1]> in range(<input>):`` followed by its body.

    The body is generated from ``layer.blocks[0]`` one indentation level
    deeper; its init lines are appended to ``init_func``. A ``pass`` is
    emitted when the body produces no forward lines, because an empty loop
    body is a syntax error in the generated code.

    The range bound is resolved with ``get_value(layer, "input", layer_id,
    different_attrs)``.
    """
    loop_range = get_value(layer, "input", layer_id, different_attrs)
    line = "for {} in range({}):".format(layer.outputs[1], loop_range)
    forward_func.extend(gen_codes([line], indent=indent))
    body_init, body_forward = layer.blocks[0].gen_dygraph_code(
        indent=indent + 1)
    init_func.extend(body_init)
    forward_func.extend(body_forward)
    if len(body_forward) == 0:
        forward_func.extend(gen_codes(["pass"], indent=indent + 1))


S
SunAhong1993 已提交
464 465 466 467 468 469 470 471
def prim_min(layer,
             indent=1,
             init_func=[],
             forward_func=[],
             layer_id=None,
             different_attrs=None):
    """Emit ``<output> = min(<input>)`` into ``forward_func``.

    The input is resolved with ``get_value(layer, "input", layer_id,
    different_attrs)``.
    """
    source = get_value(layer, "input", layer_id, different_attrs)
    line = "{} = min({})".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
475 476 477 478 479 480
def prim_mul(layer,
             indent=1,
             init_func=[],
             forward_func=[],
             layer_id=None,
             different_attrs=None):
    """Emit ``<output> = <x> * <y>`` into ``forward_func``.

    Operands are resolved with ``get_value(layer, key, layer_id,
    different_attrs)`` so attributes listed in ``different_attrs`` are
    referenced by name instead of inlined as literals.
    """
    x = get_value(layer, "x", layer_id, different_attrs)
    y = get_value(layer, "y", layer_id, different_attrs)
    line = "{} = {} * {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
487 488 489 490 491
def prim_ne(layer,
            indent=1,
            init_func=[],
            forward_func=[],
            layer_id=None,
            different_attrs=None):
    """Emit ``<output> = <x> != <y>`` into ``forward_func``.

    Operands are resolved with ``get_value(layer, key, layer_id,
    different_attrs)`` so attributes listed in ``different_attrs`` are
    referenced by name instead of inlined as literals.
    """
    x = get_value(layer, "x", layer_id, different_attrs)
    y = get_value(layer, "y", layer_id, different_attrs)
    line = "{} = {} != {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
499 500 501 502 503 504 505 506
def prim_neg(layer,
             indent=1,
             init_func=[],
             forward_func=[],
             layer_id=None,
             different_attrs=None):
    """Emit ``<output> = -<input>`` into ``forward_func``.

    The input is resolved with ``get_value(layer, "input", layer_id,
    different_attrs)``.
    """
    source = get_value(layer, "input", layer_id, different_attrs)
    line = "{} = -{}".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
510 511 512 513 514
def prim_not(layer,
             indent=1,
             init_func=[],
             forward_func=[],
             layer_id=None,
             different_attrs=None):
    """Emit ``<output> = not <input>`` into ``forward_func``.

    The input is resolved with ``get_value(layer, "input", layer_id,
    different_attrs)``.
    """
    source = get_value(layer, "input", layer_id, different_attrs)
    line = "{} = not {}".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
521 522 523 524 525
def prim_or(layer,
            indent=1,
            init_func=[],
            forward_func=[],
            layer_id=None,
            different_attrs=None):
    """Emit ``<output> = <x> or <y>`` into ``forward_func``.

    Operands are resolved with ``get_value(layer, key, layer_id,
    different_attrs)`` so attributes listed in ``different_attrs`` are
    referenced by name instead of inlined as literals.
    """
    x = get_value(layer, "x", layer_id, different_attrs)
    y = get_value(layer, "y", layer_id, different_attrs)
    line = "{} = {} or {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
533 534 535 536 537 538
def prim_replaceitem(layer,
                     indent=1,
                     init_func=[],
                     forward_func=[],
                     layer_id=None,
                     different_attrs=None):
    """Emit ``<list>[<index>] = <item>`` into ``forward_func``.

    Operands are resolved with ``get_value`` using ``layer_id`` and
    ``different_attrs``, in the order list, index, item.
    """
    operands = [
        get_value(layer, name, layer_id, different_attrs)
        for name in ("list", "index", "item")
    ]
    statement = "{}[{}] = {}".format(*operands)
    forward_func.extend(gen_codes([statement], indent=indent))


S
SunAhong1993 已提交
546 547 548 549 550 551 552 553
def prim_requires_grad(layer,
                       indent=1,
                       init_func=[],
                       forward_func=[],
                       layer_id=None,
                       different_attrs=None):
    """Emit ``<output> = not <input>.stop_gradient`` into ``forward_func``.

    The input is resolved with ``get_value(layer, "input", layer_id,
    different_attrs)``.
    """
    source = get_value(layer, "input", layer_id, different_attrs)
    line = "{} = not {}.stop_gradient".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
557 558 559 560 561 562 563 564 565 566 567
def prim_rsub(layer,
              indent=1,
              init_func=[],
              forward_func=[],
              layer_id=None,
              different_attrs=None):
    """Emit ``<output> = <y> - <x> * <alpha>`` into ``forward_func``.

    Operands are resolved with ``get_value(layer, key, layer_id,
    different_attrs)`` in the order y, x, alpha.
    """
    y = get_value(layer, "y", layer_id, different_attrs)
    x = get_value(layer, "x", layer_id, different_attrs)
    alpha = get_value(layer, "alpha", layer_id, different_attrs)
    line = "{} = {} - {} * {}".format(layer.outputs[0], y, x, alpha)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
571 572 573 574 575 576 577 578
def prim_select(layer,
                indent=1,
                init_func=[],
                forward_func=[],
                layer_id=None,
                different_attrs=None):
    """Emit ``<output> = <input>[:, ..., :, <index>]`` into ``forward_func``.

    ``layer.attrs["dim"]`` leading ``:, `` slices are emitted before
    ``<index>``. Operands are resolved with ``get_value(layer, key,
    layer_id, different_attrs)``; the input is resolved before the index.
    """
    source = get_value(layer, "input", layer_id, different_attrs)
    leading = ":, " * layer.attrs["dim"]
    index = get_value(layer, "index", layer_id, different_attrs)
    line = "{} = {}[{}{}]".format(layer.outputs[0], source, leading, index)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
585 586 587 588 589 590 591 592
def prim_set_attr(layer,
                  indent=1,
                  init_func=[],
                  forward_func=[],
                  layer_id=None,
                  different_attrs=None):
    """Emit ``<output> = <input>`` into ``forward_func``.

    The input is resolved with ``get_value(layer, "input", layer_id,
    different_attrs)``.
    """
    source = get_value(layer, "input", layer_id, different_attrs)
    line = "{} = {}".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
596 597 598 599 600 601
def prim_set_item(layer,
                  indent=1,
                  init_func=[],
                  forward_func=[],
                  layer_id=None,
                  different_attrs=None):
    """Emit ``<dict>[<key>] = <value>`` into ``forward_func``.

    This layer's output name is not used. Like prim_replaceitem, it passes
    ``layer_id`` so ``get_value`` can use its ``"layer_id/<id>_<key>"``
    lookup. Operands are resolved in the order dict, key, value.
    """
    target = get_value(layer, "dict", layer_id, different_attrs)
    key = get_value(layer, "key", layer_id, different_attrs)
    value = get_value(layer, "value", layer_id, different_attrs)
    line = "{}[{}] = {}".format(target, key, value)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
609 610 611 612 613 614
def prim_shape_dim(layer,
                   indent=1,
                   init_func=[],
                   forward_func=[],
                   layer_id=None,
                   different_attrs=None):
    """Emit ``<output> = paddle.shape(<input>)[<dim>]`` into ``forward_func``.

    Operands are resolved with ``get_value(layer, key, layer_id,
    different_attrs)``; the input is resolved before ``dim``.
    """
    source = get_value(layer, "input", layer_id, different_attrs)
    dim = get_value(layer, "dim", layer_id, different_attrs)
    line = "{} = paddle.shape({})[{}]".format(layer.outputs[0], source, dim)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_slice(layer,
               indent=1,
               init_func=[],
               forward_func=[],
               layer_id=None,
               different_attrs=None):
    """Emit ``<output> = <input>[<start>: <end>: <step>]``.

    Operands are resolved with ``get_value(layer, key, layer_id,
    different_attrs)`` in the order input, start, end, step.
    """
    source = get_value(layer, "input", layer_id, different_attrs)
    start = get_value(layer, "start", layer_id, different_attrs)
    end = get_value(layer, "end", layer_id, different_attrs)
    step = get_value(layer, "step", layer_id, different_attrs)
    line = "{} = {}[{}: {}: {}]".format(
        layer.outputs[0], source, start, end, step)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_str(layer,
             indent=1,
             init_func=[],
             forward_func=[],
             layer_id=None,
             different_attrs=None):
    """Emit ``<output> = str(<input>)`` into ``forward_func``.

    The input is resolved with ``get_value(layer, "input", layer_id,
    different_attrs)``.
    """
    source = get_value(layer, "input", layer_id, different_attrs)
    line = "{} = str({})".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
648 649 650 651 652 653
def prim_sub(layer,
             indent=1,
             init_func=[],
             forward_func=[],
             layer_id=None,
             different_attrs=None):
    """Emit ``<output> = <x> - <y>`` into ``forward_func``.

    Operands are resolved with ``get_value(layer, key, layer_id,
    different_attrs)`` so attributes listed in ``different_attrs`` are
    referenced by name instead of inlined as literals.
    """
    x = get_value(layer, "x", layer_id, different_attrs)
    y = get_value(layer, "y", layer_id, different_attrs)
    line = "{} = {} - {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
660 661 662 663 664 665
def prim_tuple(layer,
               indent=1,
               init_func=[],
               forward_func=[],
               layer_id=None,
               different_attrs=None):
    """Emit ``<output> = (<input0>, <input1>, ...)`` into ``forward_func``.

    Elements are named ``input{i}`` and may live in ``layer.inputs`` or
    ``layer.attrs``; the element count is the total size of both. Each
    element is resolved with ``get_value(layer, key, layer_id,
    different_attrs)``, in index order.

    NOTE(review): a single element yields ``(x)``, which is ``x`` rather than
    a 1-tuple. This is left unchanged because prim_tuple_unpack emits a
    plain assignment for a single output, and the two may rely on each
    other; confirm before adding a trailing comma.
    """
    count = len(layer.inputs) + len(layer.attrs)
    items = [
        get_value(layer, "input{}".format(i), layer_id, different_attrs)
        for i in range(count)
    ]
    line = "{} = ({})".format(layer.outputs[0], ', '.join(items))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
676 677 678 679 680 681
def prim_tuple_unpack(layer,
                      indent=1,
                      init_func=[],
                      forward_func=[],
                      layer_id=None,
                      different_attrs=None):
    """Emit ``<out0>, <out1>, ... = <input>`` into ``forward_func``.

    With a single output this is a plain assignment, not an unpack. The
    input is resolved with ``get_value(layer, "input", layer_id,
    different_attrs)``.
    """
    targets = ', '.join(layer.outputs)
    source = get_value(layer, "input", layer_id, different_attrs)
    line = "{} = {}".format(targets, source)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
688 689 690 691 692 693 694 695
def prim_type(layer,
              indent=1,
              init_func=[],
              forward_func=[],
              layer_id=None,
              different_attrs=None):
    """Emit ``<output> = <input>.dtype`` into ``forward_func``.

    The input is resolved with ``get_value(layer, "input", layer_id,
    different_attrs)``.
    """
    source = get_value(layer, "input", layer_id, different_attrs)
    line = "{} = {}.dtype".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
699 700 701 702 703 704 705 706
def prim_var2list(layer,
                  indent=1,
                  init_func=[],
                  forward_func=[],
                  layer_id=None,
                  different_attrs=None):
    """Emit ``<output> = <input>.numpy().tolist()`` into ``forward_func``.

    The input is resolved with ``get_value(layer, "input", layer_id,
    different_attrs)``.
    """
    source = get_value(layer, "input", layer_id, different_attrs)
    line = "{} = {}.numpy().tolist()".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
710 711 712 713 714 715
def prim_warnings(layer,
                  indent=1,
                  init_func=[],
                  forward_func=[],
                  layer_id=None,
                  different_attrs=None):
    """Emit ``import warnings`` plus a ``warnings.warn`` call.

    The message is resolved with ``get_value(layer, "input", layer_id,
    different_attrs)``; ``stacklevel`` is read directly from
    ``layer.attrs["stacklevel"]``. This layer's output name is not used.
    """
    message = get_value(layer, "input", layer_id, different_attrs)
    lines = [
        "import warnings",
        "warnings.warn({}, stacklevel={})".format(
            message, layer.attrs["stacklevel"]),
    ]
    forward_func.extend(gen_codes(lines, indent=indent))