From e4f6ffa191ca866900761070fc35e9fd4fa468a1 Mon Sep 17 00:00:00 2001 From: walloollaw <37680514+walloollaw@users.noreply.github.com> Date: Tue, 10 Jul 2018 10:34:23 +0800 Subject: [PATCH] caffe2fluid:support ssd model conversion (#1039) --- .../kaffe/custom_layers/__init__.py | 5 ++++ .../kaffe/custom_layers/flatten.py | 23 +++++---------- .../kaffe/custom_layers/reshape.py | 14 +++++++-- .../caffe2fluid/kaffe/layers.py | 29 +++++++++++++++---- .../caffe2fluid/kaffe/paddle/network.py | 4 ++- .../caffe2fluid/kaffe/paddle/transformer.py | 28 ++++++------------ .../caffe2fluid/kaffe/shapes.py | 16 +++++++--- .../caffe2fluid/kaffe/transformers.py | 2 ++ 8 files changed, 75 insertions(+), 46 deletions(-) diff --git a/fluid/image_classification/caffe2fluid/kaffe/custom_layers/__init__.py b/fluid/image_classification/caffe2fluid/kaffe/custom_layers/__init__.py index 703c6a0a..b1dbe215 100644 --- a/fluid/image_classification/caffe2fluid/kaffe/custom_layers/__init__.py +++ b/fluid/image_classification/caffe2fluid/kaffe/custom_layers/__init__.py @@ -8,6 +8,11 @@ import axpy import flatten import argmax import reshape +import roipooling +import priorbox +import permute +import detection_out +import normalize #custom layer import ends diff --git a/fluid/image_classification/caffe2fluid/kaffe/custom_layers/flatten.py b/fluid/image_classification/caffe2fluid/kaffe/custom_layers/flatten.py index 8f7af426..ebb97718 100644 --- a/fluid/image_classification/caffe2fluid/kaffe/custom_layers/flatten.py +++ b/fluid/image_classification/caffe2fluid/kaffe/custom_layers/flatten.py @@ -4,11 +4,6 @@ from .register import register -def import_fluid(): - import paddle.fluid as fluid - return fluid - - def flatten_shape(input_shape, axis=1, end_axis=-1): """ calculate the output shape of this layer using input shape @@ -28,7 +23,7 @@ def flatten_shape(input_shape, axis=1, end_axis=-1): start_axis += len(input_shape) if end_axis < 0: - end_axis += len(input_shape) + end_axis += len(input_shape) + 1 assert start_axis <= end_axis, 'invalid axis[%d] or end_axis[%d] params'\ % (start_axis, end_axis) @@ -52,18 +47,16 @@ def flatten_layer(input, name, axis=1, end_axis=-1): Returns: output (variable): output variable for this layer """ - fluid = import_fluid() + import paddle.fluid as fluid input_shape = list(input.shape) - dims = len(input_shape) - start_axis = axis if axis >= 0 else axis + dims - end_axis = end_axis if end_axis >= 0 else end_axis + dims - assert start_axis <= end_axis, 'invalid axis or end_axis params' - output_shape = input_shape[0:start_axis] - flat_sz = reduce(lambda a, b: a * b, input_shape[start_axis:end_axis]) - output_shape += [flat_sz] - output_shape += input_shape[end_axis:-1] + if input_shape[0] == -1: + input_shape[0] = 1 + output_shape = flatten_shape(input_shape, axis=axis, end_axis=end_axis) + output_shape[0] = -1 + else: + output_shape = flatten_shape(input_shape, axis=axis, end_axis=end_axis) output = fluid.layers.reshape(input, shape=output_shape, name=name) diff --git a/fluid/image_classification/caffe2fluid/kaffe/custom_layers/reshape.py b/fluid/image_classification/caffe2fluid/kaffe/custom_layers/reshape.py index 6b8d5681..da82e4d6 100644 --- a/fluid/image_classification/caffe2fluid/kaffe/custom_layers/reshape.py +++ b/fluid/image_classification/caffe2fluid/kaffe/custom_layers/reshape.py @@ -68,15 +68,23 @@ def reshape_shape(input_sp, shape, axis=0, num_axes=-1): top_dim = shape['dim'][i] if top_dim == 0: copy_axes.append(i) + copy_axis_index = start_axis + i + output_shape[copy_axis_index] = input_shape[copy_axis_index] elif top_dim == -1: assert inferred_axis == -1, "[Reshape]new shape contains multiple -1 dims" + inferred_axis = i else: constant_count *= top_dim if inferred_axis >= 0: explicit_count = constant_count - explicit_count *= count(input_shape[0:start_axis]) - explicit_count *= count(input_shape[end_axis:]) + l = input_shape[0:start_axis] + if len(l) > 0: + explicit_count *= count(l) + + l = input_shape[end_axis:] + if len(l) > 0: + explicit_count *= count(l) for i in range(len(copy_axes)): explicit_count *= output_shape[start_axis + copy_axes[i]] @@ -84,6 +92,7 @@ def reshape_shape(input_sp, shape, axis=0, num_axes=-1): assert input_count % explicit_count == 0, "[Reshape]botom count[%d] "\ "must be divisible by product of the specified dimensions[%d] "\ % (input_count, explicit_count) + output_shape[start_axis + inferred_axis] = input_count / explicit_count output_count = count(output_shape) assert output_count == input_count, "[Reshape]output count[%d] must match input count[%d]" % ( @@ -117,6 +126,7 @@ def reshape_layer(input, name, shape, axis=0, num_axes=-1): output_shape = reshape_shape(input_shape, shape, axis, num_axes) output = fluid.layers.reshape(input, shape=output_shape, name=name) + return output diff --git a/fluid/image_classification/caffe2fluid/kaffe/layers.py b/fluid/image_classification/caffe2fluid/kaffe/layers.py index f2d54c59..7a13cf8b 100644 --- a/fluid/image_classification/caffe2fluid/kaffe/layers.py +++ b/fluid/image_classification/caffe2fluid/kaffe/layers.py @@ -179,6 +179,9 @@ class LayerAdapter(object): @property def parameters(self): name = NodeDispatch.get_handler_name(self.kind) + if self.kind.lower() == "normalize": + name = "norm" + name = '_'.join((name, 'param')) try: return getattr(self.layer, name) @@ -217,9 +220,25 @@ class LayerAdapter(object): params.stride_w, params.stride, 1, default=1) p_h = self.get_kernel_value(params.pad_h, params.pad, 0, default=0) p_w = self.get_kernel_value(params.pad_w, params.pad, 1, default=0) - return KernelParameters(k_h, k_w, s_h, s_w, p_h, p_w) - -KernelParameters = namedtuple('KernelParameters', [ - 'kernel_h', 'kernel_w', 'stride_h', 'stride_w', 'pad_h', 'pad_w' -]) + dila_h = dila_w = 1 + if self.kind in (NodeKind.Convolution, ): + dila_len = len(params.dilation) + if dila_len == 2: + dila_h = params.dilation[0] + dila_w = params.dilation[1] + elif dila_len == 1: + dila_h = dila_w = params.dilation[0] + else: + assert dila_len == 0, "invalid length[%s] of dilation in convolution" % ( + dila_len) + + return KernelParameters(k_h, k_w, s_h, s_w, p_h, p_w, dila_h, dila_w) + + +KernelParameters = namedtuple( + 'KernelParameters', + [ + 'kernel_h', 'kernel_w', 'stride_h', 'stride_w', 'pad_h', 'pad_w', + 'dila_h', 'dila_w' + ], ) diff --git a/fluid/image_classification/caffe2fluid/kaffe/paddle/network.py b/fluid/image_classification/caffe2fluid/kaffe/paddle/network.py index e8b0f2c3..a6e5eaa3 100644 --- a/fluid/image_classification/caffe2fluid/kaffe/paddle/network.py +++ b/fluid/image_classification/caffe2fluid/kaffe/paddle/network.py @@ -91,7 +91,7 @@ class Network(object): name = '%s_%s' % (op_name, param_name) v = fluid.global_scope().find_var(name) w = v.get_tensor() - w.set(data, place) + w.set(data.reshape(w.shape()), place) except ValueError: if not ignore_missing: raise @@ -144,6 +144,7 @@ class Network(object): relu=True, relu_negative_slope=0.0, padding=None, + dilation=1, group=1, biased=True): if padding is None: @@ -173,6 +174,7 @@ class Network(object): num_filters=c_o, stride=[s_h, s_w], padding=padding, + dilation=dilation, groups=group, param_attr=fluid.ParamAttr(name=prefix + "weights"), bias_attr=fluid.ParamAttr(name=prefix + "biases"), diff --git a/fluid/image_classification/caffe2fluid/kaffe/paddle/transformer.py b/fluid/image_classification/caffe2fluid/kaffe/paddle/transformer.py index 02a600bc..7cb7b598 100644 --- a/fluid/image_classification/caffe2fluid/kaffe/paddle/transformer.py +++ b/fluid/image_classification/caffe2fluid/kaffe/paddle/transformer.py @@ -9,21 +9,6 @@ from ..transformers import (DataInjector, DataReshaper, NodeRenamer, from . import network -def get_padding_type(kernel_params, input_shape, output_shape): - '''Translates Caffe's numeric padding to one of ('SAME', 'VALID'). - Caffe supports arbitrary padding values, while Paddle only - supports 'SAME' and 'VALID' modes. So, not all Caffe paddings - can be translated to Paddle. There are some subtleties to - how the padding edge-cases are handled. These are described here: - https://github.com/Yangqing/caffe2/blob/master/caffe2/proto/caffe2_legacy.proto - ''' - k_h, k_w, s_h, s_w, p_h, p_w = kernel_params - if p_h > 0 or p_w > 0: - return [p_h, p_w] - else: - return None - - class PaddleNode(object): '''An intermediate representation for Paddle operations.''' @@ -78,10 +63,11 @@ class PaddleMapper(NodeMapper): def get_kernel_params(self, node): kernel_params = node.layer.kernel_parameters input_shape = node.get_only_parent().output_shape - padding = get_padding_type(kernel_params, input_shape, - node.output_shape) - # Only emit the padding if it's not the default value. - padding = {'padding': padding} if padding is not None else {} + padding = [kernel_params.pad_h, kernel_params.pad_w] + if padding[0] == 0 and padding[1] == 0: + padding = {} + else: + padding = {'padding': padding} return (kernel_params, padding) def map_convolution(self, node): @@ -95,6 +81,10 @@ class PaddleMapper(NodeMapper): kwargs['group'] = group if not node.parameters.bias_term: kwargs['biased'] = False + + if kernel_params.dila_h != 1 or kernel_params.dila_w != 1: + kwargs['dilation'] = (kernel_params.dila_h, kernel_params.dila_w) + assert kernel_params.kernel_h == h assert kernel_params.kernel_w == w return MaybeActivated(node)( diff --git a/fluid/image_classification/caffe2fluid/kaffe/shapes.py b/fluid/image_classification/caffe2fluid/kaffe/shapes.py index 379cfce6..70db5828 100644 --- a/fluid/image_classification/caffe2fluid/kaffe/shapes.py +++ b/fluid/image_classification/caffe2fluid/kaffe/shapes.py @@ -6,6 +6,8 @@ from .errors import KaffeError Tensor4DShape = namedtuple('Tensor4DShape', ['batch_size', 'channels', 'height', 'width']) +Tensor3DShape = namedtuple('Tensor3DShape', ['batch_size', 'data1', 'data2']) + Tensor2DShape = namedtuple('Tensor2DShape', ['batch_size', 'data']) ScalarShape = namedtuple('ScalarShape', ['batch_size']) @@ -14,6 +16,8 @@ ScalarShape = namedtuple('ScalarShape', ['batch_size']) def make_tensor(batch_size, d1=None, d2=None, d3=None): if d3 is not None: return Tensor4DShape(batch_size, d1, d2, d3) + elif d1 is not None and d2 is not None: + return Tensor3DShape(batch_size, d1, d2) elif d1 is not None and d2 is None: return Tensor2DShape(batch_size, d1) elif d1 is None and d2 is None and d3 is None: @@ -24,10 +28,14 @@ def make_tensor(batch_size, d1=None, d2=None, d3=None): def get_filter_output_shape(i_h, i_w, params, round_func): - o_h = (i_h + 2 * params.pad_h - params.kernel_h - ) / float(params.stride_h) + 1 - o_w = (i_w + 2 * params.pad_w - params.kernel_w - ) / float(params.stride_w) + 1 + dila_h = getattr(params, 'dila_h', 1) + dila_w = getattr(params, 'dila_w', 1) + + o_h = (i_h + 2 * params.pad_h - + (dila_h * (params.kernel_h - 1) + 1)) / float(params.stride_h) + 1 + o_w = (i_w + 2 * params.pad_w - + (dila_w * (params.kernel_w - 1) + 1)) / float(params.stride_w) + 1 + return (int(round_func(o_h)), int(round_func(o_w))) diff --git a/fluid/image_classification/caffe2fluid/kaffe/transformers.py b/fluid/image_classification/caffe2fluid/kaffe/transformers.py index 6b53e05a..7c505a79 100644 --- a/fluid/image_classification/caffe2fluid/kaffe/transformers.py +++ b/fluid/image_classification/caffe2fluid/kaffe/transformers.py @@ -337,6 +337,8 @@ class ParameterNamer(object): names = ('scale', ) if getattr(node.parameters, 'bias_term', False): names = ('scale', 'offset') + elif node.kind == "Normalize": + names = ('scale', ) else: warn('Unhandled parameters when naming this it[%s]' % (node.kind)) -- GitLab