From da616a6f2fe22b42faa9aab1caa5f2ff8c875111 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Fri, 11 Aug 2017 14:14:26 +0800 Subject: [PATCH] Fix some bugs. --- paddle/function/ConvOpTest.h | 5 +-- paddle/function/nnpack/NNPACKConvOp.cpp | 41 ++++++++++++------------- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/paddle/function/ConvOpTest.h b/paddle/function/ConvOpTest.h index d8c3bb03b3a..cb02a96d0db 100644 --- a/paddle/function/ConvOpTest.h +++ b/paddle/function/ConvOpTest.h @@ -202,9 +202,10 @@ void DepthwiseConvolution(const std::string& conv1, for (size_t outputChannels : {32, 64}) { for (size_t stride : {1, 2}) { for (size_t padding : {0, 1}) { - // NNPACK only supports stride = 1 if batchSize > 1 + // NNPACK only supports stride = 1 if batchSize > 1, + // and there has some bug when batchSize > 1 and groups != 1 if ((conv1 == "NNPACKConv-CPU" || conv2 == "NNPACKConv-CPU") && - batchSize > 1 && stride > 1) + batchSize > 1) break; size_t outputSize = diff --git a/paddle/function/nnpack/NNPACKConvOp.cpp b/paddle/function/nnpack/NNPACKConvOp.cpp index c9f1ddcd92d..6ccc487cf1c 100644 --- a/paddle/function/nnpack/NNPACKConvOp.cpp +++ b/paddle/function/nnpack/NNPACKConvOp.cpp @@ -201,28 +201,25 @@ public: CHECK_EQ(strideW(), 1); // TODO(hedaoyuan): There has some bug when batchSize > 1 and groups_ > 1. - CHECK_EQ(groups_, (size_t)1); - for (size_t g = 0; g < groups_; g++) { - nnp_status status = - nnp_convolution_output(algorithm_, - batchSize, - inputChannels / groups_, - outputChannels / groups_, - inputSize, - padding, - kernelSize, - inputData + inputOffset * g, - filterData + filterOffset * g, - nullptr, /* bias */ - outputData + outputOffset * g, - bufferPtr, - sizePtr, - nnp_activation_identity, - nullptr, - threadpool_, /* threadpool */ - nullptr); - CHECK_EQ(status, nnp_status_success); - } + CHECK_EQ(groups_, static_cast(1)); + nnp_status status = nnp_convolution_output(algorithm_, + batchSize, + inputChannels, + outputChannels, + inputSize, + padding, + kernelSize, + inputData, + filterData, + nullptr, /* bias */ + outputData, + bufferPtr, + sizePtr, + nnp_activation_identity, + nullptr, + threadpool_, /* threadpool */ + nullptr); + CHECK_EQ(status, nnp_status_success); } } -- GitLab