caffe_decoder.py 10.0 KB
Newer Older
J
jiangjiajun 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
S
SunAhong1993 已提交
14 15 16 17 18 19

import os
import sys
from google.protobuf import text_format
import numpy as np
from x2paddle.core.graph import GraphNode, Graph
S
SunAhong1993 已提交
20
from x2paddle.core.fluid_code import FluidCode
S
SunAhong1993 已提交
21
from x2paddle.op_mapper import caffe_shape
S
SunAhong1993 已提交
22 23 24


class CaffeResolver(object):
    """Load the Python module generated from caffe.proto.

    After construction the loaded module is available as ``caffepb`` and its
    ``NetParameter`` attribute as ``NetParameter``.
    """

    def __init__(self, caffe_proto):
        # Path to a user-supplied compiled proto module, or None to use the
        # module bundled under x2paddle.decoder.
        self.caffe_proto = caffe_proto
        self.import_caffe()

    def import_caffepb(self):
        """Return the compiled caffe protobuf module.

        With ``caffe_proto`` set to None the bundled
        ``x2paddle.decoder.caffe_pb2`` is imported. Otherwise the directory
        holding ``caffe_proto`` is appended to ``sys.path`` and the file is
        imported by its base name (extension stripped).

        Raises:
            Exception: ``caffe_proto`` is given but is not an existing file.
        """
        if self.caffe_proto is None:
            from x2paddle.decoder import caffe_pb2
            return caffe_pb2

        if not os.path.isfile(self.caffe_proto):
            raise Exception(
                "The .py file compiled by caffe.proto is not exist.")

        module_dir, module_file = os.path.split(
            os.path.abspath(self.caffe_proto))
        module_name = os.path.splitext(module_file)[0]
        # The module is imported by bare name, so its directory has to be on
        # the import path first.
        sys.path.append(module_dir)
        return __import__(module_name)

    def import_caffe(self):
        """Store the proto module and its ``NetParameter`` on the instance."""
        pb_module = self.import_caffepb()
        self.caffepb = pb_module
        self.NetParameter = pb_module.NetParameter


class CaffeGraphNode(GraphNode):
    """Graph node wrapping one Caffe layer.

    The name handed to the base ``GraphNode`` has every '/' and '-' replaced
    by '_'.
    """

    def __init__(self, layer, type_str, layer_name=None):
        """Create the node.

        Args:
            layer: the Caffe layer object; its ``name`` is used when
                ``layer_name`` is None.
            type_str: layer type string, stored as ``layer_type``.
            layer_name: optional name overriding ``layer.name``.
        """
        raw_name = layer.name if layer_name is None else layer_name
        safe_name = raw_name.replace('/', '_').replace('-', '_')
        super(CaffeGraphNode, self).__init__(layer, safe_name)
        self.layer_type = type_str
        self.fluid_code = FluidCode()
        # Layer parameters; stays None until set_params() is called.
        self.data = None

    def set_params(self, params):
        """Attach the layer's parameters to this node."""
        self.data = params


class CaffeGraph(Graph):
    """Graph of Caffe layers built from a parsed net definition.

    ``self.model``, ``self.node_map``, ``self.connect`` and ``self.get_node``
    are used below but not defined in this class; presumably they come from
    the ``Graph`` base class -- confirm against x2paddle.core.graph.
    """

    def __init__(self, model, params, caffe_pb):
        """Store the parameters and proto module, then init the base graph.

        Args:
            model: parsed net definition, forwarded to ``Graph.__init__``.
            params: iterable of ``(layer_name, data)`` pairs consumed by
                ``build``.
            caffe_pb: compiled caffe protobuf module (provides
                ``NetParameter`` and the V1 layer-type enum).
        """
        self.params = params
        self.caffe_pb = caffe_pb
        super(CaffeGraph, self).__init__(model)

    def filter_layers(self, layers):
        '''Filter out layers based on the current phase.

        A layer is kept when its phase resolves to 'test' and its type is not
        'Dropout'. Layers exposing an ``input`` attribute are skipped
        silently; every other dropped layer is reported with ``print``.
        '''
        phase_map = {0: 'train', 1: 'test'}
        filtered_layer_names = set()
        filtered_layers = []
        for layer in layers:
            if hasattr(layer, 'input'):
                continue
            type_str = self.get_layer_type(layer)
            phase = 'test'
            if len(layer.include):
                phase = phase_map[layer.include[0].phase]
            if len(layer.exclude):
                # An exclude rule names the phase the layer is absent from,
                # so the layer belongs to the opposite phase. The rule must
                # be read from ``exclude`` itself: ``include`` may be empty.
                phase = phase_map[1 - layer.exclude[0].phase]
            # Dropout layers appear in a fair number of Caffe
            # test-time networks. These are just ignored. We'll
            # filter them out here.
            exclude = (phase != 'test') or (type_str == 'Dropout')
            if exclude:
                print('The filter layer:' + layer.name)
                continue
            # Guard against dupes.
            assert layer.name not in filtered_layer_names
            filtered_layers.append(layer)
            filtered_layer_names.add(layer.name)
        return filtered_layers

    def generate_input_layer(self, dims, index):
        """Build an "Input" layer for ``self.model.input[index]``.

        The layer is written as prototxt text, parsed into a fresh
        ``NetParameter`` and returned as that message's layer list
        (``net.layers or net.layer``); callers take element 0.

        Args:
            dims: iterable of dimension values for the input shape.
            index: position of the input name in ``self.model.input``.
        """
        input_name = str(self.model.input[index])
        dim_str = ''.join('dim: {}\n'.format(str(dim)) for dim in dims)
        input_str = ''.join([
            'layer {\n',
            'name: \"{}\"\n '.format(input_name),
            'type: "Input"\n',
            'top: \"{}\"\n'.format(input_name),
            'input_param {\n',
            'shape {\n',
            dim_str,
            '}}}',
        ])
        net = self.caffe_pb.NetParameter()
        text_format.Merge(input_str.encode(), net)
        return net.layers or net.layer

    def input2layers(self, input_layers=None):
        """Convert net-level ``input`` declarations into Input layers.

        Dimensions come from ``self.model.input_dim`` (four values per input)
        when that field is non-empty, otherwise from the first four entries
        of ``self.model.input_shape[i].dim``.

        Args:
            input_layers: list extended in place; a new list is created when
                omitted.

        Returns:
            The list that received the generated layers.

        Raises:
            ValueError: ``input_dim`` is non-empty but does not hold exactly
                four values per declared input.
        """
        if input_layers is None:
            input_layers = []
        inputs_num = len(self.model.input)
        if inputs_num == 0:
            return input_layers

        input_dims_num = len(self.model.input_dim)
        if input_dims_num != 0:
            if input_dims_num != inputs_num * 4:
                raise ValueError('invalid input_dim[%d] param in prototxt' %
                                 (input_dims_num))
            for i in range(inputs_num):
                dims = self.model.input_dim[i * 4:(i + 1) * 4]
                input_layers.append(self.generate_input_layer(dims, i)[0])
        else:
            for i in range(inputs_num):
                dims = self.model.input_shape[i].dim[0:4]
                input_layers.append(self.generate_input_layer(dims, i)[0])
        return input_layers

    def transform_input_layers(self, layers, input_layers=None):
        """Generate an Input layer for every layer exposing ``input``.

        Args:
            layers: layers to scan.
            input_layers: list extended in place; a new list is created when
                omitted.

        Returns:
            The list that received the generated layers.

        Raises:
            ValueError: a matching layer has a non-empty ``input_dim`` whose
                length is not 4.
        """
        if input_layers is None:
            input_layers = []
        for layer in layers:
            if not hasattr(layer, 'input'):
                continue
            input_dims_num = len(layer.input_dim)
            if input_dims_num > 0 and input_dims_num != 4:
                raise ValueError('invalid input_dim[%d] param in prototxt' %
                                 (input_dims_num))
            dims = self.model.input_dim[0:4]
            # NOTE(review): the dims above are the first model input's, so
            # index 0 is used for the name as well (the original referenced
            # an undefined ``i``). Confirm the intended input index.
            input_layers.append(self.generate_input_layer(dims, 0)[0])
        return input_layers

    def get_layer_type(self, layer):
        """Return the layer type as a string.

        String types are returned unchanged. Integer (V1 enum) types are
        looked up in ``_V1LAYERPARAMETER_LAYERTYPE``; the enum name is turned
        into capitalized words joined without separators, with 'elu'
        rewritten to 'eLU' for names containing 'relu' and 'Lrn' replaced by
        'LRN'.
        """
        if not isinstance(layer.type, int):
            return layer.type

        enum_values = self.caffe_pb._V1LAYERPARAMETER_LAYERTYPE.values
        vals = [val for val in enum_values if val.number == layer.type]
        type_str = ''.join(
            part.capitalize() for part in vals[0].name.split('_'))
        if 'relu' in type_str.lower():
            type_str = type_str.replace('elu', 'eLU')
        elif type_str.lower() == 'lrn':
            type_str = 'LRN'
        return type_str

    def build(self):
        """Create nodes and edges for the model, then run the base build.

        Steps: filter the model's layers, prepend generated Input layers,
        create one ``CaffeGraphNode`` per layer, connect each bottom blob to
        the layer that most recently produced it, attach parameters from
        ``self.params`` and finally call ``Graph.build``.

        Raises:
            Exception: a layer consumes a blob no earlier layer produced.
        """
        layers = self.model.layers or self.model.layer
        layers = self.filter_layers(layers)

        input_layers = []
        self.input2layers(input_layers)
        self.transform_input_layers(layers, input_layers)
        layers = input_layers + layers

        # Blob name -> names of the layers that wrote it, in order. The last
        # entry is the current producer (in-place layers reuse a blob name).
        top_layer = {}
        for layer in layers:
            if hasattr(layer, 'input'):
                continue
            type_str = self.get_layer_type(layer)
            self.node_map[layer.name] = CaffeGraphNode(layer, type_str)
            for in_name in layer.bottom:
                if in_name not in top_layer:
                    raise Exception(
                        'input[{}] of node[{}] does not exist in node_map'.
                        format(in_name, layer.name))
                self.connect(top_layer[in_name][-1], layer.name)
            for out_name in layer.top:
                top_layer.setdefault(out_name, []).append(layer.name)

        for layer_name, data in self.params:
            if layer_name in self.node_map:
                self.node_map[layer_name].set_params(data)
            else:
                print('Ignoring parameters for non-existent layer: %s' % \
                       layer_name)

        super(CaffeGraph, self).build()

    def get_bottom_node(self, node, idx=0, copy=False):
        """Return the node feeding input ``idx`` of ``node``.

        When the producing layer has more than one top, the lookup name is
        ``"<producer>:<k>"`` where ``k`` is the position of the consumed blob
        among the producer's tops; otherwise it is the producer's name.

        Args:
            node: consumer node.
            idx: index into ``node.inputs`` / ``node.layer.bottom``.
            copy: forwarded to ``self.get_node``.
        """
        input_node_name = node.inputs[idx]
        # The message must use input_node_name: ``name`` is not bound yet.
        assert input_node_name in self.node_map, 'The {} isn\'t a valid node'.format(
            input_node_name)
        input_node = self.node_map[input_node_name]
        if len(input_node.layer.top) > 1:
            need_idx = list(input_node.layer.top).index(node.layer.bottom[idx])
            name = input_node_name + ':' + str(need_idx)
        else:
            name = input_node_name
        return self.get_node(name, copy=copy)

S
SunAhong1993 已提交
215

J
jiangjiajun 已提交
216
class CaffeDecoder(object):
    """Load a Caffe prototxt and caffemodel and build a ``CaffeGraph``."""

    def __init__(self, proto_path, model_path, caffe_proto):
        """Parse the network, load its weights and build the graph.

        Args:
            proto_path: path to the text-format network definition.
            model_path: path to the serialized (binary) model.
            caffe_proto: path to a compiled caffe.proto module, or None;
                forwarded to ``CaffeResolver``.
        """
        self.proto_path = proto_path
        self.model_path = model_path

        self.resolver = CaffeResolver(caffe_proto=caffe_proto)
        self.net = self.resolver.NetParameter()
        with open(proto_path, 'rb') as proto_file:
            proto_str = proto_file.read()
            text_format.Merge(proto_str, self.net)

        self.load_using_pb()

        self.caffe_graph = CaffeGraph(self.net, self.params,
                                      self.resolver.caffepb)
        self.caffe_graph.build()

    def load_using_pb(self):
        """Read the binary model and fill ``self.params``.

        ``self.params`` becomes a list of ``(layer_name, arrays)`` pairs, one
        per layer that has blobs, with ``arrays`` produced by
        ``normalize_pb_data``.
        """
        data = self.resolver.NetParameter()

        # Context manager so the model file is closed right after reading.
        with open(self.model_path, 'rb') as model_file:
            data.MergeFromString(model_file.read())

        layers = data.layers or data.layer
        self.params = [(layer.name, self.normalize_pb_data(layer))
                       for layer in layers if layer.blobs]

    def normalize_pb_data(self, layer):
        """Convert a layer's blobs to float32 numpy arrays.

        Blobs with a non-empty ``shape.dim`` are reshaped to four dimensions:
        'PReLU' blobs get a leading 1 and trailing 1s, other types get
        leading 1s. 'Normalize' blobs are kept flat. Blobs without
        ``shape.dim`` use ``num``, ``channels``, ``height`` and ``width``.

        Returns:
            List of arrays, one per blob, in blob order.
        """
        transformed = []
        for blob in layer.blobs:
            if len(blob.shape.dim):
                dims = list(blob.shape.dim)
                if layer.type == 'Normalize':
                    transformed.append(
                        np.asarray(list(blob.data), dtype=np.float32))
                    continue
                if layer.type == 'PReLU':
                    padded = [1] + dims + [1] * (3 - len(dims))
                else:
                    padded = [1] * (4 - len(dims)) + dims
                c_o, c_i, h, w = map(int, padded)
            else:
                c_o = blob.num
                c_i = blob.channels
                h = blob.height
                w = blob.width
            data = np.asarray(list(blob.data),
                              dtype=np.float32).reshape(c_o, c_i, h, w)
            transformed.append(data)
        return transformed