From 4b6b7251c10371ceceb84c55ebc587715591c436 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 31 Aug 2017 21:29:49 +0800 Subject: [PATCH] Refine NeonDepthwiseConv. --- paddle/function/neon/NeonDepthwiseConv.cpp | 35 +++++++++++----------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/paddle/function/neon/NeonDepthwiseConv.cpp b/paddle/function/neon/NeonDepthwiseConv.cpp index 7e5f752a0b2..3d502f5d6d1 100644 --- a/paddle/function/neon/NeonDepthwiseConv.cpp +++ b/paddle/function/neon/NeonDepthwiseConv.cpp @@ -18,8 +18,6 @@ limitations under the License. */ namespace paddle { -namespace neon { - #if defined(__ARM_NEON__) || defined(__ARM_NEON) template @@ -45,16 +43,16 @@ public: const TensorShape& filter = inputs[1].shape(); const TensorShape& output = outputs[0].shape(); - size_t batchSize = input[0]; - size_t inputChannels = input[1]; - size_t inputHeight = input[2]; - size_t inputWidth = input[3]; - size_t filterHeight = getFilterHeight(filter); - size_t filterWidth = getFilterWidth(filter); - size_t outputChannels = output[1]; - size_t outputHeight = output[2]; - size_t outputWidth = output[3]; - size_t filterMultiplier = outputChannels / groups_; + int batchSize = input[0]; + int inputChannels = input[1]; + int inputHeight = input[2]; + int inputWidth = input[3]; + int filterHeight = getFilterHeight(filter); + int filterWidth = getFilterWidth(filter); + int outputChannels = output[1]; + int outputHeight = output[2]; + int outputWidth = output[3]; + int filterMultiplier = outputChannels / groups_; CHECK_EQ(inputChannels, groups_); // only support strideH() == strideW() and filterHeight == filterWidth. @@ -90,18 +88,18 @@ public: DepthWiseConv; if (filterWidth == 3 && strideW() == 1) { - DepthWiseConv = DepthwiseConvKernel<3, 1>::run; + DepthWiseConv = neon::DepthwiseConvKernel<3, 1>::run; } else if (filterWidth == 3 && strideW() == 2) { - DepthWiseConv = DepthwiseConvKernel<3, 2>::run; + DepthWiseConv = neon::DepthwiseConvKernel<3, 2>::run; } else if (filterWidth == 4 && strideW() == 1) { - DepthWiseConv = DepthwiseConvKernel<4, 1>::run; + DepthWiseConv = neon::DepthwiseConvKernel<4, 1>::run; } else if (filterWidth == 4 && strideW() == 2) { - DepthWiseConv = DepthwiseConvKernel<4, 2>::run; + DepthWiseConv = neon::DepthwiseConvKernel<4, 2>::run; } else { LOG(FATAL) << "Not supported"; } - for (size_t i = 0; i < batchSize; i++) { + for (int i = 0; i < batchSize; i++) { DepthWiseConv(inputPadding, filterData, inputHeight, @@ -117,9 +115,10 @@ public: } }; +#ifndef PADDLE_TYPE_DOUBLE REGISTER_TYPED_FUNC(NeonDepthwiseConv, CPU, NeonDepthwiseConvFunction); +#endif #endif -} // namespace neon } // namespace paddle -- GitLab