From 7d7b1ec6ef2099335943b67d3c7a2fbe81ccb456 Mon Sep 17 00:00:00 2001 From: xiaolil1 Date: Tue, 30 Oct 2018 14:16:46 +0800 Subject: [PATCH] modify pool to enable jit --- paddle/fluid/operators/pool_mkldnn_op.cc | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/pool_mkldnn_op.cc b/paddle/fluid/operators/pool_mkldnn_op.cc index 07df199297d..28beac66b01 100644 --- a/paddle/fluid/operators/pool_mkldnn_op.cc +++ b/paddle/fluid/operators/pool_mkldnn_op.cc @@ -141,9 +141,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel { */ auto dst_md = platform::MKLDNNMemDesc(dst_tz, dt, mkldnn::memory::format::any); - + auto propagation = src_md.data.data_type == mkldnn_f32 ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring; std::shared_ptr pool_pd = - CreatePrimitiveDesc(src_md, dst_md, strides, padding_left_top, + CreatePrimitiveDesc(src_md, dst_md, propagation, strides, padding_left_top, padding_right_bottom, ksize, pooling_type, mkldnn_engine, ceil_mode); @@ -164,9 +164,15 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel { dev_ctx.SetBlob(key_pool_src_mem_p, src_memory); dev_ctx.SetBlob(key_pool_dst_mem_p, dst_memory); - pool_p = std::make_shared(*pool_pd, *(src_memory.get()), + if (propagation == mkldnn::prop_kind::forward_training) { + pool_p = std::make_shared(*pool_pd, *(src_memory.get()), *(dst_memory.get()), *workspace_memory); + } else{ + pool_p = std::make_shared(*pool_pd, *(src_memory.get()), + *(dst_memory.get()));//, + //*workspace_memory); + } dev_ctx.SetBlob(key_pool_p, pool_p); @@ -201,12 +207,13 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel { private: std::unique_ptr CreatePrimitiveDesc( const mkldnn::memory::desc& src, const mkldnn::memory::desc& dst, + const mkldnn::prop_kind& propagation, const std::vector& stride, const std::vector& padding_left_top, const std::vector& padding_right_bot, const std::vector& kernel, const std::string& pooling_type, const mkldnn::engine& engine, bool ceil_mode) const { auto pool_desc = mkldnn::pooling_forward::desc( - mkldnn::prop_kind::forward, + propagation, pooling_type == "max" ? mkldnn::algorithm::pooling_max : mkldnn::algorithm::pooling_avg, src, dst, stride, kernel, padding_left_top, padding_right_bot, -- GitLab