Unverified commit b6edaff8, authored by jakpiase and committed by GitHub

[Need review] Optimized and refactored oneDNN layer_norm kernel (#36917)

* optimization for layernorm

* further refactoring

* added reviewer suggestions
Parent 53690719
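Reviewer note: the gist of the change is that the old kernel packed scale and shift into a temporary `std::vector<float>` and then `memcpy`'d it into the oneDNN weights memory, while the new `AcquireScaleShiftMemory` takes the scale and shift tensors directly and `std::copy`s each into its half of the buffer, dropping the intermediate allocation and one full copy. A minimal standalone sketch of the packing scheme (`pack_scaleshift` is an illustrative helper, not Paddle code):

```cpp
#include <algorithm>
#include <cstddef>

// oneDNN's combined scale/shift weights for layer_norm are a single
// {2, C} float buffer: gamma (scale) in the first row, beta (shift)
// in the second. This mirrors what the new AcquireScaleShiftMemory
// does with the tensors' raw data.
void pack_scaleshift(const float* scale, const float* shift, std::size_t C,
                     float* scaleshift_buffer) {
  std::copy(scale, scale + C, scaleshift_buffer);      // gamma -> [0, C)
  std::copy(shift, shift + C, scaleshift_buffer + C);  // beta  -> [C, 2C)
}
```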
@@ -43,16 +43,20 @@ class LayerNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
     }
   }

-  std::shared_ptr<dnnl::memory> AcquireScaleShiftMemory(
-      std::vector<float>& scaleshift_data) {
-    // scaleshift_data comes from temporary buffer so we need to copy it into
-    // created memory primitivie
-    auto scaleshift_mem =
+  std::shared_ptr<dnnl::memory> AcquireScaleShiftMemory(const Tensor* scale,
+                                                        const Tensor* shift) {
+    // OneDNN requires a single piece of memory for scale and shift data
+    const unsigned int C = framework::vectorize(scale->dims())[0];
+    auto scaleshift_memory =
         this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc());
-    auto data_ptr = scaleshift_mem->get_data_handle();
-    std::size_t num_bytes = scaleshift_data.size() * sizeof(float);
-    std::memcpy(data_ptr, scaleshift_data.data(), num_bytes);
-    return scaleshift_mem;
+    auto mem_data_handle =
+        reinterpret_cast<float*>(scaleshift_memory->get_data_handle());
+    std::copy(scale->data<float>(), scale->data<float>() + C, mem_data_handle);
+    std::copy(shift->data<float>(), shift->data<float>() + C,
+              mem_data_handle + C);
+    return scaleshift_memory;
   }

   std::shared_ptr<dnnl::memory> AcquireMeanMemory(framework::Tensor* mean) {
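For context, `AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc())` hands back a `dnnl::memory` sized by the forward primitive descriptor's weights descriptor, i.e. exactly the 2*C floats filled above. Stripped of Paddle's handler machinery, the equivalent raw oneDNN 2.x call is roughly this (a hedged sketch; the real helper also handles placement and data handles):

```cpp
// Allocate the combined scale/shift weights memory straight from the
// layer_norm forward primitive descriptor.
dnnl::memory make_scaleshift_memory(
    const dnnl::layer_normalization_forward::primitive_desc& fwd_pd,
    const dnnl::engine& engine) {
  return dnnl::memory(fwd_pd.weights_desc(), engine);  // owns 2*C floats
}
```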
@@ -95,7 +99,6 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                               "axis:%d as begin_norm_axis.",
                               (src_tz.size() - 1)));
     y->mutable_data<T>(ctx.GetPlace());
-
     const bool with_scaleshift = (scale && bias);

     dnnl::normalization_flags flags{};
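The lines elided after this hunk set the flag from `with_scaleshift`; the standard oneDNN 2.x pattern, and almost certainly what this kernel does, is:

```cpp
// Assumed continuation of the elided lines: request the combined
// scale/shift weights only when both tensors are present.
if (with_scaleshift) {
  flags |= dnnl::normalization_flags::use_scale_shift;
}
```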
@@ -113,16 +116,12 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto layer_norm_p = handler.AcquireForwardPrimitive();

     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
-    std::unordered_map<int, dnnl::memory> args;
-    args.insert({DNNL_ARG_SRC, *src_memory});
-    args.insert({DNNL_ARG_DST, *dst_memory});
+    std::unordered_map<int, dnnl::memory> args = {{DNNL_ARG_SRC, *src_memory},
+                                                  {DNNL_ARG_DST, *dst_memory}};

     if (!is_test) {
       auto* mean = ctx.Output<Tensor>("Mean");
       auto* var = ctx.Output<Tensor>("Variance");
-      mean->mutable_data<T>(ctx.GetPlace());
-      var->mutable_data<T>(ctx.GetPlace());

       auto mean_memory = handler.AcquireMeanMemory(mean);
       auto variance_memory = handler.AcquireVarianceMemory(var);
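The two `mutable_data` calls can be dropped here because allocation presumably moves into the handler's acquire methods (the `AcquireMeanMemory(framework::Tensor*)` signature in the first hunk takes the output tensor itself). A hedged sketch of what that method then plausibly looks like:

```cpp
// Sketch under that assumption: the handler allocates the mean tensor,
// sized by the primitive descriptor's mean descriptor, and wraps it
// as a dnnl::memory bound to that buffer.
std::shared_ptr<dnnl::memory> AcquireMeanMemory(framework::Tensor* mean) {
  T* mean_data = mean->mutable_data<T>(this->place_,
                                       this->fwd_pd_->mean_desc().get_size());
  return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
                                          mean_data);
}
```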
@@ -131,22 +130,9 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       args.insert({DNNL_ARG_VARIANCE, *variance_memory});
     }

-    std::shared_ptr<mkldnn::memory> scaleshift_memory;
     if (with_scaleshift) {
-      auto scale_tz = paddle::framework::vectorize(scale->dims());
-      const unsigned int C = scale_tz[0];
-
-      // MKLDNN requires a single piece of memory for scale and shift/bias
-      // data
-      std::vector<float> scaleshift_data;
-      scaleshift_data.reserve(2 * C);
-
-      scaleshift_data.insert(scaleshift_data.begin(), scale->data<float>(),
-                             scale->data<float>() + C);
-      scaleshift_data.insert(scaleshift_data.end(), bias->data<float>(),
-                             bias->data<float>() + C);
-
-      scaleshift_memory = handler.AcquireScaleShiftMemory(scaleshift_data);
+      std::shared_ptr<mkldnn::memory> scaleshift_memory =
+          handler.AcquireScaleShiftMemory(scale, bias);
       args.insert({DNNL_ARG_SCALE_SHIFT, *scaleshift_memory});
     }
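Downstream of the visible hunks, the assembled `args` map feeds the primitive's `execute` call. For reference, the whole inference-path flow in bare oneDNN 2.x terms (the API generation this 2021 kernel targets) looks roughly like the following self-contained sketch; engine/stream setup and memory descriptors are simplified assumptions, not Paddle code:

```cpp
#include <unordered_map>
#include "dnnl.hpp"

// Build a layer_norm forward primitive, bind src/dst and the packed
// scale/shift weights, then execute it on the given stream.
void layer_norm_forward(dnnl::engine& eng, dnnl::stream& strm,
                        dnnl::memory& src, dnnl::memory& dst,
                        dnnl::memory& scaleshift, float epsilon) {
  dnnl::layer_normalization_forward::desc desc(
      dnnl::prop_kind::forward_inference, src.get_desc(), epsilon,
      dnnl::normalization_flags::use_scale_shift);
  dnnl::layer_normalization_forward::primitive_desc pd(desc, eng);
  std::unordered_map<int, dnnl::memory> args = {
      {DNNL_ARG_SRC, src},
      {DNNL_ARG_DST, dst},
      {DNNL_ARG_SCALE_SHIFT, scaleshift}};
  dnnl::layer_normalization_forward(pd).execute(strm, args);
  strm.wait();  // block until the primitive finishes
}
```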