prim2code.py 11.1 KB
Newer Older
S
SunAhong1993 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


S
SunAhong1993 已提交
16 17 18 19 20 21 22 23 24 25 26
def gen_codes(code_list, indent=0):
    """Indent every code line and terminate it with a newline.

    Args:
        code_list: Iterable of source lines without trailing newlines.
        indent: Number of four-space indentation levels to prepend.

    Returns:
        A new list of lines. Lines that are empty or whitespace-only are
        emitted as a bare newline, without any indentation.
    """
    prefix = "    " * indent
    return [
        '\n' if text.strip() == "" else prefix + text + '\n'
        for text in code_list
    ]


S
SunAhong1993 已提交
27 28 29 30 31 32 33 34 35 36
def get_value(layer, key):
    """Look up an operand of `layer` by name.

    After optimization an input's value may have been replaced directly by a
    constant (ConstantFuser), which moves the input into `attrs`; this helper
    therefore checks both places.

    Returns:
        `layer.inputs[key]` when the key is an input, otherwise
        `str(layer.attrs[key])`.
    """
    if key not in layer.inputs:
        return str(layer.attrs[key])
    return layer.inputs[key]


S
SunAhong1993 已提交
37
def prim_add(layer, indent=1, init_func=None, forward_func=None):
    """Emit `<out> = <x> + <y>` into the forward code.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {} + {}".format(layer.outputs[0],
                                 get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_add_(layer, indent=1, init_func=None, forward_func=None):
    """Emit `<out> = <x> + <alpha> * <y>` into the forward code.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`; `alpha` is
            always read from `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {} + {} * {}".format(layer.outputs[0],
                                      get_value(layer, "x"),
                                      layer.attrs["alpha"],
                                      get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_and(layer, indent=1, init_func=None, forward_func=None):
    """Emit `<out> = <x> and <y>` into the forward code.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {} and {}".format(layer.outputs[0],
                                   get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_append(layer, indent=1, init_func=None, forward_func=None):
    """Emit `<list>.append(<element>)` into the forward code.

    Args:
        layer: Layer providing `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    line = "{}.append({})".format(
        get_value(layer, "list"), get_value(layer, "element"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_assert(layer, indent=1, init_func=None, forward_func=None):
    """Emit an `assert` statement checking `attrs["key"]` against `attrs["value"]`.

    Only `attrs["type"] == "eq"` is supported. When the expected value is a
    list, the generated condition accepts any one of its elements.

    Args:
        layer: Layer whose `attrs` hold "type", "key" and "value".
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.

    Raises:
        NotImplementedError: If the assertion type is not "eq".
    """
    if forward_func is None:
        forward_func = []
    if layer.attrs["type"] != "eq":
        raise NotImplementedError("Not implement yet!")
    value = layer.attrs["value"]
    key = layer.attrs["key"]
    if isinstance(value, list):
        # One equality test per candidate, OR-ed together.
        condition = " or ".join("{} == {}".format(key, v) for v in value)
        line = "assert {}, 'The {} must be {}!'".format(condition, key, value)
    else:
        line = "assert {} == {}, 'The {} must be {}!'".format(
            key, value, key, value)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_constant(layer, indent=1, init_func=None, forward_func=None):
    """Emit `<out> = <value>` where the value comes from `attrs["value"]`.

    Args:
        layer: Layer providing `outputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {}".format(layer.outputs[0], layer.attrs["value"])
    forward_func.extend(gen_codes([line], indent=indent))


def prim_eq(layer, indent=1, init_func=None, forward_func=None):
    """Emit `<out> = <x> == <y>` into the forward code.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {} == {}".format(layer.outputs[0],
                                  get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_equal(layer, indent=1, init_func=None, forward_func=None):
    """Emit the plain assignment `<out> = <input>` into the forward code.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {}".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_exception(layer, indent=1, init_func=None, forward_func=None):
    """Emit `raise RaiseException(<input>)` into the forward code.

    Args:
        layer: Layer providing `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    # NOTE(review): RaiseException is only named inside the generated code;
    # presumably the generated module defines or imports it - confirm.
    line = "raise RaiseException({})".format(get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


103 104 105 106 107 108
def prim_floordiv(layer, indent=1, init_func=None, forward_func=None):
    """Emit `<out> = <x> // <y>` into the forward code.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {} // {}".format(layer.outputs[0],
                                  get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
109
def prim_if(layer, indent=1, init_func=None, forward_func=None):
    """Emit an `if`/`else` statement together with the code of its blocks.

    `layer.blocks[0]` is the true branch and is always emitted;
    `layer.blocks[1]` is emitted under `else:` only when it contains layers.

    Args:
        layer: Layer providing the condition ("input") and two `blocks`.
        indent: Indentation level of the `if` line; branch bodies are
            generated one level deeper.
        init_func: List extended in place with the init lines of the
            branches; a fresh list is used when None.
        forward_func: List extended in place with the generated forward
            lines; a fresh list is used when None, so no state is shared
            between calls.
    """
    if init_func is None:
        init_func = []
    if forward_func is None:
        forward_func = []
    line = "if {} :".format(get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))
    block = layer.blocks[0]
    b_init_lines, b_forward_lines = block.gen_dygraph_code(indent=indent + 1)
    init_func.extend(b_init_lines)
    forward_func.extend(b_forward_lines)
    block = layer.blocks[1]
    if len(block.layers) > 0:
        line = "else:"
        forward_func.extend(gen_codes([line], indent=indent))
        b_init_lines, b_forward_lines = block.gen_dygraph_code(
            indent=indent + 1)
        init_func.extend(b_init_lines)
        forward_func.extend(b_forward_lines)
S
SunAhong1993 已提交
124 125 126


def prim_getitem(layer, indent=1, init_func=None, forward_func=None):
    """Emit `<out> = <list>[<index>]` into the forward code.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {}[{}]".format(layer.outputs[0],
                                get_value(layer, "list"),
                                get_value(layer, "index"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_gt(layer, indent=1, init_func=None, forward_func=None):
    """Emit `<out> = <x> > <y>` into the forward code.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {} > {}".format(layer.outputs[0],
                                 get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_le(layer, indent=1, init_func=None, forward_func=None):
    """Emit `<out> = <x> <= <y>` into the forward code.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {} <= {}".format(layer.outputs[0],
                                  get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_len(layer, indent=1, init_func=None, forward_func=None):
    """Emit `<out> = len(<input>)` into the forward code.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = len({})".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))
S
SunAhong1993 已提交
148 149


150 151 152 153 154 155 156 157
def prim_len2list(layer, indent=1, init_func=None, forward_func=None):
    """Emit code that builds the list `[0, 1, ..., len - 1]` in `<out>`.

    The generated code initialises an empty list and appends each index in a
    `for` loop over `range(<len>)`.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated lines; a
            fresh list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    out = layer.outputs[0]
    lines = [
        "{} = []".format(out),
        "for i in range({}):".format(get_value(layer, "len")),
        "    {}.append(i)".format(out),
    ]
    forward_func.extend(gen_codes(lines, indent=indent))


S
SunAhong1993 已提交
158
def prim_lt(layer, indent=1, init_func=None, forward_func=None):
    """Emit `<out> = <x> < <y>` into the forward code.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {} < {}".format(layer.outputs[0],
                                 get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_list(layer, indent=1, init_func=None, forward_func=None):
    """Emit `<out> = [<input0>, <input1>, ...]` into the forward code.

    The element count is `len(layer.inputs) + len(layer.attrs)`; elements are
    looked up under the keys "input0", "input1", ... in that order.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    input_len = len(layer.inputs) + len(layer.attrs)
    inputs_str = ', '.join(
        get_value(layer, "input{}".format(i)) for i in range(input_len))
    line = "{} = [{}]".format(layer.outputs[0], inputs_str)
    forward_func.extend(gen_codes([line], indent=indent))


174 175 176 177 178
def prim_list_unpack(layer, indent=1, init_func=None, forward_func=None):
    """Emit `<out0>, <out1>, ... = <input>` into the forward code.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {}".format(", ".join(layer.outputs), get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
179
def prim_loop(layer, indent=1, init_func=None, forward_func=None):
    """Emit a `for <var> in range(<input>):` loop and the code of its body.

    The loop variable is `layer.outputs[1]`; the body is generated from
    `layer.blocks[0]` one indentation level deeper.

    Args:
        layer: Layer providing `outputs`, the range operand ("input") and
            `blocks`.
        indent: Indentation level of the `for` line.
        init_func: List extended in place with the init lines of the body;
            a fresh list is used when None.
        forward_func: List extended in place with the generated forward
            lines; a fresh list is used when None, so no state is shared
            between calls.
    """
    if init_func is None:
        init_func = []
    if forward_func is None:
        forward_func = []
    loop_range = get_value(layer, "input")
    line = "for {} in range({}):".format(layer.outputs[1], loop_range)
    forward_func.extend(gen_codes([line], indent=indent))
    block = layer.blocks[0]
    b_init_lines, b_forward_lines = block.gen_dygraph_code(indent=indent + 1)
    init_func.extend(b_init_lines)
    forward_func.extend(b_forward_lines)


def prim_min(layer, indent=1, init_func=None, forward_func=None):
    """Emit `<out> = min(<input>)` into the forward code.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = min({})".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_mul(layer, indent=1, init_func=None, forward_func=None):
    """Emit `<out> = <x> * <y>` into the forward code.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {} * {}".format(layer.outputs[0],
                                 get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_ne(layer, indent=1, init_func=None, forward_func=None):
    """Emit `<out> = <x> != <y>` into the forward code.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {} != {}".format(layer.outputs[0],
                                  get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_neg(layer, indent=1, init_func=None, forward_func=None):
    """Emit `<out> = -<input>` into the forward code.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = -{}".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_not(layer, indent=1, init_func=None, forward_func=None):
    """Emit `<out> = not <input>` into the forward code.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = not {}".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


216 217 218 219 220 221 222
def prim_replaceitem(layer, indent=1, init_func=None, forward_func=None):
    """Emit `<list>[<index>] = <item>` into the forward code.

    Args:
        layer: Layer providing `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    line = "{}[{}] = {}".format(
        get_value(layer, "list"),
        get_value(layer, "index"), get_value(layer, "item"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
223 224
def prim_requires_grad(layer, indent=1, init_func=None, forward_func=None):
    """Emit `<out> = not <input>.stop_gradient` into the forward code.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = not {}.stop_gradient".format(layer.outputs[0],
                                              get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_select(layer, indent=1, init_func=None, forward_func=None):
    """Emit an indexing expression that selects `<index>` along axis `dim`.

    The generated line is `<out> = <input>[:, :, ..., <index>]` with
    `attrs["dim"]` leading `:` slices.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`; `dim` is
            always read from `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {}[".format(layer.outputs[0], get_value(layer, "input"))
    # One full slice per axis that precedes the selected one.
    line += ":, " * layer.attrs["dim"]
    line += (get_value(layer, "index") + "]")
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
237 238 239 240 241
def prim_set_attr(layer, indent=1, init_func=None, forward_func=None):
    """Emit the assignment `<out> = <input>` into the forward code.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {}".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


S
SunAhong1993 已提交
242
def prim_shape(layer, indent=1, init_func=None, forward_func=None):
    """Emit `<out> = <input>.shape` into the forward code.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {}.shape".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_slice(layer, indent=1, init_func=None, forward_func=None):
    """Emit `<out> = <input>[<start>: <end>: <step>]` into the forward code.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {}[{}: {}: {}]".format(layer.outputs[0],
                                        get_value(layer, "input"),
                                        get_value(layer, "start"),
                                        get_value(layer, "end"),
                                        get_value(layer, "step"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_sub(layer, indent=1, init_func=None, forward_func=None):
    """Emit `<out> = <x> - <y>` into the forward code.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    line = "{} = {} - {}".format(layer.outputs[0],
                                 get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_tuple(layer, indent=1, init_func=None, forward_func=None):
    """Emit `<out> = (<input0>, <input1>, ...)` into the forward code.

    The element count is `len(layer.inputs) + len(layer.attrs)`; elements are
    looked up under the keys "input0", "input1", ... in that order.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    input_len = len(layer.inputs) + len(layer.attrs)
    inputs_str = ', '.join(
        get_value(layer, "input{}".format(i)) for i in range(input_len))
    # NOTE(review): with exactly one input this emits `(x)`, which Python
    # evaluates as `x`, not a 1-tuple. Left unchanged because other generated
    # code may rely on the current output - confirm before changing.
    line = "{} = ({})".format(layer.outputs[0], inputs_str)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_tuple_unpack(layer, indent=1, init_func=None, forward_func=None):
    """Emit `<out0>, <out1>, ... = <input>` into the forward code.

    Args:
        layer: Layer providing `outputs`, `inputs` and `attrs`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated line; a fresh
            list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    outputs_str = ', '.join(layer.outputs)
    line = "{} = {}".format(outputs_str, get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_warnings(layer, indent=1, init_func=None, forward_func=None):
    """Emit an `import warnings` line followed by a `warnings.warn(...)` call.

    Args:
        layer: Layer providing the message operand ("input") and
            `attrs["stacklevel"]`.
        indent: Indentation level passed to gen_codes.
        init_func: Not used by this function.
        forward_func: List extended in place with the generated lines; a
            fresh list is used when None, so no state is shared between calls.
    """
    if forward_func is None:
        forward_func = []
    lines = [
        "import warnings",
        "warnings.warn({}, stacklevel={})".format(
            get_value(layer, "input"), layer.attrs["stacklevel"]),
    ]
    forward_func.extend(gen_codes(lines, indent=indent))