diff --git a/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc
index cc4bfbae2665fe7030dccd48b28c8819164e68c7..4553b4f8ef9193e5a67a10b45ef4857ff42e0398 100644
--- a/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc
@@ -19,45 +19,36 @@ namespace paddle {
 namespace operators {
 
 template <typename T>
-class LayerNormMKLDNNHandler
-    : public platform::MKLDNNHandlerT<T, dnnl::layer_normalization_forward> {
+class LayerNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
+                                   T, dnnl::layer_normalization_forward> {
  public:
   LayerNormMKLDNNHandler(const std::vector<int64_t>& dims, const float& epsilon,
                          const dnnl::normalization_flags& flags,
                          const bool& is_test, const MKLDNNMemoryFormat fmt,
-                         const platform::MKLDNNDeviceContext& dev_ctx,
-                         platform::Place cpu_place,
-                         const std::string& uniq_name)
-      : platform::MKLDNNHandlerT<T, dnnl::layer_normalization_forward>(
-            dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(dev_ctx, dims, uniq_name)) {
-    if (!this->isCached()) {
-      auto md = dnnl::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
-      if (!is_test) {
-        // TODO(grygielski) Delete forcing stats_md after DNNL 1.2 is introduced
-        auto stats_md = dnnl::memory::desc(
-            {begin(dims), end(dims) - 1}, platform::MKLDNNGetDataType<float>(),
-            platform::MKLDNNFormatForSize(dims.size() - 1,
-                                          MKLDNNMemoryFormat::nchw));
-        this->AcquireForwardPrimitiveDescriptor(
-            dnnl::prop_kind::forward_training, md, stats_md, epsilon, flags);
-      } else {
-        this->AcquireForwardPrimitiveDescriptor(
-            dnnl::prop_kind::forward_inference, md, epsilon, flags);
-      }
+                         const mkldnn::engine engine, platform::Place cpu_place)
+      : platform::MKLDNNHandlerNoCachingT<T, dnnl::layer_normalization_forward>(
+            engine, cpu_place) {
+    auto md = dnnl::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
+    if (!is_test) {
+      // TODO(grygielski) Delete forcing stats_md after DNNL 1.2 is introduced
+      auto stats_md = dnnl::memory::desc(
+          {begin(dims), end(dims) - 1}, platform::MKLDNNGetDataType<float>(),
+          platform::MKLDNNFormatForSize(dims.size() - 1,
+                                        MKLDNNMemoryFormat::nchw));
+      this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
+                                              md, stats_md, epsilon, flags);
+    } else {
+      this->AcquireForwardPrimitiveDescriptor(
+          dnnl::prop_kind::forward_inference, md, epsilon, flags);
     }
   }
 
-  std::shared_ptr<dnnl::memory> AcquireScaleShiftMemory() {
-    return this->AcquireMemoryFromPrimitive("@scaleshift_mem_p");
-  }
-
   std::shared_ptr<dnnl::memory> AcquireScaleShiftMemory(
       std::vector<float>& scaleshift_data) {
     // scaleshift_data comes from temporary buffer so we need to copy it into
     // created memory primitivie
-    auto scaleshift_mem = this->AcquireMemoryFromPrimitive(
-        this->fwd_pd_->weights_desc(), "@scaleshift_mem_p");
+    auto scaleshift_mem =
+        this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc());
     auto data_ptr = scaleshift_mem->get_data_handle();
     std::size_t num_bytes = scaleshift_data.size() * sizeof(float);
     std::memcpy(data_ptr, scaleshift_data.data(), num_bytes);
@@ -68,7 +59,7 @@ class LayerNormMKLDNNHandler
     T* mean_data = mean->mutable_data<T>(this->place_,
                                          this->fwd_pd_->mean_desc().get_size());
     return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
-                                            mean_data, "@mean_mem_p");
+                                            mean_data);
   }
 
   std::shared_ptr<dnnl::memory> AcquireVarianceMemory(
@@ -76,7 +67,7 @@ class LayerNormMKLDNNHandler
     T* variance_data = variance->mutable_data<T>(
         this->place_, this->fwd_pd_->variance_desc().get_size());
     return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
-                                            variance_data, "@variance_mem_p");
+                                            variance_data);
   }
 };
 
@@ -95,6 +86,7 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 
     auto& dev_ctx =
         ctx.template device_context<platform::MKLDNNDeviceContext>();
+    const auto& mkldnn_engine = dev_ctx.GetEngine();
 
     auto src_tz = paddle::framework::vectorize(x->dims());
     PADDLE_ENFORCE_EQ(begin_norm_axis, (src_tz.size() - 1),
@@ -112,8 +104,8 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     }
 
     LayerNormMKLDNNHandler<T> handler(src_tz, epsilon, flags, is_test,
-                                      x->format(), dev_ctx, ctx.GetPlace(),
-                                      ctx.OutputName("Y"));
+                                      x->format(), mkldnn_engine,
+                                      ctx.GetPlace());
 
     auto src_memory = handler.AcquireSrcMemory(x);
     auto dst_memory = handler.AcquireDstMemory(y);
@@ -139,24 +131,22 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       args.insert({DNNL_ARG_VARIANCE, *variance_memory});
     }
 
-    auto scaleshift_memory = handler.AcquireScaleShiftMemory();
+    std::shared_ptr<dnnl::memory> scaleshift_memory = nullptr;
     if (with_scaleshift) {
-      if (scaleshift_memory == nullptr || !is_test) {
-        auto scale_tz = paddle::framework::vectorize(scale->dims());
-        const unsigned int C = scale_tz[0];
-
-        // MKLDNN requires a single piece of memory for scale and shift/bias
-        // data
-        std::vector<float> scaleshift_data;
-        scaleshift_data.reserve(2 * C);
-        scaleshift_data.insert(scaleshift_data.begin(), scale->data<float>(),
-                               scale->data<float>() + C);
-
-        scaleshift_data.insert(scaleshift_data.end(), bias->data<float>(),
-                               bias->data<float>() + C);
-
-        scaleshift_memory = handler.AcquireScaleShiftMemory(scaleshift_data);
-      }
+      auto scale_tz = paddle::framework::vectorize(scale->dims());
+      const unsigned int C = scale_tz[0];
+
+      // MKLDNN requires a single piece of memory for scale and shift/bias
+      // data
+      std::vector<float> scaleshift_data;
+      scaleshift_data.reserve(2 * C);
+      scaleshift_data.insert(scaleshift_data.begin(), scale->data<float>(),
+                             scale->data<float>() + C);
+
+      scaleshift_data.insert(scaleshift_data.end(), bias->data<float>(),
+                             bias->data<float>() + C);
+
+      scaleshift_memory = handler.AcquireScaleShiftMemory(scaleshift_data);
       args.insert({DNNL_ARG_SCALE_SHIFT, *scaleshift_memory});
     }