caffe_decoder.py 11.4 KB
Newer Older
J
jiangjiajun 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
S
SunAhong1993 已提交
14 15 16 17 18 19

import os
import sys
from google.protobuf import text_format
import numpy as np
from x2paddle.core.graph import GraphNode, Graph
S
SunAhong1993 已提交
20
from x2paddle.core.fluid_code import FluidCode
S
SunAhong1993 已提交
21 22 23


class CaffeResolver(object):
    """Locate and load the protobuf Python module generated from caffe.proto.

    If ``caffe_proto`` is None, the ``caffe_pb2`` module bundled with
    x2paddle is used; otherwise ``caffe_proto`` must be the path of a ``.py``
    file compiled from caffe.proto. After construction ``self.caffepb`` is
    the loaded module and ``self.NetParameter`` is its ``NetParameter``.
    """

    def __init__(self, caffe_proto):
        self.caffe_proto = caffe_proto
        self.import_caffe()

    def import_caffepb(self):
        """Import and return the caffe protobuf module.

        Raises:
            Exception: if ``caffe_proto`` is given but is not an existing file.
        """
        if self.caffe_proto is None:
            from x2paddle.decoder import caffe_pb2
            return caffe_pb2
        if not os.path.isfile(self.caffe_proto):
            raise Exception(
                "The .py file compiled from caffe.proto does not exist: {}".
                format(self.caffe_proto))
        filepath, tempfilename = os.path.split(
            os.path.abspath(self.caffe_proto))
        filename = os.path.splitext(tempfilename)[0]
        # The module is imported by its bare name, so its directory must be
        # on sys.path; do not append the same directory on every call.
        if filepath not in sys.path:
            sys.path.append(filepath)
        return __import__(filename)

    def import_caffe(self):
        """Load the protobuf module and expose its NetParameter class."""
        self.caffepb = self.import_caffepb()
        self.NetParameter = self.caffepb.NetParameter


class CaffeGraphNode(GraphNode):
    """Graph node wrapping a single Caffe layer.

    The node name is ``layer_name`` when given, otherwise ``layer.name``.
    In both cases '/' and '-' are replaced with '_' and the result is
    lower-cased before it is passed to ``GraphNode``.
    """

    def __init__(self, layer, type_str, layer_name=None):
        raw_name = layer.name if layer_name is None else layer_name
        node_name = raw_name.replace('/', '_').replace('-', '_').lower()
        super(CaffeGraphNode, self).__init__(layer, node_name)
        self.layer_type = type_str
        self.fluid_code = FluidCode()
        # Filled in by set_params().
        self.data = None

    def set_params(self, params):
        """Attach the layer's parameter data to this node."""
        self.data = params


class CaffeGraph(Graph):
    """Graph built from a parsed Caffe ``NetParameter``.

    Args:
        model: the parsed NetParameter of the network definition.
        params: iterable of ``(layer_name, data)`` pairs; ``data`` is handed
            to the matching node's ``set_params`` in ``build``.
        caffe_pb: the protobuf module generated from caffe.proto.
    """

    def __init__(self, model, params, caffe_pb):
        self.params = params
        self.caffe_pb = caffe_pb
        name = getattr(model, "name", "")
        self.graph_name = "CaffeModel" if name == "" else name
        super(CaffeGraph, self).__init__(model)

    def filter_layers(self, layers):
        '''Filter out layers based on the current phase.

        Layers that only run in the TRAIN phase (according to the first
        ``include``/``exclude`` rule) and Dropout layers are dropped. When
        a Dropout layer is not in-place, every layer reading its output is
        rewired to read the Dropout's input instead. Entries that have an
        ``input`` attribute are skipped and not returned.
        '''
        phase_map = {0: 'train', 1: 'test'}
        filtered_layer_names = set()
        filtered_layers = []
        for layer in layers:
            if hasattr(layer, 'input'):
                continue
            type_str = self.get_layer_type(layer)
            phase = 'test'
            if len(layer.include):
                phase = phase_map[layer.include[0].phase]
            if len(layer.exclude):
                # Excluding one phase means the layer runs in the other one.
                phase = phase_map[1 - layer.exclude[0].phase]
            exclude = (phase != 'test')
            # Dropout layers appear in a fair number of Caffe test-time
            # networks; they are dropped here. Compare type_str (not
            # layer.type) so V1 layers, whose type is an int enum, are
            # rewired as well.
            if not exclude and type_str == 'Dropout':
                exclude = True
                drop_layer_top = layer.top[0]
                drop_layer_bottom = layer.bottom[0]
                if drop_layer_top != drop_layer_bottom:
                    for next_layer in layers:
                        for idx, bottom_name in enumerate(next_layer.bottom):
                            if bottom_name == drop_layer_top:
                                next_layer.bottom[idx] = drop_layer_bottom

            if not exclude:
                filtered_layers.append(layer)
                # Guard against dupes.
                assert layer.name not in filtered_layer_names
                filtered_layer_names.add(layer.name)
            else:
                print('The filter layer:' + layer.name)
        return filtered_layers

    def generate_input_layer(self, dims, index):
        """Build an ``Input`` layer for the ``index``-th network input.

        A prototxt snippet declaring an Input layer named (and producing)
        ``self.model.input[index]`` with shape ``dims`` is parsed into a new
        NetParameter, whose layer list is returned.
        """
        input_name = str(self.model.input[index])
        dim_str = ''.join('dim: {}\n'.format(str(dim)) for dim in dims)
        input_str = 'layer {\n'
        input_str += 'name: \"{}\"\n'.format(input_name)
        input_str += 'type: "Input"\n'
        input_str += 'top: \"{}\"\n'.format(input_name)
        input_str += 'input_param {\n'
        input_str += 'shape {\n'
        input_str += dim_str
        input_str += '}}}'
        net = self.caffe_pb.NetParameter()
        text_format.Merge(str.encode(input_str), net)
        return net.layers or net.layer

    def input2layers(self, input_layers=None):
        """Convert top-level ``input`` declarations into Input layers.

        Shapes are taken from ``input_dim`` (4 values per input) when it is
        set, otherwise from ``input_shape``. New layers are appended to
        ``input_layers`` (a fresh list when omitted), which is returned.

        Raises:
            Exception: if ``input_dim`` is set but does not hold exactly four
                values per input.
        """
        if input_layers is None:
            input_layers = []
        inputs_num = len(self.model.input)
        if inputs_num == 0:
            return input_layers
        input_dims_num = len(self.model.input_dim)
        if input_dims_num != 0 and input_dims_num != inputs_num * 4:
            raise Exception('invalid input_dim[%d] param in prototxt' %
                            (input_dims_num))
        for i in range(inputs_num):
            if input_dims_num != 0:
                dims = self.model.input_dim[i * 4:(i + 1) * 4]
            else:
                dims = self.model.input_shape[i].dim[0:4]
            input_layers.append(self.generate_input_layer(dims, i)[0])
        return input_layers

    def transform_input_layers(self, layers, input_layers=None):
        """Turn entries of ``layers`` that have an ``input`` attribute into
        Input layers appended to ``input_layers`` (returned).

        NOTE(review): filter_layers() already drops such entries, so with
        the current build() this loop finds nothing. The k-th such entry is
        mapped to ``self.model.input[k]`` -- confirm this is the intent.

        Raises:
            Exception: if such an entry's ``input_dim`` has a length other
                than 0 or 4.
        """
        if input_layers is None:
            input_layers = []
        index = 0
        for layer in layers:
            if not hasattr(layer, 'input'):
                continue
            input_dims_num = len(layer.input_dim)
            if input_dims_num > 0 and input_dims_num != 4:
                raise Exception('invalid input_dim[%d] param in prototxt' %
                                (input_dims_num))
            dims = self.model.input_dim[0:4]
            input_layers.append(self.generate_input_layer(dims, index)[0])
            index += 1
        return input_layers

    def get_layer_type(self, layer):
        """Return the layer type as a string.

        V1 layers store their type as an int enum value; it is mapped to its
        enum name and converted to CamelCase (e.g. INNER_PRODUCT becomes
        InnerProduct), with ReLU/PReLU and LRN spellings fixed up. Otherwise
        ``layer.type`` is returned unchanged.
        """
        if not isinstance(layer.type, int):
            return layer.type
        enum_values = self.caffe_pb._V1LAYERPARAMETER_LAYERTYPE.values
        vals = [val for val in enum_values if val.number == layer.type]
        type_str = ''.join(s.capitalize() for s in vals[0].name.split('_'))
        if 'relu' in type_str.lower():
            type_str = type_str.replace('elu', 'eLU')
        elif type_str.lower() == 'lrn':
            type_str = 'LRN'
        return type_str

    def build(self):
        """Create a node per layer, connect them by blob name, attach params.

        Raises:
            Exception: if a layer reads a blob that no earlier layer writes.
        """
        layers = self.model.layers or self.model.layer
        layers = self.filter_layers(layers)

        input_layers = []
        self.input2layers(input_layers)
        self.transform_input_layers(layers, input_layers)
        layers = input_layers + layers

        def normalize(name):
            return name.replace('/', '_').replace('-', '_')

        for layer in layers:
            if hasattr(layer, 'name'):
                layer.name = normalize(layer.name)
            for i, name in enumerate(layer.bottom):
                layer.bottom[i] = normalize(name)
            for i, name in enumerate(layer.top):
                layer.top[i] = normalize(name)

        # Blob name -> names of the layers that wrote it, in order; a reader
        # is connected to the most recent writer of the blob.
        top_layer = {}
        for layer in layers:
            if hasattr(layer, 'input'):
                continue
            type_str = self.get_layer_type(layer)
            self.node_map[layer.name] = CaffeGraphNode(layer, type_str)
            for in_name in layer.bottom:
                if in_name not in top_layer:
                    raise Exception(
                        'input[{}] of node[{}] does not exist in node_map'.
                        format(in_name, layer.name))
                self.connect(top_layer[in_name][-1], layer.name)
            for out_name in layer.top:
                top_layer.setdefault(out_name, []).append(layer.name)

        for layer_name, data in self.params:
            if layer_name in self.node_map:
                self.node_map[layer_name].set_params(data)
            else:
                print('Ignoring parameters for non-existent layer: %s' %
                      layer_name)

        super(CaffeGraph, self).build()

    def get_bottom_node(self, node, idx=0, copy=False):
        """Return the node feeding the ``idx``-th input of ``node``.

        When that producer has several tops, the name looked up through
        ``get_node`` is suffixed with ``:<index of the consumed top>``.
        """
        input_node_name = node.inputs[idx]
        assert input_node_name in self.node_map, \
            'The {} isn\'t a valid node'.format(input_node_name)
        input_node = self.node_map[input_node_name]
        if len(input_node.layer.top) > 1:
            need_idx = list(input_node.layer.top).index(node.layer.bottom[idx])
            name = input_node_name + ':' + str(need_idx)
        else:
            name = input_node_name
        return self.get_node(name, copy=copy)

S
SunAhong1993 已提交
240

J
jiangjiajun 已提交
241
class CaffeDecoder(object):
    """Load a Caffe network definition and its weights into a CaffeGraph.

    Args:
        proto_path: path of the text-format network definition.
        model_path: path of the binary serialized weights.
        caffe_proto: passed to CaffeResolver (path of a module compiled from
            caffe.proto, or None for the bundled one).
    """

    def __init__(self, proto_path, model_path, caffe_proto):
        self.proto_path = proto_path
        self.model_path = model_path

        self.resolver = CaffeResolver(caffe_proto=caffe_proto)
        self.net = self.resolver.NetParameter()
        with open(proto_path, 'rb') as proto_file:
            proto_str = proto_file.read()
        text_format.Merge(proto_str, self.net)

        self.load_using_pb()

        self.caffe_graph = CaffeGraph(self.net, self.params,
                                      self.resolver.caffepb)
        self.caffe_graph.build()

    def load_using_pb(self):
        """Read the weights file and collect per-layer parameter arrays.

        Sets ``self.params`` to a list of ``(layer_name, [np.ndarray, ...])``
        pairs, one per layer that has blobs. Layer names get '/' and '-'
        replaced with '_', the same normalization CaffeGraph.build applies.
        """
        data = self.resolver.NetParameter()
        # Close the weights file promptly instead of leaking the handle.
        with open(self.model_path, 'rb') as model_file:
            data.MergeFromString(model_file.read())
        layers = data.layers or data.layer
        for layer in layers:
            layer.name = layer.name.replace('/', '_').replace('-', '_')
        self.params = [(layer.name, self.normalize_pb_data(layer))
                       for layer in layers if layer.blobs]

    def normalize_pb_data(self, layer):
        """Convert every blob of ``layer`` into a float32 numpy array.

        Blobs with a ``shape`` are reshaped to 4-D ``(c_o, c_i, h, w)``:
        - PReLU: shape prefixed with one 1 and right-padded with 1s;
        - Normalize with a 4-D shape: kept as a flat 1-D array;
        - otherwise: shape left-padded with 1s.
        Blobs without ``shape`` use the num/channels/height/width fields.
        """
        transformed = []
        for blob in layer.blobs:
            if len(blob.shape.dim):
                dims = list(blob.shape.dim)
                if layer.type == 'PReLU':
                    c_o, c_i, h, w = map(int, [1] + dims + [1] * (3 - len(dims)))
                elif layer.type == 'Normalize' and len(dims) == 4:
                    transformed.append(
                        np.asarray(list(blob.data), dtype=np.float32))
                    continue
                else:
                    c_o, c_i, h, w = map(int, [1] * (4 - len(dims)) + dims)
            else:
                c_o = blob.num
                c_i = blob.channels
                h = blob.height
                w = blob.width
            data = np.asarray(
                list(blob.data), dtype=np.float32).reshape(c_o, c_i, h, w)
            transformed.append(data)
        return transformed