提交 b5514602 编写于 作者: H hedaoyuan

Add the use_nnpack parameter in ExpandConvLayer, so that the convolution...

Add the use_nnpack parameter in ExpandConvLayer, so that the convolution calculation can be switched to the NNPACK function.
上级 cdf8d990
...@@ -70,6 +70,9 @@ public: ...@@ -70,6 +70,9 @@ public:
if (threadpool_) { if (threadpool_) {
pthreadpool_destroy(threadpool_); pthreadpool_destroy(threadpool_);
} }
if (workspaceBuffer_) {
free(workspaceBuffer_);
}
} }
virtual void check(const BufferArgs& inputs, virtual void check(const BufferArgs& inputs,
...@@ -160,7 +163,7 @@ public: ...@@ -160,7 +163,7 @@ public:
CHECK_EQ(status, nnp_status_success); CHECK_EQ(status, nnp_status_success);
} }
LOG(INFO) << "workspace size is " << needSize; VLOG(3) << "workspace size is " << needSize;
if (needSize > workspaceSize_) { if (needSize > workspaceSize_) {
workspaceSize_ = needSize; workspaceSize_ = needSize;
if (workspaceBuffer_) { if (workspaceBuffer_) {
......
...@@ -16,6 +16,10 @@ limitations under the License. */ ...@@ -16,6 +16,10 @@ limitations under the License. */
#include "paddle/utils/Logging.h" #include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h" #include "paddle/utils/Stat.h"
DEFINE_bool(use_nnpack,
false,
"Whether to use nnpack for convolution calculation.");
namespace paddle { namespace paddle {
/* /*
...@@ -37,26 +41,38 @@ bool ExpandConvLayer::init(const LayerMap &layerMap, ...@@ -37,26 +41,38 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
for (int i = 0; i < config_.inputs_size(); i++) { for (int i = 0; i < config_.inputs_size(); i++) {
std::vector<size_t> paddings = {(size_t)paddingY_[i], (size_t)padding_[i]}; std::vector<size_t> paddings = {(size_t)paddingY_[i], (size_t)padding_[i]};
std::vector<size_t> strides = {(size_t)strideY_[i], (size_t)stride_[i]}; std::vector<size_t> strides = {(size_t)strideY_[i], (size_t)stride_[i]};
createFunction(forward_,
!isDeconv_ ? "GemmConv" : "GemmConvGradInput", if (FLAGS_use_nnpack) {
FuncConfig() CHECK_EQ(isDeconv_, false);
.set("paddings", paddings) createFunction(forward_,
.set("strides", strides) "NNPACKConv",
.set("groups", (size_t)groups_[i])); FuncConfig()
.set("paddings", paddings)
createFunction(backward_, .set("strides", strides)
!isDeconv_ ? "GemmConvGradInput" : "GemmConv", .set("groups", (size_t)groups_[i])
FuncConfig() .set("algo", "auto"));
.set("paddings", paddings) } else {
.set("strides", strides) createFunction(forward_,
.set("groups", (size_t)groups_[i])); !isDeconv_ ? "GemmConv" : "GemmConvGradInput",
FuncConfig()
createFunction(backward_, .set("paddings", paddings)
"GemmConvGradFilter", .set("strides", strides)
FuncConfig() .set("groups", (size_t)groups_[i]));
.set("paddings", paddings)
.set("strides", strides) createFunction(backward_,
.set("groups", (size_t)groups_[i])); !isDeconv_ ? "GemmConvGradInput" : "GemmConv",
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)groups_[i]));
createFunction(backward_,
"GemmConvGradFilter",
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)groups_[i]));
}
} }
return true; return true;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册