From 6898b44c426c0dd2b1bf14d9b313514a4d515b2f Mon Sep 17 00:00:00 2001 From: liuqi Date: Wed, 16 May 2018 18:34:44 +0800 Subject: [PATCH] Fix concat and slice aix < 0 bug. --- mace/ops/deconv_2d_benchmark.cc | 2 +- mace/python/tools/converter_tool/caffe_converter.py | 7 +++++-- mace/python/tools/converter_tool/tensorflow_converter.py | 1 + mace/python/tools/converter_tool/transformer.py | 4 ++-- 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/mace/ops/deconv_2d_benchmark.cc b/mace/ops/deconv_2d_benchmark.cc index a25e2dae..4305b575 100644 --- a/mace/ops/deconv_2d_benchmark.cc +++ b/mace/ops/deconv_2d_benchmark.cc @@ -114,8 +114,8 @@ static void Deconv2d(int iters, BM_DECONV_2D_##N##_##C##_##H##_##W##_##KH##_##KW##_##STRIDE##_##OH##_##OW##\ _##P##_##OC##_##TYPE##_##DEVICE) +// TODO(liutuo): add cpu benchmark when optimized. #define BM_DECONV_2D(N, C, H, W, KH, KW, S, OH, OW, P, OC) \ - BM_DECONV_2D_MACRO(N, C, H, W, KH, KW, S, OH, OW, P, OC, float, CPU); \ BM_DECONV_2D_MACRO(N, C, H, W, KH, KW, S, OH, OW, P, OC, float, GPU); \ BM_DECONV_2D_MACRO(N, C, H, W, KH, KW, S, OH, OW, P, OC, half, GPU); diff --git a/mace/python/tools/converter_tool/caffe_converter.py b/mace/python/tools/converter_tool/caffe_converter.py index 6af87764..25637df5 100644 --- a/mace/python/tools/converter_tool/caffe_converter.py +++ b/mace/python/tools/converter_tool/caffe_converter.py @@ -482,6 +482,7 @@ class CaffeConverter(base_converter.ConverterInterface): axis_arg.i = param.axis elif param.HasField('concat_dim'): axis_arg.i = param.concat_dim + axis_arg.i = 4 + axis_arg.i if axis_arg.i < 0 else axis_arg.i mace_check(axis_arg.i == 1, "only support concat at channel dimension") def convert_slice(self, caffe_op): @@ -490,7 +491,8 @@ class CaffeConverter(base_converter.ConverterInterface): if caffe_op.layer.HasField('slice_param'): param = caffe_op.layer.slice_param - mace_check(not param.HasField('axis') or param.axis == 1, + mace_check(not param.HasField('axis') or param.axis == 1 + or param.axis == -3, "Mace do not support slice with axis %d" % param.axis) mace_check(len(param.slice_point) == 0, "Mace do not support slice with slice_point") @@ -503,7 +505,8 @@ class CaffeConverter(base_converter.ConverterInterface): param = caffe_op.layer.inner_product_param op.type = MaceOp.FullyConnected.name - mace_check(param.axis == 1 and not param.transpose, + mace_check((param.axis == 1 or param.axis == -3) + and not param.transpose, "Do not support non-default axis and transpose") mace_check(caffe_op.blobs[0].ndim in [2, 4], "Unexpected fc weigth ndim.") diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index 1d961938..6f2247e6 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -404,6 +404,7 @@ class TensorflowConverter(base_converter.ConverterInterface): axis_arg = op.arg.add() axis_arg.name = MaceKeyword.mace_axis_str axis = tf_op.inputs[-1].eval().astype(np.int32) + axis = 4 + axis if axis < 0 else axis axis_arg.i = axis mace_check(axis == 3, "only support concat at channel dimension") diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index 5ccd3697..14b76893 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -459,7 +459,7 @@ class Transformer(base_converter.ConverterInterface): padding_arg.i = ConverterUtil.get_arg( op, MaceKeyword.mace_padding_str).i elif ConverterUtil.get_arg( - op, MaceKeyword.mace_padding_values_str)\ + op, MaceKeyword.mace_padding_values_str) \ is not None: padding_arg = wt_op.arg.add() padding_arg.name = MaceKeyword.mace_padding_values_str @@ -687,7 +687,7 @@ class Transformer(base_converter.ConverterInterface): or op.type == MaceOp.Deconv2D.name \ or op.type == MaceOp.DepthwiseConv2d.name: if ConverterUtil.get_arg( - op, MaceKeyword.mace_winograd_filter_transformed)\ + op, MaceKeyword.mace_winograd_filter_transformed) \ is None: filter = self._consts[op.input[1]] filter_data = np.array(filter.float_data).reshape( -- GitLab