Unverified commit b6edaff8 authored by jakpiase, committed by GitHub

[Need review] Optimized and refactored oneDNN layer_norm kernel (#36917)

* optimization for layernorm

* further refactoring

* added reviewer suggestions
Parent 53690719
@@ -43,16 +43,20 @@ class LayerNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
     }
   }
 
-  std::shared_ptr<dnnl::memory> AcquireScaleShiftMemory(
-      std::vector<float>& scaleshift_data) {
-    // scaleshift_data comes from temporary buffer so we need to copy it into
-    // created memory primitivie
-    auto scaleshift_mem =
+  std::shared_ptr<dnnl::memory> AcquireScaleShiftMemory(const Tensor* scale,
+                                                        const Tensor* shift) {
+    // OneDNN requires a single piece of memory for scale and shift data
+    const unsigned int C = framework::vectorize(scale->dims())[0];
+
+    auto scaleshift_memory =
         this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc());
-    auto data_ptr = scaleshift_mem->get_data_handle();
-    std::size_t num_bytes = scaleshift_data.size() * sizeof(float);
-    std::memcpy(data_ptr, scaleshift_data.data(), num_bytes);
-    return scaleshift_mem;
+
+    auto mem_data_handle =
+        reinterpret_cast<float*>(scaleshift_memory->get_data_handle());
+    std::copy(scale->data<float>(), scale->data<float>() + C, mem_data_handle);
+    std::copy(shift->data<float>(), shift->data<float>() + C,
+              mem_data_handle + C);
+    return scaleshift_memory;
   }
 
   std::shared_ptr<dnnl::memory> AcquireMeanMemory(framework::Tensor* mean) {
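
For context on the new handler method: as the deleted comment in the last hunk below notes, oneDNN's layer normalization primitive takes scale and shift as one contiguous buffer of 2*C floats, scale in the first half and shift in the second. A minimal standalone sketch of that layout, using plain std::vector in place of Paddle tensors and dnnl::memory (PackScaleShift is a hypothetical helper, not part of the patch):

#include <algorithm>
#include <cstddef>
#include <vector>

// Packs per-channel scale and shift into the 2*C layout oneDNN expects:
// indices [0, C) hold scale, indices [C, 2*C) hold shift.
std::vector<float> PackScaleShift(const std::vector<float>& scale,
                                  const std::vector<float>& shift) {
  const std::size_t C = scale.size();  // number of channels
  std::vector<float> packed(2 * C);
  std::copy(scale.begin(), scale.end(), packed.begin());
  std::copy(shift.begin(), shift.end(), packed.begin() + C);
  return packed;
}

The patch avoids even this temporary buffer by writing through the pointer returned by get_data_handle(), so the packed layout is produced directly inside the dnnl::memory.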
@@ -95,7 +99,6 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
             "axis:%d as begin_norm_axis.",
             (src_tz.size() - 1)));
 
-    y->mutable_data<T>(ctx.GetPlace());
 
     const bool with_scaleshift = (scale && bias);
     dnnl::normalization_flags flags{};
@@ -113,16 +116,12 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto layer_norm_p = handler.AcquireForwardPrimitive();
 
     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
 
-    std::unordered_map<int, dnnl::memory> args;
-
-    args.insert({DNNL_ARG_SRC, *src_memory});
-    args.insert({DNNL_ARG_DST, *dst_memory});
+    std::unordered_map<int, dnnl::memory> args = {{DNNL_ARG_SRC, *src_memory},
+                                                  {DNNL_ARG_DST, *dst_memory}};
 
     if (!is_test) {
       auto* mean = ctx.Output<Tensor>("Mean");
       auto* var = ctx.Output<Tensor>("Variance");
-      mean->mutable_data<T>(ctx.GetPlace());
-      var->mutable_data<T>(ctx.GetPlace());
-
       auto mean_memory = handler.AcquireMeanMemory(mean);
       auto variance_memory = handler.AcquireVarianceMemory(var);
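
The args change is a small readability refactor: the two mandatory entries are known up front, so the map can be brace-initialized in one statement, while optional entries (mean and variance in training mode) are still inserted conditionally. A self-contained sketch of the pattern, with int stand-ins for the dnnl::memory values since the DNNL_ARG_* tags themselves are plain ints:

#include <unordered_map>

// Hypothetical stand-ins for DNNL_ARG_SRC / DNNL_ARG_DST / DNNL_ARG_MEAN.
constexpr int kArgSrc = 0;
constexpr int kArgDst = 1;
constexpr int kArgMean = 2;

std::unordered_map<int, int> BuildArgs(bool is_test) {
  // Mandatory arguments in one aggregate initialization.
  std::unordered_map<int, int> args = {{kArgSrc, /*src=*/10},
                                       {kArgDst, /*dst=*/20}};
  // Optional arguments are added only when needed.
  if (!is_test) args.insert({kArgMean, /*mean=*/30});
  return args;
}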
@@ -131,22 +130,9 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       args.insert({DNNL_ARG_VARIANCE, *variance_memory});
     }
 
-    std::shared_ptr<mkldnn::memory> scaleshift_memory;
     if (with_scaleshift) {
-      auto scale_tz = paddle::framework::vectorize(scale->dims());
-      const unsigned int C = scale_tz[0];
-
-      // MKLDNN requires a single piece of memory for scale and shift/bias
-      // data
-      std::vector<float> scaleshift_data;
-      scaleshift_data.reserve(2 * C);
-      scaleshift_data.insert(scaleshift_data.begin(), scale->data<float>(),
-                             scale->data<float>() + C);
-      scaleshift_data.insert(scaleshift_data.end(), bias->data<float>(),
-                             bias->data<float>() + C);
-
-      scaleshift_memory = handler.AcquireScaleShiftMemory(scaleshift_data);
+      std::shared_ptr<mkldnn::memory> scaleshift_memory =
+          handler.AcquireScaleShiftMemory(scale, bias);
       args.insert({DNNL_ARG_SCALE_SHIFT, *scaleshift_memory});
     }
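
This hunk shows the actual optimization: the old call site staged scale and bias in a temporary std::vector, which the handler then memcpy'd into the dnnl::memory buffer, i.e. two passes over the 2*C floats plus a heap allocation; the new code copies each tensor directly into the destination buffer once. A standalone sketch contrasting the two data paths (hypothetical function names; dst stands for the pointer returned by get_data_handle()):

#include <algorithm>
#include <cstddef>
#include <cstring>
#include <vector>

// Old path: staging vector (copy #1), then memcpy into dst (copy #2).
void OldPath(const float* scale, const float* shift, std::size_t C,
             float* dst) {
  std::vector<float> staging;
  staging.reserve(2 * C);
  staging.insert(staging.end(), scale, scale + C);
  staging.insert(staging.end(), shift, shift + C);
  std::memcpy(dst, staging.data(), 2 * C * sizeof(float));
}

// New path: write both halves straight into dst, one copy per tensor and
// no temporary allocation.
void NewPath(const float* scale, const float* shift, std::size_t C,
             float* dst) {
  std::copy(scale, scale + C, dst);
  std::copy(shift, shift + C, dst + C);
}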
...