diff --git a/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc
index 9669a966cc0c68521800d29a6ccdbd86f6e7c5ba..2a8b332521804ccebdbd4e6914b2763abfb5dbdc 100644
--- a/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc
@@ -38,57 +38,14 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                           "Operator DNNL Pool must use CPUPlace"));
     auto& dev_ctx =
         ctx.template device_context<platform::MKLDNNDeviceContext>();
+    const auto& mkldnn_engine = dev_ctx.GetEngine();
 
     const Tensor* input = ctx.Input<Tensor>("X");
     Tensor* output = ctx.Output<Tensor>("Out");
 
-    PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
-                      "Wrong layout set for Input tensor");
-    PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
-                      "Wrong format set for Input tensor");
-
-    std::string pooling_type = ctx.Attr<std::string>("pooling_type");
-
-    std::vector<int> ksize_temp = ctx.Attr<std::vector<int>>("ksize");
-    std::vector<int64_t> ksize(begin(ksize_temp), end(ksize_temp));
-
-    std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
-    std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));
-
-    std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
-    std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));
-
-    bool global_pooling = ctx.Attr<bool>("global_pooling");
-    std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
-
-    // Only 2D pooling is supported now
-    PADDLE_ENFORCE_EQ(ksize.size(), 2, "ksize must be 2D, i.e. 2D pooling");
-    PADDLE_ENFORCE_EQ(pooling_type == "max" || pooling_type == "avg", true,
-                      "pooling_type must be 'max' or 'avg'");
-    PADDLE_ENFORCE_EQ(input->dims().size(), 4,
-                      "Input dim must be with 4, i.e. NCHW");
-
-    auto input_dims = input->dims();
-    framework::DDim data_dims =
-        framework::slice_ddim(input_dims, 2, input_dims.size());
-
-    if (global_pooling) {
-      UpdateKsize(&ksize, data_dims);
-    }
-
-    UpdatePadding(&paddings, global_pooling, 0, padding_algorithm, data_dims,
-                  strides, ksize);
-
-    auto src_tz = paddle::framework::vectorize(input->dims());
-    auto dst_tz = paddle::framework::vectorize(output->dims());
-
-    auto is_test = ctx.Attr<bool>("is_test");
-
-    platform::PoolingMKLDNNHandler<T> handler(
-        src_tz, dst_tz, ksize, strides, paddings, pooling_type,
-        ctx.Attr<bool>("ceil_mode"), input->format(),
-        paddle::framework::ToMKLDNNDataType(input->type()), is_test, dev_ctx,
-        ctx.GetPlace(), ctx.OutputName("Out"), ctx.Attr<bool>("exclusive"));
+    platform::PoolingMKLDNNHandler<T> handler(ctx, dev_ctx, mkldnn_engine,
+                                              ctx.GetPlace(), input, output,
+                                              ctx.OutputName("Out"));
 
     auto src_memory = handler.AcquireSrcMemory(input);
     auto dst_memory = handler.AcquireDstMemory(output);
@@ -96,7 +53,8 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto pool_p = handler.AcquireForwardPrimitive();
 
     mkldnn::stream astream(dev_ctx.GetEngine());
-    if ((is_test == false) && (pooling_type == "max")) {
+    if ((ctx.Attr<bool>("is_test") == false) &&
+        (ctx.Attr<std::string>("pooling_type") == "max")) {
       // Training
       auto workspace_memory = handler.AcquireWorkspaceMemory();
       pool_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory},
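The gist of the change: attribute parsing, validation, and creation of the forward primitive descriptor move out of the kernel above and into the PoolingMKLDNNHandler constructor (mkldnn_reuse.h diff below), where they are guarded by `if (!this->isCached())`, so that work runs only the first time a handler is built for a given cache key. What follows is a minimal standalone C++ sketch of that caching pattern; the names (PoolingHandlerSketch, PrimitiveDesc, a plain std::map standing in for the per-device-context cache) are illustrative only and are not Paddle's MKLDNNHandlerT API.

// Standalone sketch of the caching idea (illustrative names, not Paddle's API).
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct PrimitiveDesc {
  // Stand-in for the oneDNN forward primitive descriptor built by the handler.
  std::string config;
};

class PoolingHandlerSketch {
 public:
  using Cache = std::map<std::string, std::shared_ptr<PrimitiveDesc>>;

  PoolingHandlerSketch(Cache* cache, const std::string& key)
      : cache_(cache), key_(key) {
    if (!isCached()) {
      // Mirrors the !isCached() branch in the real handler: parse attributes,
      // validate the input and build the primitive descriptor once per key.
      std::cout << "building primitive descriptor for " << key_ << "\n";
      (*cache_)[key_] = std::make_shared<PrimitiveDesc>(PrimitiveDesc{"pool2d_max"});
    }
  }

  bool isCached() const { return cache_->count(key_) > 0; }

  std::shared_ptr<PrimitiveDesc> AcquireForwardPrimitiveDescriptor() const {
    return cache_->at(key_);
  }

 private:
  Cache* cache_;
  std::string key_;
};

int main() {
  PoolingHandlerSketch::Cache dev_ctx_cache;
  // First construction builds the descriptor; the second one only looks it up.
  PoolingHandlerSketch first(&dev_ctx_cache, "pool2d_1x3x224x224_f32_Out");
  PoolingHandlerSketch second(&dev_ctx_cache, "pool2d_1x3x224x224_f32_Out");
  std::cout << second.AcquireForwardPrimitiveDescriptor()->config << "\n";
  return 0;
}

On a repeated execution of the same op configuration, the handler constructor therefore reduces to a cache lookup, which is why the parsing can safely live inside the cached path.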
*/ #include "boost/optional.hpp" #include "paddle/fluid/framework/data_layout_transform.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/pool_op.h" #include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/place.h" @@ -592,41 +593,100 @@ template class PoolingMKLDNNHandler : public MKLDNNHandlerT { public: - PoolingMKLDNNHandler( - const std::vector& src_dims, - const std::vector& dst_dims, const std::vector& ksize, - const std::vector& strides, const std::vector& paddings, - const std::string& pooling_type, bool ceil_mode, - const MKLDNNMemoryFormat fmt, mkldnn::memory::data_type dt, bool is_test, - const platform::MKLDNNDeviceContext& dev_ctx, platform::Place cpu_place, - const std::string& unique_name, bool exclude_padding) + PoolingMKLDNNHandler(const paddle::framework::ExecutionContext& ctx, + const MKLDNNDeviceContext& dev_ctx, + const mkldnn::engine mkldnn_engine, + platform::Place cpu_place, const Tensor* input, + Tensor* output, const std::string& unique_name) : platform::MKLDNNHandlerT( dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(src_dims, dt, unique_name)) { - auto src_md = mkldnn::memory::desc(src_dims, dt, fmt); - /* create memory descriptor for pooling without specified format - * ('any') which lets a primitive (pooling in this case) choose - * the memory format preferred for best performance - */ - auto dst_md = - platform::MKLDNNMemDesc(dst_dims, dt, MKLDNNMemoryFormat::any); + platform::CreateKey(framework::vectorize(input->dims()), + framework::ToMKLDNNDataType(input->type()), + unique_name)) { + if (!this->isCached()) { + PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN, + platform::errors::InvalidArgument( + "Wrong layout set for Input tensor")); + PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef, + platform::errors::InvalidArgument( + "Wrong format set for Input tensor")); + + const std::string pooling_type = ctx.Attr("pooling_type"); + + std::vector ksize_temp = ctx.Attr>("ksize"); + std::vector ksize(begin(ksize_temp), end(ksize_temp)); + + std::vector strides_temp = ctx.Attr>("strides"); + std::vector strides(begin(strides_temp), end(strides_temp)); + + std::vector paddings_temp = ctx.Attr>("paddings"); + std::vector paddings(begin(paddings_temp), end(paddings_temp)); + + const bool global_pooling = ctx.Attr("global_pooling"); + const std::string padding_algorithm = + ctx.Attr("padding_algorithm"); + + // Only 2D pooling is supported now + PADDLE_ENFORCE_EQ(ksize.size(), 2, + platform::errors::InvalidArgument( + "ksize must be 2D, i.e. 2D pooling")); + PADDLE_ENFORCE_EQ(pooling_type == "max" || pooling_type == "avg", true, + platform::errors::InvalidArgument( + "pooling_type must be 'max' or 'avg'")); + PADDLE_ENFORCE_EQ(input->dims().size(), 4, + platform::errors::InvalidArgument( + "Input dim must be with 4, i.e. 
NCHW")); + + const auto input_dims = input->dims(); + framework::DDim data_dims = + framework::slice_ddim(input_dims, 2, input_dims.size()); + + if (global_pooling) { + operators::UpdateKsize(&ksize, data_dims); + } - auto mkldnn_paddings = ToMkldnnPadding(paddings); + operators::UpdatePadding(&paddings, global_pooling, 0, padding_algorithm, + data_dims, strides, ksize); + + const auto src_tz = paddle::framework::vectorize(input->dims()); + const auto dst_tz = paddle::framework::vectorize(output->dims()); + + const auto is_test = ctx.Attr("is_test"); + + const auto dt = framework::ToMKLDNNDataType(input->type()); + const auto fmt = input->format(); + + const auto exclude_padding = ctx.Attr("exclusive"); + + const auto src_md = mkldnn::memory::desc(src_tz, dt, fmt); + /* create memory descriptor for pooling without specified format + * ('any') which lets a primitive (pooling in this case) choose + * the memory format preferred for best performance + */ + + const auto dst_md = + platform::MKLDNNMemDesc(dst_tz, dt, MKLDNNMemoryFormat::any); - if (ceil_mode) { - CorrectOutputSize(src_dims, dst_dims, ksize, paddings, strides, - mkldnn_paddings[1]); + auto mkldnn_paddings = ToMkldnnPadding(paddings); + + const bool ceil_mode = ctx.Attr("ceil_mode"); + + if (ceil_mode) { + CorrectOutputSize(src_tz, dst_tz, ksize, paddings, strides, + mkldnn_paddings[1]); + } + this->AcquireForwardPrimitiveDescriptor( + is_test ? mkldnn::prop_kind::forward_inference + : mkldnn::prop_kind::forward_training, + pooling_type == "max" + ? mkldnn::algorithm::pooling_max + : (exclude_padding + ? mkldnn::algorithm::pooling_avg_exclude_padding + : mkldnn::algorithm::pooling_avg_include_padding), + src_md, dst_md, strides, ksize, mkldnn_paddings[0], + mkldnn_paddings[1]); } - this->AcquireForwardPrimitiveDescriptor( - is_test ? mkldnn::prop_kind::forward_inference - : mkldnn::prop_kind::forward_training, - pooling_type == "max" - ? mkldnn::algorithm::pooling_max - : (exclude_padding - ? mkldnn::algorithm::pooling_avg_exclude_padding - : mkldnn::algorithm::pooling_avg_include_padding), - src_md, dst_md, strides, ksize, mkldnn_paddings[0], mkldnn_paddings[1]); } PoolingMKLDNNHandler(