提交 b5514602 编写于 作者: H hedaoyuan

Add the use_nnpack parameter in ExpandConvLayer, so that the convolution...

Add the use_nnpack parameter in ExpandConvLayer, so that the convolution calculation can be switched to the NNPACK function.
上级 cdf8d990
......@@ -70,6 +70,9 @@ public:
if (threadpool_) {
pthreadpool_destroy(threadpool_);
}
if (workspaceBuffer_) {
free(workspaceBuffer_);
}
}
virtual void check(const BufferArgs& inputs,
......@@ -160,7 +163,7 @@ public:
CHECK_EQ(status, nnp_status_success);
}
LOG(INFO) << "workspace size is " << needSize;
VLOG(3) << "workspace size is " << needSize;
if (needSize > workspaceSize_) {
workspaceSize_ = needSize;
if (workspaceBuffer_) {
......
......@@ -16,6 +16,10 @@ limitations under the License. */
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
DEFINE_bool(use_nnpack,
false,
"Whether to use nnpack for convolution calculation.");
namespace paddle {
/*
......@@ -37,26 +41,38 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
for (int i = 0; i < config_.inputs_size(); i++) {
std::vector<size_t> paddings = {(size_t)paddingY_[i], (size_t)padding_[i]};
std::vector<size_t> strides = {(size_t)strideY_[i], (size_t)stride_[i]};
createFunction(forward_,
!isDeconv_ ? "GemmConv" : "GemmConvGradInput",
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)groups_[i]));
createFunction(backward_,
!isDeconv_ ? "GemmConvGradInput" : "GemmConv",
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)groups_[i]));
createFunction(backward_,
"GemmConvGradFilter",
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)groups_[i]));
if (FLAGS_use_nnpack) {
CHECK_EQ(isDeconv_, false);
createFunction(forward_,
"NNPACKConv",
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)groups_[i])
.set("algo", "auto"));
} else {
createFunction(forward_,
!isDeconv_ ? "GemmConv" : "GemmConvGradInput",
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)groups_[i]));
createFunction(backward_,
!isDeconv_ ? "GemmConvGradInput" : "GemmConv",
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)groups_[i]));
createFunction(backward_,
"GemmConvGradFilter",
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)groups_[i]));
}
}
return true;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册