prim2code.py 9.9 KB
Newer Older
S
SunAhong1993 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


S
SunAhong1993 已提交
16 17 18 19 20 21 22 23 24 25 26
def gen_codes(code_list, indent=0):
    """Turn raw code lines into indented, newline-terminated lines.

    Args:
        code_list: Iterable of source-line strings without trailing newlines.
        indent: Number of four-space indentation levels to prepend.

    Returns:
        A new list with one entry per input line. Lines that are empty or
        whitespace-only become a bare newline (no indentation); every other
        line is prefixed with the indentation and suffixed with a newline.
    """
    prefix = "    " * indent
    return [
        '\n' if entry.strip() == "" else prefix + entry + '\n'
        for entry in code_list
    ]


S
SunAhong1993 已提交
27 28 29 30 31 32 33 34 35 36
def get_value(layer, key):
    """ Look up an operand of a layer by name.

        After optimization (ConstantFuser) an input's value may have been
        replaced directly by a constant, which moves it from ``inputs`` to
        ``attrs``; this helper hides that difference.

        Returns ``layer.inputs[key]`` unchanged when the key is an input,
        otherwise ``str(layer.attrs[key])``.
    """
    if key not in layer.inputs:
        return str(layer.attrs[key])
    return layer.inputs[key]


S
SunAhong1993 已提交
37
def prim_add(layer, indent=1, init_func=[], forward_func=[]):
    """Emit ``<out> = <x> + <y>`` for the layer.

    Args:
        layer: Layer providing ``outputs[0]`` and the "x"/"y" operands
            (resolved through ``get_value``).
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.
    """
    line = "{} = {} + {}".format(layer.outputs[0],
                                 get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_add_(layer, indent=1, init_func=[], forward_func=[]):
    """Emit ``<out> = <x> + <alpha> * <y>`` for the layer.

    Args:
        layer: Layer providing ``outputs[0]``, the "x"/"y" operands (resolved
            through ``get_value``) and ``attrs["alpha"]``.
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.
    """
    line = "{} = {} + {} * {}".format(layer.outputs[0],
                                      get_value(layer, "x"),
                                      layer.attrs["alpha"],
                                      get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_and(layer, indent=1, init_func=[], forward_func=[]):
    """Emit ``<out> = <x> and <y>`` for the layer.

    Args:
        layer: Layer providing ``outputs[0]`` and the "x"/"y" operands
            (resolved through ``get_value``).
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.
    """
    line = "{} = {} and {}".format(layer.outputs[0],
                                   get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_append(layer, indent=1, init_func=[], forward_func=[]):
    """Emit ``<list>.append(<element>)`` for the layer.

    No output variable is assigned; the generated statement mutates the list
    in the generated code.

    Args:
        layer: Layer providing the "list" and "element" operands (resolved
            through ``get_value``).
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.
    """
    line = "{}.append({})".format(
        get_value(layer, "list"), get_value(layer, "element"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_assert(layer, indent=1, init_func=[], forward_func=[]):
    """Emit an ``assert`` statement described by the layer's attrs.

    Only ``attrs["type"] == "eq"`` is supported. When ``attrs["value"]`` is a
    list, the assertion passes if ``attrs["key"]`` equals any listed value;
    otherwise it is a single equality check.

    Args:
        layer: Layer providing ``attrs["type"]``, ``attrs["key"]`` and
            ``attrs["value"]``.
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.

    Raises:
        Exception: If ``attrs["type"]`` is anything other than "eq".
    """
    if layer.attrs["type"] != "eq":
        raise Exception("Not implement yet!")
    expected = layer.attrs["value"]
    if isinstance(expected, list):
        # One equality test per allowed value, OR-ed together.
        condition = " or ".join(
            "{} == {}".format(layer.attrs["key"], candidate)
            for candidate in expected)
        line = "assert {}, 'The {} must be {}!'".format(
            condition, layer.attrs["key"], layer.attrs["value"])
    else:
        line = "assert {} == {}, 'The {} must be {}!'".format(
            layer.attrs["key"], layer.attrs["value"], layer.attrs["key"],
            layer.attrs["value"])
    forward_func.extend(gen_codes([line], indent=indent))


def prim_constant(layer, indent=1, init_func=[], forward_func=[]):
    """Emit ``<out> = <value>`` using the layer's constant attribute.

    Args:
        layer: Layer providing ``outputs[0]`` and ``attrs["value"]``.
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.
    """
    target = layer.outputs[0]
    assignment = "{} = {}".format(target, layer.attrs["value"])
    generated = gen_codes([assignment], indent=indent)
    forward_func.extend(generated)


def prim_eq(layer, indent=1, init_func=[], forward_func=[]):
    """Emit ``<out> = <x> == <y>`` for the layer.

    Args:
        layer: Layer providing ``outputs[0]`` and the "x"/"y" operands
            (resolved through ``get_value``).
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.
    """
    line = "{} = {} == {}".format(layer.outputs[0],
                                  get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_equal(layer, indent=1, init_func=[], forward_func=[]):
    """Emit the plain assignment ``<out> = <input>`` for the layer.

    Despite the name, the generated code is an assignment, not a comparison.

    Args:
        layer: Layer providing ``outputs[0]`` and the "input" operand
            (resolved through ``get_value``).
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.
    """
    line = "{} = {}".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_exception(layer, indent=1, init_func=[], forward_func=[]):
    """Emit ``raise RaiseException(<input>)`` for the layer.

    Args:
        layer: Layer providing the "input" operand (resolved through
            ``get_value``).
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.
    """
    # NOTE(review): ``RaiseException`` is not a Python builtin; presumably the
    # generated module defines or imports it -- confirm against the code that
    # assembles the generated file.
    line = "raise RaiseException({})".format(get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_if(layer, indent=1, init_func=[], forward_func=[]):
    """Emit an ``if``/``else`` statement with the layer's two sub-blocks.

    ``layer.blocks[0]`` is the true branch and is always emitted. The
    ``else:`` branch is emitted only when ``layer.blocks[1]`` has at least
    one layer. Each block generates its own code via ``gen_dygraph_code`` at
    one deeper indentation level.

    Args:
        layer: Layer providing the "input" condition (resolved through
            ``get_value``) and ``blocks[0]`` / ``blocks[1]``.
        indent: Indentation level of the ``if``/``else`` lines.
        init_func: List extended in place with the blocks' init lines.
        forward_func: List extended in place with the generated lines. Pass
            explicit lists; the default list objects are shared across calls.
    """
    line = "if {} :".format(get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))

    true_block = layer.blocks[0]
    b_init_lines, b_forward_lines = true_block.gen_dygraph_code(
        indent=indent + 1)
    init_func.extend(b_init_lines)
    forward_func.extend(b_forward_lines)

    false_block = layer.blocks[1]
    # Skip the else branch entirely when it would have no body.
    if len(false_block.layers) > 0:
        line = "else:"
        forward_func.extend(gen_codes([line], indent=indent))
        b_init_lines, b_forward_lines = false_block.gen_dygraph_code(
            indent=indent + 1)
        init_func.extend(b_init_lines)
        forward_func.extend(b_forward_lines)
S
SunAhong1993 已提交
118 119 120


def prim_getitem(layer, indent=1, init_func=[], forward_func=[]):
    """Emit ``<out> = <list>[<index>]`` for the layer.

    Args:
        layer: Layer providing ``outputs[0]`` and the "list"/"index" operands
            (resolved through ``get_value``).
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.
    """
    line = "{} = {}[{}]".format(layer.outputs[0],
                                get_value(layer, "list"),
                                get_value(layer, "index"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_gt(layer, indent=1, init_func=[], forward_func=[]):
    """Emit ``<out> = <x> > <y>`` for the layer.

    Args:
        layer: Layer providing ``outputs[0]`` and the "x"/"y" operands
            (resolved through ``get_value``).
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.
    """
    line = "{} = {} > {}".format(layer.outputs[0],
                                 get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_le(layer, indent=1, init_func=[], forward_func=[]):
    """Emit ``<out> = <x> <= <y>`` for the layer.

    Args:
        layer: Layer providing ``outputs[0]`` and the "x"/"y" operands
            (resolved through ``get_value``).
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.
    """
    line = "{} = {} <= {}".format(layer.outputs[0],
                                  get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_len(layer, indent=1, init_func=[], forward_func=[]):
    """Emit ``<out> = len(<input>)`` for the layer.

    Args:
        layer: Layer providing ``outputs[0]`` and the "input" operand
            (resolved through ``get_value``).
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.
    """
    line = "{} = len({})".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))
S
SunAhong1993 已提交
142 143 144


def prim_lt(layer, indent=1, init_func=[], forward_func=[]):
    """Emit ``<out> = <x> < <y>`` for the layer.

    Args:
        layer: Layer providing ``outputs[0]`` and the "x"/"y" operands
            (resolved through ``get_value``).
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.
    """
    line = "{} = {} < {}".format(layer.outputs[0],
                                 get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_list(layer, indent=1, init_func=[], forward_func=[]):
    """Emit ``<out> = [<input0>, <input1>, ...]`` for the layer.

    Elements are looked up under the keys "input0", "input1", ...; the
    element count is the combined size of ``layer.inputs`` and
    ``layer.attrs``, since each element lives in one of the two.

    Args:
        layer: Layer providing ``outputs[0]``, ``inputs`` and ``attrs``.
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.
    """
    input_len = len(layer.inputs) + len(layer.attrs)
    inputs_list = [
        get_value(layer, "input{}".format(i)) for i in range(input_len)
    ]
    inputs_str = ', '.join(inputs_list)
    line = "{} = [{}]".format(layer.outputs[0], inputs_str)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_loop(layer, indent=1, init_func=[], forward_func=[]):
    """Emit ``for <var> in range(<input>):`` followed by the loop body.

    The loop variable name is ``layer.outputs[1]``. The body is generated by
    ``layer.blocks[0].gen_dygraph_code`` at one deeper indentation level.

    Args:
        layer: Layer providing ``outputs[1]``, the "input" range bound
            (resolved through ``get_value``) and ``blocks[0]``.
        indent: Indentation level of the ``for`` line.
        init_func: List extended in place with the block's init lines.
        forward_func: List extended in place with the generated lines. Pass
            explicit lists; the default list objects are shared across calls.
    """
    loop_range = get_value(layer, "input")
    line = "for {} in range({}):".format(layer.outputs[1], loop_range)
    forward_func.extend(gen_codes([line], indent=indent))
    body = layer.blocks[0]
    b_init_lines, b_forward_lines = body.gen_dygraph_code(indent=indent + 1)
    init_func.extend(b_init_lines)
    forward_func.extend(b_forward_lines)


def prim_min(layer, indent=1, init_func=[], forward_func=[]):
    """Emit ``<out> = min(<input>)`` for the layer.

    Args:
        layer: Layer providing ``outputs[0]`` and the "input" operand
            (resolved through ``get_value``).
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.
    """
    line = "{} = min({})".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_mul(layer, indent=1, init_func=[], forward_func=[]):
    """Emit ``<out> = <x> * <y>`` for the layer.

    Args:
        layer: Layer providing ``outputs[0]`` and the "x"/"y" operands
            (resolved through ``get_value``).
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.
    """
    line = "{} = {} * {}".format(layer.outputs[0],
                                 get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_ne(layer, indent=1, init_func=[], forward_func=[]):
    """Emit ``<out> = <x> != <y>`` for the layer.

    Args:
        layer: Layer providing ``outputs[0]`` and the "x"/"y" operands
            (resolved through ``get_value``).
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.
    """
    # "ne" is not-equal: the previous template emitted "<", which made the
    # generated code identical to prim_lt and wrong whenever x > y.
    line = "{} = {} != {}".format(layer.outputs[0],
                                  get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_neg(layer, indent=1, init_func=[], forward_func=[]):
    """Emit ``<out> = -<input>`` for the layer.

    Args:
        layer: Layer providing ``outputs[0]`` and the "input" operand
            (resolved through ``get_value``).
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.
    """
    line = "{} = -{}".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_not(layer, indent=1, init_func=[], forward_func=[]):
    """Emit ``<out> = not <input>`` for the layer.

    Args:
        layer: Layer providing ``outputs[0]`` and the "input" operand
            (resolved through ``get_value``).
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.
    """
    line = "{} = not {}".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_requires_grad(layer, indent=1, init_func=[], forward_func=[]):
    """Emit ``<out> = not <input>.stop_gradient`` for the layer.

    The generated code reads the input's ``stop_gradient`` attribute and
    negates it.

    Args:
        layer: Layer providing ``outputs[0]`` and the "input" operand
            (resolved through ``get_value``).
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.
    """
    line = "{} = not {}.stop_gradient".format(layer.outputs[0],
                                              get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_select(layer, indent=1, init_func=[], forward_func=[]):
    """Emit an indexing expression that picks one index along a dimension.

    The generated line is ``<out> = <input>[:, :, ..., <index>]`` with
    ``attrs["dim"]`` leading ``:`` slices before the index.

    Args:
        layer: Layer providing ``outputs[0]``, the "input"/"index" operands
            (resolved through ``get_value``) and ``attrs["dim"]``.
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.
    """
    line = "{} = {}[".format(layer.outputs[0], get_value(layer, "input"))
    # One full slice per dimension that precedes the selected one.
    line += ":, " * layer.attrs["dim"]
    line += (get_value(layer, "index") + "]")
    forward_func.extend(gen_codes([line], indent=indent))


def prim_shape(layer, indent=1, init_func=[], forward_func=[]):
    """Emit ``<out> = <input>.shape`` for the layer.

    Args:
        layer: Layer providing ``outputs[0]`` and the "input" operand
            (resolved through ``get_value``).
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.
    """
    line = "{} = {}.shape".format(layer.outputs[0], get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_slice(layer, indent=1, init_func=[], forward_func=[]):
    """Emit ``<out> = <input>[<start>: <end>: <step>]`` for the layer.

    Args:
        layer: Layer providing ``outputs[0]`` and the "input", "start", "end"
            and "step" operands (resolved through ``get_value``).
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.
    """
    line = "{} = {}[{}: {}: {}]".format(layer.outputs[0],
                                        get_value(layer, "input"),
                                        get_value(layer, "start"),
                                        get_value(layer, "end"),
                                        get_value(layer, "step"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_sub(layer, indent=1, init_func=[], forward_func=[]):
    """Emit ``<out> = <x> - <y>`` for the layer.

    Args:
        layer: Layer providing ``outputs[0]`` and the "x"/"y" operands
            (resolved through ``get_value``).
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.
    """
    line = "{} = {} - {}".format(layer.outputs[0],
                                 get_value(layer, "x"), get_value(layer, "y"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_tuple(layer, indent=1, init_func=[], forward_func=[]):
    """Emit ``<out> = (<input0>, <input1>, ...)`` for the layer.

    Elements are looked up under the keys "input0", "input1", ...; the
    element count is the combined size of ``layer.inputs`` and
    ``layer.attrs``, since each element lives in one of the two.

    Args:
        layer: Layer providing ``outputs[0]``, ``inputs`` and ``attrs``.
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.
    """
    input_len = len(layer.inputs) + len(layer.attrs)
    inputs_list = [
        get_value(layer, "input{}".format(i)) for i in range(input_len)
    ]
    inputs_str = ', '.join(inputs_list)
    # NOTE(review): with exactly one element this emits "(x)", which Python
    # evaluates to x rather than a 1-tuple. Left unchanged because
    # prim_tuple_unpack emits the matching "a = t" for a single output --
    # confirm whether single-element tuples need real tuple semantics.
    line = "{} = ({})".format(layer.outputs[0], inputs_str)
    forward_func.extend(gen_codes([line], indent=indent))


def prim_tuple_unpack(layer, indent=1, init_func=[], forward_func=[]):
    """Emit ``<out0>, <out1>, ... = <input>`` for the layer.

    Every name in ``layer.outputs`` becomes an assignment target, joined by
    commas.

    Args:
        layer: Layer providing ``outputs`` and the "input" operand (resolved
            through ``get_value``).
        indent: Indentation level of the emitted line.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated line. Pass an
            explicit list; the default list object is shared across calls.
    """
    outputs_str = ', '.join(layer.outputs)
    line = "{} = {}".format(outputs_str, get_value(layer, "input"))
    forward_func.extend(gen_codes([line], indent=indent))


def prim_warnings(layer, indent=1, init_func=[], forward_func=[]):
    """Emit an ``import warnings`` line followed by a ``warnings.warn`` call.

    The generated call is ``warnings.warn(<input>, stacklevel=<stacklevel>)``;
    the import is emitted inline, right before the call.

    Args:
        layer: Layer providing the "input" operand (resolved through
            ``get_value``) and ``attrs["stacklevel"]``.
        indent: Indentation level of the emitted lines.
        init_func: Unused by this emitter.
        forward_func: List extended in place with the generated lines. Pass
            an explicit list; the default list object is shared across calls.
    """
    warn_call = "warnings.warn({}, stacklevel={})".format(
        get_value(layer, "input"), layer.attrs["stacklevel"])
    lines = ["import warnings", warn_call]
    forward_func.extend(gen_codes(lines, indent=indent))