diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 9c1eaa99a3ca04ddbeecab639d5587d5509e3f00..96952e20c2158453df0d94c9e43c64bb6bb1e04f 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -1894,8 +1894,7 @@ PDNode *patterns::QuantizePlacement::operator()(
 PDNode *patterns::Bfloat16Placement::operator()(
     const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
-  std::unordered_set<std::string> supported_op_types =
-      std::unordered_set<std::string>();
+  std::unordered_set<std::string> supported_op_types{"conv2d"};
   if (!bfloat16_enabled_op_types.empty()) {
     supported_op_types = bfloat16_enabled_op_types;
   }
diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc
index bf97b9d03c455182a8d95b6987896b9a580c84fe..ef8a2b38f20b99f0b1e41ddc1976f88dd8d1f5ab 100644
--- a/paddle/fluid/operators/conv_op.cc
+++ b/paddle/fluid/operators/conv_op.cc
@@ -166,7 +166,8 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
 #endif
 
   if (input_data_type != framework::proto::VarType::INT8 &&
-      input_data_type != framework::proto::VarType::UINT8) {
+      input_data_type != framework::proto::VarType::UINT8 &&
+      input_data_type != framework::proto::VarType::BF16) {
     auto filter_data_type = ctx.Input<Tensor>("Filter")->type();
     PADDLE_ENFORCE_EQ(input_data_type, filter_data_type,
                       platform::errors::InvalidArgument(
@@ -455,6 +456,11 @@ void Conv3DOpMaker::Make() {
   AddAttr<bool>("use_mkldnn",
                 "(bool, default false) Only used in mkldnn kernel")
       .SetDefault(false);
+  AddAttr<std::string>(
+      "mkldnn_data_type",
+      "(string, default \"float32\"). Data type of mkldnn kernel")
+      .SetDefault("float32")
+      .InEnum({"float32", "int8", "bfloat16"});
   AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
       .SetDefault(false);
   AddAttr<std::string>("fuse_activation",
diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
index a6cda154e55b972fc653cffc4815f9e0f6e975de..7a4e11091fd3a6d064f3c4d905bb65c61d62d882 100644
--- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
@@ -55,12 +55,12 @@ inline MKLDNNMemoryFormat GetWeightsFormat(const MKLDNNMemoryFormat format,
   }
 }
 
-static mkldnn::memory::data_type GetDstType(bool is_int8,
+static mkldnn::memory::data_type GetDstType(bool is_int8, bool is_bfloat16,
                                             bool force_fp32_output,
                                             std::string fuse_activation,
                                             bool fuse_residual_conn,
                                             const Tensor* residual_param) {
-  auto dst_dt = mkldnn::memory::data_type::f32;  // uint8_t, int8_t, float
+  auto dst_dt = mkldnn::memory::data_type::f32;
   if (is_int8) {
     dst_dt = (fuse_activation == "relu" || fuse_activation == "relu6")
                  ? mkldnn::memory::data_type::u8
@@ -72,6 +72,13 @@ static mkldnn::memory::data_type GetDstType(bool is_int8,
       auto residual_dt = framework::ToMKLDNNDataType(residual_param->type());
       if (dst_dt != residual_dt) dst_dt = residual_dt;
     }
+  } else {
+    if (!force_fp32_output && is_bfloat16) {
+      dst_dt = mkldnn::memory::data_type::bf16;
+      if (fuse_residual_conn && residual_param) {
+        dst_dt = framework::ToMKLDNNDataType(residual_param->type());
+      }
+    }
   }
   return dst_dt;
 }
@@ -224,12 +231,15 @@ class ConvMKLDNNHandlerT
                              src_tz.size(), chosen_memory_format);
       }
     }
-
-    const auto src_md = platform::MKLDNNMemDesc(
-        src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
-    const auto weights_md =
-        platform::MKLDNNMemDesc(weights_tz, platform::MKLDNNGetDataType<T>(),
-                                MKLDNNMemoryFormat::any);
+    auto data_type = mkldnn::memory::data_type::f32;
+    if (ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16" ||
+        std::is_same<T_out, platform::bfloat16>::value)
+      data_type = mkldnn::memory::data_type::bf16;
+
+    const auto src_md =
+        platform::MKLDNNMemDesc(src_tz, data_type, chosen_memory_format);
+    const auto weights_md = platform::MKLDNNMemDesc(weights_tz, data_type,
+                                                    MKLDNNMemoryFormat::any);
     const auto dst_md = platform::MKLDNNMemDesc(
         dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);
 
@@ -241,8 +251,8 @@ class ConvMKLDNNHandlerT
 
     if (bias) {
       auto bias_tz = framework::vectorize(bias->dims());
-      auto bias_md = platform::MKLDNNMemDesc(
-          bias_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x);
+      auto bias_md =
+          platform::MKLDNNMemDesc(bias_tz, data_type, MKLDNNMemoryFormat::x);
 
       this->AcquireForwardPrimitiveDescriptor(
           conv_attr, fwd_prop_kind, dnnl::algorithm::convolution_direct,
@@ -384,15 +394,21 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                           "Operator DNNL Conv must use CPUPlace"));
     bool is_INT8 =
         std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
+    bool is_BFLOAT16 = ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16";
+    auto residual_param = ctx.Input<Tensor>("ResidualData");
+    bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
+    std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
+    bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
+    auto dst_dt =
+        GetDstType(is_INT8, is_BFLOAT16, force_fp32_output, fuse_activation,
+                   fuse_residual_conn, residual_param);
     if (!is_INT8) {
-      ComputeFP32<float>(ctx);
+      if (dst_dt == mkldnn::memory::data_type::f32) {
+        ComputeFP32<float>(ctx);
+      } else if (dst_dt == mkldnn::memory::data_type::bf16) {
+        ComputeFP32<platform::bfloat16>(ctx);
+      }
     } else {
-      std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
-      bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
-      bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
-      auto residual_param = ctx.Input<Tensor>("ResidualData");
-      auto dst_dt = GetDstType(true, force_fp32_output, fuse_activation,
-                               fuse_residual_conn, residual_param);
       if (dst_dt == mkldnn::memory::data_type::f32) {
         ComputeINT8<float>(ctx);
       } else if (dst_dt == mkldnn::memory::data_type::u8) {
@@ -1103,6 +1119,10 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN,
                                     ops::kConvMKLDNNFP32,
                                     ops::ConvMKLDNNOpKernel<float, float>);
 
+REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
+    conv2d, MKLDNN, ::paddle::platform::CPUPlace, BF16, ops::kConvMKLDNNFP32,
+    ops::ConvMKLDNNOpKernel<paddle::platform::bfloat16, float>);
+
 REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN,
                                     ::paddle::platform::CPUPlace, U8,
                                     ops::kConvMKLDNNINT8,
diff --git a/paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc
index 540642c7140e707441ad9c4d71ae9b777863a7bd..70d4c34d9c5c4d28e2705c85f56bc65f90fbb3cf 100644
--- a/paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc
@@ -110,4 +110,5 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 
 REGISTER_OP_KERNEL(dequantize, MKLDNN, ::paddle::platform::CPUPlace,
-                   ops::DeQuantOpKernel<uint8_t>, ops::DeQuantOpKernel<int8_t>);
+                   ops::DeQuantOpKernel<uint8_t>, ops::DeQuantOpKernel<int8_t>,
+                   ops::DeQuantOpKernel<paddle::platform::bfloat16>);
diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h
index 5ee15073267b6eac8978022a70ead5d0f439c62f..142ab2bb9d790175a843d1b81b74dc762a3213fd 100644
--- a/paddle/fluid/pybind/tensor_py.h
+++ b/paddle/fluid/pybind/tensor_py.h
@@ -41,6 +41,7 @@ namespace detail {
 // import numpy as np
 // print np.dtype(np.float16).num  # 23
 constexpr int NPY_FLOAT16_ = 23;
+constexpr int NPY_UINT16_ = 4;
 
 // Note: Since float16 is not a builtin type in C++, we register
 // paddle::platform::float16 as numpy.float16.
@@ -60,6 +61,23 @@ struct npy_format_descriptor<paddle::platform::float16> {
   static PYBIND11_DESCR name() { return _("float16"); }
 };
 
+// Note: Since bfloat16 is not a builtin type in C++ and in numpy,
+// we register paddle::platform::bfloat16 as numpy.uint16.
+template <>
+struct npy_format_descriptor<paddle::platform::bfloat16> {
+  static py::dtype dtype() {
+    handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_UINT16_);
+    return reinterpret_borrow<py::dtype>(ptr);
+  }
+  static std::string format() {
+    // Note: "H" represents UINT16.
+    // Details at:
+    // https://docs.python.org/3/library/struct.html#format-characters.
+    return "H";
+  }
+  static PYBIND11_DESCR name() { return _("bfloat16"); }
+};
+
 }  // namespace detail
 }  // namespace pybind11
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index c7e66bb28770a659626c05dccbca7aa5d6bad10c..b4cea6761dcd84e047f98929644a1e264976503d 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -613,7 +613,9 @@ def convert_np_dtype_to_dtype_(np_dtype):
     elif dtype == np.bool:
         return core.VarDesc.VarType.BOOL
     elif dtype == np.uint16:
-        return core.VarDesc.VarType.INT16
+        # since there is still no support for bfloat16 in NumPy,
+        # uint16 is used for casting bfloat16
+        return core.VarDesc.VarType.BF16
     elif dtype == np.uint8:
         return core.VarDesc.VarType.UINT8
     elif dtype == np.int8:
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_bf16_mkldnn_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ac33383fb26b2a35362e8e39e5994d82d6fe497
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_bf16_mkldnn_op.py
@@ -0,0 +1,208 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import struct
+
+import paddle.fluid.core as core
+from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16
+from paddle.fluid.tests.unittests.test_conv2d_op import conv2d_forward_naive, TestConv2dOp
+
+
+def conv2d_forward_refer(input, filter, group, conv_param):
+    out, in_n, out_h, out_w, out_c = conv2d_forward_naive(input, filter, group,
+                                                          conv_param)
+    return out
+
+
+def conv2d_residual_naive(out, residual):
+    assert out.shape == residual.shape
+    out = np.add(out, residual)
+    return out
+
+
+class TestConv2dBf16Op(TestConv2dOp):
+    def setUp(self):
+        self.op_type = "conv2d"
+        self.use_cudnn = False
+        self.exhaustive_search = False
+        self.use_cuda = False
+        self.use_mkldnn = True
+        self.weight_type = np.float32
+        self.input_type = np.float32
+        self.use_mkldnn = True
+        self.mkldnn_data_type = "bfloat16"
+        self.force_fp32_output = False
+        self.init_group()
+        self.init_dilation()
+        self.init_test_case()
+        self.init_fuse_relu()
+        self.init_fuse_residual()
+        self.init_data_type()
+        self.init_force_fp32_output()
+
+        conv2d_param = {
+            'stride': self.stride,
+            'pad': self.pad,
+            'dilation': self.dilations
+        }
+        self.input = np.random.random(self.input_size).astype(np.float32)
+        self.filter = np.random.random(self.filter_size).astype(np.float32)
+        conv_out, _, _, _, _ = conv2d_forward_naive(self.input, self.filter,
+                                                    self.groups, conv2d_param)
+        self.conv_output_float = conv_out
+
+        if self.fuse_residual:
+            self.input_residual = np.random.random(
+                self.input_residual_size).astype(np.float32)
+            self.conv_output_float = conv2d_residual_naive(
+                self.conv_output_float, self.input_residual)
+            self.conv_output = convert_float_to_uint16(self.conv_output_float)
+            self.outputs = {'Output': self.conv_output}
+        elif self.force_fp32_output:
+            self.outputs = {'Output': self.conv_output_float.astype(np.float32)}
+
+        if self.input_type is not np.float32:
+            self.input = convert_float_to_uint16(self.input)
+
+        self.inputs = {
+            'Input': self.input.view(self.input_type),
+            'Filter': OpTest.np_dtype_to_fluid_dtype(
+                self.filter.astype(self.weight_type))
+        }
+
+        if self.fuse_residual:
+            self.inputs['ResidualData'] = OpTest.np_dtype_to_fluid_dtype(
+                convert_float_to_uint16(self.input_residual))
+
+        self.attrs = {
+            'strides': self.stride,
+            'paddings': self.pad,
+            'groups': self.groups,
+            'dilations': self.dilations,
+            'use_cudnn': self.use_cudnn,
+            'use_mkldnn': self.use_mkldnn,
+            'mkldnn_data_type': self.mkldnn_data_type,
+            'force_fp32_output': self.force_fp32_output,
+            'fuse_residual_connection': self.fuse_residual
+        }
+
+    def test_check_output(self):
+        self.check_output_with_place(core.CPUPlace())
+
+    def test_check_grad(self):
+        pass
+
+    def test_check_grad_no_filter(self):
+        pass
+
+    def test_check_grad_no_input(self):
+        pass
+
+    def init_test_case(self):
+        TestConv2dOp.init_test_case(self)
+        self.input_size = [1, 1, 5, 5]  # NCHW
+        f_c = self.input_size[1] // self.groups
+        self.input_residual_size = [1, 2, 3, 3]
+        self.filter_size = [2, f_c, 3, 3]
+
+    def init_data_type(self):
+        self.weight_type = np.float32
+        self.input_type = np.float32
+
+    def init_force_fp32_output(self):
+        self.force_fp32_output = False
+
+    def init_fuse_relu(self):
+        self.fuse_activation = "relu"
+
+    def init_fuse_residual(self):
+        self.fuse_residual = True
+
+
+class TestConv2d(TestConv2dBf16Op):
+    def init_test_case(self):
+        self.pad = [0, 0]
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 5, 5]  # NCHW
+        self.input_residual_size = [2, 6, 3, 3]
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [6, f_c, 3, 3]
+
+    def init_data_type(self):
+        self.input_type = np.uint16
+
+
+class TestWithPad(TestConv2d):
+    def init_test_case(self):
+        TestConv2d.init_test_case(self)
+        self.pad = [1, 1]
+        self.input_residual_size = [2, 6, 5, 5]
+
+
+class TestWithGroup(TestConv2d):
+    def init_group(self):
+        self.groups = 3
+
+
+class TestWithStride(TestConv2dBf16Op):
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [2, 2]
+        self.input_size = [2, 3, 6, 6]
+        self.input_residual_size = [2, 6, 3, 3]
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [6, f_c, 3, 3]
+
+    def init_data_type(self):
+        self.input_type = np.uint16
+
+
+class TestWith1x1ForceFP32Output(TestConv2dBf16Op):
+    def init_test_case(self):
+        self.pad = [0, 0]
+        self.stride = [1, 1]
+        self.input_size = [1, 3, 5, 5]
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [6, f_c, 1, 1]
+
+    def init_force_fp32_output(self):
+        self.force_fp32_output = True
+
+    def init_fuse_residual(self):
+        self.fuse_residual = False
+
+
+class TestWithInput1x1Filter1x1(TestConv2dBf16Op):
+    def init_test_case(self):
+        self.pad = [0, 0]
+        self.stride = [1, 1]
+        self.input_size = [2, 3, 1, 1]
+        self.input_residual_size = [2, 6, 1, 1]
+        assert np.mod(self.input_size[1], self.groups) == 0
+        f_c = self.input_size[1] // self.groups
+        self.filter_size = [6, f_c, 1, 1]
+
+    def init_group(self):
+        self.groups = 3
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py
index 7a494e3c2c3040356641d05772c883e15e4579e3..9731efced69d4b53bbb5b57b4d252d9a7a0c4f5a 100644
--- a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py
@@ -36,6 +36,7 @@ class TestConv2dInt8Op(TestConv2dOp):
         self.use_cuda = False
         self.use_mkldnn = False
         self.data_format = "NCHW"
+        self.mkldnn_data_type = "int8"
         self.weighttype = np.float32
         self.use_mkldnn = True
         self.init_group()
@@ -141,7 +142,8 @@ class TestConv2dInt8Op(TestConv2dOp):
             'Scale_weights': self.scale_weights,
             'Scale_in_eltwise': self.scale_in_eltwise,
             'fuse_activation': self.fuse_activation,
-            'fuse_residual_connection': self.fuse_residual
+            'fuse_residual_connection': self.fuse_residual,
+            'mkldnn_data_type': self.mkldnn_data_type
         }
         self.outputs = {'Output': output}
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_dequantize_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_dequantize_mkldnn_op.py
index 35419462909df1700219fbbe3841e4dbd094e719..70c76f1fb7186fcc983c0378af657d4aae2d2b32 100644
--- a/python/paddle/fluid/tests/unittests/mkldnn/test_dequantize_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_dequantize_mkldnn_op.py
@@ -16,7 +16,7 @@ from __future__ import print_function
 
 import unittest
 import numpy as np
-from paddle.fluid.tests.unittests.op_test import OpTest
+from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
 
 
 class TestDeQuantizeOp(OpTest):
@@ -32,6 +32,9 @@ class TestDeQuantizeOp(OpTest):
             input = (np.random.randint(0, 100, self.input_size) - 50
                      ).astype(self.data_type)
             output = (input * (1 / self.scale)).astype('float')
+        elif self.data_type == 'uint16':
+            output = np.random.random(self.input_size).astype(np.float32)
+            input = convert_float_to_uint16(output)
         else:
             input = (np.random.randint(0, 100,
                                        self.input_size)).astype(self.data_type)
@@ -70,5 +73,13 @@ class TestDeQuantizeOp2(TestDeQuantizeOp):
         self.data_type = 'uint8'
 
 
+class TestDeQuantizeOpBf16(TestDeQuantizeOp):
+    def set_scale(self):
+        self.scale = 1.0
+
+    def set_data_type(self):
+        self.data_type = 'uint16'
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index a6a4b9574c50e254def870783adbc0a0dc3c3ed8..d02fdafe99568b2e4cd55dcd92a7f8f26bc626a5 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -20,6 +20,7 @@ import warnings
 import numpy as np
 import random
 import six
+import struct
 import time
 import itertools
 import collections
@@ -167,6 +168,18 @@ def skip_check_grad_ci(reason=None):
     return wrapper
 
 
+def copy_bits_from_float_to_uint16(f):
+    return struct.unpack('<I', struct.pack('<f', f))[0] >> 16
+
+
+def convert_float_to_uint16(float_list):
+    new_output = []
+    for x in np.nditer(float_list):
+        new_output.append(np.uint16(copy_bits_from_float_to_uint16(x)))
+
+    return np.reshape(new_output, float_list.shape).view(np.uint16)
+
+
 class OpTest(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
@@ -242,6 +255,11 @@ class OpTest(unittest.TestCase):
         self.call_once = True
         self.dtype = data_type
 
+    def is_bfloat16_op(self):
+        return self.dtype == np.uint16 or (
+            hasattr(self, 'mkldnn_data_type') and
+            getattr(self, 'mkldnn_data_type') == "bfloat16")
+
     def infer_dtype_from_inputs_outputs(self, inputs, outputs):
         def is_np_data(input):
             return isinstance(input, (np.ndarray, np.generic))
@@ -276,8 +294,9 @@ class OpTest(unittest.TestCase):
         infer_dtype(inputs, dtype_set)
         dtype_list = [
             np.dtype(np.float64), np.dtype(np.float32), np.dtype(np.float16),
-            np.dtype(np.int64), np.dtype(np.int32), np.dtype(np.int16),
-            np.dtype(np.int8), np.dtype(np.uint8), np.dtype(np.bool)
+            np.dtype(np.int64), np.dtype(np.int32), np.dtype(np.uint16),
+            np.dtype(np.int16), np.dtype(np.int8), np.dtype(np.uint8),
+            np.dtype(np.bool)
         ]
         # check the dtype in dtype_list in order, select the first dtype that in dtype_set
         for dtype in dtype_list:
@@ -957,6 +976,14 @@ class OpTest(unittest.TestCase):
                 self.op_type not in op_threshold_white_list.NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST:
             atol = 0
 
+        if self.is_bfloat16_op():
+            check_dygraph = False
+            if hasattr(self, 'force_fp32_output') and getattr(
+                    self, 'force_fp32_output'):
+                atol = 1e-2
+            else:
+                atol = 2
+
         if no_check_set is not None:
             if self.op_type not in no_check_set_white_list.no_check_set_white_list:
                 raise AssertionError(
@@ -1286,8 +1313,9 @@ class OpTest(unittest.TestCase):
             no_grad_set = set()
         else:
            if (self.op_type not in no_grad_set_white_list.NEED_TO_FIX_OP_LIST
-                ) and (self.op_type not in
-                       no_grad_set_white_list.NOT_CHECK_OP_LIST):
+                ) and (
+                    self.op_type not in no_grad_set_white_list.NOT_CHECK_OP_LIST
+                ) and (not self.is_bfloat16_op()):
                 raise AssertionError("no_grad_set must be None, op_type is " +
                                      self.op_type + " Op.")
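
Note (not part of the diff): the convert_float_to_uint16 helper added to op_test.py stores bfloat16 values in numpy uint16 arrays by keeping only the upper 16 bits of each float32 word, which is also why framework.py now maps np.uint16 to VarType.BF16. Below is a minimal, vectorized sketch of the same conversion, assuming only numpy and a little-endian host; the function name is illustrative and not part of the patch:

    import numpy as np

    def float32_to_bfloat16_bits(x):
        # Reinterpret the float32 bits as uint32 and drop the low 16 mantissa
        # bits (truncation, no rounding), matching the struct-based helper in
        # op_test.py above.
        x = np.asarray(x, dtype=np.float32)
        return (x.view(np.uint32) >> 16).astype(np.uint16)

    vals = np.array([1.0, -2.5, 3.140625], dtype=np.float32)
    print(float32_to_bfloat16_bits(vals))  # uint16 bit patterns of the bf16 values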