diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc index 8fac84176d97fd371ddfac25dab2aee8c098607a..fda168c94e1e064c65e3b5fcf56b606772345b9d 100644 --- a/paddle/fluid/operators/interpolate_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -328,7 +328,7 @@ class InterpolateOp : public framework::OperatorWithKernel { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN - auto interp_method = ctx.Attr("interp_method"); + const auto& interp_method = ctx.Attr("interp_method"); // TODO(danqing): support other interp_method if (this->CanMKLDNNBeUsed(ctx, data_type) && (interp_method == "nearest" || interp_method == "bilinear")) { diff --git a/paddle/fluid/operators/interpolate_v2_op.cc b/paddle/fluid/operators/interpolate_v2_op.cc index 7783303785998e9db05a5f5117a047e2729de848..4b5a18141d5aa9ac5d1f5354fafbad0e38bb8474 100644 --- a/paddle/fluid/operators/interpolate_v2_op.cc +++ b/paddle/fluid/operators/interpolate_v2_op.cc @@ -414,7 +414,7 @@ class InterpolateV2Op : public framework::OperatorWithKernel { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN - auto interp_method = ctx.Attr("interp_method"); + const auto& interp_method = ctx.Attr("interp_method"); // TODO(danqing): support other interp_method if (this->CanMKLDNNBeUsed(ctx, data_type) && (interp_method == "nearest" || interp_method == "bilinear")) { diff --git a/paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc index 33ea36d24b8aef833890277fd69ed02e4859802f..04b90d2f1f380a72dd076774f2b68c2d1bc7e55b 100644 --- a/paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc @@ -53,17 +53,13 @@ class InterpolateMKLDNNKernel : public framework::OpKernel { std::vector ComputeOutputShape( const framework::ExecutionContext& ctx) const { const auto* x = ctx.Input("X"); - auto in_dims = x->dims(); - const bool is_channel_last = false; // In mkldnn kernel, always use NCHW - - framework::DDim in_dhw_dims; - if (is_channel_last) { // NDHWC, NHWC, NWC - in_dhw_dims = phi::slice_ddim(in_dims, 1, in_dims.size() - 1); - } else { // NCDHW, NCHW, NCW - in_dhw_dims = phi::slice_ddim(in_dims, 2, in_dims.size()); - } + const auto& in_dims = x->dims(); + + const framework::DDim in_dhw_dims = + phi::slice_ddim(in_dims, 2, in_dims.size()); std::vector out_dims; + out_dims.reserve(5); if (in_dhw_dims.size() == 1) { out_dims.push_back(ctx.Attr("out_w")); } else if (in_dhw_dims.size() == 2) { @@ -125,12 +121,8 @@ class InterpolateMKLDNNKernel : public framework::OpKernel { "out_d, out_h, out_w of Op(interpolate) " "should be greater than 0.")); - out_dims.insert(out_dims.begin(), in_dims[0]); - if (is_channel_last) { - out_dims.push_back(in_dims[in_dims.size() - 1]); - } else { - out_dims.insert(out_dims.begin() + 1, in_dims[1]); - } + const std::vector nc_dims = {in_dims[0], in_dims[1]}; + out_dims.insert(out_dims.begin(), nc_dims.begin(), nc_dims.end()); return out_dims; } @@ -143,12 +135,12 @@ class InterpolateMKLDNNKernel : public framework::OpKernel { const auto* x = ctx.Input("X"); auto* z = ctx.Output("Out"); - auto interp_method = ctx.Attr("interp_method"); - dnnl::algorithm algo = (interp_method == "nearest") - ? dnnl::algorithm::resampling_nearest - : dnnl::algorithm::resampling_linear; + const auto interp_method = ctx.Attr("interp_method"); + const dnnl::algorithm algo = (interp_method == "nearest") + ? dnnl::algorithm::resampling_nearest + : dnnl::algorithm::resampling_linear; - auto out_dims_vec = ComputeOutputShape(ctx); + const auto out_dims_vec = ComputeOutputShape(ctx); framework::DDim dim_out = phi::make_ddim(out_dims_vec); z->Resize(dim_out); @@ -162,6 +154,7 @@ class InterpolateMKLDNNKernel : public framework::OpKernel { const std::unordered_map args = { {DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}}; auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); + resampling_prim->execute(astream, args); astream.wait(); @@ -184,6 +177,7 @@ REGISTER_OP_KERNEL(bilinear_interp, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(nearest_interp_v2, MKLDNN, ::paddle::platform::CPUPlace, ops::InterpolateMKLDNNKernel, + ops::InterpolateMKLDNNKernel, ops::InterpolateMKLDNNKernel, ops::InterpolateMKLDNNKernel); REGISTER_OP_KERNEL(bilinear_interp_v2, MKLDNN, ::paddle::platform::CPUPlace, diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_nearest_interp_v2_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_nearest_interp_v2_mkldnn_op.py index 24ebf40216f4bad6f351f865b86662ce0718f690..d72a1d53d3aa57a3ce4e61f03435eef4d1471d21 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_nearest_interp_v2_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_nearest_interp_v2_mkldnn_op.py @@ -16,7 +16,7 @@ from __future__ import print_function import unittest import numpy as np -from paddle.fluid.tests.unittests.op_test import OpTest +from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16 from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci @@ -59,6 +59,7 @@ def nearest_neighbor_interp_mkldnn_np(X, @skip_check_grad_ci(reason="Haven not implement interpolate grad kernel.") +@OpTestTool.skip_if_not_cpu_bf16() class TestNearestInterpV2MKLDNNOp(OpTest): def init_test_case(self): pass @@ -84,7 +85,7 @@ class TestNearestInterpV2MKLDNNOp(OpTest): self.init_test_case() self.init_data_type() - if self.dtype == np.float32: + if self.dtype == np.float32 or self.dtype == np.uint16: input_np = np.random.random(self.input_shape).astype(self.dtype) else: init_low, init_high = (-5, 5) if self.dtype == np.int8 else (0, 10) @@ -126,6 +127,9 @@ class TestNearestInterpV2MKLDNNOp(OpTest): if isinstance(self.scale, float): self.scale = [self.scale] + if self.dtype == np.uint16: + input_np = convert_float_to_uint16(input_np) + self.inputs = {'X': input_np} if self.out_size is not None: self.inputs['OutSize'] = self.out_size @@ -191,6 +195,10 @@ def create_test_class(parent): def init_data_type(self): self.dtype = np.float32 + class TestBf16Case(parent): + def init_data_type(self): + self.dtype = np.uint16 + class TestInt8Case(parent): def init_data_type(self): self.dtype = np.int8 @@ -199,12 +207,14 @@ def create_test_class(parent): def init_data_type(self): self.dtype = np.uint8 - TestFp32Case.__name__ = parent.__name__ - TestInt8Case.__name__ = parent.__name__ - TestUint8Case.__name__ = parent.__name__ - globals()[parent.__name__] = TestFp32Case - globals()[parent.__name__] = TestInt8Case - globals()[parent.__name__] = TestUint8Case + TestFp32Case.__name__ = "{0}_{1}".format(parent.__name__, "FP32") + TestBf16Case.__name__ = "{0}_{1}".format(parent.__name__, "BF16") + TestInt8Case.__name__ = "{0}_{1}".format(parent.__name__, "INT8") + TestUint8Case.__name__ = "{0}_{1}".format(parent.__name__, "UINT8") + globals()[TestFp32Case.__name__] = TestFp32Case + globals()[TestBf16Case.__name__] = TestBf16Case + globals()[TestInt8Case.__name__] = TestInt8Case + globals()[TestUint8Case.__name__] = TestUint8Case create_test_class(TestNearestInterpV2MKLDNNOp)