diff --git a/paddle/function/neon/NeonDepthwiseConv.cpp b/paddle/function/neon/NeonDepthwiseConv.cpp index 3fe28b1de37462f1968290797eabf54aa5e1f220..f09e98587d1681d29a79a9cb0303c2d4356c6935 100644 --- a/paddle/function/neon/NeonDepthwiseConv.cpp +++ b/paddle/function/neon/NeonDepthwiseConv.cpp @@ -509,10 +509,9 @@ public: size_t filterMultiplier = outputChannels / groups_; CHECK_EQ(inputChannels, groups_); - // only support + // only support strideH() == strideW() and filterHeight == filterWidth. CHECK_EQ(strideH(), strideW()); CHECK_EQ(filterHeight, filterWidth); - CHECK_LT(strideH(), size_t(3)); float* inputData = inputs[0].data(); float* filterData = inputs[1].data(); @@ -538,49 +537,32 @@ public: inputWidth += 2 * paddingW(); } - for (size_t i = 0; i < batchSize; i++) { - if (filterWidth == 3 && strideH() == 1) { - DepthwiseConvKernel<3, 1>::run(inputPadding, - filterData, - inputHeight, - inputWidth, - outputChannels, - outputHeight, - outputWidth, - filterMultiplier, - outputData); - } else if (filterWidth == 3 && strideH() == 2) { - DepthwiseConvKernel<3, 2>::run(inputPadding, - filterData, - inputHeight, - inputWidth, - outputChannels, - outputHeight, - outputWidth, - filterMultiplier, - outputData); - } else if (filterWidth == 4 && strideH() == 1) { - DepthwiseConvKernel<4, 1>::run(inputPadding, - filterData, - inputHeight, - inputWidth, - outputChannels, - outputHeight, - outputWidth, - filterMultiplier, - outputData); - } else if (filterWidth == 4 && strideH() == 2) { - DepthwiseConvKernel<4, 2>::run(inputPadding, - filterData, - inputHeight, - inputWidth, - outputChannels, - outputHeight, - outputWidth, - filterMultiplier, - outputData); - } + std::function + DepthWiseConv; + + if (filterWidth == 3 && strideW() == 1) { + DepthWiseConv = DepthwiseConvKernel<3, 1>::run; + } else if (filterWidth == 3 && strideW() == 2) { + DepthWiseConv = DepthwiseConvKernel<3, 2>::run; + } else if (filterWidth == 4 && strideW() == 1) { + DepthWiseConv = DepthwiseConvKernel<4, 1>::run; + } else if (filterWidth == 4 && strideW() == 2) { + DepthWiseConv = DepthwiseConvKernel<4, 2>::run; + } else { + LOG(FATAL) << "Not supported"; + } + for (size_t i = 0; i < batchSize; i++) { + DepthWiseConv(inputPadding, + filterData, + inputHeight, + inputWidth, + outputChannels, + outputHeight, + outputWidth, + filterMultiplier, + outputData); inputPadding += inputChannels * inputHeight * inputWidth; outputData += outputChannels * outputHeight * outputWidth; }