From faa456100c0682c8cd873baa4b9dc32feb8bdd0e Mon Sep 17 00:00:00 2001 From: xiebaiyuan Date: Thu, 28 Feb 2019 11:08:43 +0800 Subject: [PATCH] use gemm to s1p0 instead of s1p1 --- src/operators/kernel/central-arm-func/conv_add_arm_func.h | 3 ++- src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h | 3 ++- .../kernel/central-arm-func/conv_bn_add_relu_arm_func.h | 3 ++- src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h | 3 ++- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/operators/kernel/central-arm-func/conv_add_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_arm_func.h index d6aa5052dd..0051fc9ae8 100644 --- a/src/operators/kernel/central-arm-func/conv_add_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_arm_func.h @@ -121,7 +121,8 @@ void ConvAddCompute(const FusionConvAddParam ¶m) { if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 && + param.paddings_[0] == 1) { math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), param.Bias(), true, false); } else if (param.Groups() == param.Input()->dims()[1] && diff --git a/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h index 04a84fc976..9f8e885a31 100644 --- a/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_relu_arm_func.h @@ -125,7 +125,8 @@ void ConvAddReluCompute(const FusionConvAddReluParam ¶m) { if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 && + param.paddings_[0] == 1) { math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), param.Bias(), true, true); } else if (param.Groups() == param.Input()->dims()[1] && diff --git a/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h index caaf467141..1ff51aa39c 100644 --- a/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h @@ -122,7 +122,8 @@ void ConvBNAddReluCompute(const FusionConvBNAddReluParam ¶m) { if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 && + param.paddings_[0] == 1) { math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(), param.Output(), param.NewScale(), param.NewBias(), true); diff --git a/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h index 7eeb7f7667..5606eb3304 100644 --- a/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h @@ -120,7 +120,8 @@ void ConvBNReluCompute(const FusionConvBNReluParam ¶m) { if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 && + param.paddings_[0] == 1) { math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(), param.Output(), param.NewScale(), param.NewBias(), true); -- GitLab