提交 6898b44c 编写于 作者: L liuqi

Fix concat and slice aix < 0 bug.

上级 5f10b2c3
......@@ -114,8 +114,8 @@ static void Deconv2d(int iters,
BM_DECONV_2D_##N##_##C##_##H##_##W##_##KH##_##KW##_##STRIDE##_##OH##_##OW##\
_##P##_##OC##_##TYPE##_##DEVICE)
// TODO(liutuo): add cpu benchmark when optimized.
#define BM_DECONV_2D(N, C, H, W, KH, KW, S, OH, OW, P, OC) \
BM_DECONV_2D_MACRO(N, C, H, W, KH, KW, S, OH, OW, P, OC, float, CPU); \
BM_DECONV_2D_MACRO(N, C, H, W, KH, KW, S, OH, OW, P, OC, float, GPU); \
BM_DECONV_2D_MACRO(N, C, H, W, KH, KW, S, OH, OW, P, OC, half, GPU);
......
......@@ -482,6 +482,7 @@ class CaffeConverter(base_converter.ConverterInterface):
axis_arg.i = param.axis
elif param.HasField('concat_dim'):
axis_arg.i = param.concat_dim
axis_arg.i = 4 + axis_arg.i if axis_arg.i < 0 else axis_arg.i
mace_check(axis_arg.i == 1, "only support concat at channel dimension")
def convert_slice(self, caffe_op):
......@@ -490,7 +491,8 @@ class CaffeConverter(base_converter.ConverterInterface):
if caffe_op.layer.HasField('slice_param'):
param = caffe_op.layer.slice_param
mace_check(not param.HasField('axis') or param.axis == 1,
mace_check(not param.HasField('axis') or param.axis == 1
or param.axis == -3,
"Mace do not support slice with axis %d" % param.axis)
mace_check(len(param.slice_point) == 0,
"Mace do not support slice with slice_point")
......@@ -503,7 +505,8 @@ class CaffeConverter(base_converter.ConverterInterface):
param = caffe_op.layer.inner_product_param
op.type = MaceOp.FullyConnected.name
mace_check(param.axis == 1 and not param.transpose,
mace_check((param.axis == 1 or param.axis == -3)
and not param.transpose,
"Do not support non-default axis and transpose")
mace_check(caffe_op.blobs[0].ndim in [2, 4],
"Unexpected fc weigth ndim.")
......
......@@ -404,6 +404,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis = tf_op.inputs[-1].eval().astype(np.int32)
axis = 4 + axis if axis < 0 else axis
axis_arg.i = axis
mace_check(axis == 3, "only support concat at channel dimension")
......
......@@ -459,7 +459,7 @@ class Transformer(base_converter.ConverterInterface):
padding_arg.i = ConverterUtil.get_arg(
op, MaceKeyword.mace_padding_str).i
elif ConverterUtil.get_arg(
op, MaceKeyword.mace_padding_values_str)\
op, MaceKeyword.mace_padding_values_str) \
is not None:
padding_arg = wt_op.arg.add()
padding_arg.name = MaceKeyword.mace_padding_values_str
......@@ -687,7 +687,7 @@ class Transformer(base_converter.ConverterInterface):
or op.type == MaceOp.Deconv2D.name \
or op.type == MaceOp.DepthwiseConv2d.name:
if ConverterUtil.get_arg(
op, MaceKeyword.mace_winograd_filter_transformed)\
op, MaceKeyword.mace_winograd_filter_transformed) \
is None:
filter = self._consts[op.input[1]]
filter_data = np.array(filter.float_data).reshape(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册