diff --git a/paddle/fluid/operators/lrn_mkldnn_op.cc b/paddle/fluid/operators/lrn_mkldnn_op.cc
index a2971fcd144698e96d82dc8a38a412eed45dbff7..3bead16ce44c26b9d7a6f2a5c6b471612494d595 100644
--- a/paddle/fluid/operators/lrn_mkldnn_op.cc
+++ b/paddle/fluid/operators/lrn_mkldnn_op.cc
@@ -22,6 +22,22 @@ namespace operators {
 using paddle::framework::Tensor;
 using paddle::platform::MKLDNNDeviceContext;
 
+namespace {
+template <typename T, typename... Args>
+std::shared_ptr<T> insert_to_context(const std::string& key,
+                                     const MKLDNNDeviceContext& dev_ctx,
+                                     Args&&... args) {
+  auto p = std::static_pointer_cast<T>(dev_ctx.GetBlob(key));
+
+  if (!p) {
+    p = std::make_shared<T>(args...);
+    dev_ctx.SetBlob(key, std::static_pointer_cast<void>(p));
+  }
+
+  return p;
+}
+}  // namespace
+
 template <typename T>
 class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  public:
@@ -42,15 +58,11 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto output_data = out->mutable_data<T>(ctx.GetPlace());
     mid->mutable_data<T>(ctx.GetPlace());
 
-    const std::string key = ctx.op().Output("Out");
-    const std::string key_src_memory = key + "@lrn_src_memory";
-    const std::string key_pd = key + "@lrn_pd";
-    const std::string key_workspace_memory = key + "@lrn_workspace_memory";
-
     const int n = ctx.Attr<int>("n");
     const float alpha = ctx.Attr<float>("alpha");
     const float beta = ctx.Attr<float>("beta");
     const float k = ctx.Attr<float>("k");
+    const bool is_test = ctx.Attr<bool>("is_test");
 
     auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
     e_mid = e_mid.constant(k);
@@ -71,28 +83,47 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                                            beta,
                                            k};
 
-    auto forward_pd = std::make_shared<mkldnn::lrn_forward::primitive_desc>(
-        forward_desc, mkldnn_engine);
-
-    dev_ctx.SetBlob(key_pd, forward_pd);
-
     auto src_memory_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
-    auto src_memory = std::make_shared<mkldnn::memory>(
-        src_memory_pd, static_cast<void*>(const_cast<T*>(input_data)));
-
-    dev_ctx.SetBlob(key_src_memory, src_memory);
 
     auto dst_memory = mkldnn::memory{{dst_md, mkldnn_engine},
                                      static_cast<void*>(output_data)};
 
-    auto workspace_md = forward_pd->workspace_primitive_desc();
-    auto workspace_memory = std::make_shared<mkldnn::memory>(workspace_md);
+    std::unique_ptr<mkldnn::lrn_forward> forward_op = nullptr;
+
+    if (!is_test) {
+      const std::string key = ctx.op().Output("Out");
+      const std::string key_src_memory = key + "@lrn_src_memory";
+      const std::string key_pd = key + "@lrn_pd";
+      const std::string key_workspace_memory = key + "@lrn_workspace_memory";
+
+      auto forward_pd = insert_to_context<mkldnn::lrn_forward::primitive_desc>(
+          key_pd, dev_ctx, forward_desc, mkldnn_engine);
+
+      auto src_memory = insert_to_context<mkldnn::memory>(
+          key_src_memory, dev_ctx, src_memory_pd);
+
+      src_memory->set_data_handle(
+          static_cast<void*>(const_cast<T*>(input_data)));
+
+      auto workspace_memory = insert_to_context<mkldnn::memory>(
+          key_workspace_memory, dev_ctx,
+          forward_pd->workspace_primitive_desc());
+
+      forward_op.reset(new mkldnn::lrn_forward{*forward_pd, *src_memory,
+                                               *workspace_memory, dst_memory});
-
-    dev_ctx.SetBlob(key_workspace_memory, workspace_memory);
+    } else {
+      auto forward_pd =
+          mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine};
+      auto src_memory = mkldnn::memory{
+          src_memory_pd, static_cast<void*>(const_cast<T*>(input_data))};
+      auto workspace_memory =
+          mkldnn::memory{forward_pd.workspace_primitive_desc()};
-
-    auto forward_op = mkldnn::lrn_forward{*forward_pd, *src_memory,
-                                          *workspace_memory, dst_memory};
+      forward_op.reset(new mkldnn::lrn_forward{forward_pd, src_memory,
+                                               workspace_memory, dst_memory});
+    }
 
-    std::vector<mkldnn::primitive> pipeline = {forward_op};
+    std::vector<mkldnn::primitive> pipeline = {*forward_op};
     mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
   }
 };
diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc
index bd72f0435e524613f4adac9de5d9da88c4249f25..2b1947a187bbd17871107553127647032ac7d7f9 100644
--- a/paddle/fluid/operators/lrn_op.cc
+++ b/paddle/fluid/operators/lrn_op.cc
@@ -214,6 +214,7 @@ class LRNOpMaker : public framework::OpProtoAndCheckerMaker {
         "Defaults to \"NHWC\". Specify the data format of the output data, "
         "the input will be transformed automatically. ")
         .SetDefault("AnyLayout");
+    AddAttr<bool>("is_test", "").SetDefault(false);
     AddComment(R"DOC(
 Local Response Normalization Operator.
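
Note (not part of the patch): the snippet below is a minimal standalone sketch of the
get-or-create caching pattern that insert_to_context implements above: look a blob up
by key, construct and store it on a miss, and return the cached instance on every later
call. The DeviceContext class, its blob map, and the std::string demo in main are made
up for illustration; only the GetBlob/SetBlob calls mirror the interface used in the
patch, and this is not the real MKLDNNDeviceContext.

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for a device context that stores type-erased blobs.
class DeviceContext {
 public:
  std::shared_ptr<void> GetBlob(const std::string& key) const {
    auto it = blobs_.find(key);
    return it == blobs_.end() ? nullptr : it->second;
  }
  void SetBlob(const std::string& key, std::shared_ptr<void> blob) const {
    blobs_[key] = std::move(blob);
  }

 private:
  mutable std::unordered_map<std::string, std::shared_ptr<void>> blobs_;
};

// Get-or-create: construct T from args only if no blob is cached under key.
template <typename T, typename... Args>
std::shared_ptr<T> insert_to_context(const std::string& key,
                                     const DeviceContext& dev_ctx,
                                     Args&&... args) {
  auto p = std::static_pointer_cast<T>(dev_ctx.GetBlob(key));
  if (!p) {
    p = std::make_shared<T>(std::forward<Args>(args)...);
    dev_ctx.SetBlob(key, std::static_pointer_cast<void>(p));
  }
  return p;
}

int main() {
  DeviceContext ctx;
  // First call constructs and caches the blob; the second call reuses it.
  auto a = insert_to_context<std::string>("greeting", ctx, "hello");
  auto b = insert_to_context<std::string>("greeting", ctx, "ignored");
  std::cout << *b << " (same object: " << (a == b) << ")\n";
  return 0;
}

In the patch itself the cache is keyed on ctx.op().Output("Out") and is populated only
when is_test is false, presumably so the primitive descriptor, source memory, and
workspace can be retrieved later (e.g. by a backward pass); with is_test set, the
inference path builds these objects locally and lets them go out of scope after the
stream is submitted.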