未验证 提交 2ec943a7 编写于 作者: J jakpiase 提交者: GitHub

Added nearest interp v2 BF16 FWD kernel (#39490)

* added nearest interp v2 bf16

* disabled bilinear interp nhwc test

* added skipping UT for gpu

* added NHWC support

* removed unnecessary statements

* minor change

* CI fix

* added appropriate changes to interpolate_v1

* fix after review

* minor change

* minor change

* revert unwanted deletions

* CI fix
上级 1abfc8dd
......@@ -328,7 +328,7 @@ class InterpolateOp : public framework::OperatorWithKernel {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
auto interp_method = ctx.Attr<std::string>("interp_method");
const auto& interp_method = ctx.Attr<std::string>("interp_method");
// TODO(danqing): support other interp_method
if (this->CanMKLDNNBeUsed(ctx, data_type) &&
(interp_method == "nearest" || interp_method == "bilinear")) {
......
......@@ -414,7 +414,7 @@ class InterpolateV2Op : public framework::OperatorWithKernel {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
auto interp_method = ctx.Attr<std::string>("interp_method");
const auto& interp_method = ctx.Attr<std::string>("interp_method");
// TODO(danqing): support other interp_method
if (this->CanMKLDNNBeUsed(ctx, data_type) &&
(interp_method == "nearest" || interp_method == "bilinear")) {
......
......@@ -53,17 +53,13 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
std::vector<int> ComputeOutputShape(
const framework::ExecutionContext& ctx) const {
const auto* x = ctx.Input<Tensor>("X");
auto in_dims = x->dims();
const bool is_channel_last = false; // In mkldnn kernel, always use NCHW
framework::DDim in_dhw_dims;
if (is_channel_last) { // NDHWC, NHWC, NWC
in_dhw_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1);
} else { // NCDHW, NCHW, NCW
in_dhw_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
}
const auto& in_dims = x->dims();
const framework::DDim in_dhw_dims =
phi::slice_ddim(in_dims, 2, in_dims.size());
std::vector<int> out_dims;
out_dims.reserve(5);
if (in_dhw_dims.size() == 1) {
out_dims.push_back(ctx.Attr<int>("out_w"));
} else if (in_dhw_dims.size() == 2) {
......@@ -125,12 +121,8 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
"out_d, out_h, out_w of Op(interpolate) "
"should be greater than 0."));
out_dims.insert(out_dims.begin(), in_dims[0]);
if (is_channel_last) {
out_dims.push_back(in_dims[in_dims.size() - 1]);
} else {
out_dims.insert(out_dims.begin() + 1, in_dims[1]);
}
const std::vector<int64_t> nc_dims = {in_dims[0], in_dims[1]};
out_dims.insert(out_dims.begin(), nc_dims.begin(), nc_dims.end());
return out_dims;
}
......@@ -143,12 +135,12 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
const auto* x = ctx.Input<Tensor>("X");
auto* z = ctx.Output<Tensor>("Out");
auto interp_method = ctx.Attr<std::string>("interp_method");
dnnl::algorithm algo = (interp_method == "nearest")
const auto interp_method = ctx.Attr<std::string>("interp_method");
const dnnl::algorithm algo = (interp_method == "nearest")
? dnnl::algorithm::resampling_nearest
: dnnl::algorithm::resampling_linear;
auto out_dims_vec = ComputeOutputShape(ctx);
const auto out_dims_vec = ComputeOutputShape(ctx);
framework::DDim dim_out = phi::make_ddim(out_dims_vec);
z->Resize(dim_out);
......@@ -162,6 +154,7 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}};
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
resampling_prim->execute(astream, args);
astream.wait();
......@@ -184,6 +177,7 @@ REGISTER_OP_KERNEL(bilinear_interp, MKLDNN, ::paddle::platform::CPUPlace,
REGISTER_OP_KERNEL(nearest_interp_v2, MKLDNN, ::paddle::platform::CPUPlace,
ops::InterpolateMKLDNNKernel<float>,
ops::InterpolateMKLDNNKernel<paddle::platform::bfloat16>,
ops::InterpolateMKLDNNKernel<int8_t>,
ops::InterpolateMKLDNNKernel<uint8_t>);
REGISTER_OP_KERNEL(bilinear_interp_v2, MKLDNN, ::paddle::platform::CPUPlace,
......
......@@ -16,7 +16,7 @@ from __future__ import print_function
import unittest
import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16
from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci
......@@ -59,6 +59,7 @@ def nearest_neighbor_interp_mkldnn_np(X,
@skip_check_grad_ci(reason="Haven not implement interpolate grad kernel.")
@OpTestTool.skip_if_not_cpu_bf16()
class TestNearestInterpV2MKLDNNOp(OpTest):
def init_test_case(self):
pass
......@@ -84,7 +85,7 @@ class TestNearestInterpV2MKLDNNOp(OpTest):
self.init_test_case()
self.init_data_type()
if self.dtype == np.float32:
if self.dtype == np.float32 or self.dtype == np.uint16:
input_np = np.random.random(self.input_shape).astype(self.dtype)
else:
init_low, init_high = (-5, 5) if self.dtype == np.int8 else (0, 10)
......@@ -126,6 +127,9 @@ class TestNearestInterpV2MKLDNNOp(OpTest):
if isinstance(self.scale, float):
self.scale = [self.scale]
if self.dtype == np.uint16:
input_np = convert_float_to_uint16(input_np)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
......@@ -191,6 +195,10 @@ def create_test_class(parent):
def init_data_type(self):
self.dtype = np.float32
class TestBf16Case(parent):
def init_data_type(self):
self.dtype = np.uint16
class TestInt8Case(parent):
def init_data_type(self):
self.dtype = np.int8
......@@ -199,12 +207,14 @@ def create_test_class(parent):
def init_data_type(self):
self.dtype = np.uint8
TestFp32Case.__name__ = parent.__name__
TestInt8Case.__name__ = parent.__name__
TestUint8Case.__name__ = parent.__name__
globals()[parent.__name__] = TestFp32Case
globals()[parent.__name__] = TestInt8Case
globals()[parent.__name__] = TestUint8Case
TestFp32Case.__name__ = "{0}_{1}".format(parent.__name__, "FP32")
TestBf16Case.__name__ = "{0}_{1}".format(parent.__name__, "BF16")
TestInt8Case.__name__ = "{0}_{1}".format(parent.__name__, "INT8")
TestUint8Case.__name__ = "{0}_{1}".format(parent.__name__, "UINT8")
globals()[TestFp32Case.__name__] = TestFp32Case
globals()[TestBf16Case.__name__] = TestBf16Case
globals()[TestInt8Case.__name__] = TestInt8Case
globals()[TestUint8Case.__name__] = TestUint8Case
create_test_class(TestNearestInterpV2MKLDNNOp)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册