From 92332bbad4f024a3beb4bc7d4fe441be71d60c01 Mon Sep 17 00:00:00 2001 From: eclipsess Date: Tue, 30 Oct 2018 17:15:37 +0800 Subject: [PATCH] temp fix dw3x3 --- src/operators/kernel/central-arm-func/conv_add_arm_func.h | 6 ++++-- .../kernel/central-arm-func/conv_add_bn_relu_arm_func.h | 6 ++++-- src/operators/kernel/central-arm-func/conv_arm_func.h | 6 ++++-- .../kernel/central-arm-func/conv_bn_add_relu_arm_func.h | 6 ++++-- .../kernel/central-arm-func/conv_bn_relu_arm_func.h | 6 ++++-- .../kernel/central-arm-func/depthwise_conv_arm_func.h | 6 ++++-- .../kernel/central-arm-func/dwconv_bn_relu_arm_func.h | 6 ++++-- 7 files changed, 28 insertions(+), 14 deletions(-) diff --git a/src/operators/kernel/central-arm-func/conv_add_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_arm_func.h index d71bc23597..143ce56c86 100644 --- a/src/operators/kernel/central-arm-func/conv_add_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_arm_func.h @@ -118,13 +118,15 @@ void ConvAddCompute(const FusionConvAddParam ¶m) { if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 && + param.Input()->dims()[2] == param.Input()->dims()[2]) { math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), param.Bias(), true); } else if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2 && + param.Input()->dims()[2] == param.Input()->dims()[2]) { // math::DepthwiseConv3x3(param.Input(), param.Strides(), // param.Paddings(), // param.Filter(), param.Bias(), diff --git a/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h index a7d14fbad1..2d7825ae14 100644 --- a/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h @@ -118,14 +118,16 @@ void ConvAddBNReluCompute(const FusionConvAddBNReluParam ¶m) { if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 && + param.Input()->dims()[2] == param.Input()->dims()[2]) { math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(), param.Output(), param.NewScale(), param.NewBias(), true); } else if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2 && + param.Input()->dims()[2] == param.Input()->dims()[2]) { // math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(), // param.Output(), param.NewScale(), // param.NewBias(), 1); diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h index e7a8c7f52d..14b5992e39 100644 --- a/src/operators/kernel/central-arm-func/conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_arm_func.h @@ -124,13 +124,15 @@ void ConvCompute(const ConvParam ¶m) { if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 && + param.Input()->dims()[2] == param.Input()->dims()[2]) { math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), nullptr, false); } else if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3) { + param.Filter()->dims()[2] == 3 && + param.Input()->dims()[2] == param.Input()->dims()[2]) { math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(), param.Filter(), nullptr, param.Output(), false); } else { diff --git a/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h index 7c31eed196..67015e14d1 100644 --- a/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h @@ -122,14 +122,16 @@ void ConvBNAddReluCompute(const FusionConvBNAddReluParam ¶m) { if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 && + param.Input()->dims()[2] == param.Input()->dims()[2]) { math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(), param.Output(), param.NewScale(), param.NewBias(), true); } else if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2 && + param.Input()->dims()[2] == param.Input()->dims()[2]) { // math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(), // param.Output(), param.NewScale(), // param.NewBias(), 1); diff --git a/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h index c6300f96e1..426e2840b4 100644 --- a/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h @@ -117,14 +117,16 @@ void ConvBNReluCompute(const FusionConvBNReluParam ¶m) { if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 && + param.Input()->dims()[2] == param.Input()->dims()[2]) { math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(), param.Output(), param.NewScale(), param.NewBias(), true); } else if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2 && + param.Input()->dims()[2] == param.Input()->dims()[2]) { // math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(), // param.Output(), param.NewScale(), // param.NewBias(), 1); diff --git a/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h b/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h index 73170bdab9..c8e969b854 100644 --- a/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h @@ -30,13 +30,15 @@ void DepthwiseConvCompute(const ConvParam ¶m) { Bias.mutable_data({param.Groups()}); if (param.Groups() == param.Input()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 && + param.Input()->dims()[2] == param.Input()->dims()[3]) { math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), &Bias, false); } else if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2 && + param.Input()->dims()[2] == param.Input()->dims()[3]) { // math::DepthwiseConv3x3(param.Input(), param.Strides(), // param.Paddings(), // param.Filter(), &Bias, param.Output(), false); diff --git a/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h index b60bf9b4d6..797848365e 100644 --- a/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h @@ -115,14 +115,16 @@ void DWConvBNReluCompute(const FusionDWConvBNReluParam ¶m) { if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 && + param.Input()->dims()[2] == param.Input()->dims()[2]) { math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(), param.Output(), param.NewScale(), param.NewBias(), true); } else if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2 && + param.Input()->dims()[2] == param.Input()->dims()[2]) { // math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(), // param.Output(), param.NewScale(), // param.NewBias(), 1); -- GitLab