提交 0b257940 编写于 作者: 李寅

Support TF shufflenet v2

上级 aca4c5e2
......@@ -75,8 +75,8 @@ class SplitOp<DeviceType::CPU, T> : public Operation {
#pragma omp parallel for
for (int outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
int input_idx = outer_idx * input_channels * inner_size;
int output_idx = outer_idx * output_channels * inner_size;
index_t input_idx = outer_idx * input_channels * inner_size;
index_t output_idx = outer_idx * output_channels * inner_size;
for (size_t i = 0; i < outputs_count; ++i) {
if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
memcpy(output_ptrs[i]+output_idx, input_ptr+input_idx,
......
......@@ -257,6 +257,7 @@ class TransformerRule(Enum):
FOLD_EMBEDDING_LOOKUP = 35
TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN = 36
FOLD_FC_RESHAPE = 37
TRANSFORM_CHANNEL_SHUFFLE = 38
class ConverterInterface(object):
......@@ -463,6 +464,7 @@ class ConverterOption(object):
TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC,
TransformerRule.RESHAPE_FC_WEIGHT,
TransformerRule.FOLD_FC_RESHAPE,
TransformerRule.TRANSFORM_CHANNEL_SHUFFLE,
# Model data format related transformation
TransformerRule.TRANSPOSE_FILTERS,
TransformerRule.TRANSPOSE_DATA_FORMAT,
......
......@@ -100,6 +100,8 @@ class Transformer(base_converter.ConverterInterface):
self.check_quantize_info,
TransformerRule.TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN:
self.transform_caffe_reshape_and_flatten,
TransformerRule.TRANSFORM_CHANNEL_SHUFFLE:
self.transform_channel_shuffle,
}
self._option = option
......@@ -231,7 +233,7 @@ class Transformer(base_converter.ConverterInterface):
# that the op is identity op and its input is a tensor.
mace_check(len(op.output) == 1 and len(op.input) == 1,
"cannot remove op that w/o replace op specified"
" and input/output length > 1" + str(op))
" and input/output length > 1\n" + str(op))
for consumer_op in self._consumers.get(op.output[0], []):
self.replace(consumer_op.input, op.output[0], op.input[0])
......@@ -1789,3 +1791,45 @@ class Transformer(base_converter.ConverterInterface):
self.safe_remove_node(consumer, None)
return True
return False
def transform_channel_shuffle(self):
    """Fold a Reshape->Transpose->Reshape pattern into one ChannelShuffle op.

    TF ShuffleNet-v2 expresses channel shuffle as: reshape NHWC to a 5-D
    tensor, transpose the last two axes (perm [0, 1, 2, 4, 3]), then reshape
    back to 4-D.  This pass detects that pattern, rewrites the Transpose op
    in place as a ChannelShuffle op, and removes the surrounding Reshapes.

    Returns:
        True if a pattern was matched and rewritten (the transformer
        framework presumably re-runs the pass in that case — TODO confirm),
        False otherwise.  NOTE(review): the early `return False` below also
        aborts scanning the remaining ops when a 5-D perm-matching Transpose
        is found whose consumer is not the expected 4-D Reshape — verify
        this is intentional.
    """
    net = self._model
    for op in net.op:
        # Candidate: a Transpose producing a 5-D output ...
        if op.type == MaceOp.Transpose.name and \
                len(op.output_shape[0].dims) == 5:
            perm = ConverterUtil.get_arg(op,
                                         MaceKeyword.mace_dims_str).ints
            # ... that swaps only the last two axes (the shuffle transpose).
            if [0, 1, 2, 4, 3] == list(perm):
                # Remove the following Reshape op
                # (it must be the sole consumer and restore a 4-D shape).
                reshape_op = self._consumers.get(op.output[0], None)
                if (reshape_op and
                        len(reshape_op) == 1 and
                        reshape_op[0].type == MaceOp.Reshape.name and
                        len(reshape_op[0].output_shape[0].dims) == 4):
                    print("Transform channel shuffle")
                    # Capture the final 4-D shape before the Reshape is gone.
                    output_shape = reshape_op[0].output_shape[0].dims
                    self.safe_remove_node(reshape_op[0], op,
                                          remove_input_tensor=True)
                else:
                    return False

                # Change Transpose op to ChannelShuffle
                op.type = MaceOp.ChannelShuffle.name
                # Drop the old `dims` arg; ChannelShuffle only needs `group`.
                del op.arg[:]
                group_arg = op.arg.add()
                group_arg.name = MaceKeyword.mace_group_str
                # Group count is the size of the last (still 5-D) axis; read
                # it BEFORE dims is overwritten with the 4-D output_shape.
                group_arg.i = op.output_shape[0].dims[4]
                op.output_shape[0].dims[:] = output_shape

                # Remove previous Reshape op
                producer_op = self._producer.get(op.input[0], None)
                if producer_op:
                    if producer_op.type == MaceOp.Reshape.name:
                        self.safe_remove_node(producer_op, None)
                    elif producer_op.type == MaceOp.Stack.name:
                        print("Change channel shuffle stack to concat")
                        # Change previous Stack op to Concat if any
                        # (Stack of per-group tensors becomes a channel
                        # Concat feeding the ChannelShuffle directly).
                        producer_op.type = MaceOp.Concat.name
                        producer_op.output_shape[0].dims[:] = output_shape

                return True
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册