Unverified commit 8c0ea4bf, authored by joanna.wozna.intel, committed by GitHub

Add bf16 matmul, fc, elementwise add and mul (#28729)

* Add bf16 matmul, fc, elementwise add and mul

* Correct unit test
Parent efc3b182
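This commit extends PaddlePaddle's oneDNN (MKL-DNN) bfloat16 inference support to the matmul, fc, elementwise_add and elementwise_mul operators. As a hedged usage sketch (the model path is hypothetical; enable_mkldnn_bfloat16 comes from the bf16 inference work this PR builds on), an fp32 model can be served with the supported ops lowered to bf16:

    # Sketch only: assumes a saved inference model under ./model (hypothetical
    # path). Best performance needs a CPU with AVX512-BF16; otherwise oneDNN
    # falls back to slower bf16 emulation.
    from paddle.fluid.core import AnalysisConfig, create_paddle_predictor

    config = AnalysisConfig("./model")
    config.disable_gpu()
    config.enable_mkldnn()            # bf16 kernels are oneDNN-only
    config.enable_mkldnn_bfloat16()   # triggers the Bfloat16Placement pass below
    predictor = create_paddle_predictor(config)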
@@ -2101,13 +2101,18 @@ PDNode *patterns::QuantizePlacement::operator()(
 PDNode *patterns::Bfloat16Placement::operator()(
     const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
   std::unordered_set<std::string> supported_op_types =
-      std::unordered_set<std::string>({"concat", "conv2d", "fusion_gru", "gelu",
-                                       "layer_norm", "reshape2", "softmax",
-                                       "sum", "transpose2"});
+      std::unordered_set<std::string>(
+          {"concat", "conv2d", "elementwise_add", "elementwise_mul", "fc",
+           "fusion_gru", "gelu", "layer_norm", "matmul", "reshape2", "softmax",
+           "sum", "transpose2"});
   if (!bfloat16_enabled_op_types.empty()) {
     supported_op_types = bfloat16_enabled_op_types;
   }
   auto *op = pattern->NewNode(op_repr())->assert_is_ops(supported_op_types);
+  op->assert_more([&](Node *node) {
+    return node->Op()->GetAttrIfExists<bool>("use_mkldnn") ||
+           node->Op()->Type() == "reshape2";
+  });
   return op;
 }
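In effect, the pass now marks an op for bf16 only if its type is on the supported list (or in a caller-supplied override set) and it either has use_mkldnn set or is a reshape2, which is exempted from that check. A minimal Python restatement of the predicate (function and constant names are illustrative, not Paddle API):

    # Illustrative restatement of the placement predicate above.
    SUPPORTED_OPS = {"concat", "conv2d", "elementwise_add", "elementwise_mul",
                     "fc", "fusion_gru", "gelu", "layer_norm", "matmul",
                     "reshape2", "softmax", "sum", "transpose2"}

    def marks_for_bf16(op_type, use_mkldnn, enabled_override=None):
        supported = enabled_override or SUPPORTED_OPS
        return op_type in supported and (use_mkldnn or op_type == "reshape2")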
@@ -24,10 +24,12 @@ namespace ir {
 void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
            const std::vector<std::string>& inputs,
            const std::vector<std::string>& outputs,
-           const std::string& mkldnn_data_type = "float32") {
+           const std::string& mkldnn_data_type = "float32",
+           const bool use_mkldnn = true) {
   auto* op = prog->MutableBlock(0)->AppendOp();
   op->SetType(type);
+  if (type != "reshape2") op->SetAttr("use_mkldnn", use_mkldnn);
   op->SetAttr("mkldnn_data_type", mkldnn_data_type);
   if (type == "conv2d") {
@@ -66,6 +66,8 @@ namespace ops = paddle::operators;
 REGISTER_OP_KERNEL(
     elementwise_add, MKLDNN, ::paddle::platform::CPUPlace,
     ops::EltwiseMKLDNNKernel<float, dnnl::algorithm::binary_add>,
+    ops::EltwiseMKLDNNKernel<paddle::platform::bfloat16,
+                             dnnl::algorithm::binary_add>,
     ops::EltwiseMKLDNNKernel<int8_t, dnnl::algorithm::binary_add>,
     ops::EltwiseMKLDNNKernel<uint8_t, dnnl::algorithm::binary_add>)
@@ -19,5 +19,7 @@ namespace ops = paddle::operators;
 REGISTER_OP_KERNEL(
     elementwise_mul, MKLDNN, ::paddle::platform::CPUPlace,
     ops::EltwiseMKLDNNKernel<float, dnnl::algorithm::binary_mul>,
+    ops::EltwiseMKLDNNKernel<paddle::platform::bfloat16,
+                             dnnl::algorithm::binary_mul>,
     ops::EltwiseMKLDNNKernel<int8_t, dnnl::algorithm::binary_mul>,
     ops::EltwiseMKLDNNKernel<uint8_t, dnnl::algorithm::binary_mul>)
@@ -536,9 +536,13 @@ static void ExecuteFc(const ExecutionContext& ctx, const LoDTensor* input,
       framework::vectorize<int>(w->dims()), ctx.OutputName("Out"));
   constexpr bool is_int8 =
       std::is_same<T_in, int8_t>::value || std::is_same<T_in, uint8_t>::value;
-  if (!is_int8 || force_fp32_output) {
+  bool is_bfloat16 = std::is_same<T_in, paddle::platform::bfloat16>::value;
+  if ((!is_int8 && !is_bfloat16) || force_fp32_output) {
     GetPrimitiveFactory<T_in, T_w, float>(dev_ctx, prim_key)
         ->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx);
+  } else if (is_bfloat16) {
+    GetPrimitiveFactory<T_in, T_w, platform::bfloat16>(dev_ctx, prim_key)
+        ->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx);
   } else if (fuse_relu) {
     GetPrimitiveFactory<T_in, T_w, uint8_t>(dev_ctx, prim_key)
         ->ExecuteFcPrimitive(input, w, bias, output, dev_ctx, ctx);
@@ -580,6 +584,11 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, ::paddle::platform::CPUPlace,
                                     FP32, ops::kFCMKLDNNFP32,
                                     ops::FCMKLDNNOpKernel<float, float>);
 
+REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
+    fc, MKLDNN, ::paddle::platform::CPUPlace, BF16, ops::kFCMKLDNNFP32,
+    ops::FCMKLDNNOpKernel<paddle::platform::bfloat16,
+                          paddle::platform::bfloat16>);
+
 REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, ::paddle::platform::CPUPlace,
                                     U8, ops::kFCMKLDNNINT8,
                                     ops::FCMKLDNNOpKernel<uint8_t, int8_t>);
@@ -42,6 +42,11 @@ constexpr bool IsInt8() {
   return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
 }
 
+template <typename T>
+constexpr bool IsBfloat16() {
+  return std::is_same<T, paddle::platform::bfloat16>::value;
+}
+
 // Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
 // original x_dim is returned.
 static framework::DDim RowMatrixDimsFromVector(const framework::DDim& x_dim) {
@@ -170,7 +175,9 @@ class MatMulFactory {
   void CorrectStridesWhenFloatOutputFused(const ExecutionContext& ctx,
                                           const memory::dim N, memory::dim b,
                                           memory::dims* out_strides) const {
-    if (!IsInt8<OT>() && IsOutputFused(ctx)) *out_strides = {N, b * N, 1};
+    if (!IsInt8<OT>() && !IsBfloat16<OT>() && IsOutputFused(ctx)) {
+      *out_strides = {N, b * N, 1};
+    }
   }
   MatMulDims GetMatmulDims(const ExecutionContext& ctx) {
@@ -348,10 +355,14 @@ static std::shared_ptr<MatMulFactory<XT, YT, OT>> GetPrimitiveFactory(
 template <typename XT, typename YT>
 static void ExecuteMatMul(const ExecutionContext& ctx) {
   constexpr bool is_int8 = IsInt8<XT>();
+  constexpr bool is_bfloat16 = IsBfloat16<XT>();
   const bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
   constexpr bool fuse_relu = false;  // TODO(intel): Enable eltwise fuses
-  if (!is_int8 || force_fp32_output) {
+  if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) {
     GetPrimitiveFactory<XT, YT, float>(ctx)->CreateAndExecute(ctx);
+  } else if (is_bfloat16) {
+    GetPrimitiveFactory<XT, YT, paddle::platform::bfloat16>(ctx)
+        ->CreateAndExecute(ctx);
   } else if (fuse_relu) {
     GetPrimitiveFactory<XT, YT, uint8_t>(ctx)->CreateAndExecute(ctx);
   } else {
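Both ExecuteFc above and ExecuteMatMul here now choose the primitive's output type the same way: float when fp32 output is forced or the inputs are plain fp32, bfloat16 for bf16 inputs, and u8/s8 on the quantized paths. A compact restatement of that dispatch (a sketch with illustrative names, not Paddle API):

    # Illustrative restatement of the shared output-type dispatch.
    def output_dtype(in_dtype, force_fp32_output, fuse_relu=False):
        if force_fp32_output or in_dtype not in ("int8", "uint8", "bfloat16"):
            return "float32"
        if in_dtype == "bfloat16":
            return "bfloat16"  # bf16 in -> bf16 out unless fp32 is forced
        return "uint8" if fuse_relu else "int8"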
@@ -376,5 +387,7 @@ class DNNLMatMulKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 
 REGISTER_OP_KERNEL(matmul, MKLDNN, ::paddle::platform::CPUPlace,
-                   ops::DNNLMatMulKernel<float>, ops::DNNLMatMulKernel<int8_t>,
+                   ops::DNNLMatMulKernel<float>,
+                   ops::DNNLMatMulKernel<paddle::platform::bfloat16>,
+                   ops::DNNLMatMulKernel<int8_t>,
                    ops::DNNLMatMulKernel<uint8_t>);
@@ -25,18 +25,13 @@ class TestMKLDNNCpuBfloat16Pass(InferencePassTest):
         with fluid.program_guard(self.main_program, self.startup_program):
             x = fluid.data(
                 name='x', shape=[-1] + self.shape_x, dtype=self.d_type)
-            y = fluid.data(
-                name='y', shape=[-1] + self.shape_y, dtype=self.d_type)
-            out = fluid.layers.matmul(x, y)
-            out = fluid.layers.transpose(out, perm=[0, 1, 2, 3])
+            out = fluid.layers.transpose(x, perm=[0, 1, 2, 3])
             out = fluid.layers.reshape(out, [0, 0, 0, 0])
             out = fluid.layers.fc(out, size=1)
 
         self.feeds = {
             "x":
-            np.random.random([self.bs] + self.shape_x).astype(self.d_type),
-            "y":
-            np.random.random([self.bs] + self.shape_y).astype(self.d_type)
+            np.random.random([self.bs] + self.shape_x).astype(self.d_type)
         }
         self.fetch_list = [out]
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
from paddle import enable_static
@unittest.skipIf(not core.supports_bfloat16(),
                 "place does not support BF16 evaluation")
class TestElementwiseAddBf16MklDNNOp(OpTest):
    def setUp(self):
        self.op_type = "elementwise_add"
        self.use_mkldnn = True
        self.mkldnn_data_type = "bfloat16"
        self.axis = -1

        self.generate_data()
        self.inputs = {
            'X': convert_float_to_uint16(self.x),
            'Y': convert_float_to_uint16(self.y)
        }
        self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn}
        self.outputs = {'Out': convert_float_to_uint16(self.out)}

    def generate_data(self):
        self.x = np.random.random(100, ).astype(np.float32)
        self.y = np.random.random(100, ).astype(np.float32)
        self.out = np.add(self.x, self.y)

    def test_check_output(self):
        self.check_output_with_place(core.CPUPlace())

    def test_check_grad_normal(self):
        pass

    def test_check_grad_ingore_x(self):
        pass

    def test_check_grad_ingore_y(self):
        pass


if __name__ == '__main__':
    enable_static()
    unittest.main()
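The bf16 tests feed inputs as uint16 arrays because numpy has no native bfloat16 dtype; convert_float_to_uint16 maps each float32 to its 16-bit bf16 pattern. A self-contained sketch of that mapping by plain truncation (Paddle's helper may round rather than truncate):

    # bfloat16 is the upper 16 bits of an IEEE-754 float32, so the bit
    # pattern can be produced with integer ops on the float's raw bits.
    import numpy as np

    def float_to_bf16_bits(x):
        bits = np.asarray(x, dtype=np.float32).view(np.uint32)
        return (bits >> 16).astype(np.uint16)   # drop low mantissa bits

    def bf16_bits_to_float(b):
        return (b.astype(np.uint32) << 16).view(np.float32)

    # Example: 1.2345678 -> 0x3F9E -> reads back as ~1.234375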
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
from paddle import enable_static
@unittest.skipIf(not core.supports_bfloat16(),
                 "place does not support BF16 evaluation")
class TestElementwiseMulBf16MklDNNOp(OpTest):
    def setUp(self):
        self.op_type = "elementwise_mul"
        self.use_mkldnn = True
        self.mkldnn_data_type = "bfloat16"
        self.axis = -1

        self.generate_data()
        self.inputs = {
            'X': convert_float_to_uint16(self.x),
            'Y': convert_float_to_uint16(self.y)
        }
        self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn}
        self.outputs = {'Out': convert_float_to_uint16(self.out)}

    def generate_data(self):
        self.x = np.random.random(100, ).astype(np.float32)
        self.y = np.random.random(100, ).astype(np.float32)
        self.out = np.multiply(self.x, self.y)

    def test_check_output(self):
        self.check_output_with_place(core.CPUPlace())

    def test_check_grad_normal(self):
        pass

    def test_check_grad_ingore_x(self):
        pass

    def test_check_grad_ingore_y(self):
        pass


if __name__ == '__main__':
    enable_static()
    unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
from paddle import enable_static
def fully_connected_naive(input, weights, bias_data):
    result = np.dot(input, weights) + bias_data
    return result


class MatrixGenerate:
    def __init__(self, mb, ic, oc, h, w):
        self.input = np.random.random((mb, ic * h * w)).astype(np.float32)
        self.weights = np.random.random((ic * h * w, oc)).astype(np.float32)


@unittest.skipIf(not core.supports_bfloat16(),
                 "place does not support BF16 evaluation")
class TestFcBf16MklDNNOp(OpTest):
    def generate_data(self):
        self.matrix = MatrixGenerate(1, 10, 15, 3, 3)
        self.bias = np.random.random(15).astype("float32")

    def setUp(self):
        self.op_type = "fc"
        self.use_mkldnn = True
        self.mkldnn_data_type = "bfloat16"
        self.force_fp32_output = False
        self.generate_data()

        self.output = fully_connected_naive(self.matrix.input,
                                            self.matrix.weights, self.bias)
        if not self.force_fp32_output:
            self.output = convert_float_to_uint16(self.output)

        self.inputs = {
            'Input': convert_float_to_uint16(self.matrix.input),
            'W': self.matrix.weights,
            'Bias': self.bias
        }

        self.attrs = {
            'use_mkldnn': self.use_mkldnn,
            'force_fp32_output': self.force_fp32_output
        }

        self.outputs = {'Out': self.output}

    def test_check_output(self):
        self.check_output_with_place(core.CPUPlace())

    def test_check_grad_normal(self):
        pass

    def test_check_grad_no_weight(self):
        pass


class TestFCMKLDNNOp1(TestFcBf16MklDNNOp):
    def generate_data(self):
        self.matrix = MatrixGenerate(2, 15, 48, 2, 2)
        self.bias = np.random.random(48).astype(np.float32)


if __name__ == "__main__":
    enable_static()
    unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import os
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16
from paddle import enable_static
@unittest.skipIf(not core.supports_bfloat16(),
                 "place does not support BF16 evaluation")
class TestMatmulBf16MklDNNOp(OpTest):
    def generate_data(self):
        self.x = np.random.random((25, 2, 2)).astype(np.float32)
        self.y = np.random.random((25, 2, 2)).astype(np.float32)
        self.alpha = 1.0
        self.out = self.alpha * np.matmul(self.x, self.y)

    def set_attributes(self):
        self.alpha = self.alpha if hasattr(self, 'alpha') else 1.0
        self.attrs = {
            'alpha': self.alpha,
            "use_mkldnn": self.use_mkldnn,
            "mkldnn_data_type": self.mkldnn_data_type,
            "force_fp32_output": self.force_fp32_output
        }

    def setUp(self):
        self.op_type = "matmul"
        self.use_mkldnn = True
        self.dtype = np.uint16
        self.mkldnn_data_type = "bfloat16"
        self.force_fp32_output = False
        self.generate_data()
        self.set_attributes()

        if not self.force_fp32_output:
            self.out = convert_float_to_uint16(self.out)
        self.outputs = {'Out': self.out}

        self.x = convert_float_to_uint16(self.x)
        self.y = convert_float_to_uint16(self.y)
        self.inputs = {'X': self.x, 'Y': self.y}

    def test_check_output(self):
        self.check_output_with_place(core.CPUPlace())

    def test_check_grad(self):
        pass


class TestDnnlMatMulOpAlpha(TestMatmulBf16MklDNNOp):
    def generate_data(self):
        self.x = np.random.random((17, 2, 3)).astype(np.float32)
        self.y = np.random.random((17, 3, 2)).astype(np.float32)
        self.alpha = 2.0
        self.out = self.alpha * np.matmul(self.x, self.y)


class TestDnnlMatMulOp2D(TestMatmulBf16MklDNNOp):
    def generate_data(self):
        self.x = np.random.random((12, 9)).astype(np.float32)
        self.y = np.random.random((9, 12)).astype(np.float32)
        self.out = np.matmul(self.x, self.y)


class TestDnnlMatMulOpTransposeX(TestMatmulBf16MklDNNOp):
    def generate_data(self):
        self.x = np.random.random((12, 9)).astype(np.float32)
        self.y = np.random.random((12, 9)).astype(np.float32)
        self.out = np.matmul(np.transpose(self.x), self.y)

    def set_attributes(self):
        self.attrs = {
            "use_mkldnn": self.use_mkldnn,
            "mkldnn_data_type": self.mkldnn_data_type,
            'transpose_X': True
        }


class TestDnnlMatMulOpTransposeY(TestMatmulBf16MklDNNOp):
    def generate_data(self):
        self.x = np.random.random((12, 9)).astype(np.float32)
        self.y = np.random.random((12, 9)).astype(np.float32)
        self.out = np.matmul(self.x, np.transpose(self.y))

    def set_attributes(self):
        self.attrs = {
            "use_mkldnn": self.use_mkldnn,
            "mkldnn_data_type": self.mkldnn_data_type,
            'transpose_Y': True
        }


class TestMatmulBf16MklDNNForceFp32Output(TestMatmulBf16MklDNNOp):
    def generate_data(self):
        self.x = np.random.random((12, 9)).astype(np.float32)
        self.y = np.random.random((9, 12)).astype(np.float32)
        self.force_fp32_output = True
        self.alpha = 0.5
        self.out = self.alpha * np.matmul(self.x, self.y)


if __name__ == "__main__":
    enable_static()
    unittest.main()
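For reference, the matmul attributes exercised by these tests follow the op's usual semantics: transpose_X and transpose_Y swap the last two axes of the corresponding operand before the product, and alpha scales the result. A numpy sketch of the fp32 reference the tests compute (helper name illustrative):

    # Illustrative numpy reference for the matmul attrs used above.
    import numpy as np

    def matmul_ref(x, y, transpose_X=False, transpose_Y=False, alpha=1.0):
        if transpose_X:
            x = np.swapaxes(x, -1, -2)
        if transpose_Y:
            y = np.swapaxes(y, -1, -2)
        return alpha * np.matmul(x, y)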
@@ -590,13 +590,17 @@ STATIC_MODE_TESTING_LIST = [
     'test_conv3d_mkldnn_op',
     'test_dequantize_mkldnn_op',
     'test_elementwise_add_mkldnn_op',
+    'test_elementwise_add_bf16_mkldnn_op',
     'test_elementwise_mul_mkldnn_op',
+    'test_elementwise_mul_bf16_mkldnn_op',
     'test_fc_mkldnn_op',
+    'test_fc_bf16_mkldnn_op',
     'test_fusion_gru_int8_mkldnn_op',
     'test_fusion_gru_mkldnn_op',
     'test_gaussian_random_mkldnn_op',
     'test_lrn_mkldnn_op',
     'test_matmul_mkldnn_op',
+    'test_matmul_bf16_mkldnn_op',
     'test_mul_int8_mkldnn_op',
     'test_multi_gru_mkldnn_op',
     'test_pool2d_int8_mkldnn_op',