Unverified commit b0ee1405, authored by joanna.wozna.intel, committed by GitHub

Add conv2d bfloat16 support (#27325)

Parent: b38e4f28
......@@ -1894,8 +1894,7 @@ PDNode *patterns::QuantizePlacement::operator()(
PDNode *patterns::Bfloat16Placement::operator()(
const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
std::unordered_set<std::string> supported_op_types =
std::unordered_set<std::string>();
std::unordered_set<std::string> supported_op_types{"conv2d"};
if (!bfloat16_enabled_op_types.empty()) {
supported_op_types = bfloat16_enabled_op_types;
}
......
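The default supported set for Bfloat16Placement now contains "conv2d" instead of being empty; an explicit bfloat16_enabled_op_types list still takes precedence. A minimal Python sketch of that selection rule (the function name is illustrative, not part of the patch):

def bfloat16_supported_ops(bfloat16_enabled_op_types=None):
    # Default after this change: conv2d is bf16-enabled out of the box.
    supported = {"conv2d"}
    # A non-empty user-provided list overrides the default.
    if bfloat16_enabled_op_types:
        supported = set(bfloat16_enabled_op_types)
    return supported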
......@@ -166,7 +166,8 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
#endif
if (input_data_type != framework::proto::VarType::INT8 &&
input_data_type != framework::proto::VarType::UINT8) {
input_data_type != framework::proto::VarType::UINT8 &&
input_data_type != framework::proto::VarType::BF16) {
auto filter_data_type = ctx.Input<Tensor>("Filter")->type();
PADDLE_ENFORCE_EQ(input_data_type, filter_data_type,
platform::errors::InvalidArgument(
......@@ -455,6 +456,11 @@ void Conv3DOpMaker::Make() {
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<std::string>(
"mkldnn_data_type",
"(string, default \"float32\"). Data type of mkldnn kernel")
.SetDefault("float32")
.InEnum({"float32", "int8", "bfloat16"});
AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<std::string>("fuse_activation",
......
......@@ -55,12 +55,12 @@ inline MKLDNNMemoryFormat GetWeightsFormat(const MKLDNNMemoryFormat format,
}
}
static mkldnn::memory::data_type GetDstType(bool is_int8,
static mkldnn::memory::data_type GetDstType(bool is_int8, bool is_bfloat16,
bool force_fp32_output,
std::string fuse_activation,
bool fuse_residual_conn,
const Tensor* residual_param) {
auto dst_dt = mkldnn::memory::data_type::f32; // uint8_t, int8_t, float
auto dst_dt = mkldnn::memory::data_type::f32;
if (is_int8) {
dst_dt = (fuse_activation == "relu" || fuse_activation == "relu6")
? mkldnn::memory::data_type::u8
......@@ -72,6 +72,13 @@ static mkldnn::memory::data_type GetDstType(bool is_int8,
auto residual_dt = framework::ToMKLDNNDataType(residual_param->type());
if (dst_dt != residual_dt) dst_dt = residual_dt;
}
} else {
if (!force_fp32_output && is_bfloat16) {
dst_dt = mkldnn::memory::data_type::bf16;
if (fuse_residual_conn && residual_param) {
dst_dt = framework::ToMKLDNNDataType(residual_param->type());
}
}
}
return dst_dt;
}
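GetDstType gains an is_bfloat16 flag: on the non-int8 path, the destination type becomes bf16 when bfloat16 is requested and fp32 output is not forced, falling back to the residual tensor's type when a residual connection is fused. A rough Python rendering of just that new branch (a sketch assuming residual_dtype is the MKL-DNN data type of ResidualData, or None when absent):

def get_dst_type_non_int8(is_bfloat16, force_fp32_output,
                          fuse_residual_conn, residual_dtype=None):
    # Mirrors the new `else` branch above; the int8 path is unchanged.
    dst = "f32"
    if is_bfloat16 and not force_fp32_output:
        dst = "bf16"
        if fuse_residual_conn and residual_dtype is not None:
            # Match the fused residual tensor's data type.
            dst = residual_dtype
    return dst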
......@@ -224,12 +231,15 @@ class ConvMKLDNNHandlerT
src_tz.size(), chosen_memory_format);
}
}
const auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
const auto weights_md =
platform::MKLDNNMemDesc(weights_tz, platform::MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::any);
auto data_type = mkldnn::memory::data_type::f32;
if (ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16" ||
std::is_same<T_out, platform::bfloat16>::value)
data_type = mkldnn::memory::data_type::bf16;
const auto src_md =
platform::MKLDNNMemDesc(src_tz, data_type, chosen_memory_format);
const auto weights_md = platform::MKLDNNMemDesc(weights_tz, data_type,
MKLDNNMemoryFormat::any);
const auto dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);
......@@ -241,8 +251,8 @@ class ConvMKLDNNHandlerT
if (bias) {
auto bias_tz = framework::vectorize(bias->dims());
auto bias_md = platform::MKLDNNMemDesc(
bias_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x);
auto bias_md =
platform::MKLDNNMemDesc(bias_tz, data_type, MKLDNNMemoryFormat::x);
this->AcquireForwardPrimitiveDescriptor(
conv_attr, fwd_prop_kind, dnnl::algorithm::convolution_direct,
......@@ -384,15 +394,21 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
"Operator DNNL Conv must use CPUPlace"));
bool is_INT8 =
std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
bool is_BFLOAT16 = ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16";
auto residual_param = ctx.Input<Tensor>("ResidualData");
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
auto dst_dt =
GetDstType(is_INT8, is_BFLOAT16, force_fp32_output, fuse_activation,
fuse_residual_conn, residual_param);
if (!is_INT8) {
ComputeFP32<float>(ctx);
if (dst_dt == mkldnn::memory::data_type::f32) {
ComputeFP32<float>(ctx);
} else if (dst_dt == mkldnn::memory::data_type::bf16) {
ComputeFP32<platform::bfloat16>(ctx);
}
} else {
std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
auto residual_param = ctx.Input<Tensor>("ResidualData");
auto dst_dt = GetDstType(true, force_fp32_output, fuse_activation,
fuse_residual_conn, residual_param);
if (dst_dt == mkldnn::memory::data_type::f32) {
ComputeINT8<float>(ctx);
} else if (dst_dt == mkldnn::memory::data_type::u8) {
......@@ -1103,6 +1119,10 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN,
ops::kConvMKLDNNFP32,
ops::ConvMKLDNNOpKernel<float, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
conv2d, MKLDNN, ::paddle::platform::CPUPlace, BF16, ops::kConvMKLDNNFP32,
ops::ConvMKLDNNOpKernel<paddle::platform::bfloat16, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN,
::paddle::platform::CPUPlace, U8,
ops::kConvMKLDNNINT8,
......
......@@ -110,4 +110,5 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(dequantize, MKLDNN, ::paddle::platform::CPUPlace,
ops::DeQuantOpKernel<uint8_t>, ops::DeQuantOpKernel<int8_t>);
ops::DeQuantOpKernel<uint8_t>, ops::DeQuantOpKernel<int8_t>,
ops::DeQuantOpKernel<paddle::platform::bfloat16>);
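With the bfloat16 kernel registered, dequantizing a bf16 tensor is essentially a widening of the stored bit pattern: the uint16 payload is the upper half of an IEEE float32 word. An illustrative numpy equivalent, not the kernel itself (scale handling shown as in the float path; the new bf16 test below uses scale 1.0):

import numpy as np

def dequantize_bf16(bits_u16, scale=1.0):
    # Put the 16 stored bits into the high half of a 32-bit word and
    # reinterpret the result as float32.
    f32 = (bits_u16.astype(np.uint32) << 16).view(np.float32)
    return f32 * (1.0 / scale)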
......@@ -41,6 +41,7 @@ namespace detail {
// import numpy as np
// print np.dtype(np.float16).num # 23
constexpr int NPY_FLOAT16_ = 23;
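// print np.dtype(np.uint16).num # 4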
constexpr int NPY_UINT16_ = 4;
// Note: Since float16 is not a builtin type in C++, we register
// paddle::platform::float16 as numpy.float16.
......@@ -60,6 +61,23 @@ struct npy_format_descriptor<paddle::platform::float16> {
static PYBIND11_DESCR name() { return _("float16"); }
};
// Note: Since bfloat16 is not a builtin type in C++ and in numpy,
// we register paddle::platform::bfloat16 as numpy.uint16.
template <>
struct npy_format_descriptor<paddle::platform::bfloat16> {
static py::dtype dtype() {
handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_UINT16_);
return reinterpret_borrow<py::dtype>(ptr);
}
static std::string format() {
// Note: "H" represents UINT16.
// Details at:
// https://docs.python.org/3/library/struct.html#format-characters.
return "H";
}
static PYBIND11_DESCR name() { return _("bfloat16"); }
};
} // namespace detail
} // namespace pybind11
......
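For reference, the uint16 descriptor borrowed above can be sanity-checked from plain numpy (no Paddle required); this only illustrates why "H" and a 2-byte width are the right choices:

import numpy as np

u16 = np.dtype(np.uint16)
assert u16.char == 'H'      # struct/buffer format character used above
assert u16.itemsize == 2    # same width as bfloat16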
......@@ -613,7 +613,9 @@ def convert_np_dtype_to_dtype_(np_dtype):
elif dtype == np.bool:
return core.VarDesc.VarType.BOOL
elif dtype == np.uint16:
return core.VarDesc.VarType.INT16
# NumPy has no native bfloat16 dtype, so uint16 is used
# to carry bfloat16 values
return core.VarDesc.VarType.BF16
elif dtype == np.uint8:
return core.VarDesc.VarType.UINT8
elif dtype == np.int8:
......
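After this change a uint16 ndarray is interpreted as a bfloat16 variable rather than int16. A hedged usage sketch (assuming the usual paddle.fluid import path for convert_np_dtype_to_dtype_):

import numpy as np
from paddle.fluid import core
from paddle.fluid.framework import convert_np_dtype_to_dtype_

# uint16 now maps to the BF16 proto type instead of INT16.
assert convert_np_dtype_to_dtype_(np.uint16) == core.VarDesc.VarType.BF16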
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import struct
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16
from paddle.fluid.tests.unittests.test_conv2d_op import conv2d_forward_naive, TestConv2dOp
def conv2d_forward_refer(input, filter, group, conv_param):
out, in_n, out_h, out_w, out_c = conv2d_forward_naive(input, filter, group,
conv_param)
return out
def conv2d_residual_naive(out, residual):
assert out.shape == residual.shape
out = np.add(out, residual)
return out
class TestConv2dBf16Op(TestConv2dOp):
def setUp(self):
self.op_type = "conv2d"
self.use_cudnn = False
self.exhaustive_search = False
self.use_cuda = False
self.use_mkldnn = True
self.weight_type = np.float32
self.input_type = np.float32
self.mkldnn_data_type = "bfloat16"
self.force_fp32_output = False
self.init_group()
self.init_dilation()
self.init_test_case()
self.init_fuse_relu()
self.init_fuse_residual()
self.init_data_type()
self.init_force_fp32_output()
conv2d_param = {
'stride': self.stride,
'pad': self.pad,
'dilation': self.dilations
}
self.input = np.random.random(self.input_size).astype(np.float32)
self.filter = np.random.random(self.filter_size).astype(np.float32)
conv_out, _, _, _, _ = conv2d_forward_naive(self.input, self.filter,
self.groups, conv2d_param)
self.conv_output_float = conv_out
if self.fuse_residual:
self.input_residual = np.random.random(
self.input_residual_size).astype(np.float32)
self.conv_output_float = conv2d_residual_naive(
self.conv_output_float, self.input_residual)
self.conv_output = convert_float_to_uint16(self.conv_output_float)
self.outputs = {'Output': self.conv_output}
elif self.force_fp32_output:
self.outputs = {'Output': self.conv_output_float.astype(np.float32)}
if self.input_type is not np.float32:
self.input = convert_float_to_uint16(self.input)
self.inputs = {
'Input': self.input.view(self.input_type),
'Filter': OpTest.np_dtype_to_fluid_dtype(
self.filter.astype(self.weight_type))
}
if self.fuse_residual:
self.inputs['ResidualData'] = OpTest.np_dtype_to_fluid_dtype(
convert_float_to_uint16(self.input_residual))
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
'groups': self.groups,
'dilations': self.dilations,
'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn,
'mkldnn_data_type': self.mkldnn_data_type,
'force_fp32_output': self.force_fp32_output,
'fuse_residual_connection': self.fuse_residual
}
def test_check_output(self):
self.check_output_with_place(core.CPUPlace())
def test_check_grad(self):
pass
def test_check_grad_no_filter(self):
pass
def test_check_grad_no_input(self):
pass
def init_test_case(self):
TestConv2dOp.init_test_case(self)
self.input_size = [1, 1, 5, 5] # NCHW
f_c = self.input_size[1] // self.groups
self.input_residual_size = [1, 2, 3, 3]
self.filter_size = [2, f_c, 3, 3]
def init_data_type(self):
self.weight_type = np.float32
self.input_type = np.float32
def init_force_fp32_output(self):
self.force_fp32_output = False
def init_fuse_relu(self):
self.fuse_activation = "relu"
def init_fuse_residual(self):
self.fuse_residual = True
class TestConv2d(TestConv2dBf16Op):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
self.input_residual_size = [2, 6, 3, 3]
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
def init_data_type(self):
self.input_type = np.uint16
class TestWithPad(TestConv2d):
def init_test_case(self):
TestConv2d.init_test_case(self)
self.pad = [1, 1]
self.input_residual_size = [2, 6, 5, 5]
class TestWithGroup(TestConv2d):
def init_group(self):
self.groups = 3
class TestWithStride(TestConv2dBf16Op):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [2, 2]
self.input_size = [2, 3, 6, 6]
self.input_residual_size = [2, 6, 3, 3]
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
def init_data_type(self):
self.input_type = np.uint16
class TestWith1x1ForceFP32Output(TestConv2dBf16Op):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.input_size = [1, 3, 5, 5]
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 1, 1]
def init_force_fp32_output(self):
self.force_fp32_output = True
def init_fuse_residual(self):
self.fuse_residual = False
class TestWithInput1x1Filter1x1(TestConv2dBf16Op):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.input_size = [2, 3, 1, 1]
self.input_residual_size = [2, 6, 1, 1]
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 1, 1]
def init_group(self):
self.groups = 3
if __name__ == '__main__':
unittest.main()
......@@ -36,6 +36,7 @@ class TestConv2dInt8Op(TestConv2dOp):
self.use_cuda = False
self.use_mkldnn = False
self.data_format = "NCHW"
self.mkldnn_data_type = "int8"
self.weighttype = np.float32
self.use_mkldnn = True
self.init_group()
......@@ -141,7 +142,8 @@ class TestConv2dInt8Op(TestConv2dOp):
'Scale_weights': self.scale_weights,
'Scale_in_eltwise': self.scale_in_eltwise,
'fuse_activation': self.fuse_activation,
'fuse_residual_connection': self.fuse_residual
'fuse_residual_connection': self.fuse_residual,
'mkldnn_data_type': self.mkldnn_data_type
}
self.outputs = {'Output': output}
......
......@@ -16,7 +16,7 @@ from __future__ import print_function
import unittest
import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
class TestDeQuantizeOp(OpTest):
......@@ -32,6 +32,9 @@ class TestDeQuantizeOp(OpTest):
input = (np.random.randint(0, 100, self.input_size) - 50
).astype(self.data_type)
output = (input * (1 / self.scale)).astype('float')
elif self.data_type == 'uint16':
output = np.random.random(self.input_size).astype(np.float32)
input = convert_float_to_uint16(output)
else:
input = (np.random.randint(0, 100,
self.input_size)).astype(self.data_type)
......@@ -70,5 +73,13 @@ class TestDeQuantizeOp2(TestDeQuantizeOp):
self.data_type = 'uint8'
class TestDeQuantizeOpBf16(TestDeQuantizeOp):
def set_scale(self):
self.scale = 1.0
def set_data_type(self):
self.data_type = 'uint16'
if __name__ == '__main__':
unittest.main()
......@@ -20,6 +20,7 @@ import warnings
import numpy as np
import random
import six
import struct
import time
import itertools
import collections
......@@ -167,6 +168,18 @@ def skip_check_grad_ci(reason=None):
return wrapper
def copy_bits_from_float_to_uint16(f):
return struct.unpack('<I', struct.pack('<f', f))[0] >> 16
def convert_float_to_uint16(float_list):
new_output = []
for x in np.nditer(float_list):
new_output.append(np.uint16(copy_bits_from_float_to_uint16(x)))
return np.reshape(new_output, float_list.shape).view(np.uint16)
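# Illustrative usage (not part of the patch): the helper keeps only the top
# 16 bits of each float32, i.e. bfloat16 precision (~2-3 decimal digits):
#   x = np.array([0.15625, 3.1415927], dtype=np.float32)
#   bits = convert_float_to_uint16(x)                        # dtype uint16
#   back = (bits.astype(np.uint32) << 16).view(np.float32)   # approximately x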
class OpTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
......@@ -242,6 +255,11 @@ class OpTest(unittest.TestCase):
self.call_once = True
self.dtype = data_type
def is_bfloat16_op(self):
return self.dtype == np.uint16 or (
hasattr(self, 'mkldnn_data_type') and
getattr(self, 'mkldnn_data_type') == "bfloat16")
def infer_dtype_from_inputs_outputs(self, inputs, outputs):
def is_np_data(input):
return isinstance(input, (np.ndarray, np.generic))
......@@ -276,8 +294,9 @@ class OpTest(unittest.TestCase):
infer_dtype(inputs, dtype_set)
dtype_list = [
np.dtype(np.float64), np.dtype(np.float32), np.dtype(np.float16),
np.dtype(np.int64), np.dtype(np.int32), np.dtype(np.int16),
np.dtype(np.int8), np.dtype(np.uint8), np.dtype(np.bool)
np.dtype(np.int64), np.dtype(np.int32), np.dtype(np.uint16),
np.dtype(np.int16), np.dtype(np.int8), np.dtype(np.uint8),
np.dtype(np.bool)
]
# check the dtype in dtype_list in order, select the first dtype that in dtype_set
for dtype in dtype_list:
......@@ -957,6 +976,14 @@ class OpTest(unittest.TestCase):
self.op_type not in op_threshold_white_list.NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST:
atol = 0
if self.is_bfloat16_op():
check_dygraph = False
if hasattr(self, 'force_fp32_output') and getattr(
self, 'force_fp32_output'):
atol = 1e-2
else:
atol = 2
if no_check_set is not None:
if self.op_type not in no_check_set_white_list.no_check_set_white_list:
raise AssertionError(
......@@ -1286,8 +1313,9 @@ class OpTest(unittest.TestCase):
no_grad_set = set()
else:
if (self.op_type not in no_grad_set_white_list.NEED_TO_FIX_OP_LIST
) and (self.op_type not in
no_grad_set_white_list.NOT_CHECK_OP_LIST):
) and (
self.op_type not in no_grad_set_white_list.NOT_CHECK_OP_LIST
) and (not self.is_bfloat16_op()):
raise AssertionError("no_grad_set must be None, op_type is " +
self.op_type + " Op.")
......