未验证 提交 8e6315e4 编写于 作者: P Piotr Paturej 提交者: GitHub

Add bf16 data type support to oneDNN bilinear_interp kernel (#46770)

* Enable bf16 in oneDNN bilinear_interp kernel

* Fix bilinear_interp_v2 not enabled in models

* Remove unnecessary checks
上级 e23dfed9
...@@ -2789,7 +2789,8 @@ PDNode *patterns::QuantizePlacement::operator()( ...@@ -2789,7 +2789,8 @@ PDNode *patterns::QuantizePlacement::operator()(
PDNode *patterns::Bfloat16Placement::operator()( PDNode *patterns::Bfloat16Placement::operator()(
const std::unordered_set<std::string> &bfloat16_enabled_op_types) { const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
std::unordered_set<std::string> supported_op_types = std::unordered_set<std::string> supported_op_types =
std::unordered_set<std::string>({"cast", std::unordered_set<std::string>({"bilinear_interp_v2",
"cast",
"clip", "clip",
"concat", "concat",
"conv2d", "conv2d",
......
...@@ -56,9 +56,7 @@ int CPUBfloat16PlacementPass::SetMkldnnDataType(ir::Graph* graph) const { ...@@ -56,9 +56,7 @@ int CPUBfloat16PlacementPass::SetMkldnnDataType(ir::Graph* graph) const {
// Only float input can be converted to bfloat16 // Only float input can be converted to bfloat16
if (op_in->Var()->GetDataType() != proto::VarType::FP32) return; if (op_in->Var()->GetDataType() != proto::VarType::FP32) return;
if ((op->Op()->HasAttr("mkldnn_data_type") || if (platform::HasOpINT8DataType(op->Op()) == false) {
op->Op()->HasProtoAttr("mkldnn_data_type")) &&
!platform::HasOpINT8DataType(op->Op())) {
VLOG(4) << "--- marked " << op->Op()->Type() VLOG(4) << "--- marked " << op->Op()->Type()
<< " operator to bfloat16 "; << " operator to bfloat16 ";
op->Op()->SetAttr("mkldnn_data_type", std::string("bfloat16")); op->Op()->SetAttr("mkldnn_data_type", std::string("bfloat16"));
......
...@@ -329,6 +329,7 @@ void CpuPassStrategy::EnableMKLDNN() { ...@@ -329,6 +329,7 @@ void CpuPassStrategy::EnableMKLDNN() {
"conv_transpose_eltwiseadd_bn_fuse_pass", // "conv_transpose_eltwiseadd_bn_fuse_pass", //
"conv_bias_mkldnn_fuse_pass", // "conv_bias_mkldnn_fuse_pass", //
"conv_transpose_bias_mkldnn_fuse_pass", "conv_transpose_bias_mkldnn_fuse_pass",
"interpolate_mkldnn_pass",
// TODO(baoachun): Need to support 5-dimensional input. // TODO(baoachun): Need to support 5-dimensional input.
// "conv3d_bias_mkldnn_fuse_pass", // // "conv3d_bias_mkldnn_fuse_pass", //
"conv_elementwise_add_mkldnn_fuse_pass", "conv_elementwise_add_mkldnn_fuse_pass",
......
...@@ -227,8 +227,12 @@ void NearestInterpKernel( ...@@ -227,8 +227,12 @@ void NearestInterpKernel(
} }
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(bilinear_interp,
bilinear_interp, OneDNN, ONEDNN, phi::BilinearInterpKernel, float) {} OneDNN,
ONEDNN,
phi::BilinearInterpKernel,
float,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(nearest_interp, PD_REGISTER_KERNEL(nearest_interp,
OneDNN, OneDNN,
......
...@@ -15,11 +15,11 @@ ...@@ -15,11 +15,11 @@
import unittest import unittest
import numpy as np import numpy as np
import math import math
from paddle.fluid.tests.unittests.op_test import OpTest from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci
def bilinear_interp_mkldnn_np( def bilinear_interp_onednn_np(
input, out_h, out_w, out_size=None, actual_shape=None, data_layout='NCHW' input, out_h, out_w, out_size=None, actual_shape=None, data_layout='NCHW'
): ):
"""bilinear interpolation implement in shape [N, C, H, W]""" """bilinear interpolation implement in shape [N, C, H, W]"""
...@@ -65,17 +65,21 @@ def bilinear_interp_mkldnn_np( ...@@ -65,17 +65,21 @@ def bilinear_interp_mkldnn_np(
@skip_check_grad_ci(reason="Haven not implement interpolate grad kernel.") @skip_check_grad_ci(reason="Haven not implement interpolate grad kernel.")
class TestBilinearInterpMKLDNNOp(OpTest): class TestBilinearInterpOneDNNOp(OpTest):
def init_test_case(self): def init_test_case(self):
pass pass
def init_data_type(self):
pass
def setUp(self): def setUp(self):
self.op_type = "bilinear_interp_v2" self.op_type = "bilinear_interp_v2"
self.interp_method = 'bilinear' self.interp_method = 'bilinear'
self._cpu_only = True self._cpu_only = True
self.use_mkldnn = True self.use_onednn = True
self.input_shape = [1, 1, 2, 2] self.input_shape = [1, 1, 2, 2]
self.data_layout = 'NCHW' self.data_layout = 'NCHW'
self.dtype = np.float32
# priority: actual_shape > out_size > scale > out_h & out_w # priority: actual_shape > out_size > scale > out_h & out_w
self.out_h = 1 self.out_h = 1
self.out_w = 1 self.out_w = 1
...@@ -84,8 +88,12 @@ class TestBilinearInterpMKLDNNOp(OpTest): ...@@ -84,8 +88,12 @@ class TestBilinearInterpMKLDNNOp(OpTest):
self.actual_shape = None self.actual_shape = None
self.init_test_case() self.init_test_case()
self.init_data_type()
input_np = np.random.random(self.input_shape).astype(self.dtype)
if self.dtype == np.uint16:
input_np = convert_float_to_uint16(input_np)
input_np = np.random.random(self.input_shape).astype("float32")
if self.data_layout == "NCHW": if self.data_layout == "NCHW":
in_h = self.input_shape[2] in_h = self.input_shape[2]
in_w = self.input_shape[3] in_w = self.input_shape[3]
...@@ -114,7 +122,7 @@ class TestBilinearInterpMKLDNNOp(OpTest): ...@@ -114,7 +122,7 @@ class TestBilinearInterpMKLDNNOp(OpTest):
out_h = self.out_h out_h = self.out_h
out_w = self.out_w out_w = self.out_w
output_np = bilinear_interp_mkldnn_np( output_np = bilinear_interp_onednn_np(
input_np, input_np,
out_h, out_h,
out_w, out_w,
...@@ -137,7 +145,7 @@ class TestBilinearInterpMKLDNNOp(OpTest): ...@@ -137,7 +145,7 @@ class TestBilinearInterpMKLDNNOp(OpTest):
'out_w': self.out_w, 'out_w': self.out_w,
'scale': self.scale, 'scale': self.scale,
'data_layout': self.data_layout, 'data_layout': self.data_layout,
'use_mkldnn': self.use_mkldnn, 'use_mkldnn': self.use_onednn,
} }
self.outputs = {'Out': output_np} self.outputs = {'Out': output_np}
...@@ -145,7 +153,7 @@ class TestBilinearInterpMKLDNNOp(OpTest): ...@@ -145,7 +153,7 @@ class TestBilinearInterpMKLDNNOp(OpTest):
self.check_output(check_dygraph=False) self.check_output(check_dygraph=False)
class TestBilinearInterpOpMKLDNNNHWC(TestBilinearInterpMKLDNNOp): class TestBilinearInterpOpOneDNNNHWC(TestBilinearInterpOneDNNOp):
def init_test_case(self): def init_test_case(self):
self.input_shape = [3, 2, 32, 16] self.input_shape = [3, 2, 32, 16]
self.out_h = 27 self.out_h = 27
...@@ -154,14 +162,14 @@ class TestBilinearInterpOpMKLDNNNHWC(TestBilinearInterpMKLDNNOp): ...@@ -154,14 +162,14 @@ class TestBilinearInterpOpMKLDNNNHWC(TestBilinearInterpMKLDNNOp):
self.data_layout = 'NHWC' self.data_layout = 'NHWC'
class TestBilinearNeighborInterpMKLDNNCase2(TestBilinearInterpMKLDNNOp): class TestBilinearNeighborInterpOneDNNCase2(TestBilinearInterpOneDNNOp):
def init_test_case(self): def init_test_case(self):
self.input_shape = [3, 3, 9, 6] self.input_shape = [3, 3, 9, 6]
self.out_h = 12 self.out_h = 12
self.out_w = 12 self.out_w = 12
class TestBilinearNeighborInterpCase3(TestBilinearInterpMKLDNNOp): class TestBilinearNeighborInterpOneDNNCase3(TestBilinearInterpOneDNNOp):
def init_test_case(self): def init_test_case(self):
self.input_shape = [1, 1, 32, 64] self.input_shape = [1, 1, 32, 64]
self.out_h = 64 self.out_h = 64
...@@ -169,7 +177,7 @@ class TestBilinearNeighborInterpCase3(TestBilinearInterpMKLDNNOp): ...@@ -169,7 +177,7 @@ class TestBilinearNeighborInterpCase3(TestBilinearInterpMKLDNNOp):
self.scale = [0.1, 0.05] self.scale = [0.1, 0.05]
class TestBilinearNeighborInterpCase4(TestBilinearInterpMKLDNNOp): class TestBilinearNeighborInterpOneDNNCase4(TestBilinearInterpOneDNNOp):
def init_test_case(self): def init_test_case(self):
self.input_shape = [1, 1, 32, 64] self.input_shape = [1, 1, 32, 64]
self.out_h = 64 self.out_h = 64
...@@ -178,7 +186,7 @@ class TestBilinearNeighborInterpCase4(TestBilinearInterpMKLDNNOp): ...@@ -178,7 +186,7 @@ class TestBilinearNeighborInterpCase4(TestBilinearInterpMKLDNNOp):
self.out_size = np.array([65, 129]).astype("int32") self.out_size = np.array([65, 129]).astype("int32")
class TestBilinearNeighborInterpCase5(TestBilinearInterpMKLDNNOp): class TestBilinearNeighborInterpOneDNNCase5(TestBilinearInterpOneDNNOp):
def init_test_case(self): def init_test_case(self):
self.input_shape = [1, 1, 9, 6] self.input_shape = [1, 1, 9, 6]
self.out_h = 12 self.out_h = 12
...@@ -186,7 +194,7 @@ class TestBilinearNeighborInterpCase5(TestBilinearInterpMKLDNNOp): ...@@ -186,7 +194,7 @@ class TestBilinearNeighborInterpCase5(TestBilinearInterpMKLDNNOp):
self.out_size = np.array([13, 13]).astype("int32") self.out_size = np.array([13, 13]).astype("int32")
class TestBilinearNeighborInterpCase6(TestBilinearInterpMKLDNNOp): class TestBilinearNeighborInterpOneDNNCase6(TestBilinearInterpOneDNNOp):
def init_test_case(self): def init_test_case(self):
self.input_shape = [1, 1, 32, 64] self.input_shape = [1, 1, 32, 64]
self.out_h = 64 self.out_h = 64
...@@ -195,7 +203,7 @@ class TestBilinearNeighborInterpCase6(TestBilinearInterpMKLDNNOp): ...@@ -195,7 +203,7 @@ class TestBilinearNeighborInterpCase6(TestBilinearInterpMKLDNNOp):
self.out_size = np.array([65, 129]).astype("int32") self.out_size = np.array([65, 129]).astype("int32")
class TestBilinearNeighborInterpSame(TestBilinearInterpMKLDNNOp): class TestBilinearNeighborInterpOneDNNSame(TestBilinearInterpOneDNNOp):
def init_test_case(self): def init_test_case(self):
self.input_shape = [2, 3, 32, 64] self.input_shape = [2, 3, 32, 64]
self.out_h = 32 self.out_h = 32
...@@ -204,6 +212,24 @@ class TestBilinearNeighborInterpSame(TestBilinearInterpMKLDNNOp): ...@@ -204,6 +212,24 @@ class TestBilinearNeighborInterpSame(TestBilinearInterpMKLDNNOp):
self.out_size = np.array([65, 129]).astype("int32") self.out_size = np.array([65, 129]).astype("int32")
def create_test_class(parent):
class TestBf16Case(parent):
def init_data_type(self):
self.dtype = np.uint16
TestBf16Case.__name__ = "{0}_{1}".format(parent.__name__, "BF16")
globals()[TestBf16Case.__name__] = TestBf16Case
create_test_class(TestBilinearInterpOneDNNOp)
create_test_class(TestBilinearInterpOpOneDNNNHWC)
create_test_class(TestBilinearNeighborInterpOneDNNCase2)
create_test_class(TestBilinearNeighborInterpOneDNNCase3)
create_test_class(TestBilinearNeighborInterpOneDNNCase4)
create_test_class(TestBilinearNeighborInterpOneDNNCase5)
create_test_class(TestBilinearNeighborInterpOneDNNCase6)
create_test_class(TestBilinearNeighborInterpOneDNNSame)
if __name__ == "__main__": if __name__ == "__main__":
from paddle import enable_static from paddle import enable_static
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册