# prim2code.py
#   Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Counter used to mint unique "attr_<n>" names for attributes of layers that
# are matched by layer id rather than by output name.
NO_OUTPUT_COUNT = 0


def gen_codes(code_list, indent=0):
    """Indent each line in ``code_list`` by ``indent`` levels (4 spaces each).

    Whitespace-only lines become a bare newline; every other line gets the
    indentation prefix and a trailing newline.

    Returns:
        list[str]: the formatted lines.
    """
    prefix = "    " * indent
    return ['\n' if not text.strip() else prefix + text + '\n'
            for text in code_list]


def get_value(layer, key, layer_id=None, different_attrs=None):
    """Return the source-code expression for ``key`` of ``layer``.

    After optimization (e.g. ConstantFuser) an input's value may have been
    replaced by a literal and moved from ``layer.inputs`` into
    ``layer.attrs``, so both places are consulted.

    Args:
        layer: object exposing ``inputs`` (dict), ``attrs`` (dict) and
            ``outputs`` (list of names).
        key: name of the input / attribute to resolve.
        layer_id: optional layer id, used to build a
            ``"layer_id/<id>_<key>"`` lookup name.
        different_attrs: optional list of attribute names that must be
            rendered as variable names instead of literal values. A matching
            ``"layer_id/<id>_<key>"`` entry is replaced in place by a freshly
            minted ``"attr_<n>"`` name.

    Returns:
        The input name when ``key`` is an input, otherwise a name taken from
        ``different_attrs``, otherwise ``str(layer.attrs[key])``.
    """
    # Without this declaration the "+=" below makes the name local and the
    # read on the line before it raises UnboundLocalError.
    global NO_OUTPUT_COUNT
    if key in layer.inputs:
        return layer.inputs[key]
    if different_attrs is None:
        return str(layer.attrs[key])
    # Layers without outputs can only be matched through their layer id.
    if layer.outputs:
        key_name = "{}_{}".format(layer.outputs[0], key)
        if key_name in different_attrs:
            return key_name
    if layer_id is None:
        return str(layer.attrs[key])
    key_name = "{}_{}".format("layer_id/{}".format(layer_id), key)
    if key_name in different_attrs:
        new_key_name = "attr_{}".format(NO_OUTPUT_COUNT)
        NO_OUTPUT_COUNT += 1
        diff_index = different_attrs.index(key_name)
        different_attrs[diff_index] = new_key_name
        return new_key_name
    return str(layer.attrs[key])


def prim_add(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <x> + <y>`` to ``forward_func`` at ``indent`` levels.

    Operands are resolved by ``get_value`` (inputs, literal attrs, or names
    listed in ``different_attrs``).
    """
    line = "{} = {} + {}".format(layer.outputs[0],
                                 get_value(layer, "x", layer_id, different_attrs),
                                 get_value(layer, "y", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_add_(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <x> + <alpha> * <y>`` to ``forward_func``.

    ``alpha`` is resolved through ``get_value`` like the other operands (as
    ``prim_rsub`` does), so it also works when it arrives as an input.
    """
    line = "{} = {} + {} * {}".format(layer.outputs[0],
                                      get_value(layer, "x", layer_id, different_attrs),
                                      get_value(layer, "alpha", layer_id, different_attrs),
                                      get_value(layer, "y", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_and(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <x> and <y>`` to ``forward_func``."""
    line = "{} = {} and {}".format(layer.outputs[0],
                                   get_value(layer, "x", layer_id, different_attrs),
                                   get_value(layer, "y", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_append(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<list>.append(<element>)`` to ``forward_func``.

    Both operands are resolved with ``get_value`` using the layer id, since
    this statement produces no output name of its own.
    """
    target = get_value(layer, "list", layer_id, different_attrs)
    element = get_value(layer, "element", layer_id, different_attrs)
    statement = "{}.append({})".format(target, element)
    forward_func.extend(gen_codes([statement], indent=indent))


def prim_assert(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append an ``assert`` comparing ``key`` with ``value`` to ``forward_func``.

    Only ``attrs["type"] == "eq"`` is supported; anything else raises.
    When ``attrs["value"]`` is a list, the emitted condition passes if
    ``key`` equals any of its elements (``k == a or k == b ...``).
    """
    if layer.attrs["type"] != "eq":
        raise Exception("Not implement yet!")
    key_expr = get_value(layer, "key")
    expected = layer.attrs["value"] if "value" in layer.attrs else key_expr
    if isinstance(expected, list):
        condition = " or ".join("{} == {}".format(key_expr, v) for v in expected)
        line = "assert {}, \'The {} must be {}!\'".format(
            condition, key_expr, get_value(layer, "value"))
    else:
        value_expr = get_value(layer, "value")
        line = "assert {} == {}, \'The {} must be {}!\'".format(
            key_expr, value_expr, key_expr, value_expr)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_check_dim(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append code normalising a possibly negative ``dim`` against ``len``.

    Emits::

        if <dim> < 0:
            <out> = <dim> + <len>
        else:
            <out> = <dim>
    """
    out = layer.outputs[0]
    dim = get_value(layer, "dim", layer_id, different_attrs)
    length = get_value(layer, "len", layer_id, different_attrs)
    lines = [
        "if {} < 0:".format(dim),
        "    {} = {} + {}".format(out, dim, length),
        "else:",
        "    {} = {}".format(out, dim),
    ]
    forward_func.extend(gen_codes(lines, indent=indent))


def prim_constant(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <attrs["value"]>`` (the literal as formatted) to ``forward_func``."""
    target = layer.outputs[0]
    literal = layer.attrs["value"]
    forward_func.extend(gen_codes(["{} = {}".format(target, literal)], indent=indent))


def prim_contain(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <element> in <input>`` to ``forward_func``."""
    line = "{} = {} in {}".format(layer.outputs[0],
                                  get_value(layer, "element", layer_id, different_attrs),
                                  get_value(layer, "input", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_dict(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = dict()`` to ``forward_func``."""
    statement = "{name} = dict()".format(name=layer.outputs[0])
    forward_func.extend(gen_codes([statement], indent=indent))
    
    
def prim_dict_construct(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append code building a dict from ``key<i>`` / ``value<i>`` pairs.

    Emits ``<out> = dict()`` followed by ``<out>[<key_i>] = <value_i>`` for
    every consecutive ``i`` whose ``key<i>`` exists in either ``inputs`` or
    ``attrs``. Counting pairs this way (instead of ``len(layer.inputs)``)
    stays correct when keys and values are both inputs, or when some of them
    were folded into attrs by constant fusion.
    """
    out = layer.outputs[0]
    lines = ["{} = dict()".format(out)]
    i = 0
    while ("key{}".format(i) in layer.inputs
           or "key{}".format(i) in layer.attrs):
        lines.append("{}[{}] = {}".format(
            out,
            get_value(layer, "key{}".format(i), layer_id, different_attrs),
            get_value(layer, "value{}".format(i), layer_id, different_attrs)))
        i += 1
    forward_func.extend(gen_codes(lines, indent=indent))
    
    
def prim_dict2values(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = list(<x>.values())`` to ``forward_func``."""
    line = "{} = list({}.values())".format(layer.outputs[0],
                                           get_value(layer, "x", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_div(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <x> / <y>`` to ``forward_func``."""
    line = "{} = {} / {}".format(layer.outputs[0],
                                 get_value(layer, "x", layer_id, different_attrs),
                                 get_value(layer, "y", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_eq(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <x> == <y>`` to ``forward_func``."""
    line = "{} = {} == {}".format(layer.outputs[0],
                                  get_value(layer, "x", layer_id, different_attrs),
                                  get_value(layer, "y", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_equal(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <input>`` (a plain assignment) to ``forward_func``."""
    line = "{} = {}".format(layer.outputs[0],
                            get_value(layer, "input", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_exception(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``raise RaiseException(<input>)`` to ``forward_func``.

    NOTE(review): the generated code needs ``RaiseException`` in scope;
    presumably the emitted module defines or imports it — confirm.
    """
    line = "raise RaiseException({})".format(
        get_value(layer, "input", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_float(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = float(<input>)`` to ``forward_func``."""
    line = "{} = float({})".format(layer.outputs[0],
                                   get_value(layer, "input", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_floor(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = math.floor(<input>)`` to ``forward_func``.

    The generated code references ``math``, so it must be imported there.
    """
    line = "{} = math.floor({})".format(layer.outputs[0],
                                        get_value(layer, "input", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_floordiv(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <x> // <y>`` to ``forward_func``."""
    line = "{} = {} // {}".format(layer.outputs[0],
                                  get_value(layer, "x", layer_id, different_attrs),
                                  get_value(layer, "y", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_getitem(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <list>[<index>]`` to ``forward_func``."""
    line = "{} = {}[{}]".format(layer.outputs[0],
                                get_value(layer, "list", layer_id, different_attrs),
                                get_value(layer, "index", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_gt(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <x> > <y>`` to ``forward_func``."""
    line = "{} = {} > {}".format(layer.outputs[0],
                                 get_value(layer, "x", layer_id, different_attrs),
                                 get_value(layer, "y", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_if(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append an ``if <input> :`` statement with its then/else bodies.

    ``layer.blocks[0]`` supplies the then-body. ``layer.blocks[1]`` supplies
    the else-body; ``else:`` is only emitted when that block has layers and
    generates at least one forward line. Init lines from both sub-blocks
    are appended to ``init_func``.
    """
    line = "if {} :".format(get_value(layer, "input", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))
    block = layer.blocks[0]
    b_init_lines, b_forward_lines = block.gen_dygraph_code(indent=indent + 1)
    init_func.extend(b_init_lines)
    forward_func.extend(b_forward_lines)
    block = layer.blocks[1]
    if len(block.layers) > 0:
        b_init_lines, b_forward_lines = block.gen_dygraph_code(
            indent=indent + 1)
        if len(b_forward_lines) != 0:
            forward_func.extend(gen_codes(["else:"], indent=indent))
        init_func.extend(b_init_lines)
        forward_func.extend(b_forward_lines)


def prim_int(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = int(<input>)`` to ``forward_func``."""
    line = "{} = int({})".format(layer.outputs[0],
                                 get_value(layer, "input", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_is(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <x> is <y>`` to ``forward_func``."""
    line = "{} = {} is {}".format(layer.outputs[0],
                                  get_value(layer, "x", layer_id, different_attrs),
                                  get_value(layer, "y", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_isinstance(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = isinstance(<input>, <attrs["cls"]>)`` to ``forward_func``."""
    line = "{} = isinstance({}, {})".format(layer.outputs[0],
                                            get_value(layer, "input", layer_id, different_attrs),
                                            layer.attrs["cls"])
    forward_func.extend(gen_codes([line], indent=indent))


def prim_isnot(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <x> is not <y>`` to ``forward_func``."""
    line = "{} = {} is not {}".format(layer.outputs[0],
                                      get_value(layer, "x", layer_id, different_attrs),
                                      get_value(layer, "y", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_le(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <x> <= <y>`` to ``forward_func``."""
    line = "{} = {} <= {}".format(layer.outputs[0],
                                  get_value(layer, "x", layer_id, different_attrs),
                                  get_value(layer, "y", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_len(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = len(<input>)`` to ``forward_func``."""
    line = "{} = len({})".format(layer.outputs[0],
                                 get_value(layer, "input", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_len2list(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append code building ``[0, 1, ..., len - 1]`` into ``<out>``.

    Emits an empty-list assignment followed by a ``for i in range(<len>)``
    loop that appends each index.
    """
    out = layer.outputs[0]
    lines = [
        "{} = []".format(out),
        "for i in range({}):".format(get_value(layer, "len", layer_id, different_attrs)),
        "    {}.append(i)".format(out),
    ]
    forward_func.extend(gen_codes(lines, indent=indent))


def prim_lt(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <x> < <y>`` to ``forward_func``."""
    line = "{} = {} < {}".format(layer.outputs[0],
                                 get_value(layer, "x", layer_id, different_attrs),
                                 get_value(layer, "y", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_list(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = [<input0>, <input1>, ...]`` to ``forward_func``.

    Elements are ``input0 .. input{n-1}``, where ``n`` counts both inputs and
    attrs, because constant fusion may have moved some elements into attrs.
    """
    input_len = len(layer.inputs) + len(layer.attrs)
    inputs_list = [get_value(layer, "input{}".format(i), layer_id, different_attrs)
                   for i in range(input_len)]
    line = "{} = [{}]".format(layer.outputs[0], ', '.join(inputs_list))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_list_unpack(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out0>, <out1>, ... = <input>`` to ``forward_func``."""
    line = "{} = {}".format(", ".join(layer.outputs),
                            get_value(layer, "input", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_loop(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``for <outputs[1]> in range(<input>):`` plus its body.

    The body is generated from ``layer.blocks[0]`` one indent level deeper;
    its init lines go to ``init_func``.
    """
    loop_range = get_value(layer, "input", layer_id, different_attrs)
    line = "for {} in range({}):".format(layer.outputs[1], loop_range)
    forward_func.extend(gen_codes([line], indent=indent))
    block = layer.blocks[0]
    b_init_lines, b_forward_lines = block.gen_dygraph_code(indent=indent + 1)
    init_func.extend(b_init_lines)
    forward_func.extend(b_forward_lines)


def prim_min(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = min(<input>)`` to ``forward_func``."""
    line = "{} = min({})".format(layer.outputs[0],
                                 get_value(layer, "input", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_mul(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <x> * <y>`` to ``forward_func``."""
    line = "{} = {} * {}".format(layer.outputs[0],
                                 get_value(layer, "x", layer_id, different_attrs),
                                 get_value(layer, "y", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_ne(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <x> != <y>`` to ``forward_func``."""
    line = "{} = {} != {}".format(layer.outputs[0],
                                  get_value(layer, "x", layer_id, different_attrs),
                                  get_value(layer, "y", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_neg(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = -<input>`` to ``forward_func``."""
    line = "{} = -{}".format(layer.outputs[0],
                             get_value(layer, "input", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_not(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = not <input>`` to ``forward_func``."""
    line = "{} = not {}".format(layer.outputs[0],
                                get_value(layer, "input", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_or(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <x> or <y>`` to ``forward_func``."""
    line = "{} = {} or {}".format(layer.outputs[0],
                                  get_value(layer, "x", layer_id, different_attrs),
                                  get_value(layer, "y", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_replaceitem(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<list>[<index>] = <item>`` to ``forward_func``.

    Operands are resolved with ``get_value`` using the layer id, since the
    statement has no output name of its own.
    """
    container = get_value(layer, "list", layer_id, different_attrs)
    position = get_value(layer, "index", layer_id, different_attrs)
    new_item = get_value(layer, "item", layer_id, different_attrs)
    statement = "{}[{}] = {}".format(container, position, new_item)
    forward_func.extend(gen_codes([statement], indent=indent))


def prim_requires_grad(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = not <input>.stop_gradient`` to ``forward_func``."""
    line = "{} = not {}.stop_gradient".format(layer.outputs[0],
                                              get_value(layer, "input", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_rsub(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <y> - <x> * <alpha>`` (reversed subtraction) to ``forward_func``."""
    line = "{} = {} - {} * {}".format(layer.outputs[0],
                                      get_value(layer, "y", layer_id, different_attrs),
                                      get_value(layer, "x", layer_id, different_attrs),
                                      get_value(layer, "alpha", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_select(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <input>[:, ..., <index>]`` to ``forward_func``.

    One ``:, `` slice is emitted per leading axis, i.e. ``attrs["dim"]``
    times, so ``index`` selects along axis ``dim``.
    """
    leading = ":, " * layer.attrs["dim"]
    line = "{} = {}[{}{}]".format(layer.outputs[0],
                                  get_value(layer, "input", layer_id, different_attrs),
                                  leading,
                                  get_value(layer, "index", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_set_attr(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <input>`` to ``forward_func``."""
    line = "{} = {}".format(layer.outputs[0],
                            get_value(layer, "input", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_set_item(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<dict>[<key>] = <value>`` to ``forward_func``."""
    line = "{}[{}] = {}".format(
        get_value(layer, "dict", layer_id, different_attrs),
        get_value(layer, "key", layer_id, different_attrs),
        get_value(layer, "value", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))
    
    
def prim_shape(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <input>.shape`` to ``forward_func``."""
    line = "{} = {}.shape".format(layer.outputs[0],
                                  get_value(layer, "input", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_shape_dim(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <input>.shape[<dim>]`` to ``forward_func``."""
    line = "{} = {}.shape[{}]".format(layer.outputs[0],
                                      get_value(layer, "input", layer_id, different_attrs),
                                      get_value(layer, "dim", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_slice(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <input>[<start>: <end>: <step>]`` to ``forward_func``."""
    line = "{} = {}[{}: {}: {}]".format(layer.outputs[0],
                                        get_value(layer, "input", layer_id, different_attrs),
                                        get_value(layer, "start", layer_id, different_attrs),
                                        get_value(layer, "end", layer_id, different_attrs),
                                        get_value(layer, "step", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_str(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = str(<input>)`` to ``forward_func``."""
    line = "{} = str({})".format(layer.outputs[0],
                                 get_value(layer, "input", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_sub(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <x> - <y>`` to ``forward_func``."""
    line = "{} = {} - {}".format(layer.outputs[0],
                                 get_value(layer, "x", layer_id, different_attrs),
                                 get_value(layer, "y", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_tuple(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = (<input0>, <input1>, ...)`` to ``forward_func``.

    Elements are ``input0 .. input{n-1}``, where ``n`` counts both inputs and
    attrs, because constant fusion may have moved some elements into attrs.
    A single element is emitted as ``(x,)``: ``(x)`` would be a
    parenthesised expression rather than a 1-tuple.
    """
    input_len = len(layer.inputs) + len(layer.attrs)
    inputs_list = [get_value(layer, "input{}".format(i), layer_id, different_attrs)
                   for i in range(input_len)]
    inputs_str = ', '.join(inputs_list)
    if len(inputs_list) == 1:
        inputs_str += ","
    line = "{} = ({})".format(layer.outputs[0], inputs_str)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_tuple_unpack(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out0>, <out1>, ... = <input>`` to ``forward_func``."""
    outputs_str = ', '.join(layer.outputs)
    line = "{} = {}".format(outputs_str,
                            get_value(layer, "input", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_type(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <input>.dtype`` to ``forward_func``."""
    line = "{} = {}.dtype".format(layer.outputs[0],
                                  get_value(layer, "input", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_var2list(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``<out> = <input>.numpy().tolist()`` to ``forward_func``."""
    line = "{} = {}.numpy().tolist()".format(layer.outputs[0],
                                             get_value(layer, "input", layer_id, different_attrs))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_warnings(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
    """Append ``import warnings`` and ``warnings.warn(<input>, stacklevel=...)``.

    ``stacklevel`` is taken verbatim from ``layer.attrs["stacklevel"]``.
    """
    lines = [
        "import warnings",
        "warnings.warn({}, stacklevel={})".format(
            get_value(layer, "input", layer_id, different_attrs),
            layer.attrs["stacklevel"]),
    ]
    forward_func.extend(gen_codes(lines, indent=indent))