prim2code.py 13.0 KB
Newer Older
S
SunAhong1993 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


S
SunAhong1993 已提交
16 17 18 19 20 21 22 23 24 25 26
def gen_codes(code_list, indent=0):
    """Turn source lines into indented, newline-terminated code lines.

    Args:
        code_list: iterable of source lines without trailing newlines.
        indent: nesting level; each level prepends four spaces.

    Returns:
        A new list with one entry per input line. Lines that are empty or
        whitespace-only become a bare newline with no indentation.
    """
    prefix = "    " * indent
    return [
        "\n" if not line.strip() else prefix + line + "\n"
        for line in code_list
    ]


S
SunAhong1993 已提交
27 28 29 30 31 32 33 34 35 36
def get_value(layer, key):
    """Return the code expression for operand ``key`` of ``layer``.

    Optimization passes (e.g. ConstantFuser) may replace an input with a
    literal value and move it from ``layer.inputs`` into ``layer.attrs``,
    so both places are consulted: ``layer.inputs[key]`` is returned as-is
    when present, otherwise ``str(layer.attrs[key])``.

    Raises:
        KeyError: if ``key`` is in neither ``layer.inputs`` nor
            ``layer.attrs``.
    """
    if key not in layer.inputs:
        return str(layer.attrs[key])
    return layer.inputs[key]


S
SunAhong1993 已提交
37
def prim_add(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = <x> + <y>``.

    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {} + {}".format(layer.outputs[0],
                                 get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_add_(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = <x> + <alpha> * <y>``.

    ``alpha`` is resolved with ``get_value`` like the other operands, so it
    may come from either ``layer.inputs`` or ``layer.attrs``. The line is
    appended to ``forward_func`` at ``indent`` levels; ``init_func`` is
    unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {} + {} * {}".format(layer.outputs[0],
                                      get_value(layer, "x"),
                                      get_value(layer, "alpha"),
                                      get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_and(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = <x> and <y>``.

    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {} and {}".format(layer.outputs[0],
                                   get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_append(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<list>.append(<element>)``; the layer's outputs are not used.

    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{}.append({})".format(
        get_value(layer, "list"), get_value(layer, "element"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_assert(layer, indent=1, init_func=None, forward_func=None):
    """Emit an ``assert`` that checks ``key`` against the expected ``value``.

    Only ``layer.attrs["type"] == "eq"`` is supported. If ``layer.attrs``
    holds ``"value"`` as a list, the generated assertion passes when ``key``
    equals any of its elements. Otherwise it checks ``key == value``. The
    line is appended to ``forward_func`` at ``indent`` levels; ``init_func``
    is unused.

    Raises:
        Exception: if ``layer.attrs["type"]`` is anything other than "eq".
    """
    if forward_func is None:
        forward_func = []
    if layer.attrs["type"] != "eq":
        raise Exception("Not implement yet!")
    key = get_value(layer, "key")
    value = get_value(layer, "value")
    # The raw attr (not its str()) is needed to detect a list of candidates.
    candidates = layer.attrs["value"] if "value" in layer.attrs else None
    if isinstance(candidates, list):
        if candidates:
            cond = " or ".join("{} == {}".format(key, v) for v in candidates)
        else:
            # Nothing can equal an element of an empty list. Joining nothing
            # would otherwise emit the invalid statement ``assert , ...``.
            cond = "False"
        line = "assert {}, 'The {} must be {}!'".format(cond, key, value)
    else:
        line = "assert {} == {}, 'The {} must be {}!'".format(
            key, value, key, value)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_constant(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = <value>`` using ``layer.attrs["value"]`` verbatim.

    The attr is formatted with ``str.format``, so string values are emitted
    unquoted (as code, not as string literals). The line is appended to
    ``forward_func`` at ``indent`` levels; ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {}".format(layer.outputs[0], layer.attrs["value"])
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
91 92 93 94 95 96 97 98 99 100 101 102
def prim_contain(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = <element> in <input>``.

    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {} in {}".format(layer.outputs[0],
                                  get_value(layer, "element"),
                                  get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_dict(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = dict()``.

    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = dict()".format(layer.outputs[0])
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
103
def prim_eq(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = <x> == <y>``.

    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {} == {}".format(layer.outputs[0],
                                  get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_equal(layer, indent=1, init_func=None, forward_func=None):
    """Emit the plain assignment ``<out> = <input>``.

    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {}".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_exception(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``raise RaiseException(<input>)``.

    NOTE(review): ``RaiseException`` is not defined in this module. The
    generated code must have it in scope; confirm where it is provided.
    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "raise RaiseException({})".format(get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
119 120 121 122 123 124 125 126 127 128 129
def prim_float(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = float(<input>)``.

    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = float({})".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_floor(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = math.floor(<input>)``.

    NOTE(review): the generated code must import ``math``. Nothing here
    emits that import.
    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = math.floor({})".format(layer.outputs[0],
                                        get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


130 131 132 133 134 135
def prim_floordiv(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = <x> // <y>``.

    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {} // {}".format(layer.outputs[0],
                                  get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
136 137 138 139 140 141 142 143 144 145 146 147 148
def prim_getitem(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = <list>[<index>]``.

    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {}[{}]".format(layer.outputs[0],
                                get_value(layer, "list"),
                                get_value(layer, "index"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_gt(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = <x> > <y>``.

    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {} > {}".format(layer.outputs[0],
                                 get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
149
def prim_if(layer, indent=1, init_func=None, forward_func=None):
    """Emit an ``if``/``else`` statement whose bodies come from ``layer.blocks``.

    ``layer.blocks[0]`` supplies the ``if`` body. ``layer.blocks[1]`` supplies
    the ``else`` body, which is emitted only if that block has any layers.
    Each body is generated one indent level deeper. Its init lines are
    appended to ``init_func`` and its forward lines to ``forward_func``.
    """
    if init_func is None:
        init_func = []
    if forward_func is None:
        forward_func = []

    def emit_body(block):
        # Generate a nested block's code and append it to the outputs.
        b_init_lines, b_forward_lines = block.gen_dygraph_code(
            indent=indent + 1)
        init_func.extend(b_init_lines)
        if b_forward_lines:
            forward_func.extend(b_forward_lines)
        else:
            # An empty suite after ``if``/``else`` is a SyntaxError.
            forward_func.extend(gen_codes(["pass"], indent=indent + 1))

    line = "if {} :".format(get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))
    emit_body(layer.blocks[0])
    else_block = layer.blocks[1]
    if len(else_block.layers) > 0:
        forward_func.extend(gen_codes(["else:"], indent=indent))
        emit_body(else_block)
S
SunAhong1993 已提交
164 165


S
SunAhong1993 已提交
166 167 168
def prim_is(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = <x> is <y>``.

    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {} is {}".format(layer.outputs[0],
                                  get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
172 173 174 175
def prim_isnot(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = <x> is not <y>``.

    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {} is not {}".format(layer.outputs[0],
                                      get_value(layer, "x"),
                                      get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_le(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = <x> <= <y>``.

    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {} <= {}".format(layer.outputs[0],
                                  get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_len(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = len(<input>)``.

    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = len({})".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))
S
SunAhong1993 已提交
188 189


190 191 192 193 194 195 196 197
def prim_len2list(layer, indent=1, init_func=None, forward_func=None):
    """Emit code binding ``<out>`` to the list ``[0, 1, ..., <len> - 1]``.

    The value is produced by one ``list(range(...))`` expression instead of
    a ``for i in range(...)`` loop. This keeps the generated code from
    binding a stray ``i`` that could clobber another variable of that name.
    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = list(range({}))".format(layer.outputs[0],
                                         get_value(layer, "len"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
198
def prim_lt(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = <x> < <y>``.

    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {} < {}".format(layer.outputs[0],
                                 get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_list(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = [<input0>, <input1>, ...]``.

    The element count is ``len(layer.inputs) + len(layer.attrs)``. Every
    entry of either mapping is therefore assumed to be an ``input{i}``
    operand; any other key raises ``KeyError`` in ``get_value``. The line is
    appended to ``forward_func`` at ``indent`` levels; ``init_func`` is
    unused.
    """
    if forward_func is None:
        forward_func = []
    input_len = len(layer.inputs) + len(layer.attrs)
    inputs_str = ", ".join(
        get_value(layer, "input{}".format(i)) for i in range(input_len))
    line = "{} = [{}]".format(layer.outputs[0], inputs_str)
    forward_func.extend(gen_codes([line], indent=indent))


214 215 216 217 218
def prim_list_unpack(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out0>, <out1>, ... = <input>``.

    A single output gets a trailing comma (``a, = seq``). Without it,
    ``a = seq`` would bind the whole list rather than its only element.
    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    outputs_str = ", ".join(layer.outputs)
    if len(layer.outputs) == 1:
        outputs_str += ","
    line = "{} = {}".format(outputs_str, get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
219
def prim_loop(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``for <outputs[1]> in range(<input>):`` followed by the loop body.

    The body is generated from ``layer.blocks[0]`` one indent level deeper.
    Its init lines are appended to ``init_func`` and its forward lines to
    ``forward_func``. A ``pass`` is emitted when the body yields no forward
    lines, because an empty ``for`` suite is a SyntaxError.
    """
    if init_func is None:
        init_func = []
    if forward_func is None:
        forward_func = []
    loop_range = get_value(layer, "input")
    line = "for {} in range({}):".format(layer.outputs[1], loop_range)
    forward_func.extend(gen_codes([line], indent=indent))
    b_init_lines, b_forward_lines = layer.blocks[0].gen_dygraph_code(
        indent=indent + 1)
    init_func.extend(b_init_lines)
    if b_forward_lines:
        forward_func.extend(b_forward_lines)
    else:
        forward_func.extend(gen_codes(["pass"], indent=indent + 1))


def prim_min(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = min(<input>)``.

    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = min({})".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_mul(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = <x> * <y>``.

    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {} * {}".format(layer.outputs[0],
                                 get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_ne(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = <x> != <y>``.

    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {} != {}".format(layer.outputs[0],
                                  get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_neg(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = -<input>``.

    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = -{}".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_not(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = not <input>``.

    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = not {}".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


256 257 258 259 260 261 262
def prim_replaceitem(layer, indent=1, init_func=None, forward_func=None):
    """Emit the in-place store ``<list>[<index>] = <item>``.

    The layer's outputs are not used. The line is appended to
    ``forward_func`` at ``indent`` levels; ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{}[{}] = {}".format(
        get_value(layer, "list"),
        get_value(layer, "index"), get_value(layer, "item"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
263 264
def prim_requires_grad(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = not <input>.stop_gradient``.

    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = not {}.stop_gradient".format(layer.outputs[0],
                                              get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_select(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = <input>[:, ..., :, <index>]`` indexing along ``dim``.

    One ``:, `` full slice is emitted for each leading axis before
    ``layer.attrs["dim"]``, followed by the index. The line is appended to
    ``forward_func`` at ``indent`` levels; ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {}[{}{}]".format(layer.outputs[0],
                                  get_value(layer, "input"),
                                  ":, " * layer.attrs["dim"],
                                  get_value(layer, "index"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
277 278 279 280 281
def prim_set_attr(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = <input>``.

    ``layer.outputs[0]`` is presumably an attribute path; confirm with the
    mapper that builds this layer. The line is appended to ``forward_func``
    at ``indent`` levels; ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {}".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
282 283 284 285 286 287 288
def prim_set_item(layer, indent=1, init_func=None, forward_func=None):
    """Emit the in-place store ``<dict>[<key>] = <value>``.

    The layer's outputs are not used. The line is appended to
    ``forward_func`` at ``indent`` levels; ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{}[{}] = {}".format(
        get_value(layer, "dict"),
        get_value(layer, "key"), get_value(layer, "value"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
289
def prim_shape(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = <input>.shape``.

    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {}.shape".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_slice(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = <input>[<start>: <end>: <step>]``.

    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {}[{}: {}: {}]".format(layer.outputs[0],
                                        get_value(layer, "input"),
                                        get_value(layer, "start"),
                                        get_value(layer, "end"),
                                        get_value(layer, "step"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_sub(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = <x> - <y>``.

    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {} - {}".format(layer.outputs[0],
                                 get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_tuple(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out> = (<input0>, <input1>, ...)``.

    The element count is ``len(layer.inputs) + len(layer.attrs)``, so every
    entry is assumed to be an ``input{i}`` operand. A single element gets a
    trailing comma because ``(a)`` is just ``a`` in Python, not a tuple.
    The line is appended to ``forward_func`` at ``indent`` levels;
    ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    input_len = len(layer.inputs) + len(layer.attrs)
    inputs_list = [
        get_value(layer, "input{}".format(i)) for i in range(input_len)
    ]
    inputs_str = ", ".join(inputs_list)
    if len(inputs_list) == 1:
        inputs_str += ","
    line = "{} = ({})".format(layer.outputs[0], inputs_str)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_tuple_unpack(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``<out0>, <out1>, ... = <input>``.

    A single output gets a trailing comma (``a, = t``), so the tuple's only
    element is bound rather than the tuple itself. This matches
    ``prim_tuple``, which builds real one-element tuples. The line is
    appended to ``forward_func`` at ``indent`` levels; ``init_func`` is
    unused.
    """
    if forward_func is None:
        forward_func = []
    outputs_str = ", ".join(layer.outputs)
    if len(layer.outputs) == 1:
        outputs_str += ","
    line = "{} = {}".format(outputs_str, get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_warnings(layer, indent=1, init_func=None, forward_func=None):
    """Emit ``import warnings`` followed by a ``warnings.warn(...)`` call.

    The message comes from the ``input`` operand and the stack level from
    ``layer.attrs["stacklevel"]``. Both lines are appended to
    ``forward_func`` at ``indent`` levels; ``init_func`` is unused.
    """
    if forward_func is None:
        forward_func = []
    lines = [
        # The import is emitted inline so the generated call is
        # self-contained wherever it ends up.
        "import warnings",
        "warnings.warn({}, stacklevel={})".format(
            get_value(layer, "input"), layer.attrs["stacklevel"]),
    ]
    forward_func.extend(gen_codes(lines, indent=indent))