fluid_code.py 5.5 KB
Newer Older
J
jiangjiajun 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
from x2paddle.core.graph import GraphNode
J
jiangjiajun 已提交
16
import collections
17

J
jiangjiajun 已提交
18 19 20 21 22

class Layer(object):
    """One generated layer call, rendered as a line of Python source.

    Callers assign ``op``, ``inputs``, ``output`` and ``param_attr`` after
    construction, then call ``get_code`` (renders ``fluid.layers.<op>(...)``)
    or ``get_custom_code`` (renders a bare ``<op>(...)``).
    """

    def __init__(self):
        # Name of the function the generated line calls.
        self.op = None
        # Extra keyword arguments, rendered as ``key=value`` via str.format;
        # values are inserted verbatim, so string literals must already
        # carry their own quotes.
        self.param_attr = dict()
        # Call inputs: a list, dict, GraphNode or str (see get_code).
        self.inputs = dict()
        # Assignment target: a str, or an object whose ``layer_name`` is used.
        self.output = None
        # Not read within this class; presumably consumed by callers
        # elsewhere -- NOTE(review): confirm.
        self.is_new = False
        # Selects get_custom_code over get_code when the layer is rendered.
        # Initialised here so layers not built via FluidCode.add_layer still
        # carry the attribute.
        self.is_custom_layer = False

    @staticmethod
    def _node_ref(node):
        """Return the expression naming ``node``'s output (``name[idx]`` if indexed)."""
        if hasattr(node, "index"):
            return node.layer_name + "[{}]".format(node.index)
        return node.layer_name

    @classmethod
    def _input_ref(cls, item):
        """Render one element of a list/dict ``inputs`` (GraphNode or str)."""
        if isinstance(item, GraphNode):
            return cls._node_ref(item)
        if isinstance(item, str):
            return item
        raise Exception("Element of inputs should GraphNode or String")

    def _output_prefix(self):
        """Return ``"<target> = "`` for the assignment, or "" when no output."""
        if self.output is None:
            return ""
        if isinstance(self.output, str):
            return self.output + " = "
        return self.output.layer_name + " = "

    def _param_args(self):
        """Render ``param_attr`` as a list of ``key=value`` strings, in order."""
        return [
            key + "={}".format(value)
            for key, value in collections.OrderedDict(self.param_attr).items()
        ]

    def get_code(self):
        """Render this layer as a ``fluid.layers.<op>(...)`` call.

        ``inputs`` is rendered as:
          * list of GraphNode/str -> a single list literal argument;
          * dict of GraphNode/str -> ``key=value`` arguments, in dict order;
          * single GraphNode or str -> one positional argument.
        ``param_attr`` entries follow as ``key=value`` arguments.

        Returns:
            str: e.g. ``"y = fluid.layers.relu(x, name='y')"``.

        Raises:
            Exception: if ``inputs`` or one of its elements has an
                unsupported type.
        """
        # Built first so errors surface in the same order as before.
        head = self._output_prefix() + "fluid.layers." + self.op + "("

        args = []
        if isinstance(self.inputs, list):
            refs = [self._input_ref(item) for item in self.inputs]
            args.append("[" + ", ".join(refs) + "]")
        elif isinstance(self.inputs, dict):
            for key, item in collections.OrderedDict(self.inputs).items():
                args.append(key + "=" + self._input_ref(item))
        elif isinstance(self.inputs, GraphNode):
            args.append(self._node_ref(self.inputs))
        elif isinstance(self.inputs, str):
            args.append(self.inputs)
        else:
            raise Exception("Unknown type of inputs.")
        args.extend(self._param_args())

        # Joining (instead of appending ", " and stripping) keeps any
        # trailing spaces/commas that belong to the last argument's text.
        return head + ", ".join(args) + ")"

    def get_custom_code(self):
        """Render this layer as a bare ``<op>(...)`` call (custom layers).

        Only a list of GraphNode inputs is supported; it becomes a single
        list literal argument, followed by ``param_attr`` as ``key=value``.

        Returns:
            str: e.g. ``"y = my_op([a, b[1]], k=2)"``.

        Raises:
            AssertionError: if an element of ``inputs`` is not a GraphNode.
            Exception: if ``inputs`` is not a list.
        """
        head = self._output_prefix() + self.op + "("
        if not isinstance(self.inputs, list):
            raise Exception("Unknown type of inputs.")

        refs = []
        for item in self.inputs:
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``; AssertionError is kept for existing callers.
            if not isinstance(item, GraphNode):
                raise AssertionError("Type of input should be GraphNode")
            refs.append(self._node_ref(item))

        args = ["[" + ", ".join(refs) + "]"]
        args.extend(self._param_args())
        return head + ", ".join(args) + ")"

J
jiangjiajun 已提交
118 119 120

class FluidCode(object):
    """Ordered buffer of generated Layer objects and raw note lines."""

    def __init__(self):
        # Mix of Layer instances and plain strings, in emission order.
        self.layers = list()

    def add_layer(self,
                  op,
                  inputs,
                  output,
                  param_attr=None,
                  is_custom_layer=False):
        """Append a new Layer built from the given pieces.

        Args:
            op: name of the function the generated line calls.
            inputs: the layer's inputs; when None, the Layer's own default
                is kept.
            output: assignment target (str, or object with ``layer_name``).
            param_attr: optional keyword arguments for the call; when None,
                the Layer's own default is kept.
            is_custom_layer: if true, the layer is rendered with
                ``get_custom_code`` instead of ``get_code``.
        """
        layer = Layer()
        layer.op = op
        layer.is_custom_layer = is_custom_layer
        if inputs is not None:
            layer.inputs = inputs
        layer.output = output
        if param_attr is not None:
            layer.param_attr = param_attr
        self.layers.append(layer)

    def add_note(self, note):
        """Append a raw source line (e.g. a comment) emitted as-is.

        ``note`` should be a str; gen_codes skips entries that are neither
        strings nor Layer objects.
        """
        self.layers.append(note)

    def clear(self):
        """Discard every queued layer and note."""
        self.layers = list()

    def gen_codes(self):
        """Return the generated source lines, one per queued entry.

        String notes are emitted unchanged. Layers flagged as custom use
        ``get_custom_code``; all other layers use ``get_code``. Entries of
        any other type are skipped.
        """
        codes = list()
        for layer in self.layers:
            # A str is never a Layer, so testing it first does not change
            # which branch an entry takes.
            if isinstance(layer, str):
                codes.append(layer)
            elif isinstance(layer, Layer):
                # Layers appended without add_layer may lack the flag.
                if getattr(layer, "is_custom_layer", False):
                    codes.append(layer.get_custom_code())
                else:
                    codes.append(layer.get_code())
        return codes