diff --git a/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc
index e84266caa227c9f75c8127a6d6cab254bc601d23..8ab4612ff04b504b31fddaf245ce4127e5edc8ae 100644
--- a/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc
@@ -43,16 +43,20 @@ class LayerNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
     }
   }
 
-  std::shared_ptr<dnnl::memory> AcquireScaleShiftMemory(
-      std::vector<float>& scaleshift_data) {
-    // scaleshift_data comes from temporary buffer so we need to copy it into
-    // created memory primitivie
-    auto scaleshift_mem =
+  std::shared_ptr<dnnl::memory> AcquireScaleShiftMemory(const Tensor* scale,
+                                                        const Tensor* shift) {
+    // OneDNN requires a single piece of memory for scale and shift data
+    const unsigned int C = framework::vectorize(scale->dims())[0];
+
+    auto scaleshift_memory =
         this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc());
-    auto data_ptr = scaleshift_mem->get_data_handle();
-    std::size_t num_bytes = scaleshift_data.size() * sizeof(float);
-    std::memcpy(data_ptr, scaleshift_data.data(), num_bytes);
-    return scaleshift_mem;
+
+    auto mem_data_handle =
+        reinterpret_cast<float*>(scaleshift_memory->get_data_handle());
+    std::copy(scale->data<float>(), scale->data<float>() + C, mem_data_handle);
+    std::copy(shift->data<float>(), shift->data<float>() + C,
+              mem_data_handle + C);
+    return scaleshift_memory;
   }
 
   std::shared_ptr<dnnl::memory> AcquireMeanMemory(framework::Tensor* mean) {
@@ -95,7 +99,6 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                               "axis:%d as begin_norm_axis.",
                               (src_tz.size() - 1)));
 
-    y->mutable_data<T>(ctx.GetPlace());
     const bool with_scaleshift = (scale && bias);
     dnnl::normalization_flags flags{};
 
@@ -113,16 +116,12 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto layer_norm_p = handler.AcquireForwardPrimitive();
 
     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
-    std::unordered_map<int, dnnl::memory> args;
-
-    args.insert({DNNL_ARG_SRC, *src_memory});
-    args.insert({DNNL_ARG_DST, *dst_memory});
+    std::unordered_map<int, dnnl::memory> args = {{DNNL_ARG_SRC, *src_memory},
+                                                  {DNNL_ARG_DST, *dst_memory}};
 
     if (!is_test) {
       auto* mean = ctx.Output<Tensor>("Mean");
       auto* var = ctx.Output<Tensor>("Variance");
-      mean->mutable_data<float>(ctx.GetPlace());
-      var->mutable_data<float>(ctx.GetPlace());
 
       auto mean_memory = handler.AcquireMeanMemory(mean);
       auto variance_memory = handler.AcquireVarianceMemory(var);
@@ -131,22 +130,9 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       args.insert({DNNL_ARG_VARIANCE, *variance_memory});
     }
 
-    std::shared_ptr<dnnl::memory> scaleshift_memory;
     if (with_scaleshift) {
-      auto scale_tz = paddle::framework::vectorize(scale->dims());
-      const unsigned int C = scale_tz[0];
-
-      // MKLDNN requires a single piece of memory for scale and shift/bias
-      // data
-      std::vector<float> scaleshift_data;
-      scaleshift_data.reserve(2 * C);
-      scaleshift_data.insert(scaleshift_data.begin(), scale->data<float>(),
-                             scale->data<float>() + C);
-
-      scaleshift_data.insert(scaleshift_data.end(), bias->data<float>(),
-                             bias->data<float>() + C);
-
-      scaleshift_memory = handler.AcquireScaleShiftMemory(scaleshift_data);
+      std::shared_ptr<dnnl::memory> scaleshift_memory =
+          handler.AcquireScaleShiftMemory(scale, bias);
       args.insert({DNNL_ARG_SCALE_SHIFT, *scaleshift_memory});
     }
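
Note (not part of the patch): the layout the new AcquireScaleShiftMemory writes is the one oneDNN's layer normalization primitive expects for DNNL_ARG_SCALE_SHIFT, a single contiguous float buffer holding the C scale values followed by the C shift values. Below is a minimal standalone C++ sketch of that packing; PackScaleShift is a hypothetical helper, and the std::vector merely stands in for the memory acquired from fwd_pd_->weights_desc().

    // Standalone sketch of the scale/shift packing scheme (not Paddle code).
    #include <algorithm>
    #include <cassert>
    #include <cstdio>
    #include <vector>

    std::vector<float> PackScaleShift(const std::vector<float>& scale,
                                      const std::vector<float>& shift) {
      assert(scale.size() == shift.size());
      const std::size_t C = scale.size();
      // One contiguous buffer of 2*C floats: scale occupies [0, C),
      // shift occupies [C, 2C).
      std::vector<float> scaleshift(2 * C);
      std::copy(scale.begin(), scale.end(), scaleshift.begin());
      std::copy(shift.begin(), shift.end(), scaleshift.begin() + C);
      return scaleshift;
    }

    int main() {
      const auto packed = PackScaleShift({1.f, 2.f, 3.f}, {4.f, 5.f, 6.f});
      for (float v : packed) std::printf("%g ", v);  // prints: 1 2 3 4 5 6
      std::printf("\n");
      return 0;
    }

Unlike this sketch, the patched kernel writes straight into the primitive's weights memory via its data handle, dropping the temporary std::vector and the extra std::memcpy of the removed code path, one less allocation and one less copy per call.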