diff --git a/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc index 99b8d020436fc1418bd8877dd1fd640ae0bb3994..eb241b9157fecd8615797e276ab53f9f6812e21d 100644 --- a/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc @@ -35,159 +35,141 @@ using paddle::platform::MKLDNNDeviceContext; using platform::to_void_cast; template -class BatchNormMKLDNNHandler - : public platform::MKLDNNHandlerT { +class BatchNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT< + T, mkldnn::batch_normalization_forward, + mkldnn::batch_normalization_backward> { public: BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx, - const platform::MKLDNNDeviceContext &dev_ctx, - const mkldnn::engine mkldnn_engine, - platform::Place cpu_place, const Tensor *x, - const bool global_stats, const bool test_mode, - const std::string &unique_name) - : platform::MKLDNNHandlerT( - dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(dev_ctx, framework::vectorize(x->dims()), - unique_name)) { - if (!this->isCached()) { - const float epsilon = ctx.Attr("epsilon"); - const bool fuse_with_relu = ctx.Attr("fuse_with_relu"); - - std::vector DataLayout_error_msg = {"kNHWC", "kNCHW", - "kAnyLayout", "kMKLDNN"}; - PADDLE_ENFORCE_EQ( - x->layout(), DataLayout::kMKLDNN, - platform::errors::InvalidArgument( - "Wrong layout set for X tensor. Expected layout is `kMKLDNN`, " - "But received %s.", - DataLayout_error_msg[static_cast(DataLayout::kMKLDNN)])); - PADDLE_ENFORCE_NE( - x->format(), MKLDNNMemoryFormat::undef, - platform::errors::InvalidArgument("Wrong format set for X tensor")); - - auto src_tz = paddle::framework::vectorize(x->dims()); - - // Flags are added by bitwise OR operation - auto flags = mkldnn::normalization_flags::use_scale_shift; // 001 - if (global_stats) - flags |= mkldnn::normalization_flags::use_global_stats; // 010 - if (fuse_with_relu && test_mode) - flags |= mkldnn::normalization_flags::fuse_norm_relu; // 100 - - auto md = mkldnn::memory::desc( - src_tz, platform::MKLDNNGetDataType(), - platform::MKLDNNFormatForSize(src_tz.size(), x->format())); - - this->AcquireForwardPrimitiveDescriptor( - global_stats == true ? mkldnn::prop_kind::forward_scoring - : mkldnn::prop_kind::forward_training, - md, epsilon, flags); - } + const mkldnn::engine mkldnn_engine, const Tensor *x, + const bool global_stats, const bool test_mode) + : platform::MKLDNNHandlerNoCachingT( + mkldnn_engine, ctx.GetPlace()) { + const float epsilon = ctx.Attr("epsilon"); + const bool fuse_with_relu = ctx.Attr("fuse_with_relu"); + + std::vector DataLayout_error_msg = {"kNHWC", "kNCHW", + "kAnyLayout", "kMKLDNN"}; + PADDLE_ENFORCE_EQ( + x->layout(), DataLayout::kMKLDNN, + platform::errors::InvalidArgument( + "Wrong layout set for X tensor. Expected layout is `kMKLDNN`, " + "But received %s.", + DataLayout_error_msg[static_cast(DataLayout::kMKLDNN)])); + PADDLE_ENFORCE_NE( + x->format(), MKLDNNMemoryFormat::undef, + platform::errors::InvalidArgument("Wrong format set for X tensor")); + + auto src_tz = paddle::framework::vectorize(x->dims()); + + // Flags are added by bitwise OR operation + auto flags = mkldnn::normalization_flags::use_scale_shift; // 001 + if (global_stats) + flags |= mkldnn::normalization_flags::use_global_stats; // 010 + if (fuse_with_relu && test_mode) + flags |= mkldnn::normalization_flags::fuse_norm_relu; // 100 + + auto md = mkldnn::memory::desc( + src_tz, platform::MKLDNNGetDataType(), + platform::MKLDNNFormatForSize(src_tz.size(), x->format())); + + this->AcquireForwardPrimitiveDescriptor( + global_stats == true ? mkldnn::prop_kind::forward_scoring + : mkldnn::prop_kind::forward_training, + md, epsilon, flags); } BatchNormMKLDNNHandler(const paddle::framework::ExecutionContext &ctx, - const platform::MKLDNNDeviceContext &dev_ctx, - platform::Place cpu_place, const Tensor *in_x, - const Tensor *scale, const Tensor *out_grad, - const std::string &unique_name) - : platform::MKLDNNHandlerT( - dev_ctx, dev_ctx.GetEngine(), cpu_place, - platform::CreateKey(dev_ctx, framework::vectorize(in_x->dims()), - unique_name)) { - if (!this->isBwdCached()) { - PADDLE_ENFORCE_EQ(out_grad->layout(), DataLayout::kMKLDNN, - platform::errors::InvalidArgument( - "Wrong layout set for Input out_grad tensor")); - PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::undef, - platform::errors::InvalidArgument( - "Wrong format set for Input out_grad tensor")); - - auto src_tz = paddle::framework::vectorize(in_x->dims()); - auto scale_tz = paddle::framework::vectorize(scale->dims()); - PADDLE_ENFORCE_EQ( - scale_tz.size(), 1, - platform::errors::InvalidArgument( - "Dims of scale tensor must be 1, but received scale's size is %d", - scale_tz.size())); - - MKLDNNMemoryFormat diff_fmt = - platform::MKLDNNFormatForSize(src_tz.size(), out_grad->format()); - - MKLDNNMemoryFormat src_fmt = - platform::MKLDNNFormatForSize(src_tz.size(), in_x->format()); - - auto dims = framework::vectorize(in_x->dims()); - auto diff_dst_md = mkldnn::memory::desc( - dims, platform::MKLDNNGetDataType(), diff_fmt); - auto src_md = - mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), src_fmt); - - const float epsilon = ctx.Attr("epsilon"); - - this->AcquireForwardPrimitiveDescriptor( - mkldnn::prop_kind::forward_training, src_md, epsilon, - mkldnn::normalization_flags::use_scale_shift); - this->AcquireBackwardPrimitiveDescriptor( - mkldnn::prop_kind::backward, diff_dst_md, src_md, epsilon, - mkldnn::normalization_flags::use_scale_shift); - } + const mkldnn::engine mkldnn_engine, const Tensor *in_x, + const Tensor *scale, const Tensor *out_grad) + : platform::MKLDNNHandlerNoCachingT( + mkldnn_engine, ctx.GetPlace()) { + PADDLE_ENFORCE_EQ(out_grad->layout(), DataLayout::kMKLDNN, + platform::errors::InvalidArgument( + "Wrong layout set for Input out_grad tensor")); + PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::undef, + platform::errors::InvalidArgument( + "Wrong format set for Input out_grad tensor")); + + auto src_tz = paddle::framework::vectorize(in_x->dims()); + auto scale_tz = paddle::framework::vectorize(scale->dims()); + PADDLE_ENFORCE_EQ( + scale_tz.size(), 1, + platform::errors::InvalidArgument( + "Dims of scale tensor must be 1, but received scale's size is %d", + scale_tz.size())); + + MKLDNNMemoryFormat diff_fmt = + platform::MKLDNNFormatForSize(src_tz.size(), out_grad->format()); + + MKLDNNMemoryFormat src_fmt = + platform::MKLDNNFormatForSize(src_tz.size(), in_x->format()); + + auto dims = framework::vectorize(in_x->dims()); + auto diff_dst_md = + mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), diff_fmt); + auto src_md = + mkldnn::memory::desc(dims, platform::MKLDNNGetDataType(), src_fmt); + + const float epsilon = ctx.Attr("epsilon"); + + this->AcquireForwardPrimitiveDescriptor( + mkldnn::prop_kind::forward_training, src_md, epsilon, + mkldnn::normalization_flags::use_scale_shift); + this->AcquireBackwardPrimitiveDescriptor( + mkldnn::prop_kind::backward, diff_dst_md, src_md, epsilon, + mkldnn::normalization_flags::use_scale_shift); } std::shared_ptr AcquireScaleShiftMemory(const Tensor *scale, - const Tensor *shift, - const bool is_test) { - auto scaleshift_memory = this->AcquireMemory("@scaleshift_mem_p"); - if (scaleshift_memory == nullptr || !is_test) { - auto scale_tz = paddle::framework::vectorize(scale->dims()); - const unsigned int C = scale_tz[0]; - PADDLE_ENFORCE_EQ( - scale_tz.size(), 1, - platform::errors::InvalidArgument( - "Dims of scale tensor must be 1, but received scale's size is %d", - scale_tz.size())); - - auto mem_p = this->AcquireMemoryFromPrimitive( - this->fwd_pd_->weights_desc(), "@scaleshift_mem_p"); - - // MKLDNN requires a single piece of memory for scale and shift/bias data - auto mem_data_handle = reinterpret_cast(mem_p->get_data_handle()); - std::copy(scale->data(), scale->data() + C, mem_data_handle); - std::copy(shift->data(), shift->data() + C, mem_data_handle + C); - - return mem_p; - } + const Tensor *shift) { + auto scale_tz = paddle::framework::vectorize(scale->dims()); + const unsigned int C = scale_tz[0]; + PADDLE_ENFORCE_EQ( + scale_tz.size(), 1, + platform::errors::InvalidArgument( + "Dims of scale tensor must be 1, but received scale's size is %d", + scale_tz.size())); + + auto scaleshift_memory = + this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc()); + + // MKLDNN requires a single piece of memory for scale and shift/bias data + auto mem_data_handle = + reinterpret_cast(scaleshift_memory->get_data_handle()); + std::copy(scale->data(), scale->data() + C, mem_data_handle); + std::copy(shift->data(), shift->data() + C, mem_data_handle + C); return scaleshift_memory; } std::shared_ptr AcquireDiffScaleShiftMemory( T *diff_scaleshift_data) { return this->AcquireMemoryFromPrimitive(this->bwd_pd_->diff_weights_desc(), - diff_scaleshift_data, - "@diff_scaleshift_mem_p"); + diff_scaleshift_data); } std::shared_ptr AcquireMeanMemory( const framework::Tensor *mean) { const T *mean_data = mean->data(); - return this->AcquireMemoryFromPrimitive( - this->fwd_pd_->mean_desc(), to_void_cast(mean_data), "@mean_mem_p"); + return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(), + to_void_cast(mean_data)); } std::shared_ptr AcquireMeanMemory(framework::Tensor *mean) { T *mean_data = mean->mutable_data(this->place_, this->fwd_pd_->mean_desc().get_size()); return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(), - mean_data, "@mean_mem_p"); + mean_data); } std::shared_ptr AcquireVarianceMemory( const framework::Tensor *variance) { const T *variance_data = variance->data(); return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(), - to_void_cast(variance_data), - "@variance_mem_p"); + to_void_cast(variance_data)); } std::shared_ptr AcquireVarianceMemory( @@ -195,7 +177,7 @@ class BatchNormMKLDNNHandler T *variance_data = variance->mutable_data( this->place_, this->fwd_pd_->variance_desc().get_size()); return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(), - variance_data, "@variance_mem_p"); + variance_data); } }; @@ -220,13 +202,11 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { auto *batch_mean = ctx.Output("SavedMean"); auto *batch_variance = ctx.Output("SavedVariance"); - BatchNormMKLDNNHandler handler(ctx, dev_ctx, mkldnn_engine, - ctx.GetPlace(), x, global_stats, - test_mode, ctx.OutputName("SavedMean")); + BatchNormMKLDNNHandler handler(ctx, mkldnn_engine, x, global_stats, + test_mode); auto src_memory = handler.AcquireSrcMemory(x); - auto scaleshift_memory = - handler.AcquireScaleShiftMemory(scale, shift, is_test); + auto scaleshift_memory = handler.AcquireScaleShiftMemory(scale, shift); auto dst_memory = handler.AcquireDstMemory(y); auto batch_norm_p = handler.AcquireForwardPrimitive(); @@ -303,8 +283,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel { auto *diff_scale = ctx.Output(framework::GradVarName("Scale")); auto *diff_shift = ctx.Output(framework::GradVarName("Bias")); - BatchNormMKLDNNHandler handler(ctx, dev_ctx, ctx.GetPlace(), x, scale, - diff_y, ctx.InputName("SavedMean")); + BatchNormMKLDNNHandler handler(ctx, mkldnn_engine, x, scale, diff_y); // MKLDNN requires a single piece of memory for scale and shift/bias data const unsigned int C = paddle::framework::vectorize(scale->dims())[0]; @@ -316,8 +295,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel { auto mean_memory = handler.AcquireMeanMemory(batch_mean); auto variance_memory = handler.AcquireVarianceMemory(batch_variance); auto diff_dst_memory = handler.AcquireDiffDstMemory(diff_y); - auto scaleshift_memory = - handler.AcquireScaleShiftMemory(scale, shift, false); + auto scaleshift_memory = handler.AcquireScaleShiftMemory(scale, shift); auto diff_src_memory = handler.AcquireDiffSrcMemory(diff_x); auto diff_scaleshift_memory = handler.AcquireDiffScaleShiftMemory(diff_scaleshift_data.data()); diff --git a/paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc index 9d80286f4c4efa54ce83ca6148399d0875d64dc0..90f0de60b592deb4d2a26befeca7302b7fc6c87f 100644 --- a/paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc @@ -30,27 +30,21 @@ using platform::to_void_cast; template class InterpolateMKLDNNHandler - : public platform::MKLDNNHandlerT { + : public platform::MKLDNNHandlerNoCachingT { public: InterpolateMKLDNNHandler(const dnnl::algorithm algo, - const platform::MKLDNNDeviceContext& dev_ctx, const dnnl::engine engine, platform::Place cpu_place, - const Tensor* x, Tensor* z, - const std::string& uniq_name) - : platform::MKLDNNHandlerT( - dev_ctx, engine, cpu_place, - platform::CreateKey(dev_ctx, framework::vectorize(x->dims()), - uniq_name)) { - if (!this->isCached()) { - const auto src_x_tz = framework::vectorize(x->dims()); - const auto dst_tz = framework::vectorize(z->dims()); - const auto src_md = dnnl::memory::desc( - src_x_tz, platform::MKLDNNGetDataType(), x->format()); - const auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType(), - MKLDNNMemoryFormat::any); - this->AcquireForwardPrimitiveDescriptor( - dnnl::prop_kind::forward_inference, algo, src_md, dst_md); - } + const Tensor* x, Tensor* z) + : platform::MKLDNNHandlerNoCachingT( + engine, cpu_place) { + const auto src_x_tz = framework::vectorize(x->dims()); + const auto dst_tz = framework::vectorize(z->dims()); + const auto src_md = dnnl::memory::desc( + src_x_tz, platform::MKLDNNGetDataType(), x->format()); + const auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType(), + MKLDNNMemoryFormat::any); + this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_inference, + algo, src_md, dst_md); } }; @@ -145,7 +139,6 @@ class InterpolateMKLDNNKernel : public framework::OpKernel { const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto* x = ctx.Input("X"); - std::vector scale_prior; auto* z = ctx.Output("Out"); auto interp_method = ctx.Attr("interp_method"); @@ -155,11 +148,10 @@ class InterpolateMKLDNNKernel : public framework::OpKernel { auto out_dims_vec = ComputeOutputShape(ctx); framework::DDim dim_out = framework::make_ddim(out_dims_vec); - z->mutable_data(dim_out, ctx.GetPlace()); + z->Resize(dim_out); - InterpolateMKLDNNHandler handler(algo, dev_ctx, mkldnn_engine, - ctx.GetPlace(), x, z, - ctx.OutputName("Out")); + InterpolateMKLDNNHandler handler(algo, mkldnn_engine, ctx.GetPlace(), x, + z); auto src_memory_p = handler.AcquireSrcMemory(x); auto dst_memory_p = handler.AcquireDstMemory(z);