diff --git a/mace/python/tools/converter_tool/caffe_converter.py b/mace/python/tools/converter_tool/caffe_converter.py index 3ff2fa231e2b461481a41c966b65831a78c955bc..56e8b645089758a23d7d70017e9032b3a6273432 100644 --- a/mace/python/tools/converter_tool/caffe_converter.py +++ b/mace/python/tools/converter_tool/caffe_converter.py @@ -187,6 +187,7 @@ class CaffeConverter(base_converter.ConverterInterface): 'BatchNorm': self.convert_folded_batchnorm, 'Crop': self.convert_crop, 'Scale': self.convert_scale, + 'ShuffleChannel': self.convert_channel_shuffle, } self._option = option self._mace_net_def = mace_pb2.NetDef() @@ -656,3 +657,14 @@ class CaffeConverter(base_converter.ConverterInterface): ConverterUtil.add_data_format_arg(biasadd_op, DataFormat.NCHW) + + def convert_channel_shuffle(self, caffe_op): + op = self.convert_general_op(caffe_op) + param = caffe_op.layer.shuffle_channel_param + op.type = MaceOp.ChannelShuffle.name + + group_arg = op.arg.add() + group_arg.name = MaceKeyword.mace_group_str + group_arg.i = 1 + if param.HasField('group'): + group_arg.i = param.group diff --git a/mace/python/tools/converter_tool/shape_inference.py b/mace/python/tools/converter_tool/shape_inference.py index da6541f499459532e9ead7b277cb1192a9b3fdd1..aeb19022badc324855d82820a61d5b9a2a7f1cb1 100644 --- a/mace/python/tools/converter_tool/shape_inference.py +++ b/mace/python/tools/converter_tool/shape_inference.py @@ -48,6 +48,7 @@ class ShapeInference(object): MaceOp.FullyConnected.name: self.infer_shape_fully_connected, MaceOp.Crop.name: self.infer_shape_crop, MaceOp.BiasAdd.name: self.infer_shape_general, + MaceOp.ChannelShuffle.name: self.infer_shape_channel_shuffle, } self._net = net @@ -220,3 +221,7 @@ class ShapeInference(object): mace_check(len(op.input) == 2, "crop layer needs two inputs") output_shape = self._output_shape_cache[op.input[1]] self.add_output_shape(op, [output_shape]) + + def infer_shape_channel_shuffle(self, op): + output_shape = self._output_shape_cache[op.input[0]] + self.add_output_shape(op, [output_shape]) diff --git a/third_party/caffe/caffe.proto b/third_party/caffe/caffe.proto index cec617b99b8a5cde9e93e2bb14be0cab21794908..54ccf20ca2378f7e15881930333f0014c1923b63 100644 --- a/third_party/caffe/caffe.proto +++ b/third_party/caffe/caffe.proto @@ -421,6 +421,7 @@ message LayerParameter { optional ThresholdParameter threshold_param = 128; optional TileParameter tile_param = 138; optional WindowDataParameter window_data_param = 129; + optional ShuffleChannelParameter shuffle_channel_param = 164; } // Message that stores parameters used to apply transformation @@ -1439,3 +1440,7 @@ message PReLUParameter { // Whether or not slope parameters are shared across channels. optional bool channel_shared = 2 [default = false]; } + +message ShuffleChannelParameter { + optional uint32 group = 1[default = 1]; // The number of group +}