提交 1b272a6f 编写于 作者: 刘托

Merge branch 'master' into 'master'

Add channel shuffel caffe convert

See merge request !901
......@@ -187,6 +187,7 @@ class CaffeConverter(base_converter.ConverterInterface):
'BatchNorm': self.convert_folded_batchnorm,
'Crop': self.convert_crop,
'Scale': self.convert_scale,
'ShuffleChannel': self.convert_channel_shuffle,
}
self._option = option
self._mace_net_def = mace_pb2.NetDef()
......@@ -656,3 +657,14 @@ class CaffeConverter(base_converter.ConverterInterface):
ConverterUtil.add_data_format_arg(biasadd_op,
DataFormat.NCHW)
def convert_channel_shuffle(self, caffe_op):
op = self.convert_general_op(caffe_op)
param = caffe_op.layer.shuffle_channel_param
op.type = MaceOp.ChannelShuffle.name
group_arg = op.arg.add()
group_arg.name = MaceKeyword.mace_group_str
group_arg.i = 1
if param.HasField('group'):
group_arg.i = param.group
......@@ -48,6 +48,7 @@ class ShapeInference(object):
MaceOp.FullyConnected.name: self.infer_shape_fully_connected,
MaceOp.Crop.name: self.infer_shape_crop,
MaceOp.BiasAdd.name: self.infer_shape_general,
MaceOp.ChannelShuffle.name: self.infer_shape_channel_shuffle,
}
self._net = net
......@@ -220,3 +221,7 @@ class ShapeInference(object):
mace_check(len(op.input) == 2, "crop layer needs two inputs")
output_shape = self._output_shape_cache[op.input[1]]
self.add_output_shape(op, [output_shape])
def infer_shape_channel_shuffle(self, op):
output_shape = self._output_shape_cache[op.input[0]]
self.add_output_shape(op, [output_shape])
......@@ -421,6 +421,7 @@ message LayerParameter {
optional ThresholdParameter threshold_param = 128;
optional TileParameter tile_param = 138;
optional WindowDataParameter window_data_param = 129;
optional ShuffleChannelParameter shuffle_channel_param = 164;
}
// Message that stores parameters used to apply transformation
......@@ -1439,3 +1440,7 @@ message PReLUParameter {
// Whether or not slope parameters are shared across channels.
optional bool channel_shared = 2 [default = false];
}
message ShuffleChannelParameter {
optional uint32 group = 1[default = 1]; // The number of group
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册