Unverified commit 2ec943a7, authored by J jakpiase, committed by GitHub

Added nearest interp v2 BF16 FWD kernel (#39490)

* added nearest interp v2 bf16

* disabled bilinear interp nhwc test

* added skipping UT for gpu

* added NHWC support

* removed unnecessary statements

* minor change

* CI fix

* added appropriate changes to interpolate_v1

* fix after review

* minor change

* minor change

* revert unwanted deletions

* CI fix
Parent 1abfc8dd
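For orientation, the forward computation the new BF16 kernel performs is plain nearest-neighbour resampling of an NCHW tensor. A rough numpy sketch of that mapping, using the simple floor(i * in/out) index rule (the oneDNN kernel's exact coordinate transform may differ, and nearest_interp_nchw is a hypothetical helper name):

import numpy as np

def nearest_interp_nchw(x, out_h, out_w):
    # Pick the nearest source row/column for every output position.
    n, c, in_h, in_w = x.shape
    rows = np.minimum((np.arange(out_h) * in_h / out_h).astype(np.int64), in_h - 1)
    cols = np.minimum((np.arange(out_w) * in_w / out_w).astype(np.int64), in_w - 1)
    return x[:, :, rows[:, None], cols[None, :]]

x = np.random.random([1, 3, 4, 4]).astype(np.float32)
y = nearest_interp_nchw(x, 8, 8)  # shape (1, 3, 8, 8)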
@@ -328,7 +328,7 @@ class InterpolateOp : public framework::OperatorWithKernel {
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
-    auto interp_method = ctx.Attr<std::string>("interp_method");
+    const auto& interp_method = ctx.Attr<std::string>("interp_method");
     // TODO(danqing): support other interp_method
     if (this->CanMKLDNNBeUsed(ctx, data_type) &&
         (interp_method == "nearest" || interp_method == "bilinear")) {
......
@@ -414,7 +414,7 @@ class InterpolateV2Op : public framework::OperatorWithKernel {
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
-    auto interp_method = ctx.Attr<std::string>("interp_method");
+    const auto& interp_method = ctx.Attr<std::string>("interp_method");
     // TODO(danqing): support other interp_method
     if (this->CanMKLDNNBeUsed(ctx, data_type) &&
         (interp_method == "nearest" || interp_method == "bilinear")) {
......
@@ -53,17 +53,13 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
   std::vector<int> ComputeOutputShape(
       const framework::ExecutionContext& ctx) const {
     const auto* x = ctx.Input<Tensor>("X");
-    auto in_dims = x->dims();
-    const bool is_channel_last = false;  // In mkldnn kernel, always use NCHW
-
-    framework::DDim in_dhw_dims;
-    if (is_channel_last) {  // NDHWC, NHWC, NWC
-      in_dhw_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
-    } else {  // NCDHW, NCHW, NCW
-      in_dhw_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
-    }
+    const auto& in_dims = x->dims();
+    const framework::DDim in_dhw_dims =
+        phi::slice_ddim(in_dims, 2, in_dims.size());

     std::vector<int> out_dims;
+    out_dims.reserve(5);
     if (in_dhw_dims.size() == 1) {
       out_dims.push_back(ctx.Attr<int>("out_w"));
     } else if (in_dhw_dims.size() == 2) {
@@ -125,12 +121,8 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
                           "out_d, out_h, out_w of Op(interpolate) "
                           "should be greater than 0."));

-    out_dims.insert(out_dims.begin(), in_dims[0]);
-    if (is_channel_last) {
-      out_dims.push_back(in_dims[in_dims.size() - 1]);
-    } else {
-      out_dims.insert(out_dims.begin() + 1, in_dims[1]);
-    }
+    const std::vector<int64_t> nc_dims = {in_dims[0], in_dims[1]};
+    out_dims.insert(out_dims.begin(), nc_dims.begin(), nc_dims.end());

     return out_dims;
   }
@@ -143,12 +135,12 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
     const auto* x = ctx.Input<Tensor>("X");
     auto* z = ctx.Output<Tensor>("Out");

-    auto interp_method = ctx.Attr<std::string>("interp_method");
-    dnnl::algorithm algo = (interp_method == "nearest")
+    const auto interp_method = ctx.Attr<std::string>("interp_method");
+    const dnnl::algorithm algo = (interp_method == "nearest")
                                ? dnnl::algorithm::resampling_nearest
                                : dnnl::algorithm::resampling_linear;

-    auto out_dims_vec = ComputeOutputShape(ctx);
+    const auto out_dims_vec = ComputeOutputShape(ctx);
     framework::DDim dim_out = phi::make_ddim(out_dims_vec);
     z->Resize(dim_out);
@@ -162,6 +154,7 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
     const std::unordered_map<int, dnnl::memory> args = {
         {DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}};

     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
+
     resampling_prim->execute(astream, args);
     astream.wait();
@@ -184,6 +177,7 @@ REGISTER_OP_KERNEL(bilinear_interp, MKLDNN, ::paddle::platform::CPUPlace,
 REGISTER_OP_KERNEL(nearest_interp_v2, MKLDNN, ::paddle::platform::CPUPlace,
                    ops::InterpolateMKLDNNKernel<float>,
+                   ops::InterpolateMKLDNNKernel<paddle::platform::bfloat16>,
                    ops::InterpolateMKLDNNKernel<int8_t>,
                    ops::InterpolateMKLDNNKernel<uint8_t>);

 REGISTER_OP_KERNEL(bilinear_interp_v2, MKLDNN, ::paddle::platform::CPUPlace,
......
@@ -16,7 +16,7 @@ from __future__ import print_function

 import unittest
 import numpy as np
-from paddle.fluid.tests.unittests.op_test import OpTest
+from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16
 from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci
@@ -59,6 +59,7 @@ def nearest_neighbor_interp_mkldnn_np(X,

 @skip_check_grad_ci(reason="Haven not implement interpolate grad kernel.")
+@OpTestTool.skip_if_not_cpu_bf16()
 class TestNearestInterpV2MKLDNNOp(OpTest):
     def init_test_case(self):
         pass
@@ -84,7 +85,7 @@ class TestNearestInterpV2MKLDNNOp(OpTest):
         self.init_test_case()
         self.init_data_type()

-        if self.dtype == np.float32:
+        if self.dtype == np.float32 or self.dtype == np.uint16:
             input_np = np.random.random(self.input_shape).astype(self.dtype)
         else:
             init_low, init_high = (-5, 5) if self.dtype == np.int8 else (0, 10)
@@ -126,6 +127,9 @@ class TestNearestInterpV2MKLDNNOp(OpTest):
         if isinstance(self.scale, float):
             self.scale = [self.scale]

+        if self.dtype == np.uint16:
+            input_np = convert_float_to_uint16(input_np)
+
         self.inputs = {'X': input_np}
         if self.out_size is not None:
             self.inputs['OutSize'] = self.out_size
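The uint16 branch above feeds bf16 data to OpTest as the upper 16 bits of the float32 representation. A minimal numpy sketch of that conversion, assuming simple truncation (the real convert_float_to_uint16 helper may also round; float32_to_bf16_bits is a hypothetical name):

import numpy as np

def float32_to_bf16_bits(arr):
    # Reinterpret float32 as uint32 and keep the upper 16 bits,
    # which is the bfloat16 bit pattern stored as uint16.
    arr = np.ascontiguousarray(arr, dtype=np.float32)
    return np.right_shift(arr.view(np.uint32), 16).astype(np.uint16)

x = np.random.random([2, 3, 4, 5]).astype(np.float32)
x_bf16 = float32_to_bf16_bits(x)  # dtype uint16, same shape as x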
@@ -191,6 +195,10 @@ def create_test_class(parent):
         def init_data_type(self):
             self.dtype = np.float32

+    class TestBf16Case(parent):
+        def init_data_type(self):
+            self.dtype = np.uint16
+
     class TestInt8Case(parent):
         def init_data_type(self):
             self.dtype = np.int8
@@ -199,12 +207,14 @@ def create_test_class(parent):
         def init_data_type(self):
             self.dtype = np.uint8

-    TestFp32Case.__name__ = parent.__name__
-    TestInt8Case.__name__ = parent.__name__
-    TestUint8Case.__name__ = parent.__name__
-    globals()[parent.__name__] = TestFp32Case
-    globals()[parent.__name__] = TestInt8Case
-    globals()[parent.__name__] = TestUint8Case
+    TestFp32Case.__name__ = "{0}_{1}".format(parent.__name__, "FP32")
+    TestBf16Case.__name__ = "{0}_{1}".format(parent.__name__, "BF16")
+    TestInt8Case.__name__ = "{0}_{1}".format(parent.__name__, "INT8")
+    TestUint8Case.__name__ = "{0}_{1}".format(parent.__name__, "UINT8")
+    globals()[TestFp32Case.__name__] = TestFp32Case
+    globals()[TestBf16Case.__name__] = TestBf16Case
+    globals()[TestInt8Case.__name__] = TestInt8Case
+    globals()[TestUint8Case.__name__] = TestUint8Case

 create_test_class(TestNearestInterpV2MKLDNNOp)
......
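The renaming in the last hunk fixes a shadowing bug: previously every dtype subclass reused parent.__name__ and was written to the same globals() key, so only the last one was ever collected by unittest. A minimal sketch of the corrected pattern, with hypothetical names (TestParent, create_variants):

import unittest

class TestParent(unittest.TestCase):
    def test_dummy(self):
        pass

def create_variants(parent):
    for tag in ("FP32", "BF16", "INT8", "UINT8"):
        # Unique class name and unique globals() key, so test discovery
        # keeps every dtype variant instead of only the last one.
        cls = type("{0}_{1}".format(parent.__name__, tag), (parent,), {})
        globals()[cls.__name__] = cls

create_variants(TestParent)

if __name__ == "__main__":
    unittest.main()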