You need to sign in or sign up before continuing.
提交 834d1824 编写于 作者: 李寅

Merge branch 'fix-concat-bug' into 'master'

Fix concat and slice aix < 0 bug.

See merge request !485
...@@ -114,8 +114,8 @@ static void Deconv2d(int iters, ...@@ -114,8 +114,8 @@ static void Deconv2d(int iters,
BM_DECONV_2D_##N##_##C##_##H##_##W##_##KH##_##KW##_##STRIDE##_##OH##_##OW##\ BM_DECONV_2D_##N##_##C##_##H##_##W##_##KH##_##KW##_##STRIDE##_##OH##_##OW##\
_##P##_##OC##_##TYPE##_##DEVICE) _##P##_##OC##_##TYPE##_##DEVICE)
// TODO(liutuo): add cpu benchmark when optimized.
#define BM_DECONV_2D(N, C, H, W, KH, KW, S, OH, OW, P, OC) \ #define BM_DECONV_2D(N, C, H, W, KH, KW, S, OH, OW, P, OC) \
BM_DECONV_2D_MACRO(N, C, H, W, KH, KW, S, OH, OW, P, OC, float, CPU); \
BM_DECONV_2D_MACRO(N, C, H, W, KH, KW, S, OH, OW, P, OC, float, GPU); \ BM_DECONV_2D_MACRO(N, C, H, W, KH, KW, S, OH, OW, P, OC, float, GPU); \
BM_DECONV_2D_MACRO(N, C, H, W, KH, KW, S, OH, OW, P, OC, half, GPU); BM_DECONV_2D_MACRO(N, C, H, W, KH, KW, S, OH, OW, P, OC, half, GPU);
......
...@@ -482,6 +482,7 @@ class CaffeConverter(base_converter.ConverterInterface): ...@@ -482,6 +482,7 @@ class CaffeConverter(base_converter.ConverterInterface):
axis_arg.i = param.axis axis_arg.i = param.axis
elif param.HasField('concat_dim'): elif param.HasField('concat_dim'):
axis_arg.i = param.concat_dim axis_arg.i = param.concat_dim
axis_arg.i = 4 + axis_arg.i if axis_arg.i < 0 else axis_arg.i
mace_check(axis_arg.i == 1, "only support concat at channel dimension") mace_check(axis_arg.i == 1, "only support concat at channel dimension")
def convert_slice(self, caffe_op): def convert_slice(self, caffe_op):
...@@ -490,7 +491,8 @@ class CaffeConverter(base_converter.ConverterInterface): ...@@ -490,7 +491,8 @@ class CaffeConverter(base_converter.ConverterInterface):
if caffe_op.layer.HasField('slice_param'): if caffe_op.layer.HasField('slice_param'):
param = caffe_op.layer.slice_param param = caffe_op.layer.slice_param
mace_check(not param.HasField('axis') or param.axis == 1, mace_check(not param.HasField('axis') or param.axis == 1
or param.axis == -3,
"Mace do not support slice with axis %d" % param.axis) "Mace do not support slice with axis %d" % param.axis)
mace_check(len(param.slice_point) == 0, mace_check(len(param.slice_point) == 0,
"Mace do not support slice with slice_point") "Mace do not support slice with slice_point")
...@@ -503,7 +505,8 @@ class CaffeConverter(base_converter.ConverterInterface): ...@@ -503,7 +505,8 @@ class CaffeConverter(base_converter.ConverterInterface):
param = caffe_op.layer.inner_product_param param = caffe_op.layer.inner_product_param
op.type = MaceOp.FullyConnected.name op.type = MaceOp.FullyConnected.name
mace_check(param.axis == 1 and not param.transpose, mace_check((param.axis == 1 or param.axis == -3)
and not param.transpose,
"Do not support non-default axis and transpose") "Do not support non-default axis and transpose")
mace_check(caffe_op.blobs[0].ndim in [2, 4], mace_check(caffe_op.blobs[0].ndim in [2, 4],
"Unexpected fc weigth ndim.") "Unexpected fc weigth ndim.")
......
...@@ -404,6 +404,7 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -404,6 +404,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
axis_arg = op.arg.add() axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str axis_arg.name = MaceKeyword.mace_axis_str
axis = tf_op.inputs[-1].eval().astype(np.int32) axis = tf_op.inputs[-1].eval().astype(np.int32)
axis = 4 + axis if axis < 0 else axis
axis_arg.i = axis axis_arg.i = axis
mace_check(axis == 3, "only support concat at channel dimension") mace_check(axis == 3, "only support concat at channel dimension")
......
...@@ -459,7 +459,7 @@ class Transformer(base_converter.ConverterInterface): ...@@ -459,7 +459,7 @@ class Transformer(base_converter.ConverterInterface):
padding_arg.i = ConverterUtil.get_arg( padding_arg.i = ConverterUtil.get_arg(
op, MaceKeyword.mace_padding_str).i op, MaceKeyword.mace_padding_str).i
elif ConverterUtil.get_arg( elif ConverterUtil.get_arg(
op, MaceKeyword.mace_padding_values_str)\ op, MaceKeyword.mace_padding_values_str) \
is not None: is not None:
padding_arg = wt_op.arg.add() padding_arg = wt_op.arg.add()
padding_arg.name = MaceKeyword.mace_padding_values_str padding_arg.name = MaceKeyword.mace_padding_values_str
...@@ -687,7 +687,7 @@ class Transformer(base_converter.ConverterInterface): ...@@ -687,7 +687,7 @@ class Transformer(base_converter.ConverterInterface):
or op.type == MaceOp.Deconv2D.name \ or op.type == MaceOp.Deconv2D.name \
or op.type == MaceOp.DepthwiseConv2d.name: or op.type == MaceOp.DepthwiseConv2d.name:
if ConverterUtil.get_arg( if ConverterUtil.get_arg(
op, MaceKeyword.mace_winograd_filter_transformed)\ op, MaceKeyword.mace_winograd_filter_transformed) \
is None: is None:
filter = self._consts[op.input[1]] filter = self._consts[op.input[1]]
filter_data = np.array(filter.float_data).reshape( filter_data = np.array(filter.float_data).reshape(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册