From 8c0ea4bffeb582327662b3387ad29135f87090e0 Mon Sep 17 00:00:00 2001 From: "joanna.wozna.intel" Date: Fri, 20 Nov 2020 09:03:46 +0100 Subject: [PATCH] Add bf16 matmul, fc, elementwise add and mul (#28729) * Add bf16 matmul, fc, elementwise add and mul * Correct unit test --- .../framework/ir/graph_pattern_detector.cc | 11 +- .../cpu_bfloat16_placement_pass_tester.cc | 4 +- .../mkldnn/elementwise_add_mkldnn_op.cc | 2 + .../mkldnn/elementwise_mul_mkldnn_op.cc | 2 + paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc | 11 +- .../operators/mkldnn/matmul_mkldnn_op.cc | 19 ++- .../test_mkldnn_cpu_bfloat16_pass.py | 9 +- .../test_elementwise_add_bf16_mkldnn_op.py | 60 +++++++++ .../test_elementwise_mul_bf16_mkldnn_op.py | 60 +++++++++ .../mkldnn/test_fc_bf16_mkldnn_op.py | 85 ++++++++++++ .../mkldnn/test_matmul_bf16_mkldnn_op.py | 121 ++++++++++++++++++ tools/static_mode_white_list.py | 4 + 12 files changed, 373 insertions(+), 15 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_add_bf16_mkldnn_op.py create mode 100644 python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_bf16_mkldnn_op.py create mode 100644 python/paddle/fluid/tests/unittests/mkldnn/test_fc_bf16_mkldnn_op.py create mode 100644 python/paddle/fluid/tests/unittests/mkldnn/test_matmul_bf16_mkldnn_op.py diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 5546a0e3726..56dacdc6db4 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2101,13 +2101,18 @@ PDNode *patterns::QuantizePlacement::operator()( PDNode *patterns::Bfloat16Placement::operator()( const std::unordered_set &bfloat16_enabled_op_types) { std::unordered_set supported_op_types = - std::unordered_set({"concat", "conv2d", "fusion_gru", "gelu", - "layer_norm", "reshape2", "softmax", - "sum", "transpose2"}); + std::unordered_set( + {"concat", "conv2d", "elementwise_add", "elementwise_mul", "fc", + "fusion_gru", "gelu", "layer_norm", "matmul", "reshape2", "softmax", + "sum", "transpose2"}); if (!bfloat16_enabled_op_types.empty()) { supported_op_types = bfloat16_enabled_op_types; } auto *op = pattern->NewNode(op_repr())->assert_is_ops(supported_op_types); + op->assert_more([&](Node *node) { + return node->Op()->GetAttrIfExists("use_mkldnn") || + node->Op()->Type() == "reshape2"; + }); return op; } diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass_tester.cc index 4e3704e510c..c64bc8a214a 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_placement_pass_tester.cc @@ -24,10 +24,12 @@ namespace ir { void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, const std::vector& inputs, const std::vector& outputs, - const std::string& mkldnn_data_type = "float32") { + const std::string& mkldnn_data_type = "float32", + const bool use_mkldnn = true) { auto* op = prog->MutableBlock(0)->AppendOp(); op->SetType(type); + if (type != "reshape2") op->SetAttr("use_mkldnn", use_mkldnn); op->SetAttr("mkldnn_data_type", mkldnn_data_type); if (type == "conv2d") { diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc index 3dcf5bf6a32..54902015ce1 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc @@ -66,6 +66,8 @@ namespace ops = paddle::operators; REGISTER_OP_KERNEL( elementwise_add, MKLDNN, ::paddle::platform::CPUPlace, ops::EltwiseMKLDNNKernel, + ops::EltwiseMKLDNNKernel, ops::EltwiseMKLDNNKernel, ops::EltwiseMKLDNNKernel) diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc index c73b502a40e..293b5a1a2d3 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc @@ -19,5 +19,7 @@ namespace ops = paddle::operators; REGISTER_OP_KERNEL( elementwise_mul, MKLDNN, ::paddle::platform::CPUPlace, ops::EltwiseMKLDNNKernel, + ops::EltwiseMKLDNNKernel, ops::EltwiseMKLDNNKernel, ops::EltwiseMKLDNNKernel) diff --git a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc index 0bec5619f54..d560e80a332 100644 --- a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc @@ -536,9 +536,13 @@ static void ExecuteFc(const ExecutionContext& ctx, const LoDTensor* input, framework::vectorize(w->dims()), ctx.OutputName("Out")); constexpr bool is_int8 = std::is_same::value || std::is_same::value; - if (!is_int8 || force_fp32_output) { + bool is_bfloat16 = std::is_same::value; + if ((!is_int8 && !is_bfloat16) || force_fp32_output) { GetPrimitiveFactory(dev_ctx, prim_key) ->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx); + } else if (is_bfloat16) { + GetPrimitiveFactory(dev_ctx, prim_key) + ->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx); } else if (fuse_relu) { GetPrimitiveFactory(dev_ctx, prim_key) ->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx); @@ -580,6 +584,11 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, ::paddle::platform::CPUPlace, FP32, ops::kFCMKLDNNFP32, ops::FCMKLDNNOpKernel); +REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( + fc, MKLDNN, ::paddle::platform::CPUPlace, BF16, ops::kFCMKLDNNFP32, + ops::FCMKLDNNOpKernel); + REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, ::paddle::platform::CPUPlace, U8, ops::kFCMKLDNNINT8, ops::FCMKLDNNOpKernel); diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc index 3ae34fe0e90..21f94c07c1f 100644 --- a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc @@ -42,6 +42,11 @@ constexpr bool IsInt8() { return std::is_same::value || std::is_same::value; } +template +constexpr bool IsBfloat16() { + return std::is_same::value; +} + // Get row matrix shape from a vector shape. If the rank of x_dim > 1, the // original x_dim is returned. static framework::DDim RowMatrixDimsFromVector(const framework::DDim& x_dim) { @@ -170,7 +175,9 @@ class MatMulFactory { void CorrectStridesWhenFloatOutputFused(const ExecutionContext& ctx, const memory::dim N, memory::dim b, memory::dims* out_strides) const { - if (!IsInt8() && IsOutputFused(ctx)) *out_strides = {N, b * N, 1}; + if (!IsInt8() && !IsBfloat16() && IsOutputFused(ctx)) { + *out_strides = {N, b * N, 1}; + } } MatMulDims GetMatmulDims(const ExecutionContext& ctx) { @@ -348,10 +355,14 @@ static std::shared_ptr> GetPrimitiveFactory( template static void ExecuteMatMul(const ExecutionContext& ctx) { constexpr bool is_int8 = IsInt8(); + constexpr bool is_bfloat16 = IsBfloat16(); const bool force_fp32_output = ctx.Attr("force_fp32_output"); constexpr bool fuse_relu = false; // TODO(intel): Enable eltwise fuses - if (!is_int8 || force_fp32_output) { + if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) { GetPrimitiveFactory(ctx)->CreateAndExecute(ctx); + } else if (is_bfloat16) { + GetPrimitiveFactory(ctx) + ->CreateAndExecute(ctx); } else if (fuse_relu) { GetPrimitiveFactory(ctx)->CreateAndExecute(ctx); } else { @@ -376,5 +387,7 @@ class DNNLMatMulKernel : public framework::OpKernel { namespace ops = paddle::operators; REGISTER_OP_KERNEL(matmul, MKLDNN, ::paddle::platform::CPUPlace, - ops::DNNLMatMulKernel, ops::DNNLMatMulKernel, + ops::DNNLMatMulKernel, + ops::DNNLMatMulKernel, + ops::DNNLMatMulKernel, ops::DNNLMatMulKernel); diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_cpu_bfloat16_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_cpu_bfloat16_pass.py index 0a4d460d1fb..4b36e4b742c 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_cpu_bfloat16_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_cpu_bfloat16_pass.py @@ -25,18 +25,13 @@ class TestMKLDNNCpuBfloat16Pass(InferencePassTest): with fluid.program_guard(self.main_program, self.startup_program): x = fluid.data( name='x', shape=[-1] + self.shape_x, dtype=self.d_type) - y = fluid.data( - name='y', shape=[-1] + self.shape_y, dtype=self.d_type) - out = fluid.layers.matmul(x, y) - out = fluid.layers.transpose(out, perm=[0, 1, 2, 3]) + out = fluid.layers.transpose(x, perm=[0, 1, 2, 3]) out = fluid.layers.reshape(out, [0, 0, 0, 0]) out = fluid.layers.fc(out, size=1) self.feeds = { "x": - np.random.random([self.bs] + self.shape_x).astype(self.d_type), - "y": - np.random.random([self.bs] + self.shape_y).astype(self.d_type) + np.random.random([self.bs] + self.shape_x).astype(self.d_type) } self.fetch_list = [out] diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_add_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_add_bf16_mkldnn_op.py new file mode 100644 index 00000000000..7e4a1172380 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_add_bf16_mkldnn_op.py @@ -0,0 +1,60 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np +import paddle.fluid.core as core +from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16 +from paddle import enable_static + + +@unittest.skipIf(not core.supports_bfloat16(), + "place does not support BF16 evaluation") +class TestElementwiseAddBf16MklDNNOp(OpTest): + def setUp(self): + self.op_type = "elementwise_add" + self.use_mkldnn = True + self.mkldnn_data_type = "bfloat16" + self.axis = -1 + + self.generate_data() + self.inputs = { + 'X': convert_float_to_uint16(self.x), + 'Y': convert_float_to_uint16(self.y) + } + self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn} + self.outputs = {'Out': convert_float_to_uint16(self.out)} + + def generate_data(self): + self.x = np.random.random(100, ).astype(np.float32) + self.y = np.random.random(100, ).astype(np.float32) + self.out = np.add(self.x, self.y) + + def test_check_output(self): + self.check_output_with_place(core.CPUPlace()) + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + + +if __name__ == '__main__': + enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_bf16_mkldnn_op.py new file mode 100644 index 00000000000..c2716420fba --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_bf16_mkldnn_op.py @@ -0,0 +1,60 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np +import paddle.fluid.core as core +from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16 +from paddle import enable_static + + +@unittest.skipIf(not core.supports_bfloat16(), + "place does not support BF16 evaluation") +class TestElementwiseMulBf16MklDNNOp(OpTest): + def setUp(self): + self.op_type = "elementwise_mul" + self.use_mkldnn = True + self.mkldnn_data_type = "bfloat16" + self.axis = -1 + + self.generate_data() + self.inputs = { + 'X': convert_float_to_uint16(self.x), + 'Y': convert_float_to_uint16(self.y) + } + self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn} + self.outputs = {'Out': convert_float_to_uint16(self.out)} + + def generate_data(self): + self.x = np.random.random(100, ).astype(np.float32) + self.y = np.random.random(100, ).astype(np.float32) + self.out = np.multiply(self.x, self.y) + + def test_check_output(self): + self.check_output_with_place(core.CPUPlace()) + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + + +if __name__ == '__main__': + enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_fc_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_fc_bf16_mkldnn_op.py new file mode 100644 index 00000000000..1104372c741 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_fc_bf16_mkldnn_op.py @@ -0,0 +1,85 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle.fluid.core as core +from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16 +from paddle import enable_static + + +def fully_connected_naive(input, weights, bias_data): + result = np.dot(input, weights) + bias_data + return result + + +class MatrixGenerate: + def __init__(self, mb, ic, oc, h, w): + self.input = np.random.random((mb, ic * h * w)).astype(np.float32) + self.weights = np.random.random((ic * h * w, oc)).astype(np.float32) + + +@unittest.skipIf(not core.supports_bfloat16(), + "place does not support BF16 evaluation") +class TestFcBf16MklDNNOp(OpTest): + def generate_data(self): + self.matrix = MatrixGenerate(1, 10, 15, 3, 3) + self.bias = np.random.random(15).astype("float32") + + def setUp(self): + self.op_type = "fc" + self.use_mkldnn = True + self.mkldnn_data_type = "bfloat16" + self.force_fp32_output = False + self.generate_data() + + self.output = fully_connected_naive(self.matrix.input, + self.matrix.weights, self.bias) + if not self.force_fp32_output: + self.output = convert_float_to_uint16(self.output) + + self.inputs = { + 'Input': convert_float_to_uint16(self.matrix.input), + 'W': self.matrix.weights, + 'Bias': self.bias + } + + self.attrs = { + 'use_mkldnn': self.use_mkldnn, + 'force_fp32_output': self.force_fp32_output + } + + self.outputs = {'Out': self.output} + + def test_check_output(self): + self.check_output_with_place(core.CPUPlace()) + + def test_check_grad_normal(self): + pass + + def test_check_grad_no_weight(self): + pass + + +class TestFCMKLDNNOp1(TestFcBf16MklDNNOp): + def generate_data(self): + self.matrix = MatrixGenerate(2, 15, 48, 2, 2) + self.bias = np.random.random(48).astype(np.float32) + + +if __name__ == "__main__": + enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_bf16_mkldnn_op.py new file mode 100644 index 00000000000..149002fc765 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_bf16_mkldnn_op.py @@ -0,0 +1,121 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import os +import numpy as np +import paddle.fluid.core as core +from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16 +from paddle import enable_static + + +@unittest.skipIf(not core.supports_bfloat16(), + "place does not support BF16 evaluation") +class TestMatmulBf16MklDNNOp(OpTest): + def generate_data(self): + self.x = np.random.random((25, 2, 2)).astype(np.float32) + self.y = np.random.random((25, 2, 2)).astype(np.float32) + self.alpha = 1.0 + self.out = self.alpha * np.matmul(self.x, self.y) + + def set_attributes(self): + self.alpha = self.alpha if hasattr(self, 'alpha') else 1.0 + self.attrs = { + 'alpha': self.alpha, + "use_mkldnn": self.use_mkldnn, + "mkldnn_data_type": self.mkldnn_data_type, + "force_fp32_output": self.force_fp32_output + } + + def setUp(self): + self.op_type = "matmul" + self.use_mkldnn = True + self.dtype = np.uint16 + self.mkldnn_data_type = "bfloat16" + self.force_fp32_output = False + self.generate_data() + self.set_attributes() + + if not self.force_fp32_output: + self.out = convert_float_to_uint16(self.out) + self.outputs = {'Out': self.out} + + self.x = convert_float_to_uint16(self.x) + self.y = convert_float_to_uint16(self.y) + self.inputs = {'X': self.x, 'Y': self.y} + + def test_check_output(self): + self.check_output_with_place(core.CPUPlace()) + + def test_check_grad(self): + pass + + +class TestDnnlMatMulOpAlpha(TestMatmulBf16MklDNNOp): + def generate_data(self): + self.x = np.random.random((17, 2, 3)).astype(np.float32) + self.y = np.random.random((17, 3, 2)).astype(np.float32) + self.alpha = 2.0 + self.out = self.alpha * np.matmul(self.x, self.y) + + +class TestDnnlMatMulOp2D(TestMatmulBf16MklDNNOp): + def generate_data(self): + self.x = np.random.random((12, 9)).astype(np.float32) + self.y = np.random.random((9, 12)).astype(np.float32) + self.out = np.matmul(self.x, self.y) + + +class TestDnnlMatMulOpTransposeX(TestMatmulBf16MklDNNOp): + def generate_data(self): + self.x = np.random.random((12, 9)).astype(np.float32) + self.y = np.random.random((12, 9)).astype(np.float32) + self.out = np.matmul(np.transpose(self.x), self.y) + + def set_attributes(self): + self.attrs = { + "use_mkldnn": self.use_mkldnn, + "mkldnn_data_type": self.mkldnn_data_type, + 'transpose_X': True + } + + +class TestDnnlMatMulOpTransposeY(TestMatmulBf16MklDNNOp): + def generate_data(self): + self.x = np.random.random((12, 9)).astype(np.float32) + self.y = np.random.random((12, 9)).astype(np.float32) + self.out = np.matmul(self.x, np.transpose(self.y)) + + def set_attributes(self): + self.attrs = { + "use_mkldnn": self.use_mkldnn, + "mkldnn_data_type": self.mkldnn_data_type, + 'transpose_Y': True + } + + +class TestMatmulBf16MklDNNForceFp32Output(TestMatmulBf16MklDNNOp): + def generate_data(self): + self.x = np.random.random((12, 9)).astype(np.float32) + self.y = np.random.random((9, 12)).astype(np.float32) + self.force_fp32_output = True + self.alpha = 0.5 + self.out = self.alpha * np.matmul(self.x, self.y) + + +if __name__ == "__main__": + enable_static() + unittest.main() diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index 7f2ee9cb170..544c79fb13a 100644 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -590,13 +590,17 @@ STATIC_MODE_TESTING_LIST = [ 'test_conv3d_mkldnn_op', 'test_dequantize_mkldnn_op', 'test_elementwise_add_mkldnn_op', + 'test_elementwise_add_bf16_mkldnn_op', 'test_elementwise_mul_mkldnn_op', + 'test_elementwise_mul_bf16_mkldnn_op', 'test_fc_mkldnn_op', + 'test_fc_bf16_mkldnn_op', 'test_fusion_gru_int8_mkldnn_op', 'test_fusion_gru_mkldnn_op', 'test_gaussian_random_mkldnn_op', 'test_lrn_mkldnn_op', 'test_matmul_mkldnn_op', + 'test_matmul_bf16_mkldnn_op', 'test_mul_int8_mkldnn_op', 'test_multi_gru_mkldnn_op', 'test_pool2d_int8_mkldnn_op', -- GitLab