diff --git a/paddle/fluid/operators/lrn_mkldnn_op.cc b/paddle/fluid/operators/lrn_mkldnn_op.cc index 3bead16ce44c26b9d7a6f2a5c6b471612494d595..0a18882e8199c2a375a230a693b8b01d12aabfa0 100644 --- a/paddle/fluid/operators/lrn_mkldnn_op.cc +++ b/paddle/fluid/operators/lrn_mkldnn_op.cc @@ -36,6 +36,14 @@ std::shared_ptr insert_to_context(const std::string& key, return p; } + +template +void run_primitive(Args&&... args) { + auto forward_op = mkldnn::lrn_forward{args...}; + + std::vector pipeline = {forward_op}; + mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); +} } // namespace template @@ -87,8 +95,6 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel { auto dst_memory = mkldnn::memory{{dst_md, mkldnn_engine}, static_cast(output_data)}; - std::unique_ptr forward_op = nullptr; - if (!is_test) { const std::string key = ctx.op().Output("Out"); const std::string key_src_memory = key + "@lrn_src_memory"; @@ -108,9 +114,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel { key_workspace_memory, dev_ctx, forward_pd->workspace_primitive_desc()); - forward_op.reset(new mkldnn::lrn_forward{*forward_pd, *src_memory, - *workspace_memory, dst_memory}); - + run_primitive(*forward_pd, *src_memory, *workspace_memory, dst_memory); } else { auto forward_pd = mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine}; @@ -119,12 +123,8 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel { auto workspace_memory = mkldnn::memory{forward_pd.workspace_primitive_desc()}; - forward_op.reset(new mkldnn::lrn_forward{forward_pd, src_memory, - workspace_memory, dst_memory}); + run_primitive(forward_pd, src_memory, workspace_memory, dst_memory); } - - std::vector pipeline = {*forward_op}; - mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); } }; @@ -136,6 +136,9 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel { "MKLDNN LRN must use float data."); PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), "MKLDNN LRN must use CPUPlace."); + PADDLE_ENFORCE( + !ctx.Attr("is_test"), + "is_test attribute should be set to False in training phase."); auto x = ctx.Input("X"); diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc index 2b1947a187bbd17871107553127647032ac7d7f9..b36b5c3a339bd7e534bcc3eb7a2efef313cb2a5d 100644 --- a/paddle/fluid/operators/lrn_op.cc +++ b/paddle/fluid/operators/lrn_op.cc @@ -155,8 +155,8 @@ class LRNOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(x_dim.size(), 4, "Input(X)'rank of LRNOp should be 4."); ctx->SetOutputDim("Out", x_dim); - ctx->SetOutputDim("MidOut", x_dim); ctx->ShareLoD("X", /*->*/ "Out"); + ctx->SetOutputDim("MidOut", x_dim); } framework::OpKernelType GetExpectedKernelType( diff --git a/python/paddle/fluid/tests/unittests/test_lrn_op.py b/python/paddle/fluid/tests/unittests/test_lrn_op.py index 2268eafdbd08cd0d6a175d19cedd79b7b984289b..8fa480b9bce84d2936f23cce9e41e8e54014b074 100644 --- a/python/paddle/fluid/tests/unittests/test_lrn_op.py +++ b/python/paddle/fluid/tests/unittests/test_lrn_op.py @@ -97,5 +97,24 @@ class TestLRNMKLDNNOp(TestLRNOp): self.check_output(atol=0.002) +class TestLRNMKLDNNOpWithIsTest(TestLRNMKLDNNOp): + def get_attrs(self): + attrs = TestLRNMKLDNNOp.get_attrs(self) + attrs['is_test'] = True + return attrs + + def test_check_grad_normal(self): + def check_raise_is_test(): + try: + self.check_grad(['X'], 'Out', max_relative_error=0.01) + except Exception as e: + t = \ + "is_test attribute should be set to False in training phase." + if t in str(e): + raise AttributeError + + self.assertRaises(AttributeError, check_raise_is_test) + + if __name__ == "__main__": unittest.main()