caffe_decoder.py 11.4 KB
Newer Older
J
jiangjiajun 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
S
SunAhong1993 已提交
14 15 16 17 18 19

import os
import sys
from google.protobuf import text_format
import numpy as np
from x2paddle.core.graph import GraphNode, Graph
S
SunAhong1993 已提交
20
from x2paddle.core.fluid_code import FluidCode
S
SunAhong1993 已提交
21
from x2paddle.op_mapper import caffe_shape
S
SunAhong1993 已提交
22 23 24


class CaffeResolver(object):
    """Locate and import the compiled caffe protobuf module.

    caffe_proto: path to a caffe_pb2-style .py file compiled from
    caffe.proto, or None to fall back to the copy bundled with x2paddle.
    """

    def __init__(self, caffe_proto):
        self.caffe_proto = caffe_proto
        self.import_caffe()

    def import_caffepb(self):
        """Return the protobuf module that defines the caffe messages.

        Raises:
            Exception: if a custom proto path was given but the file
                does not exist.
        """
        if self.caffe_proto is None:
            from x2paddle.decoder import caffe_pb2
            out = caffe_pb2
        else:
            if not os.path.isfile(self.caffe_proto):
                # BUGFIX(wording): error message was ungrammatical
                # ("is not exist").
                raise Exception(
                    "The .py file compiled by caffe.proto does not exist.")
            (filepath,
             tempfilename) = os.path.split(os.path.abspath(self.caffe_proto))
            (filename, extension) = os.path.splitext(tempfilename)
            # Make the custom module importable by its bare file name.
            sys.path.append(filepath)
            out = __import__(filename)
        return out

    def import_caffe(self):
        """Cache the protobuf module and its NetParameter message class."""
        self.caffepb = self.import_caffepb()
        self.NetParameter = self.caffepb.NetParameter


class CaffeGraphNode(GraphNode):
    """A single node in a CaffeGraph, wrapping one Caffe layer message."""

    def __init__(self, layer, type_str, layer_name=None):
        # Use the layer's own name unless an explicit one is supplied.
        # '/' and '-' are not legal in Paddle identifiers, so normalize.
        # (Original duplicated the super().__init__ call in both branches.)
        name = layer.name if layer_name is None else layer_name
        super(CaffeGraphNode, self).__init__(
            layer, name.replace('/', '_').replace('-', '_'))
        # Canonical layer-type string (e.g. 'Convolution'), resolved by
        # CaffeGraph.get_layer_type.
        self.layer_type = type_str
        # Code emitter consumed later by the op mapper.
        self.fluid_code = FluidCode()
        # Trained weights (list of ndarrays), filled in by set_params.
        self.data = None

    def set_params(self, params):
        """Attach the layer's trained parameter blobs to this node."""
        self.data = params


class CaffeGraph(Graph):
    """Graph of a Caffe network built from a parsed NetParameter.

    Args:
        model: parsed NetParameter message (from the prototxt).
        params: list of (layer_name, [ndarray blobs]) from the caffemodel.
        caffe_pb: the compiled caffe protobuf module, used to build
            auxiliary messages (synthetic Input layers).
    """

    def __init__(self, model, params, caffe_pb):
        self.params = params
        self.caffe_pb = caffe_pb
        super(CaffeGraph, self).__init__(model)

    def filter_layers(self, layers):
        '''Filter out layers based on the current phase.'''
        phase_map = {0: 'train', 1: 'test'}
        filtered_layer_names = set()
        filtered_layers = []
        for layer in layers:
            # Old-style net inputs are handled by transform_input_layers.
            if hasattr(layer, 'input'):
                continue
            type_str = self.get_layer_type(layer)
            phase = 'test'
            if len(layer.include):
                phase = phase_map[layer.include[0].phase]
            if len(layer.exclude):
                # BUGFIX: the excluded phase comes from layer.exclude; the
                # original indexed layer.include, which may be empty here
                # and would raise IndexError.
                phase = phase_map[1 - layer.exclude[0].phase]
            exclude = (phase != 'test')
            # Dropout layers appear in a fair number of Caffe
            # test-time networks. These are just ignored. We'll
            # filter them out here.
            if (not exclude) and (phase == 'test'):
                exclude = (type_str == 'Dropout')
                if layer.type == 'Dropout':
                    drop_layer_top = layer.top[0]
                    drop_layer_bottom = layer.bottom[0]
                    if drop_layer_top != drop_layer_bottom:
                        # Rewire consumers of the dropout output to read
                        # directly from the dropout input blob.
                        for next_layer in layers:
                            for next_layer_bottom_idx, next_layer_bottom in enumerate(
                                    next_layer.bottom):
                                if drop_layer_top == next_layer_bottom:
                                    next_layer.bottom.remove(drop_layer_top)
                                    next_layer.bottom.insert(
                                        next_layer_bottom_idx,
                                        drop_layer_bottom)

            if not exclude:
                filtered_layers.append(layer)
                # Guard against dupes.
                if layer.name in filtered_layer_names:
                    layer.name += "_0"
                assert layer.name not in filtered_layer_names
                filtered_layer_names.add(layer.name)
            else:
                print('The filter layer:' + layer.name)
        return filtered_layers

    def generate_input_layer(self, dims, index):
        """Build a synthetic 'Input' layer message for model.input[index].

        The layer is produced by rendering a prototxt snippet with the
        given blob dims and parsing it back through text_format.
        """
        dim_str = ''
        for dim in dims:
            dim_str += 'dim: {}\n'.format(str(dim))
        input_str = 'layer {\n'
        input_str += 'name: \"{}\"\n '.format(str(self.model.input[index]))
        input_str += 'type: "Input"\n'
        input_str += 'top: \"{}\"\n'.format(str(self.model.input[index]))
        input_str += 'input_param {\n'
        input_str += 'shape {\n'
        input_str += dim_str
        input_str += '}}}'
        input_str = str.encode(input_str)
        net = self.caffe_pb.NetParameter()
        text_format.Merge(input_str, net)
        return net.layers or net.layer

    def input2layers(self, input_layers=None):
        """Convert deprecated net-level input/input_dim/input_shape fields
        into explicit Input layers, appended to input_layers.

        Raises:
            ValueError: if input_dim does not contain 4 entries per input.
        """
        # BUGFIX: mutable default argument ([]) was shared between calls.
        if input_layers is None:
            input_layers = []
        inputs_num = len(self.model.input)
        if inputs_num != 0:
            input_dims_num = len(self.model.input_dim)
            if input_dims_num != 0:
                if input_dims_num > 0 and input_dims_num != inputs_num * 4:
                    # BUGFIX: 'Error' was an undefined name; raise a real
                    # exception type instead.
                    raise ValueError('invalid input_dim[%d] param in prototxt'
                                     % (input_dims_num))
                for i in range(inputs_num):
                    dims = self.model.input_dim[i * 4:(i + 1) * 4]
                    l = self.generate_input_layer(dims, i)
                    input_layers.append(l[0])
            else:
                for i in range(inputs_num):
                    dims = self.model.input_shape[i].dim[0:4]
                    l = self.generate_input_layer(dims, i)
                    input_layers.append(l[0])

    def transform_input_layers(self, layers, input_layers=None):
        """Convert layer-level 'input' declarations into Input layers.

        NOTE(review): this path looks dead for standard LayerParameter
        messages (they define no 'input' field) — kept for compatibility.
        """
        # BUGFIX: mutable default argument ([]) was shared between calls.
        if input_layers is None:
            input_layers = []
        for layer in layers:
            if hasattr(layer, 'input'):
                # BUGFIX: original read 'layers.input_dim' (a list has no
                # such attribute) instead of this layer's field.
                input_dims_num = len(layer.input_dim)
                if input_dims_num > 0 and input_dims_num != 4:
                    # BUGFIX: 'Error' was an undefined name.
                    raise ValueError('invalid input_dim[%d] param in prototxt'
                                     % (input_dims_num))
                dims = self.model.input_dim[0:4]
                # BUGFIX: 'i' was undefined here; the single old-style
                # input corresponds to index 0 of model.input.
                l = self.generate_input_layer(dims, 0)
                input_layers.append(l[0])

    def get_layer_type(self, layer):
        """Return the canonical type string for a layer.

        V1 layers store an enum int; map it to the CamelCase name used by
        V2 prototxt (with special-casing for ReLU/PReLU and LRN).
        """
        if isinstance(layer.type, int):
            enum_values = self.caffe_pb._V1LAYERPARAMETER_LAYERTYPE.values
            vals = [val for val in enum_values if val.number == layer.type]
            part = vals[0].name.split('_')
            part = [s.capitalize() for s in part]
            type_str = ''
            type_str = type_str.join(part)
            if 'relu' in type_str.lower():
                type_str = type_str.replace('elu', 'eLU')
            elif type_str.lower() == 'lrn':
                type_str = 'LRN'
            return type_str
        else:
            return layer.type

    def build(self):
        """Assemble node_map and edges from the layer list, then attach
        trained parameters to the matching nodes."""
        layers = self.model.layers or self.model.layer

        layers = self.filter_layers(layers)

        input_layers = []

        self.input2layers(input_layers)
        self.transform_input_layers(layers, input_layers)
        layers = input_layers + layers
        # Normalize every layer/blob name the same way CaffeGraphNode
        # normalizes node names, so lookups stay consistent.
        for layer in layers:
            if hasattr(layer, 'name'):
                name = getattr(layer, 'name')
                setattr(layer, 'name', name.replace('/', '_').replace('-', '_'))
            for i, name in enumerate(layer.bottom):
                layer.bottom[i] = name.replace('/', '_').replace('-', '_')
            for i, name in enumerate(layer.top):
                layer.top[i] = name.replace('/', '_').replace('-', '_')

        # blob name -> list of layer names that produced it (in order).
        top_layer = {}
        for layer in layers:
            if hasattr(layer, 'input'):
                continue
            type_str = self.get_layer_type(layer)
            self.node_map[layer.name] = CaffeGraphNode(layer, type_str)
            for in_name in layer.bottom:
                if in_name in top_layer:
                    # Connect to the most recent producer; in-place layers
                    # may re-emit the same blob name.
                    self.connect(top_layer[in_name][-1], layer.name)
                else:
                    raise Exception(
                        'input[{}] of node[{}] does not exist in node_map'.
                        format(in_name, layer.name))
            for out_name in layer.top:
                if out_name not in top_layer:
                    top_layer[out_name] = [layer.name]
                else:
                    top_layer[out_name].append(layer.name)
        for layer_name, data in self.params:
            if layer_name in self.node_map:
                node = self.node_map[layer_name]
                node.set_params(data)
            else:
                print('Ignoring parameters for non-existent layer: %s' % \
                       layer_name)

        super(CaffeGraph, self).build()

    def get_bottom_node(self, node, idx=0, copy=False):
        """Return the graph node that produces input `idx` of `node`.

        Multi-output producers (other than Input) are addressed by the
        'name:i' convention used in node_map.
        """
        input_node_name = node.inputs[idx]
        # BUGFIX: the assert message referenced an undefined name 'name'.
        assert input_node_name in self.node_map, 'The {} isn\'t a valid node'.format(
            input_node_name)
        input_node = self.node_map[input_node_name]
        if len(input_node.layer.top) > 1 and input_node.layer_type != "Input":
            need_idx = list(input_node.layer.top).index(node.layer.bottom[idx])
            name = input_node_name + ':' + str(need_idx)
        else:
            name = input_node_name
        return self.get_node(name, copy=copy)

S
SunAhong1993 已提交
236

J
jiangjiajun 已提交
237
class CaffeDecoder(object):
    """Load a Caffe model (prototxt + caffemodel) and build a CaffeGraph.

    Args:
        proto_path: path to the network definition (.prototxt).
        model_path: path to the trained weights (.caffemodel).
        caffe_proto: optional path to a custom compiled caffe_pb2 module
            (forwarded to CaffeResolver).
    """

    def __init__(self, proto_path, model_path, caffe_proto):
        self.proto_path = proto_path
        self.model_path = model_path

        self.resolver = CaffeResolver(caffe_proto=caffe_proto)
        self.net = self.resolver.NetParameter()
        with open(proto_path, 'rb') as proto_file:
            proto_str = proto_file.read()
            text_format.Merge(proto_str, self.net)

        self.load_using_pb()

        self.caffe_graph = CaffeGraph(self.net, self.params,
                                      self.resolver.caffepb)
        self.caffe_graph.build()

    def load_using_pb(self):
        """Parse the binary caffemodel and collect per-layer blob arrays
        into self.params as (layer_name, [ndarray, ...]) pairs."""
        data = self.resolver.NetParameter()
        # BUGFIX: close the weights file instead of leaking the handle.
        with open(self.model_path, 'rb') as model_file:
            data.MergeFromString(model_file.read())
        layers = data.layers or data.layer
        # Keep names consistent with the prototxt-side normalization.
        for layer in layers:
            setattr(layer, 'name',
                    layer.name.replace('/', '_').replace('-', '_'))
        pair = lambda layer: (layer.name, self.normalize_pb_data(layer))
        self.params = [pair(layer) for layer in layers if layer.blobs]

    def normalize_pb_data(self, layer):
        """Return layer.blobs as a list of float32 ndarrays.

        Blobs carrying an explicit shape are padded to rank 4
        (c_o, c_i, h, w); PReLU weights are placed on the channel axis;
        rank-4 Normalize blobs are kept flat; legacy blobs fall back to
        the num/channels/height/width fields.
        """
        transformed = []
        for blob in layer.blobs:
            if len(blob.shape.dim):
                dims = blob.shape.dim
                if layer.type == 'PReLU':
                    # Per-channel slope: pad to (1, C, 1, 1).
                    c_o, c_i, h, w = map(int, [1] + \
                        list(dims) + [1]* (3 - len(dims)))
                elif layer.type == 'Normalize' and len(dims) == 4:
                    # Scale blob is consumed flat; skip the 4-D reshape.
                    data = np.asarray(list(blob.data), dtype=np.float32)
                    transformed.append(data)
                    continue
                else:
                    # Left-pad with 1s up to rank 4.
                    c_o, c_i, h, w = map(int,
                                         [1] * (4 - len(dims)) + list(dims))
            else:
                # Legacy (pre-BlobShape) blob layout.
                c_o = blob.num
                c_i = blob.channels
                h = blob.height
                w = blob.width
            data = np.asarray(
                list(blob.data), dtype=np.float32).reshape(c_o, c_i, h, w)

            transformed.append(data)
        return transformed