# -*- coding:UTF-8 -*-
#   Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Module-level counter that get_value uses to build fresh "attr_<n>" names.
NO_OUTPUT_COUNT = 0
 
def gen_codes(code_list, indent=0):
    """Turn a list of source lines into indented, newline-terminated lines.

    Args:
        code_list (list[str]): lines of code without trailing newlines.
        indent (int): number of 4-space indentation levels to prepend.

    Returns:
        list[str]: one entry per input line. Empty or whitespace-only lines
        become a bare newline with no indentation; every other line gets the
        indentation prefix and a trailing newline.
    """
    prefix = "    " * indent
    return [prefix + code_line + '\n' if code_line.strip() else '\n'
            for code_line in code_list]


def get_value(layer, key, layer_id=None, different_attrs=None):
    """Return the source-code expression for argument ``key`` of ``layer``.

    After optimization (e.g. ConstantFuser) an input's value may have been
    replaced by a literal, which moves it from ``layer.inputs`` to
    ``layer.attrs``; this helper resolves either location.

    Args:
        layer: the layer being emitted; its ``inputs``, ``attrs`` and
            ``outputs`` are read.
        key (str): argument name to look up.
        layer_id: optional layer identifier; only consulted when
            ``different_attrs`` is given.
        different_attrs (list, optional): names of attributes that are
            emitted as variable names instead of literals. May be modified in
            place: a matching ``"layer_id/<id>_<key>"`` entry is replaced by a
            fresh ``"attr_<n>"`` name.

    Returns:
        ``layer.inputs[key]`` if ``key`` is an input; otherwise a variable name
        taken from (or written into) ``different_attrs``, or ``str()`` of the
        attribute value.

    Raises:
        KeyError: if a literal is needed but ``key`` is not in ``layer.attrs``.
    """
    # The counter is module-level state. Without this declaration, the
    # increment below makes it a local and the read raises UnboundLocalError.
    global NO_OUTPUT_COUNT
    if key in layer.inputs:
        return layer.inputs[key]
    if different_attrs is None:
        return str(layer.attrs[key])
    key_name = "{}_{}".format(layer.outputs[0], key)
    if key_name in different_attrs:
        return key_name
    if layer_id is None:
        return str(layer.attrs[key])
    key_name = "{}_{}".format("layer_id/{}".format(layer_id), key)
    if key_name not in different_attrs:
        return str(layer.attrs[key])
    new_key_name = "attr_{}".format(NO_OUTPUT_COUNT)
    NO_OUTPUT_COUNT += 1
    different_attrs[different_attrs.index(key_name)] = new_key_name
    return new_key_name


def prim_add(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = <x> + <y>`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id, so passing it positionally silently dropped it.
    x = get_value(layer, "x", different_attrs=different_attrs)
    y = get_value(layer, "y", different_attrs=different_attrs)
    line = "{} = {} + {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_add_(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = <x> + <alpha> * <y>`` and append it to ``forward_func``.

    ``alpha`` is read directly from ``layer.attrs``. ``init_func`` and
    ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    x = get_value(layer, "x", different_attrs=different_attrs)
    alpha = layer.attrs["alpha"]
    y = get_value(layer, "y", different_attrs=different_attrs)
    line = "{} = {} + {} * {}".format(layer.outputs[0], x, alpha, y)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_and(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = <x> and <y>`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    x = get_value(layer, "x", different_attrs=different_attrs)
    y = get_value(layer, "y", different_attrs=different_attrs)
    line = "{} = {} and {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_append(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<list>.append(<element>)`` and append it to ``forward_func``.

    Both operands are resolved through ``get_value`` with ``layer_id`` and
    ``different_attrs``; that lookup may rename entries of ``different_attrs``.
    ``init_func`` is not used.
    """
    # Resolve "list" before "element": get_value may consume counter values.
    target = get_value(layer, "list", layer_id, different_attrs)
    element = get_value(layer, "element", layer_id, different_attrs)
    code = "{}.append({})".format(target, element)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_assert(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit an ``assert`` that checks ``key`` against ``value``.

    Only ``layer.attrs["type"] == "eq"`` is supported. If the candidate
    values form a list (``layer.attrs["value"]`` when present, otherwise the
    resolved ``key``), the generated condition accepts any element
    (``key == v0 or key == v1 ...``). Otherwise a single ``key == value``
    test is emitted. ``different_attrs`` is not consulted.

    Raises:
        Exception: if ``layer.attrs["type"]`` is anything other than ``"eq"``.
    """
    if layer.attrs["type"] != "eq":
        raise Exception("Not implement yet!")
    key = get_value(layer, "key")
    candidates = layer.attrs["value"] if "value" in layer.attrs else key
    if isinstance(candidates, list):
        condition = " or ".join("{} == {}".format(key, v) for v in candidates)
    else:
        condition = "{} == {}".format(key, get_value(layer, "value"))
    line = "assert {}, 'The {} must be {}!'".format(
        condition, key, get_value(layer, "value"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_check_dim(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit code that normalizes a possibly negative dimension index.

    The generated snippet sets ``<out>`` to ``<dim> + <len>`` when ``<dim>``
    is negative, and to ``<dim>`` otherwise. The lines are appended to
    ``forward_func``; ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    dim = get_value(layer, "dim", different_attrs=different_attrs)
    length = get_value(layer, "len", different_attrs=different_attrs)
    out = layer.outputs[0]
    lines = [
        "if {} < 0:".format(dim),
        "    {} = {} + {}".format(out, dim, length),
        "else:",
        "    {} = {}".format(out, dim),
    ]
    forward_func.extend(gen_codes(lines, indent=indent))


def prim_constant(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = <value>``, with ``value`` taken from ``layer.attrs["value"]``.

    The value is formatted with ``str.format`` as-is. ``init_func``,
    ``layer_id`` and ``different_attrs`` are not used.
    """
    target, value = layer.outputs[0], layer.attrs["value"]
    code = "{} = {}".format(target, value)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_contain(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = <element> in <input>`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    element = get_value(layer, "element", different_attrs=different_attrs)
    container = get_value(layer, "input", different_attrs=different_attrs)
    line = "{} = {} in {}".format(layer.outputs[0], element, container)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_dict(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = dict()`` and append it to ``forward_func``.

    Only ``layer.outputs[0]`` is read; the remaining parameters are unused.
    """
    target = layer.outputs[0]
    forward_func.extend(gen_codes(["{} = dict()".format(target)], indent=indent))
    
    
def prim_dict_construct(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit code that builds a dict from ``key<i>``/``value<i>`` pairs.

    Generates ``<out> = dict()`` followed by one ``<out>[<key_i>] = <value_i>``
    line for each ``i`` in ``range(len(layer.inputs))``. The lines are
    appended to ``forward_func``; ``init_func`` and ``layer_id`` are not used.
    """
    # NOTE(review): the pair count comes from len(layer.inputs), which assumes
    # each key/value pair contributes exactly one entry to layer.inputs —
    # confirm against the mapper that builds this layer.
    out = layer.outputs[0]
    lines = ["{} = dict()".format(out)]
    for i in range(len(layer.inputs)):
        # different_attrs is passed by keyword: get_value's 3rd positional
        # parameter is layer_id.
        key = get_value(layer, "key{}".format(i), different_attrs=different_attrs)
        value = get_value(layer, "value{}".format(i), different_attrs=different_attrs)
        lines.append("{}[{}] = {}".format(out, key, value))
    forward_func.extend(gen_codes(lines, indent=indent))
    
    
def prim_dict2values(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = list(<x>.values())`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    source = get_value(layer, "x", different_attrs=different_attrs)
    line = "{} = list({}.values())".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_div(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = <x> / <y>`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    x = get_value(layer, "x", different_attrs=different_attrs)
    y = get_value(layer, "y", different_attrs=different_attrs)
    line = "{} = {} / {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_eq(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = <x> == <y>`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    x = get_value(layer, "x", different_attrs=different_attrs)
    y = get_value(layer, "y", different_attrs=different_attrs)
    line = "{} = {} == {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_equal(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit the plain assignment ``<out> = <input>`` into ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    source = get_value(layer, "input", different_attrs=different_attrs)
    line = "{} = {}".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_exception(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``raise RaiseException(<input>)`` into ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # NOTE(review): RaiseException is not a builtin; presumably the generated
    # module defines or imports it — confirm in the code generator.
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    message = get_value(layer, "input", different_attrs=different_attrs)
    line = "raise RaiseException({})".format(message)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_float(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = float(<input>)`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    source = get_value(layer, "input", different_attrs=different_attrs)
    line = "{} = float({})".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_floor(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = math.floor(<input>)`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # NOTE(review): the generated line uses ``math``; the generated module is
    # presumably expected to import it — confirm in the code generator.
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    source = get_value(layer, "input", different_attrs=different_attrs)
    line = "{} = math.floor({})".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_floordiv(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = <x> // <y>`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    x = get_value(layer, "x", different_attrs=different_attrs)
    y = get_value(layer, "y", different_attrs=different_attrs)
    line = "{} = {} // {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_getitem(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = <list>[<index>]`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    container = get_value(layer, "list", different_attrs=different_attrs)
    index = get_value(layer, "index", different_attrs=different_attrs)
    line = "{} = {}[{}]".format(layer.outputs[0], container, index)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_gt(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = <x> > <y>`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    x = get_value(layer, "x", different_attrs=different_attrs)
    y = get_value(layer, "y", different_attrs=different_attrs)
    line = "{} = {} > {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_if(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit an ``if``/``else`` statement built from ``layer.blocks[0]`` and ``[1]``.

    The condition is the layer's ``input``. Each block's code is produced by
    ``gen_dygraph_code(indent=indent + 1)``; its init lines go to
    ``init_func`` and its forward lines to ``forward_func``. A true branch
    that yields no forward code becomes ``pass``. The ``else:`` line is
    emitted only when the false branch yields forward code.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    cond = get_value(layer, "input", different_attrs=different_attrs)
    forward_func.extend(gen_codes(["if {} :".format(cond)], indent=indent))

    true_block = layer.blocks[0]
    b_forward_lines = []
    if len(true_block.layers) > 0:
        b_init_lines, b_forward_lines = true_block.gen_dygraph_code(indent=indent + 1)
        init_func.extend(b_init_lines)
    # An ``if`` with no body is a syntax error in the generated module, which
    # also covers a non-empty block that produced only init code.
    if len(b_forward_lines) == 0:
        b_forward_lines = gen_codes(["pass"], indent=indent + 1)
    forward_func.extend(b_forward_lines)

    false_block = layer.blocks[1]
    if len(false_block.layers) > 0:
        b_init_lines, b_forward_lines = false_block.gen_dygraph_code(indent=indent + 1)
        if len(b_forward_lines) != 0:
            forward_func.extend(gen_codes(["else:"], indent=indent))
        init_func.extend(b_init_lines)
        forward_func.extend(b_forward_lines)


def prim_int(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = int(<input>)`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    source = get_value(layer, "input", different_attrs=different_attrs)
    line = "{} = int({})".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_is(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = <x> is <y>`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    x = get_value(layer, "x", different_attrs=different_attrs)
    y = get_value(layer, "y", different_attrs=different_attrs)
    line = "{} = {} is {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_isinstance(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = isinstance(<input>, <cls>)`` into ``forward_func``.

    ``cls`` is read directly from ``layer.attrs["cls"]``. ``init_func`` and
    ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    source = get_value(layer, "input", different_attrs=different_attrs)
    line = "{} = isinstance({}, {})".format(layer.outputs[0], source, layer.attrs["cls"])
    forward_func.extend(gen_codes([line], indent=indent))


def prim_isnot(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = <x> is not <y>`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    x = get_value(layer, "x", different_attrs=different_attrs)
    y = get_value(layer, "y", different_attrs=different_attrs)
    line = "{} = {} is not {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_le(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = <x> <= <y>`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    x = get_value(layer, "x", different_attrs=different_attrs)
    y = get_value(layer, "y", different_attrs=different_attrs)
    line = "{} = {} <= {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_len(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = len(<input>)`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    source = get_value(layer, "input", different_attrs=different_attrs)
    line = "{} = len({})".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_len2list(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit code that builds ``<out> = [0, 1, ..., <len> - 1]``.

    The generated snippet initializes an empty list and appends ``i`` for
    each ``i`` in ``range(<len>)``. The lines are appended to
    ``forward_func``; ``init_func`` and ``layer_id`` are not used.
    """
    out = layer.outputs[0]
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    length = get_value(layer, "len", different_attrs=different_attrs)
    lines = [
        "{} = []".format(out),
        "for i in range({}):".format(length),
        "    {}.append(i)".format(out),
    ]
    forward_func.extend(gen_codes(lines, indent=indent))


def prim_lt(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = <x> < <y>`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    x = get_value(layer, "x", different_attrs=different_attrs)
    y = get_value(layer, "y", different_attrs=different_attrs)
    line = "{} = {} < {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_list(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = [<input0>, <input1>, ...]`` into ``forward_func``.

    The element count is ``len(layer.inputs) + len(layer.attrs)``, and each
    element ``i`` is looked up as ``"input<i>"`` in either inputs or attrs.
    ``init_func`` and ``layer_id`` are not used.
    """
    count = len(layer.inputs) + len(layer.attrs)
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    items = [get_value(layer, "input{}".format(i), different_attrs=different_attrs)
             for i in range(count)]
    line = "{} = [{}]".format(layer.outputs[0], ", ".join(items))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_list_unpack(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out0>, <out1>, ... = <input>`` (list unpacking) into ``forward_func``.

    With a single output, the target is written as ``<out0>,`` so the sole
    element is unpacked rather than the whole list being bound.
    ``init_func`` and ``layer_id`` are not used.
    """
    targets = ", ".join(layer.outputs)
    # "a = seq" would bind the whole sequence; "a, = seq" unpacks it.
    if len(layer.outputs) == 1:
        targets += ","
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    source = get_value(layer, "input", different_attrs=different_attrs)
    line = "{} = {}".format(targets, source)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_loop(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``for <outputs[1]> in range(<input>):`` followed by ``layer.blocks[0]``.

    The body comes from ``blocks[0].gen_dygraph_code(indent=indent + 1)``.
    Its init lines go to ``init_func`` and its forward lines to
    ``forward_func``. A body that yields no forward code becomes ``pass``.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    loop_range = get_value(layer, "input", different_attrs=different_attrs)
    line = "for {} in range({}):".format(layer.outputs[1], loop_range)
    forward_func.extend(gen_codes([line], indent=indent))
    b_init_lines, b_forward_lines = layer.blocks[0].gen_dygraph_code(indent=indent + 1)
    init_func.extend(b_init_lines)
    # A ``for`` with an empty body is a syntax error in the generated module.
    if len(b_forward_lines) == 0:
        b_forward_lines = gen_codes(["pass"], indent=indent + 1)
    forward_func.extend(b_forward_lines)


def prim_min(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = min(<input>)`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    source = get_value(layer, "input", different_attrs=different_attrs)
    line = "{} = min({})".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_mul(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = <x> * <y>`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    x = get_value(layer, "x", different_attrs=different_attrs)
    y = get_value(layer, "y", different_attrs=different_attrs)
    line = "{} = {} * {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_ne(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = <x> != <y>`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    x = get_value(layer, "x", different_attrs=different_attrs)
    y = get_value(layer, "y", different_attrs=different_attrs)
    line = "{} = {} != {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_neg(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = -<input>`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    source = get_value(layer, "input", different_attrs=different_attrs)
    line = "{} = -{}".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_not(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = not <input>`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    source = get_value(layer, "input", different_attrs=different_attrs)
    line = "{} = not {}".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_or(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = <x> or <y>`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    x = get_value(layer, "x", different_attrs=different_attrs)
    y = get_value(layer, "y", different_attrs=different_attrs)
    line = "{} = {} or {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_replaceitem(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<list>[<index>] = <item>`` and append it to ``forward_func``.

    All three operands are resolved through ``get_value`` with ``layer_id``
    and ``different_attrs``; that lookup may rename entries of
    ``different_attrs``. ``init_func`` is not used.
    """
    # Resolution order (list, index, item) matters: get_value may consume
    # counter values when it renames attributes.
    container, position, item = (
        get_value(layer, name, layer_id, different_attrs)
        for name in ("list", "index", "item"))
    code = "{}[{}] = {}".format(container, position, item)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_requires_grad(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = not <input>.stop_gradient`` into ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    source = get_value(layer, "input", different_attrs=different_attrs)
    line = "{} = not {}.stop_gradient".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_rsub(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = <y> - <x> * <alpha>`` (reversed subtraction) into ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    y = get_value(layer, "y", different_attrs=different_attrs)
    x = get_value(layer, "x", different_attrs=different_attrs)
    alpha = get_value(layer, "alpha", different_attrs=different_attrs)
    line = "{} = {} - {} * {}".format(layer.outputs[0], y, x, alpha)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_select(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = <input>[:, ..., <index>]``, selecting along ``layer.attrs["dim"]``.

    One full slice ``:, `` is emitted for each leading dimension, then the
    index. ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    source = get_value(layer, "input", different_attrs=different_attrs)
    leading = ":, " * layer.attrs["dim"]
    index = get_value(layer, "index", different_attrs=different_attrs)
    line = "{} = {}[{}{}]".format(layer.outputs[0], source, leading, index)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_set_attr(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit the assignment ``<out> = <input>`` into ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    source = get_value(layer, "input", different_attrs=different_attrs)
    line = "{} = {}".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_set_item(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<dict>[<key>] = <value>`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    mapping = get_value(layer, "dict", different_attrs=different_attrs)
    key = get_value(layer, "key", different_attrs=different_attrs)
    value = get_value(layer, "value", different_attrs=different_attrs)
    line = "{}[{}] = {}".format(mapping, key, value)
    forward_func.extend(gen_codes([line], indent=indent))
    
    
def prim_shape(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = <input>.shape`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    source = get_value(layer, "input", different_attrs=different_attrs)
    line = "{} = {}.shape".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_shape_dim(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = <input>.shape[<dim>]`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    source = get_value(layer, "input", different_attrs=different_attrs)
    dim = get_value(layer, "dim", different_attrs=different_attrs)
    line = "{} = {}.shape[{}]".format(layer.outputs[0], source, dim)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_slice(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = <input>[<start>: <end>: <step>]`` into ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    source, start, end, step = (
        get_value(layer, name, different_attrs=different_attrs)
        for name in ("input", "start", "end", "step"))
    line = "{} = {}[{}: {}: {}]".format(layer.outputs[0], source, start, end, step)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_str(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = str(<input>)`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    source = get_value(layer, "input", different_attrs=different_attrs)
    line = "{} = str({})".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_sub(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = <x> - <y>`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    x = get_value(layer, "x", different_attrs=different_attrs)
    y = get_value(layer, "y", different_attrs=different_attrs)
    line = "{} = {} - {}".format(layer.outputs[0], x, y)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_tuple(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = (<input0>, <input1>, ...)`` into ``forward_func``.

    The element count is ``len(layer.inputs) + len(layer.attrs)``, and each
    element ``i`` is looked up as ``"input<i>"`` in either inputs or attrs.
    A one-element tuple gets a trailing comma. ``init_func`` and ``layer_id``
    are not used.
    """
    count = len(layer.inputs) + len(layer.attrs)
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    items = [get_value(layer, "input{}".format(i), different_attrs=different_attrs)
             for i in range(count)]
    inputs_str = ", ".join(items)
    # "(x)" is just a parenthesized expression; "(x,)" is a 1-tuple.
    if len(items) == 1:
        inputs_str += ","
    line = "{} = ({})".format(layer.outputs[0], inputs_str)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_tuple_unpack(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out0>, <out1>, ... = <input>`` (tuple unpacking) into ``forward_func``.

    With a single output, the target is written as ``<out0>,`` so the sole
    element is unpacked rather than the whole tuple being bound.
    ``init_func`` and ``layer_id`` are not used.
    """
    outputs_str = ", ".join(layer.outputs)
    # "a = t" would bind the whole tuple; "a, = t" unpacks it.
    if len(layer.outputs) == 1:
        outputs_str += ","
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    source = get_value(layer, "input", different_attrs=different_attrs)
    line = "{} = {}".format(outputs_str, source)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_type(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = <input>.dtype`` and append it to ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    source = get_value(layer, "input", different_attrs=different_attrs)
    line = "{} = {}.dtype".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_var2list(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``<out> = <input>.numpy().tolist()`` into ``forward_func``.

    ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    source = get_value(layer, "input", different_attrs=different_attrs)
    line = "{} = {}.numpy().tolist()".format(layer.outputs[0], source)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_warnings(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Emit ``import warnings`` followed by ``warnings.warn(<input>, stacklevel=<n>)``.

    ``stacklevel`` is read directly from ``layer.attrs``. Both lines are
    appended to ``forward_func``; ``init_func`` and ``layer_id`` are not used.
    """
    # different_attrs is passed by keyword: get_value's 3rd positional
    # parameter is layer_id.
    message = get_value(layer, "input", different_attrs=different_attrs)
    lines = [
        "import warnings",
        "warnings.warn({}, stacklevel={})".format(message, layer.attrs["stacklevel"]),
    ]
    forward_func.extend(gen_codes(lines, indent=indent))