提交 1d0f9cc1 编写于 作者: 李寅

Refactor caffe depthwise

上级 1f9b2ee2
......@@ -144,6 +144,7 @@ class CaffeConverter(base_converter.ConverterInterface):
'ReLU': ActivationType.RELU,
'PReLU': ActivationType.PRELU,
'TanH': ActivationType.TANH,
'Sigmoid': ActivationType.SIGMOID,
}
def __init__(self, option, src_model_file, src_weight_file):
......@@ -337,10 +338,15 @@ class CaffeConverter(base_converter.ConverterInterface):
param = caffe_op.layer.convolution_param
is_depthwise = False
if param.HasField(caffe_group_str):
mace_check(param.group == caffe_op.blob[0].shape[1] and
caffe_op.blob[0].shape[0] == 1,
filter_data = caffe_op.blobs[0]
mace_check(param.group == filter_data.shape[0] and
filter_data.shape[1] == 1,
"Mace do not support group convolution yet")
is_depthwise = True
caffe_op.blobs[0] = filter_data.reshape(1,
filter_data.shape[0],
filter_data.shape[2],
filter_data.shape[3])
if is_depthwise:
op.type = MaceOp.DepthwiseConv2d.name
......
......@@ -17,6 +17,7 @@ class ShapeInference(object):
def __init__(self, net, input_nodes):
self._op_shape_inference = {
MaceOp.Conv2D.name: self.infer_shape_conv_pool_shape,
MaceOp.DepthwiseConv2d.name: self.infer_shape_conv_pool_shape,
MaceOp.Eltwise.name: self.infer_shape_general,
MaceOp.FoldedBatchNorm.name: self.infer_shape_general,
MaceOp.AddN.name: self.infer_shape_general,
......@@ -104,7 +105,10 @@ class ShapeInference(object):
if ConverterUtil.data_format(op) == DataFormat.NCHW \
and ConverterUtil.filter_format(self._net) == FilterFormat.OIHW: # noqa
# filter format: OIHW
output_shape[1] = filter_shape[0]
if op.type == MaceOp.DepthwiseConv2d.name:
output_shape[1] = filter_shape[0] * filter_shape[1]
else:
output_shape[1] = filter_shape[0]
output_shape[2] = int(
round_func((input_shape[2] + paddings[0] - filter_shape[2] -
(filter_shape[2] - 1) *
......
......@@ -692,7 +692,7 @@ class Transformer(base_converter.ConverterInterface):
filter_data = filter_data.transpose(2, 3, 0, 1)
filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape
elif op.type == MaceOp.Depthwiseconv2d.name:
elif op.type == MaceOp.DepthwiseConv2d.name:
filter_data = filter_data.transpose(2, 3, 1, 0)
filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册