From 953f8ddf05e7db8b07f2f9b011a3f3efd8d3976f Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Tue, 1 Aug 2017 20:18:46 +0800 Subject: [PATCH] Support groups in NNPACKFunction. --- paddle/function/nnpack/NNPACKConvOp.cpp | 100 +++++++++++++----------- 1 file changed, 53 insertions(+), 47 deletions(-) diff --git a/paddle/function/nnpack/NNPACKConvOp.cpp b/paddle/function/nnpack/NNPACKConvOp.cpp index f0ec77a5d00..00d048eb216 100644 --- a/paddle/function/nnpack/NNPACKConvOp.cpp +++ b/paddle/function/nnpack/NNPACKConvOp.cpp @@ -49,9 +49,7 @@ class NNPACKConvFunction : public ConvFunctionBase { public: void init(const FuncConfig& config) override { ConvFunctionBase::init(config); - CHECK_EQ(groups_, (size_t)1); algorithm_ = get_nnp_convolution_algorithm(config.get("algo")); - // algorithm_ = nnp_convolution_algorithm_auto; transform_strategy_ = nnp_convolution_transform_strategy_compute; nnp_status status = nnp_initialize(); CHECK_EQ(status, nnp_status_success); @@ -67,8 +65,7 @@ public: } } - virtual void check(const BufferArgs& inputs, - const BufferArgs& outputs) override { + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { const TensorShape& input = inputs[0].shape(); const TensorShape& filter = inputs[1].shape(); const TensorShape& output = outputs[0].shape(); @@ -91,8 +88,8 @@ public: size_t filterHeight = getFilterHeight(filter); size_t filterWidth = getFilterWidth(filter); size_t outputChannels = output[1]; - // size_t outputHeight = output[2]; - // size_t outputWidth = output[3]; + size_t outputHeight = output[2]; + size_t outputWidth = output[3]; nnp_size inputSize = {.width = inputWidth, .height = inputHeight}; nnp_padding padding = {.top = (size_t)paddingH(), @@ -171,49 +168,58 @@ public: } } + size_t inputOffset = inputChannels / groups_ * inputHeight * inputWidth; + size_t outputOffset = outputChannels / groups_ * outputHeight * outputWidth; + size_t filterOffset = filter.getElements() / groups_; + if (batchSize == 1) { - nnp_status status = - nnp_convolution_inference(algorithm_, - transform_strategy_, - inputChannels, - outputChannels, - inputSize, - padding, - kernelSize, - outputSubsampling, - inputData, - filterData, - nullptr, /* bias */ - outputData, - bufferPtr, - sizePtr, - nnp_activation_identity, - nullptr, - threadpool_, /* threadpool */ - nullptr); - CHECK_EQ(status, nnp_status_success); + for (size_t g = 0; g < groups_; g++) { + nnp_status status = + nnp_convolution_inference(algorithm_, + transform_strategy_, + inputChannels / groups_, + outputChannels / groups_, + inputSize, + padding, + kernelSize, + outputSubsampling, + inputData + inputOffset * g, + filterData + filterOffset * g, + nullptr, /* bias */ + outputData + outputOffset * g, + bufferPtr, + sizePtr, + nnp_activation_identity, + nullptr, + threadpool_, /* threadpool */ + nullptr); + CHECK_EQ(status, nnp_status_success); + } } else { - // only supports stride = 1 - CHECK_EQ(strideH(), 1); - CHECK_EQ(strideW(), 1); - nnp_status status = nnp_convolution_output(algorithm_, - batchSize, - inputChannels, - outputChannels, - inputSize, - padding, - kernelSize, - inputData, - filterData, - nullptr, /* bias */ - outputData, - bufferPtr, - sizePtr, - nnp_activation_identity, - nullptr, - threadpool_, /* threadpool */ - nullptr); - CHECK_EQ(status, nnp_status_success); + for (size_t g = 0; g < groups_; g++) { + // only supports stride = 1 + CHECK_EQ(strideH(), 1); + CHECK_EQ(strideW(), 1); + nnp_status status = + nnp_convolution_output(algorithm_, + batchSize, + inputChannels / groups_, + outputChannels / groups_, + inputSize, + padding, + kernelSize, + inputData + inputOffset * g, + filterData + filterOffset * g, + nullptr, /* bias */ + outputData + outputOffset * g, + bufferPtr, + sizePtr, + nnp_activation_identity, + nullptr, + threadpool_, /* threadpool */ + nullptr); + CHECK_EQ(status, nnp_status_success); + } } } -- GitLab