fluid_code.py 5.0 KB
Newer Older
J
jiangjiajun 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
from x2paddle.core.graph import GraphNode
16
from x2paddle.core.util import *
M
mamingjie-China 已提交
17 18
import collections
import six
19

J
jiangjiajun 已提交
20 21 22 23 24

class Layer(object):
    """One generated statement, rendered by :meth:`get_code`.

    The rendered form is ``<output> = <callee>(<inputs>, <params>)``; for
    ``op == "="`` no call is emitted, only ``<output> = <inputs>``.
    """

    def __init__(self):
        # Op name to call; the special value "=" means a plain assignment.
        self.op = None
        # Extra keyword arguments, rendered as ``key=value`` after the inputs.
        self.param_attr = dict()
        # A list, a dict of keyword inputs, a GraphNode, or a string.
        self.inputs = dict()
        # Assignment target: a string, an object with ``layer_name``, or None.
        self.output = None
        # When True the op name is emitted without any module prefix.
        self.is_custom_layer = False
        # When True the op is emitted as ``fluid.<op>``.
        self.use_fluid = False

    def get_code(self):
        """Render this layer as a single line of Python source.

        Returns:
            str: e.g. ``"out = fluid.layers.relu(x, name='r')"``.

        Raises:
            Exception: if ``inputs`` (or an element of a list ``inputs``) has
                an unsupported type.
        """
        # Evaluated first so errors surface in the same order as before:
        # output, callee, inputs, params.
        prefix = self._output_prefix() + self._callee_prefix()
        # Join fragments once instead of appending ", " and stripping it off
        # with str.strip(", "): strip removes any run of ',' / ' ' characters
        # from both ends and could eat the tail of the last argument value.
        args = self._format_inputs() + self._format_params()
        code = prefix + ", ".join(args)
        if self.op != "=":
            code += ")"
        return code

    def _output_prefix(self):
        """Return ``"<name> = "`` for the assignment target, or "" if none."""
        if self.output is None:
            return ""
        if isinstance(self.output, six.string_types):
            return self.output + " = "
        return self.output.layer_name + " = "

    def _callee_prefix(self):
        """Return the callee text up to and including "(" ("" for op "=")."""
        # Precedence matches the original chain: custom, "=", fluid, full_like.
        if self.is_custom_layer:
            return self.op + "("
        if self.op == "=":
            return ""
        if self.use_fluid:
            return "fluid." + self.op + "("
        if self.op == "full_like":
            return "paddle." + self.op + "("
        return "fluid.layers." + self.op + "("

    @staticmethod
    def _node_ref(node):
        """Return ``node.layer_name``, suffixed with ``[index]`` if present."""
        if hasattr(node, "index"):
            return node.layer_name + "[{}]".format(node.index)
        return node.layer_name

    def _format_inputs(self):
        """Return the argument fragments produced by ``self.inputs``."""
        inputs = self.inputs
        if isinstance(inputs, list):
            items = []
            for item in inputs:
                if isinstance(item, GraphNode):
                    items.append(self._node_ref(item))
                elif isinstance(item, six.string_types):
                    items.append(item)
                else:
                    raise Exception(
                        "Element of inputs should GraphNode or String")
            # A list of inputs becomes one list-literal positional argument.
            return ["[" + ", ".join(items) + "]"]
        if isinstance(inputs, dict):
            fragments = []
            for key, value in collections.OrderedDict(inputs).items():
                if isinstance(value, GraphNode):
                    value = self._node_ref(value)
                fragments.append("{}={}".format(key, value))
            return fragments
        if isinstance(inputs, GraphNode):
            return [self._node_ref(inputs)]
        if isinstance(inputs, six.string_types):
            return [inputs]
        raise Exception("Unknown type of inputs.")

    def _format_params(self):
        """Return the ``key=value`` fragments for ``self.param_attr``."""
        fragments = []
        for key, value in collections.OrderedDict(self.param_attr).items():
            text = str(value)
            if '\n' in text:
                # Keep the emitted statement on one line; ``string`` is
                # presumably provided by the x2paddle.core.util star import.
                value = string(text.replace('\n', ','))
            if str(key) == 'attr':
                value = 'ParamAttr(' + str(value) + ')'
            fragments.append("{}={}".format(key, value))
        return fragments
S
SunAhong1993 已提交
105

J
jiangjiajun 已提交
106 107 108

class FluidCode(object):
    """Ordered buffer of generated layers and free-form note lines.

    ``layers`` holds ``Layer`` objects interleaved with plain strings;
    :meth:`gen_codes` renders them in the order they were added.
    """

    def __init__(self):
        # Entries are emitted in append order.
        self.layers = []

    def add_layer(self,
                  op,
                  inputs,
                  output,
                  param_attr=None,
                  use_fluid=False,
                  is_custom_layer=False):
        """Build a ``Layer`` from the given fields and append it.

        Args:
            op: Op name to call, or "=" for a plain assignment.
            inputs: Layer inputs; None keeps the Layer's default empty dict.
            output: Assignment target, or None.
            param_attr: Keyword arguments; None keeps the Layer's default
                empty dict.
            use_fluid: Stored on the layer as ``use_fluid``.
            is_custom_layer: Stored on the layer as ``is_custom_layer``.
        """
        new_layer = Layer()
        new_layer.op, new_layer.output = op, output
        new_layer.use_fluid = use_fluid
        new_layer.is_custom_layer = is_custom_layer
        # None means "not supplied": leave Layer's own defaults in place.
        if inputs is not None:
            new_layer.inputs = inputs
        if param_attr is not None:
            new_layer.param_attr = param_attr
        self.layers.append(new_layer)

    def add_note(self, note):
        """Append a free-form line; it should be a string.

        Non-string notes are accepted here but skipped by :meth:`gen_codes`.
        """
        self.layers.append(note)

    def clear(self):
        """Discard every buffered entry by rebinding ``layers`` to a new list."""
        self.layers = []

    def gen_codes(self):
        """Render the buffer into a list of source lines.

        ``Layer`` entries are rendered via ``get_code()``, string entries are
        passed through unchanged, and any other entry is silently dropped.
        """
        skipped = object()

        def render(entry):
            if isinstance(entry, Layer):
                return entry.get_code()
            if isinstance(entry, six.string_types):
                return entry
            return skipped

        return [line for line in map(render, self.layers)
                if line is not skipped]