#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
from google.protobuf import text_format
import numpy as np
from x2paddle.core.graph import GraphNode, Graph
from x2paddle.core.fluid_code import FluidCode
from x2paddle.op_mapper import caffe_shape


class CaffeResolver(object):
    """Locate and import the protobuf module compiled from caffe.proto.

    When no custom compiled proto is supplied, falls back to the
    ``caffe_pb2`` module bundled with x2paddle.
    """

    def __init__(self, caffe_proto):
        # Path to a caffe.proto-compiled .py file, or None for the bundled one.
        self.caffe_proto = caffe_proto
        self.import_caffe()

    def import_caffepb(self):
        """Return the caffe protobuf module (custom or bundled)."""
        if self.caffe_proto is None:
            from x2paddle.decoder import caffe_pb2
            return caffe_pb2
        if not os.path.isfile(self.caffe_proto):
            raise Exception(
                "The .py file compiled by caffe.proto is not exist.")
        # Make the directory holding the compiled module importable,
        # then load the module by its bare file name.
        dir_path, base_name = os.path.split(os.path.abspath(self.caffe_proto))
        module_name = os.path.splitext(base_name)[0]
        sys.path.append(dir_path)
        return __import__(module_name)

    def import_caffe(self):
        """Cache the protobuf module and its NetParameter message class."""
        self.caffepb = self.import_caffepb()
        self.NetParameter = self.caffepb.NetParameter


class CaffeGraphNode(GraphNode):
    """Graph node wrapping a single caffe layer definition."""

    def __init__(self, layer, type_str, layer_name=None):
        # '/' and '-' are illegal in generated code identifiers; sanitize
        # whichever name we were given before registering the node.
        raw_name = layer.name if layer_name is None else layer_name
        safe_name = raw_name.replace('/', '_').replace('-', '_')
        super(CaffeGraphNode, self).__init__(layer, safe_name)
        # Canonical caffe layer type string (e.g. 'Convolution').
        self.layer_type = type_str
        # Holds the paddle code emitted for this node.
        self.fluid_code = FluidCode()
        # Trained parameter blobs, filled in later via set_params.
        self.data = None

    def set_params(self, params):
        """Attach the layer's trained parameter blobs."""
        self.data = params


class CaffeGraph(Graph):
    """Graph of a caffe network.

    Wires layers together through their bottom/top blob names and
    attaches the trained parameters loaded from the caffemodel.
    """

    def __init__(self, model, params, caffe_pb):
        # model: parsed NetParameter message.
        # params: list of (layer_name, blobs) pairs from the caffemodel.
        # caffe_pb: the protobuf module used to parse the model.
        self.params = params
        self.caffe_pb = caffe_pb
        super(CaffeGraph, self).__init__(model)

    def filter_layers(self, layers):
        '''Filter out layers based on the current phase.'''
        phase_map = {0: 'train', 1: 'test'}
        filtered_layer_names = set()
        filtered_layers = []
        for layer in layers:
            # Deprecated per-layer 'input' declarations are handled in
            # transform_input_layers, not here.
            if hasattr(layer, 'input'):
                continue
            type_str = self.get_layer_type(layer)
            phase = 'test'
            if len(layer.include):
                phase = phase_map[layer.include[0].phase]
            if len(layer.exclude):
                # An exclude rule means the layer runs in the opposite phase.
                # Bugfix: was layer.include[0], which raises IndexError when
                # a layer specifies exclude without include.
                phase = phase_map[1 - layer.exclude[0].phase]
            exclude = (phase != 'test')
            # Dropout layers appear in a fair number of Caffe
            # test-time networks. These are just ignored. We'll
            # filter them out here.
            if (not exclude) and (phase == 'test'):
                exclude = (type_str == 'Dropout')
            if not exclude:
                filtered_layers.append(layer)
                # Guard against dupes.
                assert layer.name not in filtered_layer_names
                filtered_layer_names.add(layer.name)
            else:
                print('The filter layer:' + layer.name)
        return filtered_layers

    def generate_input_layer(self, dims, index):
        """Build a prototxt 'Input' layer for model input `index` with the
        given dims, parse it with text_format, and return the layer list."""
        dim_str = ''
        for dim in dims:
            dim_str += 'dim: {}\n'.format(str(dim))
        input_str = 'layer {\n'
        input_str += 'name: \"{}\"\n '.format(str(self.model.input[index]))
        input_str += 'type: "Input"\n'
        input_str += 'top: \"{}\"\n'.format(str(self.model.input[index]))
        input_str += 'input_param {\n'
        input_str += 'shape {\n'
        input_str += dim_str
        input_str += '}}}'
        input_str = str.encode(input_str)
        net = self.caffe_pb.NetParameter()
        text_format.Merge(input_str, net)
        return net.layers or net.layer

    def input2layers(self, input_layers=None):
        """Convert deprecated net-level input/input_dim/input_shape fields
        into explicit Input layers, appended to `input_layers` in place."""
        # Bugfix: default was a shared mutable list.
        if input_layers is None:
            input_layers = []
        inputs_num = len(self.model.input)
        if inputs_num != 0:
            input_dims_num = len(self.model.input_dim)
            if input_dims_num != 0:
                if input_dims_num > 0 and input_dims_num != inputs_num * 4:
                    # Bugfix: was `raise Error(...)` -- an undefined name.
                    raise Exception('invalid input_dim[%d] param in prototxt' %
                                    (input_dims_num))
                for i in range(inputs_num):
                    # Each input owns four consecutive dim entries (NCHW).
                    dims = self.model.input_dim[i * 4:(i + 1) * 4]
                    l = self.generate_input_layer(dims, i)
                    input_layers.append(l[0])
            else:
                for i in range(inputs_num):
                    dims = self.model.input_shape[i].dim[0:4]
                    l = self.generate_input_layer(dims, i)
                    input_layers.append(l[0])

    def transform_input_layers(self, layers, input_layers=None):
        """Convert per-layer deprecated 'input' declarations into Input
        layers, appended to `input_layers` in place."""
        # Bugfix: default was a shared mutable list.
        if input_layers is None:
            input_layers = []
        for layer in layers:
            if hasattr(layer, 'input'):
                # Bugfix: was len(layers.input_dim) -- a list has no
                # input_dim attribute; the per-layer field was intended.
                input_dims_num = len(layer.input_dim)
                if input_dims_num > 0 and input_dims_num != 4:
                    # Bugfix: was `raise Error(...)` -- an undefined name.
                    raise Exception('invalid input_dim[%d] param in prototxt' %
                                    (input_dims_num))
                dims = self.model.input_dim[0:4]
                # Bugfix: `i` was undefined here; a lone 'input' declaration
                # describes the first model input (index 0).
                l = self.generate_input_layer(dims, 0)
                input_layers.append(l[0])

    def get_layer_type(self, layer):
        """Map a (possibly numeric, V1-format) layer type to its string
        name, normalizing capitalization quirks (ReLU, PReLU, LRN)."""
        if isinstance(layer.type, int):
            enum_values = self.caffe_pb._V1LAYERPARAMETER_LAYERTYPE.values
            vals = [val for val in enum_values if val.number == layer.type]
            part = vals[0].name.split('_')
            part = [s.capitalize() for s in part]
            type_str = ''
            type_str = type_str.join(part)
            if 'relu' in type_str.lower():
                type_str = type_str.replace('elu', 'eLU')
            elif type_str.lower() == 'lrn':
                type_str = 'LRN'
            return type_str
        else:
            return layer.type

    def build(self):
        """Create graph nodes for every layer and connect each layer to
        the most recent producer of each of its bottom blobs."""
        layers = self.model.layers or self.model.layer
        layers = self.filter_layers(layers)

        input_layers = []
        self.input2layers(input_layers)
        self.transform_input_layers(layers, input_layers)
        layers = input_layers + layers
        # Sanitize names everywhere: '/' and '-' are illegal in the
        # identifiers the mapper generates from these names.
        for layer in layers:
            if hasattr(layer, 'name'):
                name = getattr(layer, 'name')
                setattr(layer, 'name', name.replace('/', '_').replace('-', '_'))
            for i, name in enumerate(layer.bottom):
                layer.bottom[i] = name.replace('/', '_').replace('-', '_')
            for i, name in enumerate(layer.top):
                layer.top[i] = name.replace('/', '_').replace('-', '_')

        # top_layer maps a blob name to the layers that produced it;
        # in-place layers re-top the same blob, so the last entry wins.
        top_layer = {}
        for layer in layers:
            if hasattr(layer, 'input'):
                continue
            type_str = self.get_layer_type(layer)
            self.node_map[layer.name] = CaffeGraphNode(layer, type_str)
            for in_name in layer.bottom:
                if in_name in top_layer:
                    self.connect(top_layer[in_name][-1], layer.name)
                else:
                    raise Exception(
                        'input[{}] of node[{}] does not exist in node_map'.
                        format(in_name, layer.name))
            for out_name in layer.top:
                if out_name not in top_layer:
                    top_layer[out_name] = [layer.name]
                else:
                    top_layer[out_name].append(layer.name)

        # Attach trained blobs to their owning nodes.
        for layer_name, data in self.params:
            if layer_name in self.node_map:
                node = self.node_map[layer_name]
                node.set_params(data)
            else:
                print('Ignoring parameters for non-existent layer: %s' % \
                       layer_name)

        super(CaffeGraph, self).build()

    def get_bottom_node(self, node, idx=0, copy=False):
        """Return the graph node feeding input `idx` of `node`, resolving
        multi-output producers to 'name:out_idx' form."""
        input_node_name = node.inputs[idx]
        # Bugfix: the assert message referenced an undefined `name`.
        assert input_node_name in self.node_map, 'The {} isn\'t a valid node'.format(
            input_node_name)
        input_node = self.node_map[input_node_name]
        if len(input_node.layer.top) > 1:
            need_idx = list(input_node.layer.top).index(node.layer.bottom[idx])
            name = input_node_name + ':' + str(need_idx)
        else:
            name = input_node_name
        return self.get_node(name, copy=copy)


class CaffeDecoder(object):
    """Parse a caffe prototxt + caffemodel pair into a built CaffeGraph."""

    def __init__(self, proto_path, model_path, caffe_proto):
        # proto_path: network definition (.prototxt).
        # model_path: trained weights (.caffemodel).
        # caffe_proto: optional caffe.proto-compiled .py path, or None.
        self.proto_path = proto_path
        self.model_path = model_path

        self.resolver = CaffeResolver(caffe_proto=caffe_proto)
        self.net = self.resolver.NetParameter()
        with open(proto_path, 'rb') as proto_file:
            proto_str = proto_file.read()
            text_format.Merge(proto_str, self.net)

        self.load_using_pb()

        self.caffe_graph = CaffeGraph(self.net, self.params,
                                      self.resolver.caffepb)
        self.caffe_graph.build()

    def load_using_pb(self):
        """Read the binary caffemodel and collect (name, blobs) pairs
        into self.params for every layer that carries weights."""
        data = self.resolver.NetParameter()
        # Bugfix: the original `open(...).read()` leaked the file handle.
        with open(self.model_path, 'rb') as model_file:
            data.MergeFromString(model_file.read())
        layers = data.layers or data.layer
        for layer in layers:
            # Keep names in sync with CaffeGraph's identifier sanitizing.
            setattr(layer, 'name',
                    layer.name.replace('/', '_').replace('-', '_'))
        # Only layers with trained blobs contribute parameters.
        self.params = [(layer.name, self.normalize_pb_data(layer))
                       for layer in layers if layer.blobs]

    def normalize_pb_data(self, layer):
        """Convert a layer's blobs to float32 ndarrays.

        Blobs are reshaped to 4-D (c_o, c_i, h, w), except Normalize
        scale blobs, which are returned as a flat float32 array.
        """
        transformed = []
        for blob in layer.blobs:
            if len(blob.shape.dim):
                dims = blob.shape.dim
                if layer.type == 'PReLU':
                    # PReLU slopes become (1, C, 1, 1): prepend one dim,
                    # pad the tail with 1s.
                    c_o, c_i, h, w = map(int, [1] + \
                        list(dims) + [1]* (3 - len(dims)))
                elif layer.type == 'Normalize' and len(dims) == 4:
                    # Normalize params are kept as a flat float32 array.
                    data = np.asarray(list(blob.data), dtype=np.float32)
                    transformed.append(data)
                    continue
                else:
                    # Left-pad with 1s up to a full 4-D shape.
                    c_o, c_i, h, w = map(int,
                                         [1] * (4 - len(dims)) + list(dims))
            else:
                # Old-style blobs carry explicit num/channels/height/width.
                c_o = blob.num
                c_i = blob.channels
                h = blob.height
                w = blob.width
            data = np.asarray(
                list(blob.data), dtype=np.float32).reshape(c_o, c_i, h, w)
            transformed.append(data)
        return transformed