prim2code.py 14.6 KB
Newer Older
S
SunAhong1993 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def gen_codes(code_list, indent=0):
    """Turn a list of source lines into newline-terminated, indented lines.

    Args:
        code_list: iterable of source-line strings (without trailing newline).
        indent: number of 4-space indentation levels to prefix to each line.

    Returns:
        A new list with one entry per input line.  Lines that are empty or
        whitespace-only become a bare ``"\\n"`` (no indentation); every other
        line gets the indentation prefix and a trailing ``"\\n"``.
    """
    prefix = "    " * indent
    return [
        "\n" if not line.strip() else prefix + line + "\n"
        for line in code_list
    ]


def get_value(layer, key):
    """Return the source text used for ``key`` of ``layer``.

    After the optimizer runs, an input may have been replaced directly by a
    literal value (ConstantFuser), which moves it from ``layer.inputs`` to
    ``layer.attrs``.  This helper therefore checks both places: an input is
    returned unchanged, an attr is converted with ``str``.

    Raises:
        KeyError: if ``key`` is in neither ``layer.inputs`` nor ``layer.attrs``.
    """
    if key not in layer.inputs:
        return str(layer.attrs[key])
    return layer.inputs[key]


def prim_add(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = <x> + <y>`` to ``forward_func``."""
    target = layer.outputs[0]
    lhs = get_value(layer, "x")
    rhs = get_value(layer, "y")
    code = "{out} = {a} + {b}".format(out=target, a=lhs, b=rhs)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_add_(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = <x> + <alpha> * <y>`` to ``forward_func``.

    ``alpha`` is read directly from ``layer.attrs["alpha"]``.
    """
    target = layer.outputs[0]
    lhs = get_value(layer, "x")
    scale = layer.attrs["alpha"]
    rhs = get_value(layer, "y")
    code = "{out} = {a} + {k} * {b}".format(out=target, a=lhs, k=scale, b=rhs)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_and(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = <x> and <y>`` to ``forward_func``."""
    target = layer.outputs[0]
    lhs = get_value(layer, "x")
    rhs = get_value(layer, "y")
    code = "{out} = {a} and {b}".format(out=target, a=lhs, b=rhs)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_append(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<list>.append(<element>)`` to ``forward_func``."""
    container = get_value(layer, "list")
    item = get_value(layer, "element")
    code = "{seq}.append({obj})".format(seq=container, obj=item)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_assert(layer, indent=1, init_func=[], forward_func=[]):
    """Append an ``assert`` statement for an assertion layer.

    Only ``layer.attrs["type"] == "eq"`` is supported.  The expected value is
    taken from ``layer.attrs["value"]`` when present (it may be a list of
    allowed values); otherwise it comes from ``layer.inputs["value"]``.

    * list of allowed values -> ``assert k == v0 or k == v1 ..., '<msg>'``
      (an empty list yields ``assert False, '<msg>'``)
    * anything else          -> ``assert k == v, '<msg>'``

    Raises:
        Exception: if the assertion type is not ``"eq"``.
        KeyError: if ``type``, ``key`` or ``value`` is missing.
    """
    if layer.attrs["type"] != "eq":
        raise Exception("Not implement yet!")
    key = get_value(layer, "key")
    value = get_value(layer, "value")
    # Prefer the raw attr so a list of allowed values stays a list;
    # get_value() would have stringified it.
    expected = layer.attrs["value"] if "value" in layer.attrs else value
    if isinstance(expected, list):
        # Without the "False" fallback an empty list would generate the
        # syntactically invalid ``assert , '...'``.
        condition = " or ".join(
            "{} == {}".format(key, v) for v in expected) or "False"
        line = "assert {}, 'The {} must be {}!'".format(condition, key, value)
    else:
        line = "assert {} == {}, 'The {} must be {}!'".format(
            key, value, key, value)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_check_dim(layer, indent=1, init_func=[], forward_func=[]):
    """Append code that normalises a possibly negative ``dim``.

    The emitted snippet assigns ``dim + len`` to the output when ``dim`` is
    negative and ``dim`` itself otherwise.
    """
    dim = get_value(layer, "dim")
    out = layer.outputs[0]
    size = get_value(layer, "len")
    snippet = [
        "if {} < 0:".format(dim),
        "    {} = {} + {}".format(out, dim, size),
        "else:",
        "    {} = {}".format(out, dim),
    ]
    forward_func.extend(gen_codes(snippet, indent=indent))


def prim_constant(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = <layer.attrs["value"]>`` to ``forward_func``."""
    code = "{out} = {val}".format(
        out=layer.outputs[0], val=layer.attrs["value"])
    forward_func.extend(gen_codes([code], indent=indent))


def prim_contain(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = <element> in <input>`` to ``forward_func``."""
    target = layer.outputs[0]
    needle = get_value(layer, "element")
    haystack = get_value(layer, "input")
    code = "{out} = {n} in {h}".format(out=target, n=needle, h=haystack)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_dict(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = dict()`` to ``forward_func``."""
    code = "{out} = dict()".format(out=layer.outputs[0])
    forward_func.extend(gen_codes([code], indent=indent))


def prim_div(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = <x> / <y>`` to ``forward_func``."""
    target = layer.outputs[0]
    lhs = get_value(layer, "x")
    rhs = get_value(layer, "y")
    code = "{out} = {a} / {b}".format(out=target, a=lhs, b=rhs)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_eq(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = <x> == <y>`` to ``forward_func``."""
    target = layer.outputs[0]
    lhs = get_value(layer, "x")
    rhs = get_value(layer, "y")
    code = "{out} = {a} == {b}".format(out=target, a=lhs, b=rhs)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_equal(layer, indent=1, init_func=[], forward_func=[]):
    """Append the plain assignment ``<out> = <input>`` to ``forward_func``."""
    target = layer.outputs[0]
    source = get_value(layer, "input")
    code = "{out} = {src}".format(out=target, src=source)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_exception(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``raise RaiseException(<input>)`` to ``forward_func``.

    NOTE(review): ``RaiseException`` is not defined in this file; the emitted
    module presumably defines or imports it -- confirm.
    """
    message = get_value(layer, "input")
    code = "raise RaiseException({msg})".format(msg=message)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_float(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = float(<input>)`` to ``forward_func``."""
    target = layer.outputs[0]
    source = get_value(layer, "input")
    code = "{out} = float({src})".format(out=target, src=source)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_floor(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = math.floor(<input>)`` to ``forward_func``.

    NOTE(review): the emitted code relies on ``math`` being imported in the
    generated module; that import is not produced here -- confirm.
    """
    target = layer.outputs[0]
    source = get_value(layer, "input")
    code = "{out} = math.floor({src})".format(out=target, src=source)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_floordiv(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = <x> // <y>`` to ``forward_func``."""
    target = layer.outputs[0]
    lhs = get_value(layer, "x")
    rhs = get_value(layer, "y")
    code = "{out} = {a} // {b}".format(out=target, a=lhs, b=rhs)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_getitem(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = <list>[<index>]`` to ``forward_func``."""
    target = layer.outputs[0]
    container = get_value(layer, "list")
    position = get_value(layer, "index")
    code = "{out} = {seq}[{idx}]".format(
        out=target, seq=container, idx=position)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_gt(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = <x> > <y>`` to ``forward_func``."""
    target = layer.outputs[0]
    lhs = get_value(layer, "x")
    rhs = get_value(layer, "y")
    code = "{out} = {a} > {b}".format(out=target, a=lhs, b=rhs)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_if(layer, indent=1, init_func=[], forward_func=[]):
    """Append an ``if``/``else`` statement for a control-flow layer.

    ``layer.blocks[0]`` supplies the body of the ``if`` branch and is always
    emitted.  ``layer.blocks[1]`` supplies the ``else`` branch, which is
    emitted only when it contains layers.  Each sub-block is generated one
    indentation level deeper; its init lines go to ``init_func`` and its
    forward lines to ``forward_func``.
    """
    def _emit_block(sub_block):
        # Generate a nested block and route its two outputs.
        sub_init, sub_forward = sub_block.gen_dygraph_code(indent=indent + 1)
        init_func.extend(sub_init)
        forward_func.extend(sub_forward)

    condition = get_value(layer, "input")
    header = "if {} :".format(condition)
    forward_func.extend(gen_codes([header], indent=indent))
    _emit_block(layer.blocks[0])

    else_block = layer.blocks[1]
    if len(else_block.layers) == 0:
        return
    forward_func.extend(gen_codes(["else:"], indent=indent))
    _emit_block(else_block)


def prim_int(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = int(<input>)`` to ``forward_func``."""
    target = layer.outputs[0]
    source = get_value(layer, "input")
    code = "{out} = int({src})".format(out=target, src=source)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_is(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = <x> is <y>`` to ``forward_func``."""
    target = layer.outputs[0]
    lhs = get_value(layer, "x")
    rhs = get_value(layer, "y")
    code = "{out} = {a} is {b}".format(out=target, a=lhs, b=rhs)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_isnot(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = <x> is not <y>`` to ``forward_func``."""
    target = layer.outputs[0]
    lhs = get_value(layer, "x")
    rhs = get_value(layer, "y")
    code = "{out} = {a} is not {b}".format(out=target, a=lhs, b=rhs)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_le(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = <x> <= <y>`` to ``forward_func``."""
    target = layer.outputs[0]
    lhs = get_value(layer, "x")
    rhs = get_value(layer, "y")
    code = "{out} = {a} <= {b}".format(out=target, a=lhs, b=rhs)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_len(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = len(<input>)`` to ``forward_func``."""
    target = layer.outputs[0]
    source = get_value(layer, "input")
    code = "{out} = len({src})".format(out=target, src=source)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_len2list(layer, indent=1, init_func=[], forward_func=[]):
    """Append a loop that fills the output with ``0 .. len - 1``.

    The emitted snippet initialises the output to ``[]`` and appends each
    value of ``range(<len>)`` using the loop variable ``i``.
    """
    out = layer.outputs[0]
    length = get_value(layer, "len")
    snippet = [
        "{} = []".format(out),
        "for i in range({}):".format(length),
        "    {}.append(i)".format(out),
    ]
    forward_func.extend(gen_codes(snippet, indent=indent))


def prim_lt(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = <x> < <y>`` to ``forward_func``."""
    target = layer.outputs[0]
    lhs = get_value(layer, "x")
    rhs = get_value(layer, "y")
    code = "{out} = {a} < {b}".format(out=target, a=lhs, b=rhs)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_list(layer, indent=1, init_func=[], forward_func=[]):
    """Append a list literal built from ``input0`` .. ``input{n-1}``.

    ``n`` is the combined size of ``layer.inputs`` and ``layer.attrs``.
    NOTE(review): this assumes every key in both mappings is of the form
    ``input{i}`` -- confirm.
    """
    count = len(layer.inputs) + len(layer.attrs)
    elements = [get_value(layer, "input{}".format(i)) for i in range(count)]
    joined = ", ".join(elements)
    code = "{} = [{}]".format(layer.outputs[0], joined)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_list_unpack(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out0>, <out1>, ... = <input>`` to ``forward_func``."""
    targets = ", ".join(layer.outputs)
    source = get_value(layer, "input")
    code = "{lhs} = {rhs}".format(lhs=targets, rhs=source)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_loop(layer, indent=1, init_func=[], forward_func=[]):
    """Append a ``for <outputs[1]> in range(<input>):`` loop.

    The loop body comes from ``layer.blocks[0]``, generated one indentation
    level deeper.  Its init lines go to ``init_func`` and its forward lines
    to ``forward_func``.
    """
    limit = get_value(layer, "input")
    counter = layer.outputs[1]
    header = "for {} in range({}):".format(counter, limit)
    forward_func.extend(gen_codes([header], indent=indent))
    body_init, body_forward = layer.blocks[0].gen_dygraph_code(
        indent=indent + 1)
    init_func.extend(body_init)
    forward_func.extend(body_forward)


def prim_min(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = min(<input>)`` to ``forward_func``."""
    target = layer.outputs[0]
    source = get_value(layer, "input")
    code = "{out} = min({src})".format(out=target, src=source)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_mul(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = <x> * <y>`` to ``forward_func``."""
    line = "{} = {} * {}".format(layer.outputs[0],
                                 get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_ne(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = <x> != <y>`` to ``forward_func``."""
    target = layer.outputs[0]
    lhs = get_value(layer, "x")
    rhs = get_value(layer, "y")
    code = "{out} = {a} != {b}".format(out=target, a=lhs, b=rhs)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_neg(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = -<input>`` to ``forward_func``."""
    target = layer.outputs[0]
    source = get_value(layer, "input")
    code = "{out} = -{src}".format(out=target, src=source)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_not(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = not <input>`` to ``forward_func``."""
    target = layer.outputs[0]
    source = get_value(layer, "input")
    code = "{out} = not {src}".format(out=target, src=source)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_replaceitem(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<list>[<index>] = <item>`` to ``forward_func``."""
    container = get_value(layer, "list")
    position = get_value(layer, "index")
    replacement = get_value(layer, "item")
    code = "{seq}[{idx}] = {val}".format(
        seq=container, idx=position, val=replacement)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_requires_grad(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = not <input>.stop_gradient`` to ``forward_func``."""
    target = layer.outputs[0]
    tensor = get_value(layer, "input")
    code = "{out} = not {t}.stop_gradient".format(out=target, t=tensor)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_rsub(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = <y> - <x> * <alpha>`` to ``forward_func``.

    The operands are reversed relative to ``prim_sub``: ``y`` is the
    minuend, and ``x`` is scaled by ``alpha`` before being subtracted.
    """
    target = layer.outputs[0]
    minuend = get_value(layer, "y")
    subtrahend = get_value(layer, "x")
    scale = get_value(layer, "alpha")
    code = "{out} = {y} - {x} * {k}".format(
        out=target, y=minuend, x=subtrahend, k=scale)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_select(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = <input>[:, :, ..., <index>]`` to ``forward_func``.

    The subscript gets ``layer.attrs["dim"]`` leading ``:, `` slices, so
    ``index`` selects along axis ``dim``.
    """
    target = layer.outputs[0]
    source = get_value(layer, "input")
    leading = ":, " * layer.attrs["dim"]
    position = get_value(layer, "index")
    code = "{} = {}[".format(target, source) + leading + (position + "]")
    forward_func.extend(gen_codes([code], indent=indent))


def prim_set_attr(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = <input>`` to ``forward_func``.

    The output name is used as the assignment target, which may be an
    attribute path.
    """
    target = layer.outputs[0]
    source = get_value(layer, "input")
    code = "{out} = {src}".format(out=target, src=source)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_set_item(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<dict>[<key>] = <value>`` to ``forward_func``."""
    mapping = get_value(layer, "dict")
    slot = get_value(layer, "key")
    payload = get_value(layer, "value")
    code = "{d}[{k}] = {v}".format(d=mapping, k=slot, v=payload)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_shape(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = <input>.shape`` to ``forward_func``."""
    target = layer.outputs[0]
    tensor = get_value(layer, "input")
    code = "{out} = {t}.shape".format(out=target, t=tensor)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_shape_dim(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = <input>.shape[<dim>]`` to ``forward_func``."""
    target = layer.outputs[0]
    tensor = get_value(layer, "input")
    axis = get_value(layer, "dim")
    code = "{out} = {t}.shape[{d}]".format(out=target, t=tensor, d=axis)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_slice(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = <input>[<start>: <end>: <step>]`` to ``forward_func``."""
    target = layer.outputs[0]
    source = get_value(layer, "input")
    lo = get_value(layer, "start")
    hi = get_value(layer, "end")
    stride = get_value(layer, "step")
    code = "{out} = {seq}[{a}: {b}: {c}]".format(
        out=target, seq=source, a=lo, b=hi, c=stride)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_sub(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out> = <x> - <y>`` to ``forward_func``."""
    target = layer.outputs[0]
    lhs = get_value(layer, "x")
    rhs = get_value(layer, "y")
    code = "{out} = {a} - {b}".format(out=target, a=lhs, b=rhs)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_tuple(layer, indent=1, init_func=[], forward_func=[]):
    """Append a tuple literal built from ``input0`` .. ``input{n-1}``.

    ``n`` is the combined size of ``layer.inputs`` and ``layer.attrs``.
    NOTE(review): this assumes every key in both mappings is of the form
    ``input{i}`` -- confirm.

    A single element gets a trailing comma, so the generated code builds a
    1-tuple rather than a parenthesised scalar.
    """
    input_len = len(layer.inputs) + len(layer.attrs)
    inputs_list = [
        get_value(layer, "input{}".format(i)) for i in range(input_len)
    ]
    inputs_str = ", ".join(inputs_list)
    if input_len == 1:
        # "(a)" is just "a" in Python; only "(a,)" is a tuple.
        inputs_str += ","
    line = "{} = ({})".format(layer.outputs[0], inputs_str)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_tuple_unpack(layer, indent=1, init_func=[], forward_func=[]):
    """Append ``<out0>, <out1>, ... = <input>`` to ``forward_func``."""
    targets = ", ".join(layer.outputs)
    source = get_value(layer, "input")
    code = "{lhs} = {rhs}".format(lhs=targets, rhs=source)
    forward_func.extend(gen_codes([code], indent=indent))


def prim_warnings(layer, indent=1, init_func=[], forward_func=[]):
    """Append a ``warnings.warn(<input>, stacklevel=<n>)`` call.

    An ``import warnings`` line is emitted immediately before the call, at
    the same indentation, so the generated snippet is self-contained.
    """
    message = get_value(layer, "input")
    level = layer.attrs["stacklevel"]
    snippet = [
        "import warnings",
        "warnings.warn({}, stacklevel={})".format(message, level),
    ]
    forward_func.extend(gen_codes(snippet, indent=indent))