Unverified commit b0ee1405, authored by joanna.wozna.intel and committed by GitHub

Add conv2d bfloat16 support (#27325)

Parent: b38e4f28
@@ -1894,8 +1894,7 @@ PDNode *patterns::QuantizePlacement::operator()(
 PDNode *patterns::Bfloat16Placement::operator()(
     const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
-  std::unordered_set<std::string> supported_op_types =
-      std::unordered_set<std::string>();
+  std::unordered_set<std::string> supported_op_types{"conv2d"};
   if (!bfloat16_enabled_op_types.empty()) {
     supported_op_types = bfloat16_enabled_op_types;
   }
......
@@ -166,7 +166,8 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
 #endif

   if (input_data_type != framework::proto::VarType::INT8 &&
-      input_data_type != framework::proto::VarType::UINT8) {
+      input_data_type != framework::proto::VarType::UINT8 &&
+      input_data_type != framework::proto::VarType::BF16) {
     auto filter_data_type = ctx.Input<Tensor>("Filter")->type();
     PADDLE_ENFORCE_EQ(input_data_type, filter_data_type,
                       platform::errors::InvalidArgument(
@@ -455,6 +456,11 @@ void Conv3DOpMaker::Make() {
   AddAttr<bool>("use_mkldnn",
                 "(bool, default false) Only used in mkldnn kernel")
       .SetDefault(false);
+  AddAttr<std::string>(
+      "mkldnn_data_type",
+      "(string, default \"float32\"). Data type of mkldnn kernel")
+      .SetDefault("float32")
+      .InEnum({"float32", "int8", "bfloat16"});
   AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
       .SetDefault(false);
   AddAttr<std::string>("fuse_activation",
......
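Illustration only, not part of the diff: the new mkldnn_data_type attribute is a plain string attribute restricted to "float32", "int8", or "bfloat16", and the unit tests later in this commit pass it through the op's attrs dict in exactly this form (the int8 test likewise now sets it to "int8").

# Hypothetical OpTest-style attrs dict mirroring the new tests below.
conv2d_attrs = {
    'use_mkldnn': True,
    'mkldnn_data_type': 'bfloat16',  # one of "float32", "int8", "bfloat16"
    'force_fp32_output': False,
}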
@@ -55,12 +55,12 @@ inline MKLDNNMemoryFormat GetWeightsFormat(const MKLDNNMemoryFormat format,
   }
 }

-static mkldnn::memory::data_type GetDstType(bool is_int8,
+static mkldnn::memory::data_type GetDstType(bool is_int8, bool is_bfloat16,
                                             bool force_fp32_output,
                                             std::string fuse_activation,
                                             bool fuse_residual_conn,
                                             const Tensor* residual_param) {
-  auto dst_dt = mkldnn::memory::data_type::f32;  // uint8_t, int8_t, float
+  auto dst_dt = mkldnn::memory::data_type::f32;
   if (is_int8) {
     dst_dt = (fuse_activation == "relu" || fuse_activation == "relu6")
                  ? mkldnn::memory::data_type::u8
@@ -72,6 +72,13 @@ static mkldnn::memory::data_type GetDstType(bool is_int8,
       auto residual_dt = framework::ToMKLDNNDataType(residual_param->type());
       if (dst_dt != residual_dt) dst_dt = residual_dt;
     }
+  } else {
+    if (!force_fp32_output && is_bfloat16) {
+      dst_dt = mkldnn::memory::data_type::bf16;
+      if (fuse_residual_conn && residual_param) {
+        dst_dt = framework::ToMKLDNNDataType(residual_param->type());
+      }
+    }
   }
   return dst_dt;
 }
@@ -224,12 +231,15 @@ class ConvMKLDNNHandlerT
              src_tz.size(), chosen_memory_format);
        }
      }

-      const auto src_md = platform::MKLDNNMemDesc(
-          src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
-      const auto weights_md =
-          platform::MKLDNNMemDesc(weights_tz, platform::MKLDNNGetDataType<T>(),
-                                  MKLDNNMemoryFormat::any);
+      auto data_type = mkldnn::memory::data_type::f32;
+      if (ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16" ||
+          std::is_same<T_out, platform::bfloat16>::value)
+        data_type = mkldnn::memory::data_type::bf16;
+
+      const auto src_md =
+          platform::MKLDNNMemDesc(src_tz, data_type, chosen_memory_format);
+      const auto weights_md = platform::MKLDNNMemDesc(weights_tz, data_type,
+                                                      MKLDNNMemoryFormat::any);
       const auto dst_md = platform::MKLDNNMemDesc(
           dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);
@@ -241,8 +251,8 @@ class ConvMKLDNNHandlerT
       if (bias) {
         auto bias_tz = framework::vectorize(bias->dims());
-        auto bias_md = platform::MKLDNNMemDesc(
-            bias_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x);
+        auto bias_md =
+            platform::MKLDNNMemDesc(bias_tz, data_type, MKLDNNMemoryFormat::x);

         this->AcquireForwardPrimitiveDescriptor(
             conv_attr, fwd_prop_kind, dnnl::algorithm::convolution_direct,
@@ -384,15 +394,21 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                           "Operator DNNL Conv must use CPUPlace"));
     bool is_INT8 =
         std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
+    bool is_BFLOAT16 = ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16";
+    auto residual_param = ctx.Input<Tensor>("ResidualData");
+    bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
+    std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
+    bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
+    auto dst_dt =
+        GetDstType(is_INT8, is_BFLOAT16, force_fp32_output, fuse_activation,
+                   fuse_residual_conn, residual_param);
     if (!is_INT8) {
-      ComputeFP32<float>(ctx);
+      if (dst_dt == mkldnn::memory::data_type::f32) {
+        ComputeFP32<float>(ctx);
+      } else if (dst_dt == mkldnn::memory::data_type::bf16) {
+        ComputeFP32<platform::bfloat16>(ctx);
+      }
     } else {
-      std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
-      bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
-      bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
-      auto residual_param = ctx.Input<Tensor>("ResidualData");
-      auto dst_dt = GetDstType(true, force_fp32_output, fuse_activation,
-                               fuse_residual_conn, residual_param);
       if (dst_dt == mkldnn::memory::data_type::f32) {
         ComputeINT8<float>(ctx);
       } else if (dst_dt == mkldnn::memory::data_type::u8) {
@@ -1103,6 +1119,10 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN,
                                     ops::kConvMKLDNNFP32,
                                     ops::ConvMKLDNNOpKernel<float, float>);

+REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
+    conv2d, MKLDNN, ::paddle::platform::CPUPlace, BF16, ops::kConvMKLDNNFP32,
+    ops::ConvMKLDNNOpKernel<paddle::platform::bfloat16, float>);
+
 REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN,
                                     ::paddle::platform::CPUPlace, U8,
                                     ops::kConvMKLDNNINT8,
......
@@ -110,4 +110,5 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;

 REGISTER_OP_KERNEL(dequantize, MKLDNN, ::paddle::platform::CPUPlace,
-                   ops::DeQuantOpKernel<uint8_t>, ops::DeQuantOpKernel<int8_t>);
+                   ops::DeQuantOpKernel<uint8_t>, ops::DeQuantOpKernel<int8_t>,
+                   ops::DeQuantOpKernel<paddle::platform::bfloat16>);
@@ -41,6 +41,7 @@ namespace detail {
 // import numpy as np
 // print np.dtype(np.float16).num  # 23
 constexpr int NPY_FLOAT16_ = 23;
+constexpr int NPY_UINT16_ = 4;

 // Note: Since float16 is not a builtin type in C++, we register
 // paddle::platform::float16 as numpy.float16.
@@ -60,6 +61,23 @@ struct npy_format_descriptor<paddle::platform::float16> {
   static PYBIND11_DESCR name() { return _("float16"); }
 };

+// Note: Since bfloat16 is not a builtin type in C++ and in numpy,
+// we register paddle::platform::bfloat16 as numpy.uint16.
+template <>
+struct npy_format_descriptor<paddle::platform::bfloat16> {
+  static py::dtype dtype() {
+    handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_UINT16_);
+    return reinterpret_borrow<py::dtype>(ptr);
+  }
+  static std::string format() {
+    // Note: "H" represents UINT16.
+    // Details at:
+    // https://docs.python.org/3/library/struct.html#format-characters.
+    return "H";
+  }
+  static PYBIND11_DESCR name() { return _("bfloat16"); }
+};
+
 }  // namespace detail
 }  // namespace pybind11
......
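For context, a minimal NumPy sketch (illustration only, not part of the diff) of why uint16 with struct format "H" is a workable container for bfloat16: a bfloat16 value is simply the upper 16 bits of the corresponding float32, so the bit pattern round-trips through a uint16 array.

import numpy as np

x = np.array([1.5, -2.0, 3.1415927], dtype=np.float32)
# Keep sign, exponent and the top 7 mantissa bits -- the bfloat16 bit pattern.
bf16_bits = (x.view(np.uint32) >> np.uint32(16)).astype(np.uint16)
# Pad the low 16 bits with zeros to get back a float32.
back = (bf16_bits.astype(np.uint32) << np.uint32(16)).view(np.float32)
print(back)  # 1.5 and -2.0 survive exactly; 3.1415927 becomes 3.140625 (mantissa truncation)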
@@ -613,7 +613,9 @@ def convert_np_dtype_to_dtype_(np_dtype):
     elif dtype == np.bool:
         return core.VarDesc.VarType.BOOL
     elif dtype == np.uint16:
-        return core.VarDesc.VarType.INT16
+        # since there is still no support for bfloat16 in NumPy,
+        # uint16 is used for casting bfloat16
+        return core.VarDesc.VarType.BF16
     elif dtype == np.uint8:
         return core.VarDesc.VarType.UINT8
     elif dtype == np.int8:
......
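A quick way to see the effect of the remapping above (assuming a Paddle build that contains this change): np.uint16 now resolves to the BF16 variable type rather than INT16.

import numpy as np
import paddle.fluid.core as core
from paddle.fluid.framework import convert_np_dtype_to_dtype_

# uint16-backed arrays are treated as bfloat16 variables after this commit.
assert convert_np_dtype_to_dtype_(np.uint16) == core.VarDesc.VarType.BF16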
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import struct
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16
from paddle.fluid.tests.unittests.test_conv2d_op import conv2d_forward_naive, TestConv2dOp
def conv2d_forward_refer(input, filter, group, conv_param):
out, in_n, out_h, out_w, out_c = conv2d_forward_naive(input, filter, group,
conv_param)
return out
def conv2d_residual_naive(out, residual):
assert out.shape == residual.shape
out = np.add(out, residual)
return out
class TestConv2dBf16Op(TestConv2dOp):
def setUp(self):
self.op_type = "conv2d"
self.use_cudnn = False
self.exhaustive_search = False
self.use_cuda = False
self.use_mkldnn = True
self.weight_type = np.float32
self.input_type = np.float32
self.use_mkldnn = True
self.mkldnn_data_type = "bfloat16"
self.force_fp32_output = False
self.init_group()
self.init_dilation()
self.init_test_case()
self.init_fuse_relu()
self.init_fuse_residual()
self.init_data_type()
self.init_force_fp32_output()
conv2d_param = {
'stride': self.stride,
'pad': self.pad,
'dilation': self.dilations
}
self.input = np.random.random(self.input_size).astype(np.float32)
self.filter = np.random.random(self.filter_size).astype(np.float32)
conv_out, _, _, _, _ = conv2d_forward_naive(self.input, self.filter,
self.groups, conv2d_param)
self.conv_output_float = conv_out
if self.fuse_residual:
self.input_residual = np.random.random(
self.input_residual_size).astype(np.float32)
self.conv_output_float = conv2d_residual_naive(
self.conv_output_float, self.input_residual)
self.conv_output = convert_float_to_uint16(self.conv_output_float)
self.outputs = {'Output': self.conv_output}
elif self.force_fp32_output:
self.outputs = {'Output': self.conv_output_float.astype(np.float32)}
if self.input_type is not np.float32:
self.input = convert_float_to_uint16(self.input)
self.inputs = {
'Input': self.input.view(self.input_type),
'Filter': OpTest.np_dtype_to_fluid_dtype(
self.filter.astype(self.weight_type))
}
if self.fuse_residual:
self.inputs['ResidualData'] = OpTest.np_dtype_to_fluid_dtype(
convert_float_to_uint16(self.input_residual))
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
'groups': self.groups,
'dilations': self.dilations,
'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn,
'mkldnn_data_type': self.mkldnn_data_type,
'force_fp32_output': self.force_fp32_output,
'fuse_residual_connection': self.fuse_residual
}
def test_check_output(self):
self.check_output_with_place(core.CPUPlace())
def test_check_grad(self):
pass
def test_check_grad_no_filter(self):
pass
def test_check_grad_no_input(self):
pass
def init_test_case(self):
TestConv2dOp.init_test_case(self)
self.input_size = [1, 1, 5, 5] # NCHW
f_c = self.input_size[1] // self.groups
self.input_residual_size = [1, 2, 3, 3]
self.filter_size = [2, f_c, 3, 3]
def init_data_type(self):
self.weight_type = np.float32
self.input_type = np.float32
def init_force_fp32_output(self):
self.force_fp32_output = False
def init_fuse_relu(self):
self.fuse_activation = "relu"
def init_fuse_residual(self):
self.fuse_residual = True
class TestConv2d(TestConv2dBf16Op):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
self.input_residual_size = [2, 6, 3, 3]
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
def init_data_type(self):
self.input_type = np.uint16
class TestWithPad(TestConv2d):
def init_test_case(self):
TestConv2d.init_test_case(self)
self.pad = [1, 1]
self.input_residual_size = [2, 6, 5, 5]
class TestWithGroup(TestConv2d):
def init_group(self):
self.groups = 3
class TestWithStride(TestConv2dBf16Op):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [2, 2]
self.input_size = [2, 3, 6, 6]
self.input_residual_size = [2, 6, 3, 3]
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
def init_data_type(self):
self.input_type = np.uint16
class TestWith1x1ForceFP32Output(TestConv2dBf16Op):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.input_size = [1, 3, 5, 5]
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 1, 1]
def init_force_fp32_output(self):
self.force_fp32_output = True
def init_fuse_residual(self):
self.fuse_residual = False
class TestWithInput1x1Filter1x1(TestConv2dBf16Op):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.input_size = [2, 3, 1, 1]
self.input_residual_size = [2, 6, 1, 1]
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 1, 1]
def init_group(self):
self.groups = 3
if __name__ == '__main__':
unittest.main()
@@ -36,6 +36,7 @@ class TestConv2dInt8Op(TestConv2dOp):
         self.use_cuda = False
         self.use_mkldnn = False
         self.data_format = "NCHW"
+        self.mkldnn_data_type = "int8"
         self.weighttype = np.float32
         self.use_mkldnn = True
         self.init_group()
@@ -141,7 +142,8 @@
             'Scale_weights': self.scale_weights,
             'Scale_in_eltwise': self.scale_in_eltwise,
             'fuse_activation': self.fuse_activation,
-            'fuse_residual_connection': self.fuse_residual
+            'fuse_residual_connection': self.fuse_residual,
+            'mkldnn_data_type': self.mkldnn_data_type
         }
         self.outputs = {'Output': output}
......
...@@ -16,7 +16,7 @@ from __future__ import print_function ...@@ -16,7 +16,7 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
class TestDeQuantizeOp(OpTest): class TestDeQuantizeOp(OpTest):
...@@ -32,6 +32,9 @@ class TestDeQuantizeOp(OpTest): ...@@ -32,6 +32,9 @@ class TestDeQuantizeOp(OpTest):
input = (np.random.randint(0, 100, self.input_size) - 50 input = (np.random.randint(0, 100, self.input_size) - 50
).astype(self.data_type) ).astype(self.data_type)
output = (input * (1 / self.scale)).astype('float') output = (input * (1 / self.scale)).astype('float')
elif self.data_type == 'uint16':
output = np.random.random(self.input_size).astype(np.float32)
input = convert_float_to_uint16(output)
else: else:
input = (np.random.randint(0, 100, input = (np.random.randint(0, 100,
self.input_size)).astype(self.data_type) self.input_size)).astype(self.data_type)
...@@ -70,5 +73,13 @@ class TestDeQuantizeOp2(TestDeQuantizeOp): ...@@ -70,5 +73,13 @@ class TestDeQuantizeOp2(TestDeQuantizeOp):
self.data_type = 'uint8' self.data_type = 'uint8'
class TestDeQuantizeOpBf16(TestDeQuantizeOp):
def set_scale(self):
self.scale = 1.0
def set_data_type(self):
self.data_type = 'uint16'
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
@@ -20,6 +20,7 @@ import warnings
 import numpy as np
 import random
 import six
+import struct
 import time
 import itertools
 import collections
@@ -167,6 +168,18 @@ def skip_check_grad_ci(reason=None):
     return wrapper


+def copy_bits_from_float_to_uint16(f):
+    return struct.unpack('<I', struct.pack('<f', f))[0] >> 16
+
+
+def convert_float_to_uint16(float_list):
+    new_output = []
+    for x in np.nditer(float_list):
+        new_output.append(np.uint16(copy_bits_from_float_to_uint16(x)))
+
+    return np.reshape(new_output, float_list.shape).view(np.uint16)
+
+
 class OpTest(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
@@ -242,6 +255,11 @@
         self.call_once = True
         self.dtype = data_type

+    def is_bfloat16_op(self):
+        return self.dtype == np.uint16 or (
+            hasattr(self, 'mkldnn_data_type') and
+            getattr(self, 'mkldnn_data_type') is "bfloat16")
+
     def infer_dtype_from_inputs_outputs(self, inputs, outputs):
         def is_np_data(input):
             return isinstance(input, (np.ndarray, np.generic))
@@ -276,8 +294,9 @@
         infer_dtype(inputs, dtype_set)
         dtype_list = [
             np.dtype(np.float64), np.dtype(np.float32), np.dtype(np.float16),
-            np.dtype(np.int64), np.dtype(np.int32), np.dtype(np.int16),
-            np.dtype(np.int8), np.dtype(np.uint8), np.dtype(np.bool)
+            np.dtype(np.int64), np.dtype(np.int32), np.dtype(np.uint16),
+            np.dtype(np.int16), np.dtype(np.int8), np.dtype(np.uint8),
+            np.dtype(np.bool)
         ]
         # check the dtype in dtype_list in order, select the first dtype that in dtype_set
         for dtype in dtype_list:
@@ -957,6 +976,14 @@
                 self.op_type not in op_threshold_white_list.NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST:
             atol = 0

+        if self.is_bfloat16_op():
+            check_dygraph = False
+            if hasattr(self, 'force_fp32_output') and getattr(
+                    self, 'force_fp32_output'):
+                atol = 1e-2
+            else:
+                atol = 2
+
         if no_check_set is not None:
             if self.op_type not in no_check_set_white_list.no_check_set_white_list:
                 raise AssertionError(
@@ -1286,8 +1313,9 @@
             no_grad_set = set()
         else:
             if (self.op_type not in no_grad_set_white_list.NEED_TO_FIX_OP_LIST
-                ) and (self.op_type not in
-                       no_grad_set_white_list.NOT_CHECK_OP_LIST):
+                ) and (
+                    self.op_type not in no_grad_set_white_list.NOT_CHECK_OP_LIST
+                ) and (not self.is_bfloat16_op()):
                 raise AssertionError("no_grad_set must be None, op_type is " +
                                      self.op_type + " Op.")
......
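Usage sketch for the convert_float_to_uint16 helper added to op_test.py above (assuming the Paddle test package from this commit is importable): it truncates float32 reference data to bfloat16 bit patterns stored in a uint16 array, which is how the new conv2d bf16 test builds its inputs and expected outputs.

import numpy as np
from paddle.fluid.tests.unittests.op_test import convert_float_to_uint16

ref = np.random.random((2, 3, 5, 5)).astype(np.float32)
bf16_feed = convert_float_to_uint16(ref)  # uint16 array holding bfloat16 bit patterns
assert bf16_feed.dtype == np.uint16 and bf16_feed.shape == ref.shape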