prim2code.py 15.7 KB
Newer Older
S
SunAhong1993 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


S
SunAhong1993 已提交
16 17 18 19 20 21 22 23 24 25 26
def gen_codes(code_list, indent=0):
    """Indent source lines and terminate each one with a newline.

    Args:
        code_list: iterable of source-code strings without trailing newlines.
        indent: nesting level; each level contributes four spaces.

    Returns:
        A new list of strings. A line that is empty or whitespace-only is
        replaced by a lone newline character (no indentation). Every other
        line gets the indentation prefix and a trailing newline.
    """
    prefix = "    " * indent
    return ['\n' if not text.strip() else prefix + text + '\n'
            for text in code_list]


S
SunAhong1993 已提交
27 28 29 30 31 32 33 34 35 36
def get_value(layer, key):
    """Return the value bound to ``key`` on ``layer``.

    After the optimizer runs, an input's value may have been replaced by a
    literal (e.g. by ConstantFuser), which moves it from ``layer.inputs`` to
    ``layer.attrs``. Both places therefore have to be checked.

    Returns:
        ``layer.inputs[key]`` unchanged when present, otherwise
        ``str(layer.attrs[key])``.

    Raises:
        KeyError: if ``key`` is in neither ``layer.inputs`` nor ``layer.attrs``.
    """
    inputs = layer.inputs
    if key not in inputs:
        return str(layer.attrs[key])
    return inputs[key]


S
SunAhong1993 已提交
37
def prim_add(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = <x> + <y>`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = {} + {}".format(layer.outputs[0],
                                 get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_add_(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = <x> + <alpha> * <y>`` to ``forward_func``.

    ``alpha`` is resolved through ``get_value``, like every other operand, so
    it works whether it is stored as an input or as an attribute.
    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = {} + {} * {}".format(layer.outputs[0],
                                      get_value(layer, "x"),
                                      get_value(layer, "alpha"),
                                      get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_and(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = <x> and <y>`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = {} and {}".format(layer.outputs[0],
                                   get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_append(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<list>.append(<element>)`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{}.append({})".format(
        get_value(layer, "list"), get_value(layer, "element"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_assert(layer, indent=1, init_func=None, forward_func=None):
    """Append an ``assert`` statement comparing ``key`` with ``value``.

    Only ``layer.attrs["type"] == "eq"`` is supported. If the expected value
    (``layer.attrs["value"]`` when present) is a list, the generated condition
    accepts any element: ``key == v0 or key == v1 ...``. An empty list yields
    ``assert False, ...``, because nothing can match. Otherwise a single
    equality test is emitted. ``init_func`` is not used by this emitter.

    Raises:
        Exception: if the assertion type is not ``"eq"``.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    if layer.attrs["type"] != "eq":
        raise Exception("Not implement yet!")
    key = get_value(layer, "key")
    values = key
    if "value" in layer.attrs:
        values = layer.attrs["value"]
    if isinstance(values, list):
        condition = " or ".join("{} == {}".format(key, v) for v in values)
        if not condition:
            # Without this, an empty list would emit `assert , ...`, which is
            # a syntax error in the generated code.
            condition = "False"
        line = "assert {}, 'The {} must be {}!'".format(
            condition, key, get_value(layer, "value"))
    else:
        value = get_value(layer, "value")
        line = "assert {} == {}, 'The {} must be {}!'".format(
            key, value, key, value)
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
86 87 88 89 90 91 92 93 94 95 96
def prim_check_dim(layer, indent=1, init_func=None, forward_func=None):
    """Append code that converts a possibly negative ``dim`` to a non-negative one.

    The emitted code sets ``<out>`` to ``dim + len`` when ``dim < 0``, and to
    ``dim`` otherwise. ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    output = layer.outputs[0]
    dim = get_value(layer, "dim")
    lines = [
        "if {} < 0:".format(dim),
        "    {} = {} + {}".format(output, dim, get_value(layer, "len")),
        "else:",
        "    {} = {}".format(output, dim),
    ]
    forward_func.extend(gen_codes(lines, indent=indent))


S
SunAhong1993 已提交
97 98 99 100 101
def prim_constant(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = <layer.attrs["value"]>`` to ``forward_func``.

    The attribute is formatted with ``str.format``, i.e. its ``str()`` form.
    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = {}".format(layer.outputs[0], layer.attrs["value"])
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
102 103 104 105 106 107 108 109 110 111 112 113
def prim_contain(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = <element> in <input>`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = {} in {}".format(layer.outputs[0],
                                  get_value(layer, "element"),
                                  get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_dict(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = dict()`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = dict()".format(layer.outputs[0])
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
114 115 116 117 118 119
def prim_div(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = <x> / <y>`` (true division) to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = {} / {}".format(layer.outputs[0],
                                 get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
120
def prim_eq(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = <x> == <y>`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = {} == {}".format(layer.outputs[0],
                                  get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_equal(layer, indent=1, init_func=None, forward_func=None):
    """Append the plain assignment ``<out> = <input>`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = {}".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_exception(layer, indent=1, init_func=None, forward_func=None):
    """Append ``raise RaiseException(<input>)`` to ``forward_func``.

    NOTE(review): ``RaiseException`` must be defined where the generated code
    runs. It is not defined in this file; confirm against the code template.
    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "raise RaiseException({})".format(get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
136 137 138 139 140 141 142 143 144 145 146
def prim_float(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = float(<input>)`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = float({})".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_floor(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = math.floor(<input>)`` to ``forward_func``.

    NOTE(review): the generated code needs ``math`` imported in its module.
    Confirm that the code template provides it. ``init_func`` is not used by
    this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = math.floor({})".format(layer.outputs[0],
                                        get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


147 148 149 150 151 152
def prim_floordiv(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = <x> // <y>`` (floor division) to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = {} // {}".format(layer.outputs[0],
                                  get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
153 154 155 156 157 158 159 160 161 162 163 164 165
def prim_getitem(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = <list>[<index>]`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = {}[{}]".format(layer.outputs[0],
                                get_value(layer, "list"),
                                get_value(layer, "index"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_gt(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = <x> > <y>`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = {} > {}".format(layer.outputs[0],
                                 get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
166
def prim_if(layer, indent=1, init_func=None, forward_func=None):
    """Append an ``if <input> :`` statement with its branches.

    ``layer.blocks[0]`` is the true branch and ``layer.blocks[1]`` the false
    branch. Each is generated one indentation level deeper via
    ``gen_dygraph_code``. Their init lines go to ``init_func`` and their
    forward lines to ``forward_func``.

    If the true branch produces no forward lines, a ``pass`` is emitted so the
    generated ``if`` is still valid Python. The ``else:`` header is emitted
    only when the false branch has layers and produces forward lines.
    """
    if init_func is None:
        # Fresh lists per call; ``[]`` defaults would be shared across calls.
        init_func = []
    if forward_func is None:
        forward_func = []
    line = "if {} :".format(get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))

    true_block = layer.blocks[0]
    b_init_lines, b_forward_lines = true_block.gen_dygraph_code(
        indent=indent + 1)
    init_func.extend(b_init_lines)
    if b_forward_lines:
        forward_func.extend(b_forward_lines)
    else:
        # An `if` header followed by no body is a syntax error.
        forward_func.extend(gen_codes(["pass"], indent=indent + 1))

    false_block = layer.blocks[1]
    if len(false_block.layers) > 0:
        b_init_lines, b_forward_lines = false_block.gen_dygraph_code(
            indent=indent + 1)
        if b_forward_lines:
            forward_func.extend(gen_codes(["else:"], indent=indent))
        init_func.extend(b_init_lines)
        forward_func.extend(b_forward_lines)
S
SunAhong1993 已提交
182 183


S
SunAhong1993 已提交
184 185 186 187 188
def prim_int(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = int(<input>)`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = int({})".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
189 190 191
def prim_is(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = <x> is <y>`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = {} is {}".format(layer.outputs[0],
                                  get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
195 196 197 198 199 200 201
def prim_isinstance(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = isinstance(<input>, <cls>)`` to ``forward_func``.

    ``cls`` is taken verbatim from ``layer.attrs["cls"]``.
    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = isinstance({}, {})".format(layer.outputs[0],
                                            get_value(layer, "input"),
                                            layer.attrs["cls"])
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
202 203 204 205
def prim_isnot(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = <x> is not <y>`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = {} is not {}".format(layer.outputs[0],
                                      get_value(layer, "x"),
                                      get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_le(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = <x> <= <y>`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = {} <= {}".format(layer.outputs[0],
                                  get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_len(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = len(<input>)`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = len({})".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))
S
SunAhong1993 已提交
218 219


220 221 222 223 224 225 226 227
def prim_len2list(layer, indent=1, init_func=None, forward_func=None):
    """Append code that builds ``<out> = [0, 1, ..., len - 1]``.

    The emitted code builds the list with an explicit ``for``/``append`` loop.
    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    output = layer.outputs[0]
    lines = [
        "{} = []".format(output),
        "for i in range({}):".format(get_value(layer, "len")),
        "    {}.append(i)".format(output),
    ]
    forward_func.extend(gen_codes(lines, indent=indent))


S
SunAhong1993 已提交
228
def prim_lt(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = <x> < <y>`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = {} < {}".format(layer.outputs[0],
                                 get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_list(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = [<input0>, <input1>, ...]`` to ``forward_func``.

    The number of elements is ``len(layer.inputs) + len(layer.attrs)``.
    Elements are looked up as ``input0``, ``input1``, ... in either place.
    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    element_count = len(layer.inputs) + len(layer.attrs)
    elements = [get_value(layer, "input{}".format(i))
                for i in range(element_count)]
    line = "{} = [{}]".format(layer.outputs[0], ', '.join(elements))
    forward_func.extend(gen_codes([line], indent=indent))


244 245 246 247 248
def prim_list_unpack(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out0>, <out1>, ... = <input>`` to ``forward_func``.

    NOTE(review): with exactly one output this emits ``<out0> = <input>``, a
    plain assignment rather than an unpack. Confirm that this is intended.
    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = {}".format(", ".join(layer.outputs), get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
249
def prim_loop(layer, indent=1, init_func=None, forward_func=None):
    """Append ``for <outputs[1]> in range(<input>):`` and the loop body.

    The body is ``layer.blocks[0]`` generated one indentation level deeper.
    Its init lines go to ``init_func`` and its forward lines to
    ``forward_func``. If the body produces no forward lines, a ``pass`` is
    emitted so the generated loop is valid Python.
    """
    if init_func is None:
        # Fresh lists per call; ``[]`` defaults would be shared across calls.
        init_func = []
    if forward_func is None:
        forward_func = []
    loop_range = get_value(layer, "input")
    line = "for {} in range({}):".format(layer.outputs[1], loop_range)
    forward_func.extend(gen_codes([line], indent=indent))
    body = layer.blocks[0]
    b_init_lines, b_forward_lines = body.gen_dygraph_code(indent=indent + 1)
    init_func.extend(b_init_lines)
    if b_forward_lines:
        forward_func.extend(b_forward_lines)
    else:
        # A `for` header followed by no body is a syntax error.
        forward_func.extend(gen_codes(["pass"], indent=indent + 1))


def prim_min(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = min(<input>)`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = min({})".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_mul(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = <x> * <y>`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = {} * {}".format(layer.outputs[0],
                                 get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_ne(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = <x> != <y>`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = {} != {}".format(layer.outputs[0],
                                  get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_neg(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = -<input>`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = -{}".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_not(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = not <input>`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = not {}".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
286 287 288 289 290 291
def prim_or(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = <x> or <y>`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = {} or {}".format(layer.outputs[0],
                                  get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


292 293 294 295 296 297 298
def prim_replaceitem(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<list>[<index>] = <item>`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{}[{}] = {}".format(
        get_value(layer, "list"),
        get_value(layer, "index"), get_value(layer, "item"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
299 300
def prim_requires_grad(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = not <input>.stop_gradient`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = not {}.stop_gradient".format(layer.outputs[0],
                                              get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
305 306 307 308 309 310 311 312
def prim_rsub(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = <y> - <x> * <alpha>`` (reversed subtraction).

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = {} - {} * {}".format(layer.outputs[0],
                                      get_value(layer, "y"),
                                      get_value(layer, "x"),
                                      get_value(layer, "alpha"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
313
def prim_select(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = <input>[:, ..., <index>]`` to ``forward_func``.

    ``layer.attrs["dim"]`` full slices (``:, ``) precede the index, so the
    index applies to axis ``dim``.

    NOTE(review): a negative ``dim`` yields no leading slices, so the index is
    applied to axis 0. Confirm that ``dim`` is normalized upstream.
    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    leading_slices = ":, " * layer.attrs["dim"]
    line = "{} = {}[{}{}]".format(layer.outputs[0],
                                  get_value(layer, "input"),
                                  leading_slices,
                                  get_value(layer, "index"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
321 322 323 324 325
def prim_set_attr(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = <input>`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = {}".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
326 327 328 329 330 331 332
def prim_set_item(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<dict>[<key>] = <value>`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{}[{}] = {}".format(
        get_value(layer, "dict"),
        get_value(layer, "key"), get_value(layer, "value"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
333
def prim_shape_dim(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = fluid.layers.shape(<input>)[<dim>]`` to ``forward_func``.

    NOTE(review): the generated code needs ``fluid`` in scope. Confirm that the
    code template imports it. ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = fluid.layers.shape({})[{}]".format(layer.outputs[0],
                                                    get_value(layer, "input"),
                                                    get_value(layer, "dim"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
340
def prim_slice(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = <input>[<start>: <end>: <step>]`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = {}[{}: {}: {}]".format(layer.outputs[0],
                                        get_value(layer, "input"),
                                        get_value(layer, "start"),
                                        get_value(layer, "end"),
                                        get_value(layer, "step"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
349 350 351 352 353
def prim_str(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = str(<input>)`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = str({})".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
354
def prim_sub(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = <x> - <y>`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = {} - {}".format(layer.outputs[0],
                                 get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_tuple(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = (<input0>, <input1>, ...)`` to ``forward_func``.

    The number of elements is ``len(layer.inputs) + len(layer.attrs)``.
    Elements are looked up as ``input0``, ``input1``, ... in either place.

    NOTE(review): with exactly one element this emits ``(x)``, which Python
    evaluates to ``x`` rather than a 1-tuple. ``prim_tuple_unpack`` emits a
    plain assignment for one output, so the two currently agree. Change both
    together, if at all. ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    element_count = len(layer.inputs) + len(layer.attrs)
    elements = [get_value(layer, "input{}".format(i))
                for i in range(element_count)]
    line = "{} = ({})".format(layer.outputs[0], ', '.join(elements))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_tuple_unpack(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out0>, <out1>, ... = <input>`` to ``forward_func``.

    NOTE(review): with exactly one output this emits ``<out0> = <input>``, a
    plain assignment rather than an unpack. Confirm that this is intended.
    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = {}".format(', '.join(layer.outputs), get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
376 377 378 379 380
def prim_type(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = <input>.dtype`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = {}.dtype".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
381 382 383 384 385 386
def prim_var2list(layer, indent=1, init_func=None, forward_func=None):
    """Append ``<out> = <input>.numpy().tolist()`` to ``forward_func``.

    ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    line = "{} = {}.numpy().tolist()".format(layer.outputs[0],
                                             get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
387 388
def prim_warnings(layer, indent=1, init_func=None, forward_func=None):
    """Append ``import warnings`` and a ``warnings.warn`` call to ``forward_func``.

    The emitted call is ``warnings.warn(<input>, stacklevel=<stacklevel>)``,
    with ``stacklevel`` taken from ``layer.attrs``. The import is emitted
    inline each time, so the generated code does not depend on a module-level
    import. ``init_func`` is not used by this emitter.
    """
    if forward_func is None:
        # Fresh list per call; a ``[]`` default would be shared across calls.
        forward_func = []
    lines = [
        "import warnings",
        "warnings.warn({}, stacklevel={})".format(
            get_value(layer, "input"), layer.attrs["stacklevel"]),
    ]
    forward_func.extend(gen_codes(lines, indent=indent))