From 6d8d3d4c22ba0bbed57912ca831a26e5340d1c92 Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Tue, 17 Nov 2020 11:59:10 +0100 Subject: [PATCH] [oneDNN] Layer norm bf16 kernel (#28619) --- .../framework/ir/graph_pattern_detector.cc | 4 +- paddle/fluid/operators/layer_norm_op.cc | 35 ++++ .../operators/mkldnn/layer_norm_mkldnn_op.cc | 177 ++++++++++++++++++ paddle/fluid/platform/mkldnn_reuse.h | 6 + .../mkldnn/test_layer_norm_bf16_mkldnn_op.py | 146 +++++++++++++++ .../mkldnn/test_layer_norm_mkldnn_op.py | 151 +++++++++++++++ .../mkldnn/test_sum_bf16_mkldnn_op.py | 2 +- .../tests/unittests/test_layer_norm_op.py | 11 +- tools/static_mode_white_list.py | 2 + 9 files changed, 528 insertions(+), 6 deletions(-) create mode 100644 paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc create mode 100644 python/paddle/fluid/tests/unittests/mkldnn/test_layer_norm_bf16_mkldnn_op.py create mode 100644 python/paddle/fluid/tests/unittests/mkldnn/test_layer_norm_mkldnn_op.py diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 5704dd09cf2..5546a0e3726 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2102,8 +2102,8 @@ PDNode *patterns::Bfloat16Placement::operator()( const std::unordered_set &bfloat16_enabled_op_types) { std::unordered_set supported_op_types = std::unordered_set({"concat", "conv2d", "fusion_gru", "gelu", - "reshape2", "softmax", "sum", - "transpose2"}); + "layer_norm", "reshape2", "softmax", + "sum", "transpose2"}); if (!bfloat16_enabled_op_types.empty()) { supported_op_types = bfloat16_enabled_op_types; } diff --git a/paddle/fluid/operators/layer_norm_op.cc b/paddle/fluid/operators/layer_norm_op.cc index 89d8b57505d..79e3d3b90a9 100644 --- a/paddle/fluid/operators/layer_norm_op.cc +++ b/paddle/fluid/operators/layer_norm_op.cc @@ -15,6 +15,10 @@ limitations under the License. */ #include "paddle/fluid/operators/layer_norm_op.h" #include +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif + namespace paddle { namespace operators { @@ -91,6 +95,25 @@ class LayerNormOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Variance", {left}); ctx->ShareLoD("X", "Y"); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const { + framework::LibraryType library = framework::LibraryType::kPlain; + framework::DataLayout layout = framework::DataLayout::kAnyLayout; + +#ifdef PADDLE_WITH_MKLDNN + if (library == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library = framework::LibraryType::kMKLDNN; + layout = framework::DataLayout::kMKLDNN; + } +#endif + + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), + layout, library); + } }; class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker { @@ -134,6 +157,18 @@ class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker { "greater than zero. But received [%d].", begin_norm_axis)); }); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); + AddAttr( + "mkldnn_data_type", + "(string, default \"float32\"). Data type of mkldnn kernel") + .SetDefault("float32") + .InEnum({"float32", "bfloat16"}); + AddAttr("is_test", + "(bool, default false) Set to true for inference only, false " + "for training. Some layers may run faster when this is true.") + .SetDefault(false); AddComment(R"DOC( Assume feature vectors exist on dimensions diff --git a/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc new file mode 100644 index 00000000000..22261e948aa --- /dev/null +++ b/paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc @@ -0,0 +1,177 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/layer_norm_op.h" +#include "paddle/fluid/platform/mkldnn_reuse.h" + +namespace paddle { +namespace operators { + +template +class LayerNormMKLDNNHandler + : public platform::MKLDNNHandlerT { + public: + LayerNormMKLDNNHandler(const std::vector& dims, const float& epsilon, + const dnnl::normalization_flags& flags, + const bool& is_test, const MKLDNNMemoryFormat fmt, + const platform::MKLDNNDeviceContext& dev_ctx, + platform::Place cpu_place, + const std::string& uniq_name) + : platform::MKLDNNHandlerT( + dev_ctx, dev_ctx.GetEngine(), cpu_place, + platform::CreateKey(dims, uniq_name)) { + if (!this->isCached()) { + auto md = dnnl::memory::desc(dims, platform::MKLDNNGetDataType(), fmt); + if (!is_test) { + // TODO(grygielski) Delete forcing stats_md after DNNL 1.2 is introduced + auto stats_md = dnnl::memory::desc( + {begin(dims), end(dims) - 1}, platform::MKLDNNGetDataType(), + platform::MKLDNNFormatForSize(dims.size() - 1, + MKLDNNMemoryFormat::nchw)); + this->AcquireForwardPrimitiveDescriptor( + dnnl::prop_kind::forward_training, md, stats_md, epsilon, flags); + } else { + this->AcquireForwardPrimitiveDescriptor( + dnnl::prop_kind::forward_inference, md, epsilon, flags); + } + } + } + + std::shared_ptr AcquireScaleShiftMemory() { + return this->AcquireMemoryFromPrimitive("@scaleshift_mem_p"); + } + + std::shared_ptr AcquireScaleShiftMemory( + std::vector& scaleshift_data) { + // scaleshift_data comes from temporary buffer so we need to copy it into + // created memory primitivie + auto scaleshift_mem = this->AcquireMemoryFromPrimitive( + this->fwd_pd_->weights_desc(), "@scaleshift_mem_p"); + auto data_ptr = scaleshift_mem->get_data_handle(); + std::size_t num_bytes = scaleshift_data.size() * sizeof(float); + std::memcpy(data_ptr, scaleshift_data.data(), num_bytes); + return scaleshift_mem; + } + + std::shared_ptr AcquireMeanMemory(framework::Tensor* mean) { + T* mean_data = mean->mutable_data(this->place_, + this->fwd_pd_->mean_desc().get_size()); + return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(), + mean_data, "@mean_mem_p"); + } + + std::shared_ptr AcquireVarianceMemory( + framework::Tensor* variance) { + T* variance_data = variance->mutable_data( + this->place_, this->fwd_pd_->variance_desc().get_size()); + return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(), + variance_data, "@variance_mem_p"); + } +}; + +template +class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* scale = ctx.Input("Scale"); + auto* bias = ctx.Input("Bias"); + auto* y = ctx.Output("Y"); + + const float epsilon = ctx.Attr("epsilon"); + const auto begin_norm_axis = ctx.Attr("begin_norm_axis"); + const bool is_test = ctx.Attr("is_test"); + + auto& dev_ctx = + ctx.template device_context(); + + auto src_tz = paddle::framework::vectorize(x->dims()); + PADDLE_ENFORCE_EQ(begin_norm_axis, (src_tz.size() - 1), + platform::errors::InvalidArgument( + "MKL-DNN Layer Norm supports only last logical " + "axis:%d as begin_norm_axis.", + (src_tz.size() - 1))); + + y->mutable_data(ctx.GetPlace()); + const bool with_scaleshift = (scale && bias); + dnnl::normalization_flags flags{}; + + if (with_scaleshift) { + flags |= dnnl::normalization_flags::use_scale_shift; + } + + LayerNormMKLDNNHandler handler(src_tz, epsilon, flags, is_test, + x->format(), dev_ctx, ctx.GetPlace(), + ctx.OutputName("Y")); + + auto src_memory = handler.AcquireSrcMemory(x); + auto dst_memory = handler.AcquireDstMemory(y); + + auto layer_norm_p = handler.AcquireForwardPrimitive(); + + dnnl::stream astream(dev_ctx.GetEngine()); + std::unordered_map args; + + args.insert({DNNL_ARG_SRC, *src_memory}); + args.insert({DNNL_ARG_DST, *dst_memory}); + + if (!is_test) { + auto* mean = ctx.Output("Mean"); + auto* var = ctx.Output("Variance"); + mean->mutable_data(ctx.GetPlace()); + var->mutable_data(ctx.GetPlace()); + + auto mean_memory = handler.AcquireMeanMemory(mean); + auto variance_memory = handler.AcquireVarianceMemory(var); + + args.insert({DNNL_ARG_MEAN, *mean_memory}); + args.insert({DNNL_ARG_VARIANCE, *variance_memory}); + } + + auto scaleshift_memory = handler.AcquireScaleShiftMemory(); + if (with_scaleshift) { + if (scaleshift_memory == nullptr || !is_test) { + auto scale_tz = paddle::framework::vectorize(scale->dims()); + const unsigned int C = scale_tz[0]; + + // MKLDNN requires a single piece of memory for scale and shift/bias + // data + std::vector scaleshift_data; + scaleshift_data.reserve(2 * C); + scaleshift_data.insert(scaleshift_data.begin(), scale->data(), + scale->data() + C); + + scaleshift_data.insert(scaleshift_data.end(), bias->data(), + bias->data() + C); + + scaleshift_memory = handler.AcquireScaleShiftMemory(scaleshift_data); + } + args.insert({DNNL_ARG_SCALE_SHIFT, *scaleshift_memory}); + } + + layer_norm_p->execute(astream, args); + astream.wait(); + + y->set_layout(DataLayout::kMKLDNN); + y->set_format(platform::GetMKLDNNFormat(*dst_memory)); + } +}; + +} // namespace operators +} // namespace paddle + +// TODO(jczaja): Enable FP32 when performance is good +namespace ops = paddle::operators; +REGISTER_OP_KERNEL(layer_norm, MKLDNN, ::paddle::platform::CPUPlace, + ops::LayerNormMKLDNNOpKernel); diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 54f8cb1dc88..8649b90321c 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -190,6 +190,12 @@ class MKLDNNHandlerT { } } + std::shared_ptr AcquireMemoryFromPrimitive( + const std::string& suffix) { + return std::static_pointer_cast( + dev_ctx_.GetBlob(key_ + suffix)); + } + std::shared_ptr AcquireMemoryFromPrimitive( mkldnn::memory::desc md, void* ptr, const std::string& suffix) { const auto local_key = key_ + suffix; diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_layer_norm_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_layer_norm_bf16_mkldnn_op.py new file mode 100644 index 00000000000..dc881a57521 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_layer_norm_bf16_mkldnn_op.py @@ -0,0 +1,146 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# from paddle.fluid.tests.unittests.test_layer_norm_op import * +from __future__ import print_function +import unittest +import numpy as np + +from operator import mul +import paddle.fluid.core as core +import paddle.fluid as fluid +from paddle import enable_static +from functools import reduce + +from paddle.fluid.tests.unittests.mkldnn.test_layer_norm_mkldnn_op import TestLayerNormMKLDNNOp +from paddle.fluid.tests.unittests.mkldnn.test_layer_norm_mkldnn_op import _reference_layer_norm_naive +from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16 +from paddle.fluid.tests.unittests.op_test import _set_use_system_allocator + +np.random.random(123) + +_set_use_system_allocator(True) + + +@unittest.skipIf(not core.supports_bfloat16(), + "place does not support BF16 evaluation") +class TestLayerNormBF16MKLDNNOp(TestLayerNormMKLDNNOp): + def __assert_close(self, tensor, np_array, msg, rtol=2e-02, atol=2): + self.assertTrue( + np.allclose( + np.array(tensor), np_array, rtol=rtol, atol=atol), msg) + + def check_forward(self, + shape, + begin_norm_axis, + with_scale_bias=True, + with_is_test=False): + # attr + epsilon = 0.00001 + x_shape = shape + D = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1) + scale_shape = [D] + + np.random.seed(123) + x = np.random.random_sample(x_shape).astype(np.float32) + x_bf16 = convert_float_to_uint16(x) + + if with_scale_bias: + scale = np.random.random_sample(scale_shape).astype(np.float32) + bias = np.random.random_sample(scale_shape).astype(np.float32) + else: + scale = np.array([]) + bias = np.array([]) + + # reference forward & backward + y, mean, variance = _reference_layer_norm_naive(x, scale, bias, epsilon, + begin_norm_axis) + + y_bf16 = convert_float_to_uint16(y) + + var_dict = locals() + var_names = ['x_bf16', 'mean', 'variance', 'y_bf16'] + if with_scale_bias: + var_names.append('scale') + var_names.append('bias') + ground_truth = {name: var_dict[name] for name in var_names} + + program = fluid.Program() + with fluid.program_guard(program): + block = program.global_block() + + # scale and bias are fp32 and other vars are of bf16 + for name in ground_truth: + if name == 'x_bf16' or name == 'y_bf16': + block.create_var( + name=name, + dtype='uint16', + shape=ground_truth[name].shape) + else: + block.create_var( + name=name, + dtype='float32', + shape=ground_truth[name].shape) + + inputs = {"X": block.var('x_bf16')} + if with_scale_bias: + inputs["Scale"] = block.var('scale') + inputs["Bias"] = block.var('bias') + + block.append_op( + type="layer_norm", + inputs=inputs, + outputs={ + "Y": block.var('y_bf16'), + "Mean": block.var('mean'), # share the same memory + "Variance": block.var('variance'), # share the same memory + }, + attrs={ + "epsilon": epsilon, + "begin_norm_axis": begin_norm_axis, + "use_mkldnn": True, + "is_test": with_is_test + }) + + exe = fluid.Executor(core.CPUPlace()) + + input_list = ['x_bf16'] + if with_scale_bias: + input_list.append('scale') + input_list.append('bias') + + out = exe.run(program, + feed={name: var_dict[name] + for name in input_list}, + fetch_list=['y_bf16', 'mean', 'variance']) + self.__assert_close(y_bf16, out[0], "y_bf16", 2) + if not with_is_test: + self.__assert_close(mean, out[1], "mean") + self.__assert_close(variance, out[2], "variance", 1e-3) + + def test_check_forward_with_is_test(self): + self.check_forward( + shape=[2, 3, 4, 5], begin_norm_axis=3, with_is_test=True) + + # TODO (jczaja): Enable those to test when enabling training using bf16 + def test_check_forward_with_scale_and_bias(self): + pass + + def test_check_forward_without_scale_and_bias(self): + pass + + +if __name__ == "__main__": + enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_layer_norm_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_layer_norm_mkldnn_op.py new file mode 100644 index 00000000000..d20fb003ee9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_layer_norm_mkldnn_op.py @@ -0,0 +1,151 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# from paddle.fluid.tests.unittests.test_layer_norm_op import * +from __future__ import print_function +import unittest +import numpy as np + +from operator import mul +import paddle.fluid.core as core +import paddle.fluid as fluid +from paddle import enable_static +from functools import reduce + +from paddle.fluid.tests.unittests.op_test import _set_use_system_allocator + +np.random.random(123) + +_set_use_system_allocator(True) + + +def _reference_layer_norm_naive(x, scale, beta, epsilon, begin_norm_axis=1): + x_shape = x.shape + N = reduce(mul, x_shape[0:begin_norm_axis], 1) + D = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1) + x.shape = [N, D] + if scale.size == 0 and beta.size == 0: + scale = np.ones([1, D]) + beta = np.zeros([1, D]) + else: + scale = scale.reshape([1, D]) + beta = beta.reshape([1, D]) + + mean = np.mean(x, axis=1) + var = np.var(x, axis=1) + epsilon + output = scale * np.divide((x - mean.reshape([N, 1])), + (np.sqrt(var)).reshape([N, 1])) + beta + + x.shape, output.shape = x_shape, x_shape + return output, mean, var + + +class TestLayerNormMKLDNNOp(unittest.TestCase): + def setUp(self): + self.use_mkldnn = True + + def __assert_close(self, tensor, np_array, msg, atol=1e-4): + self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg) + + def check_forward(self, + shape, + begin_norm_axis, + with_scale_bias=True, + with_is_test=False): + # attr + epsilon = 0.00001 + x_shape = shape + D = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1) + scale_shape = [D] + + np.random.seed(123) + x = np.random.random_sample(x_shape).astype(np.float32) + + if with_scale_bias: + scale = np.random.random_sample(scale_shape).astype(np.float32) + bias = np.random.random_sample(scale_shape).astype(np.float32) + else: + scale = np.array([]) + bias = np.array([]) + + # reference forward & backward + y, mean, variance = _reference_layer_norm_naive(x, scale, bias, epsilon, + begin_norm_axis) + + var_dict = locals() + var_names = ['x', 'mean', 'variance', 'y'] + if with_scale_bias: + var_names.append('scale') + var_names.append('bias') + ground_truth = {name: var_dict[name] for name in var_names} + + program = fluid.Program() + with fluid.program_guard(program): + block = program.global_block() + + for name in ground_truth: + block.create_var( + name=name, dtype='float32', shape=ground_truth[name].shape) + + inputs = {"X": block.var('x')} + if with_scale_bias: + inputs["Scale"] = block.var('scale') + inputs["Bias"] = block.var('bias') + + block.append_op( + type="layer_norm", + inputs=inputs, + outputs={ + "Y": block.var('y'), + "Mean": block.var('mean'), # share the same memory + "Variance": block.var('variance'), # share the same memory + }, + attrs={ + "epsilon": epsilon, + "begin_norm_axis": begin_norm_axis, + "use_mkldnn": True, + "is_test": with_is_test + }) + + exe = fluid.Executor(core.CPUPlace()) + + input_list = ['x'] + if with_scale_bias: + input_list.append('scale') + input_list.append('bias') + + out = exe.run(program, + feed={name: var_dict[name] + for name in input_list}, + fetch_list=['y', 'mean', 'variance']) + self.__assert_close(y, out[0], "y") + if not with_is_test: + self.__assert_close(mean, out[1], "mean") + self.__assert_close(variance, out[2], "variance", 1e-3) + + def test_check_forward_with_scale_and_bias(self): + self.check_forward(shape=[2, 3, 4, 5], begin_norm_axis=3) + + def test_check_forward_without_scale_and_bias(self): + self.check_forward( + shape=[2, 3, 4, 5], begin_norm_axis=3, with_scale_bias=False) + + def test_check_forward_with_is_test(self): + self.check_forward( + shape=[2, 3, 4, 5], begin_norm_axis=3, with_is_test=True) + + +if __name__ == "__main__": + enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_sum_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_sum_bf16_mkldnn_op.py index 05d739ae1f3..c71baad0c70 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_sum_bf16_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_sum_bf16_mkldnn_op.py @@ -25,7 +25,7 @@ import paddle.fluid.op as fluid_op @unittest.skipIf(not core.supports_bfloat16(), "place does not support BF16 evaluation") -class TestSumMKLDNN(TestSumOp): +class TestSumBF16MKLDNN(TestSumOp): def setUp(self): self.op_type = "sum" self.use_mkldnn = True diff --git a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py index 8df7ea35ec1..d2c07c185dd 100644 --- a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py @@ -117,8 +117,12 @@ class TestLayerNormOp(unittest.TestCase): begin_norm_axis, has_scale=True, has_bias=True, - y_grad_scale=1.0): - def test_with_place(place, shape, begin_norm_axis): + y_grad_scale=1.0, + use_mkldnn=False): + def test_with_place(place, + shape, + begin_norm_axis, + use_mkldnn=use_mkldnn): # attr epsilon = 0.00001 x_shape = shape @@ -181,7 +185,8 @@ class TestLayerNormOp(unittest.TestCase): }, attrs={ "epsilon": epsilon, - "begin_norm_axis": begin_norm_axis + "begin_norm_axis": begin_norm_axis, + "use_mkldnn": use_mkldnn }) # generate backward op_desc grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc( diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index 1f153442aff..5fe1cc722e8 100644 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -293,6 +293,8 @@ STATIC_MODE_TESTING_LIST = [ 'test_label_smooth_op', 'test_lamb_op', 'test_layer_norm_op', + 'test_layer_norm_mkldnn_op', + 'test_layer_norm_bf16_mkldnn_op', 'test_layer_norm_op_v2', 'test_learning_rate_scheduler', 'test_linear_interp_op', -- GitLab