From b6bc6f7ae4a45ba4f857d93d196f6a4e919ee30b Mon Sep 17 00:00:00 2001 From: jakpiase Date: Mon, 20 Jun 2022 09:23:42 +0200 Subject: [PATCH] Fix for oneDNN layernorm for begin_norm_axis != last_dim (#43476) * fix for layer_norm * minor fix --- paddle/fluid/operators/layer_norm_op.cc | 4 +++- .../tests/unittests/mkldnn/test_layer_norm_mkldnn_op.py | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/layer_norm_op.cc b/paddle/fluid/operators/layer_norm_op.cc index 3d1e563ef1a..d6421cf541d 100644 --- a/paddle/fluid/operators/layer_norm_op.cc +++ b/paddle/fluid/operators/layer_norm_op.cc @@ -106,8 +106,10 @@ class LayerNormOp : public framework::OperatorWithKernel { framework::DataLayout layout = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN + int begin_norm_axis = ctx.Attr("begin_norm_axis"); if (library == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, input_data_type)) { + this->CanMKLDNNBeUsed(ctx, input_data_type) && + begin_norm_axis == ctx.Input("X")->dims().size() - 1) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_layer_norm_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_layer_norm_mkldnn_op.py index d36b5cc9e64..98e44f8f745 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_layer_norm_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_layer_norm_mkldnn_op.py @@ -23,7 +23,7 @@ import paddle.fluid as fluid from paddle import enable_static from functools import reduce -from paddle.fluid.tests.unittests.op_test import _set_use_system_allocator +from paddle.fluid.tests.unittests.op_test import _set_use_system_allocator, OpTestTool np.random.random(123) @@ -136,6 +136,10 @@ class TestLayerNormMKLDNNOp(unittest.TestCase): self.__assert_close(mean, out[1], "mean") self.__assert_close(variance, out[2], "variance", 1e-3) + @OpTestTool.skip_if_not_cpu_bf16() + def test_check_forward_non_last_begin_norm_axis(self): + self.check_forward(shape=[2, 3, 4, 5], begin_norm_axis=2) + def test_check_forward_with_scale_and_bias(self): self.check_forward(shape=[2, 3, 4, 5], begin_norm_axis=3) -- GitLab