提交 7d632862 编写于 作者: 刘琦

Merge branch 'transform' into 'master'

Refactor caffe depthwise

See merge request !481
...@@ -144,6 +144,7 @@ class CaffeConverter(base_converter.ConverterInterface): ...@@ -144,6 +144,7 @@ class CaffeConverter(base_converter.ConverterInterface):
'ReLU': ActivationType.RELU, 'ReLU': ActivationType.RELU,
'PReLU': ActivationType.PRELU, 'PReLU': ActivationType.PRELU,
'TanH': ActivationType.TANH, 'TanH': ActivationType.TANH,
'Sigmoid': ActivationType.SIGMOID,
} }
def __init__(self, option, src_model_file, src_weight_file): def __init__(self, option, src_model_file, src_weight_file):
...@@ -337,10 +338,15 @@ class CaffeConverter(base_converter.ConverterInterface): ...@@ -337,10 +338,15 @@ class CaffeConverter(base_converter.ConverterInterface):
param = caffe_op.layer.convolution_param param = caffe_op.layer.convolution_param
is_depthwise = False is_depthwise = False
if param.HasField(caffe_group_str): if param.HasField(caffe_group_str):
mace_check(param.group == caffe_op.blob[0].shape[1] and filter_data = caffe_op.blobs[0]
caffe_op.blob[0].shape[0] == 1, mace_check(param.group == filter_data.shape[0] and
filter_data.shape[1] == 1,
"Mace do not support group convolution yet") "Mace do not support group convolution yet")
is_depthwise = True is_depthwise = True
caffe_op.blobs[0] = filter_data.reshape(1,
filter_data.shape[0],
filter_data.shape[2],
filter_data.shape[3])
if is_depthwise: if is_depthwise:
op.type = MaceOp.DepthwiseConv2d.name op.type = MaceOp.DepthwiseConv2d.name
......
...@@ -17,6 +17,7 @@ class ShapeInference(object): ...@@ -17,6 +17,7 @@ class ShapeInference(object):
def __init__(self, net, input_nodes): def __init__(self, net, input_nodes):
self._op_shape_inference = { self._op_shape_inference = {
MaceOp.Conv2D.name: self.infer_shape_conv_pool_shape, MaceOp.Conv2D.name: self.infer_shape_conv_pool_shape,
MaceOp.DepthwiseConv2d.name: self.infer_shape_conv_pool_shape,
MaceOp.Eltwise.name: self.infer_shape_general, MaceOp.Eltwise.name: self.infer_shape_general,
MaceOp.FoldedBatchNorm.name: self.infer_shape_general, MaceOp.FoldedBatchNorm.name: self.infer_shape_general,
MaceOp.AddN.name: self.infer_shape_general, MaceOp.AddN.name: self.infer_shape_general,
...@@ -104,7 +105,10 @@ class ShapeInference(object): ...@@ -104,7 +105,10 @@ class ShapeInference(object):
if ConverterUtil.data_format(op) == DataFormat.NCHW \ if ConverterUtil.data_format(op) == DataFormat.NCHW \
and ConverterUtil.filter_format(self._net) == FilterFormat.OIHW: # noqa and ConverterUtil.filter_format(self._net) == FilterFormat.OIHW: # noqa
# filter format: OIHW # filter format: OIHW
output_shape[1] = filter_shape[0] if op.type == MaceOp.DepthwiseConv2d.name:
output_shape[1] = filter_shape[0] * filter_shape[1]
else:
output_shape[1] = filter_shape[0]
output_shape[2] = int( output_shape[2] = int(
round_func((input_shape[2] + paddings[0] - filter_shape[2] - round_func((input_shape[2] + paddings[0] - filter_shape[2] -
(filter_shape[2] - 1) * (filter_shape[2] - 1) *
......
...@@ -692,7 +692,7 @@ class Transformer(base_converter.ConverterInterface): ...@@ -692,7 +692,7 @@ class Transformer(base_converter.ConverterInterface):
filter_data = filter_data.transpose(2, 3, 0, 1) filter_data = filter_data.transpose(2, 3, 0, 1)
filter.float_data[:] = filter_data.flat filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape filter.dims[:] = filter_data.shape
elif op.type == MaceOp.Depthwiseconv2d.name: elif op.type == MaceOp.DepthwiseConv2d.name:
filter_data = filter_data.transpose(2, 3, 1, 0) filter_data = filter_data.transpose(2, 3, 1, 0)
filter.float_data[:] = filter_data.flat filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape filter.dims[:] = filter_data.shape
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册