From 1d0f9cc12f9df7a7efaa179c8e8d111018b934fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?= Date: Wed, 16 May 2018 11:26:49 +0800 Subject: [PATCH] Refactor caffe depthwise --- mace/python/tools/converter_tool/caffe_converter.py | 10 ++++++++-- mace/python/tools/converter_tool/shape_inference.py | 6 +++++- mace/python/tools/converter_tool/transformer.py | 2 +- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/mace/python/tools/converter_tool/caffe_converter.py b/mace/python/tools/converter_tool/caffe_converter.py index a0298bb1..9084ee82 100644 --- a/mace/python/tools/converter_tool/caffe_converter.py +++ b/mace/python/tools/converter_tool/caffe_converter.py @@ -144,6 +144,7 @@ class CaffeConverter(base_converter.ConverterInterface): 'ReLU': ActivationType.RELU, 'PReLU': ActivationType.PRELU, 'TanH': ActivationType.TANH, + 'Sigmoid': ActivationType.SIGMOID, } def __init__(self, option, src_model_file, src_weight_file): @@ -337,10 +338,15 @@ class CaffeConverter(base_converter.ConverterInterface): param = caffe_op.layer.convolution_param is_depthwise = False if param.HasField(caffe_group_str): - mace_check(param.group == caffe_op.blob[0].shape[1] and - caffe_op.blob[0].shape[0] == 1, + filter_data = caffe_op.blobs[0] + mace_check(param.group == filter_data.shape[0] and + filter_data.shape[1] == 1, "Mace do not support group convolution yet") is_depthwise = True + caffe_op.blobs[0] = filter_data.reshape(1, + filter_data.shape[0], + filter_data.shape[2], + filter_data.shape[3]) if is_depthwise: op.type = MaceOp.DepthwiseConv2d.name diff --git a/mace/python/tools/converter_tool/shape_inference.py b/mace/python/tools/converter_tool/shape_inference.py index f6dfda11..a260be1c 100644 --- a/mace/python/tools/converter_tool/shape_inference.py +++ b/mace/python/tools/converter_tool/shape_inference.py @@ -17,6 +17,7 @@ class ShapeInference(object): def __init__(self, net, input_nodes): self._op_shape_inference = { MaceOp.Conv2D.name: self.infer_shape_conv_pool_shape, + MaceOp.DepthwiseConv2d.name: self.infer_shape_conv_pool_shape, MaceOp.Eltwise.name: self.infer_shape_general, MaceOp.FoldedBatchNorm.name: self.infer_shape_general, MaceOp.AddN.name: self.infer_shape_general, @@ -104,7 +105,10 @@ class ShapeInference(object): if ConverterUtil.data_format(op) == DataFormat.NCHW \ and ConverterUtil.filter_format(self._net) == FilterFormat.OIHW: # noqa # filter format: OIHW - output_shape[1] = filter_shape[0] + if op.type == MaceOp.DepthwiseConv2d.name: + output_shape[1] = filter_shape[0] * filter_shape[1] + else: + output_shape[1] = filter_shape[0] output_shape[2] = int( round_func((input_shape[2] + paddings[0] - filter_shape[2] - (filter_shape[2] - 1) * diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index adfe3e0c..6dc51b7d 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -692,7 +692,7 @@ class Transformer(base_converter.ConverterInterface): filter_data = filter_data.transpose(2, 3, 0, 1) filter.float_data[:] = filter_data.flat filter.dims[:] = filter_data.shape - elif op.type == MaceOp.Depthwiseconv2d.name: + elif op.type == MaceOp.DepthwiseConv2d.name: filter_data = filter_data.transpose(2, 3, 1, 0) filter.float_data[:] = filter_data.flat filter.dims[:] = filter_data.shape -- GitLab