From 0b2579402b03e75de817f7bb03a4469ef36937df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?= Date: Tue, 25 Dec 2018 16:46:31 +0800 Subject: [PATCH] Support TF shufflenet v2 --- mace/ops/split.cc | 4 +- .../tools/converter_tool/base_converter.py | 2 + .../tools/converter_tool/transformer.py | 46 ++++++++++++++++++- 3 files changed, 49 insertions(+), 3 deletions(-) diff --git a/mace/ops/split.cc b/mace/ops/split.cc index 0f9dcc04..d7f33965 100644 --- a/mace/ops/split.cc +++ b/mace/ops/split.cc @@ -75,8 +75,8 @@ class SplitOp : public Operation { #pragma omp parallel for for (int outer_idx = 0; outer_idx < outer_size; ++outer_idx) { - int input_idx = outer_idx * input_channels * inner_size; - int output_idx = outer_idx * output_channels * inner_size; + index_t input_idx = outer_idx * input_channels * inner_size; + index_t output_idx = outer_idx * output_channels * inner_size; for (size_t i = 0; i < outputs_count; ++i) { if (DataTypeCanUseMemcpy(DataTypeToEnum::v())) { memcpy(output_ptrs[i]+output_idx, input_ptr+input_idx, diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index cb28e2e1..9bfd6909 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -257,6 +257,7 @@ class TransformerRule(Enum): FOLD_EMBEDDING_LOOKUP = 35 TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN = 36 FOLD_FC_RESHAPE = 37 + TRANSFORM_CHANNEL_SHUFFLE = 38 class ConverterInterface(object): @@ -463,6 +464,7 @@ class ConverterOption(object): TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC, TransformerRule.RESHAPE_FC_WEIGHT, TransformerRule.FOLD_FC_RESHAPE, + TransformerRule.TRANSFORM_CHANNEL_SHUFFLE, # Model data format related transformation TransformerRule.TRANSPOSE_FILTERS, TransformerRule.TRANSPOSE_DATA_FORMAT, diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index ec4a0376..49cba5b8 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -100,6 +100,8 @@ class Transformer(base_converter.ConverterInterface): self.check_quantize_info, TransformerRule.TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN: self.transform_caffe_reshape_and_flatten, + TransformerRule.TRANSFORM_CHANNEL_SHUFFLE: + self.transform_channel_shuffle, } self._option = option @@ -231,7 +233,7 @@ class Transformer(base_converter.ConverterInterface): # that the op is identity op and its input is a tensor. mace_check(len(op.output) == 1 and len(op.input) == 1, "cannot remove op that w/o replace op specified" - " and input/output length > 1" + str(op)) + " and input/output length > 1\n" + str(op)) for consumer_op in self._consumers.get(op.output[0], []): self.replace(consumer_op.input, op.output[0], op.input[0]) @@ -1789,3 +1791,45 @@ class Transformer(base_converter.ConverterInterface): self.safe_remove_node(consumer, None) return True return False + + def transform_channel_shuffle(self): + net = self._model + for op in net.op: + if op.type == MaceOp.Transpose.name and \ + len(op.output_shape[0].dims) == 5: + perm = ConverterUtil.get_arg(op, + MaceKeyword.mace_dims_str).ints + if [0, 1, 2, 4, 3] == list(perm): + # Remove the following Reshape op + reshape_op = self._consumers.get(op.output[0], None) + if (reshape_op and + len(reshape_op) == 1 and + reshape_op[0].type == MaceOp.Reshape.name and + len(reshape_op[0].output_shape[0].dims) == 4): + print("Transform channel shuffle") + output_shape = reshape_op[0].output_shape[0].dims + self.safe_remove_node(reshape_op[0], op, + remove_input_tensor=True) + else: + return False + + # Change Transpose op to ChannelShuffle + op.type = MaceOp.ChannelShuffle.name + del op.arg[:] + group_arg = op.arg.add() + group_arg.name = MaceKeyword.mace_group_str + group_arg.i = op.output_shape[0].dims[4] + op.output_shape[0].dims[:] = output_shape + + # Remove previous Reshape op + producer_op = self._producer.get(op.input[0], None) + if producer_op: + if producer_op.type == MaceOp.Reshape.name: + self.safe_remove_node(producer_op, None) + elif producer_op.type == MaceOp.Stack.name: + print("Change channel shuffle stack to concat") + # Change previous Stack op to Concat if any + producer_op.type = MaceOp.Concat.name + producer_op.output_shape[0].dims[:] = output_shape + + return True -- GitLab