提交 ac177d61 编写于 作者: Z Zhang, Guoming

merge is_test feature for pooling op(mkldnn)

上级 fa164241
......@@ -87,6 +87,8 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
bool is_test = ctx.Attr<bool>("is_test");
if (ctx.Attr<bool>("global_pooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
......@@ -145,17 +147,11 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::shared_ptr<mkldnn::pooling_forward::primitive_desc> pool_pd =
CreatePrimitiveDesc(src_md, dst_md, propagation, strides, padding_left_top,
padding_right_bottom, ksize, pooling_type,
mkldnn_engine, ceil_mode);
mkldnn_engine, ceil_mode,is_test);
// save pool_pd into global device context to be referred in backward path
dev_ctx.SetBlob(key_pool_pd, pool_pd);
std::shared_ptr<mkldnn::memory> workspace_memory =
CreateWorkspaceMemory(pool_pd, pooling_type, mkldnn_engine);
// save pool_workspace_memory to be referred in backward path
dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory);
auto src_memory = std::make_shared<memory>(pool_pd->src_primitive_desc(),
to_void_cast<T>(input_data));
auto dst_memory =
......@@ -164,14 +160,17 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dev_ctx.SetBlob(key_pool_src_mem_p, src_memory);
dev_ctx.SetBlob(key_pool_dst_mem_p, dst_memory);
if (propagation == mkldnn::prop_kind::forward_training) {
pool_p = std::make_shared<pooling_forward>(*pool_pd, *(src_memory.get()),
*(dst_memory.get()),
*workspace_memory);
} else{
pool_p = std::make_shared<pooling_forward>(*pool_pd, *(src_memory.get()),
*(dst_memory.get()));//,
//*workspace_memory);
if (is_test) {
pool_p = std::make_shared<pooling_forward>(
*pool_pd, *(src_memory.get()), *(dst_memory.get()));
} else {
std::shared_ptr<mkldnn::memory> workspace_memory =
CreateWorkspaceMemory(pool_pd, pooling_type, mkldnn_engine);
// save pool_workspace_memory to be referred in backward path
dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory);
pool_p = std::make_shared<pooling_forward>(
*pool_pd, *(src_memory.get()), *(dst_memory.get()),
*workspace_memory);
}
dev_ctx.SetBlob(key_pool_p, pool_p);
......@@ -211,17 +210,22 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const std::vector<int>& stride, const std::vector<int>& padding_left_top,
const std::vector<int>& padding_right_bot, const std::vector<int>& kernel,
const std::string& pooling_type, const mkldnn::engine& engine,
bool ceil_mode) const {
auto pool_desc = mkldnn::pooling_forward::desc(
propagation,
bool ceil_mode, bool is_test) const {
auto mkldnn_forward_prop_kind = is_test
? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training;
std::cout<<is_test<<" "<<__LINE__<<std::endl;
auto pool_desc = mkldnn::pooling_forward::desc(
mkldnn_forward_prop_kind,
pooling_type == "max" ? mkldnn::algorithm::pooling_max
: mkldnn::algorithm::pooling_avg,
src, dst, stride, kernel, padding_left_top, padding_right_bot,
mkldnn::padding_kind::zero);
auto p_pool_pd =
new mkldnn::pooling_forward::primitive_desc(pool_desc, engine);
return std::unique_ptr<mkldnn::pooling_forward::primitive_desc>(p_pool_pd);
auto p_pool_pd =
new mkldnn::pooling_forward::primitive_desc(pool_desc, engine);
return std::unique_ptr<mkldnn::pooling_forward::primitive_desc>(p_pool_pd);
}
std::unique_ptr<mkldnn::memory> CreateWorkspaceMemory(
......@@ -233,7 +237,7 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
: mkldnn::memory::primitive_desc({{},
platform::MKLDNNGetDataType<T>(),
mkldnn::memory::format::nchw},
engine);
engine);
auto p_workspace_memory = new mkldnn::memory(workspace_md);
return std::unique_ptr<mkldnn::memory>(p_workspace_memory);
......
......@@ -206,6 +206,11 @@ void Pool2dOpMaker::Make() {
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
AddAttr<bool>("is_test",
"(bool, default false) If true, the forward pass is not "
"part of training."
"MKL-DNN might be faster if this is set to true.")
.SetDefault(false);
// TODO(dzhwinter): need to registered layout transform function
AddComment(R"DOC(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册