diff --git a/paddle/function/nnpack/NNPACKConvOp.cpp b/paddle/function/nnpack/NNPACKConvOp.cpp index f0ec77a5d00333993427fb8d0bc938c884e50c95..00d048eb216baf37c875c870a31cfd55a97f2974 100644 --- a/paddle/function/nnpack/NNPACKConvOp.cpp +++ b/paddle/function/nnpack/NNPACKConvOp.cpp @@ -49,9 +49,7 @@ class NNPACKConvFunction : public ConvFunctionBase { public: void init(const FuncConfig& config) override { ConvFunctionBase::init(config); - CHECK_EQ(groups_, (size_t)1); algorithm_ = get_nnp_convolution_algorithm(config.get("algo")); - // algorithm_ = nnp_convolution_algorithm_auto; transform_strategy_ = nnp_convolution_transform_strategy_compute; nnp_status status = nnp_initialize(); CHECK_EQ(status, nnp_status_success); @@ -67,8 +65,7 @@ public: } } - virtual void check(const BufferArgs& inputs, - const BufferArgs& outputs) override { + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { const TensorShape& input = inputs[0].shape(); const TensorShape& filter = inputs[1].shape(); const TensorShape& output = outputs[0].shape(); @@ -91,8 +88,8 @@ public: size_t filterHeight = getFilterHeight(filter); size_t filterWidth = getFilterWidth(filter); size_t outputChannels = output[1]; - // size_t outputHeight = output[2]; - // size_t outputWidth = output[3]; + size_t outputHeight = output[2]; + size_t outputWidth = output[3]; nnp_size inputSize = {.width = inputWidth, .height = inputHeight}; nnp_padding padding = {.top = (size_t)paddingH(), @@ -171,49 +168,58 @@ public: } } + size_t inputOffset = inputChannels / groups_ * inputHeight * inputWidth; + size_t outputOffset = outputChannels / groups_ * outputHeight * outputWidth; + size_t filterOffset = filter.getElements() / groups_; + if (batchSize == 1) { - nnp_status status = - nnp_convolution_inference(algorithm_, - transform_strategy_, - inputChannels, - outputChannels, - inputSize, - padding, - kernelSize, - outputSubsampling, - inputData, - filterData, - nullptr, /* bias */ - outputData, - bufferPtr, - sizePtr, - nnp_activation_identity, - nullptr, - threadpool_, /* threadpool */ - nullptr); - CHECK_EQ(status, nnp_status_success); + for (size_t g = 0; g < groups_; g++) { + nnp_status status = + nnp_convolution_inference(algorithm_, + transform_strategy_, + inputChannels / groups_, + outputChannels / groups_, + inputSize, + padding, + kernelSize, + outputSubsampling, + inputData + inputOffset * g, + filterData + filterOffset * g, + nullptr, /* bias */ + outputData + outputOffset * g, + bufferPtr, + sizePtr, + nnp_activation_identity, + nullptr, + threadpool_, /* threadpool */ + nullptr); + CHECK_EQ(status, nnp_status_success); + } } else { - // only supports stride = 1 - CHECK_EQ(strideH(), 1); - CHECK_EQ(strideW(), 1); - nnp_status status = nnp_convolution_output(algorithm_, - batchSize, - inputChannels, - outputChannels, - inputSize, - padding, - kernelSize, - inputData, - filterData, - nullptr, /* bias */ - outputData, - bufferPtr, - sizePtr, - nnp_activation_identity, - nullptr, - threadpool_, /* threadpool */ - nullptr); - CHECK_EQ(status, nnp_status_success); + for (size_t g = 0; g < groups_; g++) { + // only supports stride = 1 + CHECK_EQ(strideH(), 1); + CHECK_EQ(strideW(), 1); + nnp_status status = + nnp_convolution_output(algorithm_, + batchSize, + inputChannels / groups_, + outputChannels / groups_, + inputSize, + padding, + kernelSize, + inputData + inputOffset * g, + filterData + filterOffset * g, + nullptr, /* bias */ + outputData + outputOffset * g, + bufferPtr, + sizePtr, + nnp_activation_identity, + nullptr, + threadpool_, /* threadpool */ + nullptr); + CHECK_EQ(status, nnp_status_success); + } } } diff --git a/paddle/gserver/layers/ExpandConvLayer.cpp b/paddle/gserver/layers/ExpandConvLayer.cpp index 783e02e47cb91e28eb88b079f1e94439d34fa775..0ece2799318ea5ecc91f97f71289d4d07246dcaa 100644 --- a/paddle/gserver/layers/ExpandConvLayer.cpp +++ b/paddle/gserver/layers/ExpandConvLayer.cpp @@ -57,8 +57,7 @@ bool ExpandConvLayer::init(const LayerMap &layerMap, convGradFilterType = "GemmConvGradFilter"; } - if (FLAGS_use_nnpack) { - CHECK_EQ(isDeconv_, false); + if (FLAGS_use_nnpack && !isDeconv_) { createFunction(forward_, "NNPACKConv", FuncConfig()