未验证 提交 b6bc6f7a 编写于 作者: J jakpiase 提交者: GitHub

Fix for oneDNN layernorm for begin_norm_axis != last_dim (#43476)

* fix for layer_norm

* minor fix
上级 8727bb7c
...@@ -106,8 +106,10 @@ class LayerNormOp : public framework::OperatorWithKernel { ...@@ -106,8 +106,10 @@ class LayerNormOp : public framework::OperatorWithKernel {
framework::DataLayout layout = framework::DataLayout::kAnyLayout; framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
int begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
if (library == framework::LibraryType::kPlain && if (library == framework::LibraryType::kPlain &&
this->CanMKLDNNBeUsed(ctx, input_data_type)) { this->CanMKLDNNBeUsed(ctx, input_data_type) &&
begin_norm_axis == ctx.Input<Tensor>("X")->dims().size() - 1) {
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
} }
......
...@@ -23,7 +23,7 @@ import paddle.fluid as fluid ...@@ -23,7 +23,7 @@ import paddle.fluid as fluid
from paddle import enable_static from paddle import enable_static
from functools import reduce from functools import reduce
from paddle.fluid.tests.unittests.op_test import _set_use_system_allocator from paddle.fluid.tests.unittests.op_test import _set_use_system_allocator, OpTestTool
np.random.random(123) np.random.random(123)
...@@ -136,6 +136,10 @@ class TestLayerNormMKLDNNOp(unittest.TestCase): ...@@ -136,6 +136,10 @@ class TestLayerNormMKLDNNOp(unittest.TestCase):
self.__assert_close(mean, out[1], "mean") self.__assert_close(mean, out[1], "mean")
self.__assert_close(variance, out[2], "variance", 1e-3) self.__assert_close(variance, out[2], "variance", 1e-3)
@OpTestTool.skip_if_not_cpu_bf16()
def test_check_forward_non_last_begin_norm_axis(self):
self.check_forward(shape=[2, 3, 4, 5], begin_norm_axis=2)
def test_check_forward_with_scale_and_bias(self): def test_check_forward_with_scale_and_bias(self):
self.check_forward(shape=[2, 3, 4, 5], begin_norm_axis=3) self.check_forward(shape=[2, 3, 4, 5], begin_norm_axis=3)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册