From b5514602b6019a4b30515079e4be17bf4276cb19 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Fri, 30 Jun 2017 17:46:48 +0800 Subject: [PATCH] Add the use_nnpack parameter in ExpandConvLayer, so that the convolution calculation can be switched to the NNPACK function. --- paddle/function/nnpack/NNPACKConvOp.cpp | 5 +- paddle/gserver/layers/ExpandConvLayer.cpp | 56 +++++++++++++++-------- 2 files changed, 40 insertions(+), 21 deletions(-) diff --git a/paddle/function/nnpack/NNPACKConvOp.cpp b/paddle/function/nnpack/NNPACKConvOp.cpp index d75fab04033..e8080c3d714 100644 --- a/paddle/function/nnpack/NNPACKConvOp.cpp +++ b/paddle/function/nnpack/NNPACKConvOp.cpp @@ -70,6 +70,9 @@ public: if (threadpool_) { pthreadpool_destroy(threadpool_); } + if (workspaceBuffer_) { + free(workspaceBuffer_); + } } virtual void check(const BufferArgs& inputs, @@ -160,7 +163,7 @@ public: CHECK_EQ(status, nnp_status_success); } - LOG(INFO) << "workspace size is " << needSize; + VLOG(3) << "workspace size is " << needSize; if (needSize > workspaceSize_) { workspaceSize_ = needSize; if (workspaceBuffer_) { diff --git a/paddle/gserver/layers/ExpandConvLayer.cpp b/paddle/gserver/layers/ExpandConvLayer.cpp index 914689e66cd..29e2113afff 100644 --- a/paddle/gserver/layers/ExpandConvLayer.cpp +++ b/paddle/gserver/layers/ExpandConvLayer.cpp @@ -16,6 +16,10 @@ limitations under the License. */ #include "paddle/utils/Logging.h" #include "paddle/utils/Stat.h" +DEFINE_bool(use_nnpack, + false, + "Whether to use nnpack for convolution calculation."); + namespace paddle { /* @@ -37,26 +41,38 @@ bool ExpandConvLayer::init(const LayerMap &layerMap, for (int i = 0; i < config_.inputs_size(); i++) { std::vector paddings = {(size_t)paddingY_[i], (size_t)padding_[i]}; std::vector strides = {(size_t)strideY_[i], (size_t)stride_[i]}; - createFunction(forward_, - !isDeconv_ ? "GemmConv" : "GemmConvGradInput", - FuncConfig() - .set("paddings", paddings) - .set("strides", strides) - .set("groups", (size_t)groups_[i])); - - createFunction(backward_, - !isDeconv_ ? "GemmConvGradInput" : "GemmConv", - FuncConfig() - .set("paddings", paddings) - .set("strides", strides) - .set("groups", (size_t)groups_[i])); - - createFunction(backward_, - "GemmConvGradFilter", - FuncConfig() - .set("paddings", paddings) - .set("strides", strides) - .set("groups", (size_t)groups_[i])); + + if (FLAGS_use_nnpack) { + CHECK_EQ(isDeconv_, false); + createFunction(forward_, + "NNPACKConv", + FuncConfig() + .set("paddings", paddings) + .set("strides", strides) + .set("groups", (size_t)groups_[i]) + .set("algo", "auto")); + } else { + createFunction(forward_, + !isDeconv_ ? "GemmConv" : "GemmConvGradInput", + FuncConfig() + .set("paddings", paddings) + .set("strides", strides) + .set("groups", (size_t)groups_[i])); + + createFunction(backward_, + !isDeconv_ ? "GemmConvGradInput" : "GemmConv", + FuncConfig() + .set("paddings", paddings) + .set("strides", strides) + .set("groups", (size_t)groups_[i])); + + createFunction(backward_, + "GemmConvGradFilter", + FuncConfig() + .set("paddings", paddings) + .set("strides", strides) + .set("groups", (size_t)groups_[i])); + } } return true; } -- GitLab