From ac177d6148a1d1648780127491eec2fcc47d031f Mon Sep 17 00:00:00 2001
From: "Zhang, Guoming"
Date: Tue, 6 Nov 2018 10:26:26 +0800
Subject: [PATCH] merge is_test feature for pooling op (mkldnn)

---
 paddle/fluid/operators/pool_mkldnn_op.cc | 48 +++++++++++++-----------
 paddle/fluid/operators/pool_op.cc        |  5 +++
 2 files changed, 31 insertions(+), 22 deletions(-)

diff --git a/paddle/fluid/operators/pool_mkldnn_op.cc b/paddle/fluid/operators/pool_mkldnn_op.cc
index 28beac66b01..0b13a7394d6 100644
--- a/paddle/fluid/operators/pool_mkldnn_op.cc
+++ b/paddle/fluid/operators/pool_mkldnn_op.cc
@@ -87,6 +87,8 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
     std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
+    bool is_test = ctx.Attr<bool>("is_test");
+
     if (ctx.Attr<bool>("global_pooling")) {
       for (size_t i = 0; i < ksize.size(); ++i) {
         paddings[i] = 0;
@@ -145,17 +147,11 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::shared_ptr<mkldnn::pooling_forward::primitive_desc> pool_pd =
         CreatePrimitiveDesc(src_md, dst_md, propagation, strides,
                             padding_left_top, padding_right_bottom, ksize,
                             pooling_type,
-                            mkldnn_engine, ceil_mode);
+                            mkldnn_engine, ceil_mode, is_test);
 
     // save pool_pd into global device context to be referred in backward path
     dev_ctx.SetBlob(key_pool_pd, pool_pd);
 
-    std::shared_ptr<mkldnn::memory> workspace_memory =
-        CreateWorkspaceMemory(pool_pd, pooling_type, mkldnn_engine);
-
-    // save pool_workspace_memory to be referred in backward path
-    dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory);
-
     auto src_memory = std::make_shared<mkldnn::memory>(
         pool_pd->src_primitive_desc(), to_void_cast<T>(input_data));
     auto dst_memory =
@@ -164,14 +160,17 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     dev_ctx.SetBlob(key_pool_src_mem_p, src_memory);
     dev_ctx.SetBlob(key_pool_dst_mem_p, dst_memory);
 
-    if (propagation == mkldnn::prop_kind::forward_training) {
-      pool_p = std::make_shared<mkldnn::pooling_forward>(*pool_pd, *(src_memory.get()),
-                                                         *(dst_memory.get()),
-                                                         *workspace_memory);
-    } else{
-      pool_p = std::make_shared<mkldnn::pooling_forward>(*pool_pd, *(src_memory.get()),
-                                                         *(dst_memory.get()));//,
-                                                         //*workspace_memory);
+    if (is_test) {
+      pool_p = std::make_shared<mkldnn::pooling_forward>(
+          *pool_pd, *(src_memory.get()), *(dst_memory.get()));
+    } else {
+      std::shared_ptr<mkldnn::memory> workspace_memory =
+          CreateWorkspaceMemory(pool_pd, pooling_type, mkldnn_engine);
+      // save pool_workspace_memory to be referred in backward path
+      dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory);
+      pool_p = std::make_shared<mkldnn::pooling_forward>(
+          *pool_pd, *(src_memory.get()), *(dst_memory.get()),
+          *workspace_memory);
     }
 
     dev_ctx.SetBlob(key_pool_p, pool_p);
@@ -211,17 +210,22 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       const std::vector<int>& stride, const std::vector<int>& padding_left_top,
      const std::vector<int>& padding_right_bot, const std::vector<int>& kernel,
      const std::string& pooling_type, const mkldnn::engine& engine,
-      bool ceil_mode) const {
-    auto pool_desc = mkldnn::pooling_forward::desc(
-        propagation,
+      bool ceil_mode, bool is_test) const {
+
+    auto mkldnn_forward_prop_kind = is_test
+                                        ? mkldnn::prop_kind::forward_inference
+                                        : mkldnn::prop_kind::forward_training;
+    std::cout << "is_test: " << is_test << std::endl;
+    auto pool_desc = mkldnn::pooling_forward::desc(
+        mkldnn_forward_prop_kind,
         pooling_type == "max" ? mkldnn::algorithm::pooling_max
                               : mkldnn::algorithm::pooling_avg,
         src, dst, stride, kernel, padding_left_top, padding_right_bot,
         mkldnn::padding_kind::zero);
 
-    auto p_pool_pd = new mkldnn::pooling_forward::primitive_desc(pool_desc,
-                                                                 engine);
-    return std::unique_ptr<mkldnn::pooling_forward::primitive_desc>(p_pool_pd);
+    auto p_pool_pd =
+        new mkldnn::pooling_forward::primitive_desc(pool_desc, engine);
+    return std::unique_ptr<mkldnn::pooling_forward::primitive_desc>(p_pool_pd);
   }
 
   std::unique_ptr<mkldnn::memory> CreateWorkspaceMemory(
@@ -233,7 +237,7 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
             : mkldnn::memory::primitive_desc({{},
                                               platform::MKLDNNGetDataType<T>(),
                                               mkldnn::memory::format::nchw},
-                                              engine);
+                                             engine);
 
     auto p_workspace_memory = new mkldnn::memory(workspace_md);
     return std::unique_ptr<mkldnn::memory>(p_workspace_memory);
diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc
index 484cb657466..4208b128730 100644
--- a/paddle/fluid/operators/pool_op.cc
+++ b/paddle/fluid/operators/pool_op.cc
@@ -206,6 +206,11 @@ void Pool2dOpMaker::Make() {
       "Defaults to \"NHWC\". Specify the data format of the output data, "
      "the input will be transformed automatically. ")
      .SetDefault("AnyLayout");
+  AddAttr<bool>("is_test",
+                "(bool, default false) If true, the forward pass is not "
+                "part of training. "
+                "MKL-DNN might be faster if this is set to true.")
+      .SetDefault(false);
   // TODO(dzhwinter): need to registered layout transform function
 
   AddComment(R"DOC(
--
GitLab
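For quick reference, the behavioral core of the patch is that is_test selects the MKL-DNN propagation kind and decides whether a workspace is created at all (max pooling needs the workspace only for the backward pass). Below is a minimal standalone C++ sketch of that selection logic; it is not Paddle or MKL-DNN code, and select_prop_kind and needs_workspace are illustrative names only.

    #include <initializer_list>
    #include <iostream>

    // Stand-in for the mkldnn::prop_kind values referenced in the patch (MKL-DNN 0.x names).
    enum class prop_kind { forward_training, forward_inference };

    // Mirrors CreatePrimitiveDesc in the patch: is_test selects forward_inference,
    // otherwise the pooling descriptor is built for forward_training.
    prop_kind select_prop_kind(bool is_test) {
      return is_test ? prop_kind::forward_inference : prop_kind::forward_training;
    }

    // Mirrors the kernel body: workspace memory is created and cached for the
    // backward pass only when not running in test mode.
    bool needs_workspace(bool is_test) { return !is_test; }

    int main() {
      for (bool is_test : {false, true}) {
        std::cout << "is_test=" << is_test << " inference="
                  << (select_prop_kind(is_test) == prop_kind::forward_inference)
                  << " workspace=" << needs_workspace(is_test) << "\n";
      }
      return 0;
    }

In the real kernel the saving comes from skipping the workspace allocation and building an inference-only primitive, which is what the "MKL-DNN might be faster" note in the pool_op.cc attribute description alludes to.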