diff --git a/paddle/fluid/framework/ir/mkldnn/interpolate_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/interpolate_mkldnn_pass.cc
index 06df1caca35b922ac96d7d886296a6dee6bfb764..4eb532b47cb4b59cb3df0fe775400caa01354269 100644
--- a/paddle/fluid/framework/ir/mkldnn/interpolate_mkldnn_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/interpolate_mkldnn_pass.cc
@@ -43,8 +43,9 @@ void InterpolateMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
   int found_count = 0;
   const std::vector<std::string> interpolate_op_types = {
-      "bilinear_interp", "nearest_interp", "trilinear_interp", "bicubic_interp",
-      "linear_interp"};
+      "bilinear_interp", "nearest_interp",  "trilinear_interp",
+      "bicubic_interp",  "linear_interp",   "bilinear_interp_v2",
+      "nearest_interp_v2"};
 
   for (const Node* node : graph->Nodes()) {
     if (node->IsOp() &&
diff --git a/paddle/fluid/framework/ir/placement_pass_base.cc b/paddle/fluid/framework/ir/placement_pass_base.cc
index fd604ffe7b5de440fb3509a01fd2a1bc1a553574..35ba92006077999a541e700c6884db0d32f0bfab 100644
--- a/paddle/fluid/framework/ir/placement_pass_base.cc
+++ b/paddle/fluid/framework/ir/placement_pass_base.cc
@@ -77,7 +77,8 @@ bool PlacementPassBase::IsDefaultOpTypes(const std::string& op_type) const {
   // the corresponding pass.
   const std::vector<std::string> not_default_op_types = {
       "bilinear_interp", "nearest_interp", "trilinear_interp",
-      "bicubic_interp",  "linear_interp"};
+      "bicubic_interp",  "linear_interp",  "bilinear_interp_v2",
+      "linear_interp_v2"};
   bool is_interpolate_op =
       std::find(not_default_op_types.begin(), not_default_op_types.end(),
                 op_type) != not_default_op_types.end();
diff --git a/paddle/fluid/operators/interpolate_v2_op.cc b/paddle/fluid/operators/interpolate_v2_op.cc
index cfbe1778c766467e0f8f9d1c7c395ba6ccfff0a7..cb93044ca58445dcb4817629ef859e312f900983 100644
--- a/paddle/fluid/operators/interpolate_v2_op.cc
+++ b/paddle/fluid/operators/interpolate_v2_op.cc
@@ -14,6 +14,9 @@
 #include <string>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif
 
 namespace paddle {
 namespace operators {
@@ -359,13 +362,41 @@ class InterpolateV2Op : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+    framework::LibraryType library = framework::LibraryType::kPlain;
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
+
+#ifdef PADDLE_WITH_MKLDNN
+    auto interp_method = ctx.Attr<std::string>("interp_method");
+    // TODO(danqing): support other interp_method
+    if (this->CanMKLDNNBeUsed(ctx, data_type) &&
+        (interp_method == "nearest" || interp_method == "bilinear")) {
+      layout = framework::DataLayout::kMKLDNN;
+      library = framework::LibraryType::kMKLDNN;
+    }
+#endif
+
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
   }
 
   framework::OpKernelType GetKernelTypeForVar(
       const std::string& var_name, const Tensor& tensor,
       const framework::OpKernelType& expected_kernel_type) const override {
+#ifdef PADDLE_WITH_MKLDNN
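+    // If the oneDNN kernel was chosen above but the input tensor still
+    // carries a plain (non-oneDNN) layout, honor the op's "data_layout"
+    // attribute (e.g. NHWC) instead of assuming the default layout.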
+    if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
+        (tensor.layout() != framework::DataLayout::kMKLDNN)) {
+      auto attrs = Attrs();
+      auto ar = paddle::framework::AttrReader(attrs);
+      const std::string data_format = ar.Get<std::string>("data_layout");
+      auto dl = framework::StringToDataLayout(data_format);
+      // Some models may have intentionally set "AnyLayout" for the
+      // interpolate op. Treat this as NCHW (default data_format value)
+      if (dl != framework::DataLayout::kAnyLayout) {
+        return framework::OpKernelType(expected_kernel_type.data_type_,
+                                       tensor.place(), dl);
+      }
+    }
+#endif
     if (var_name == "SizeTensor" || var_name == "Scale") {
       return expected_kernel_type;
     }
@@ -436,6 +467,9 @@ class InterpolateV2OpMaker : public framework::OpProtoAndCheckerMaker {
                  "can be \'0\' for src_idx = scale*(dst_indx+0.5)-0.5 , "
                  "can be \'1\' for src_idx = scale*dst_index .")
         .SetDefault(1);
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
     AddComment(R"DOC(
   This operator samples input X to given output shape by using specified
   interpolation method, the interpolation methods can be \"nearest\"
diff --git a/paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc
index 64a1903c2da4ff5bc5e903ab33124d49bf1b8cdd..9d80286f4c4efa54ce83ca6148399d0875d64dc0 100644
--- a/paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc
@@ -33,7 +33,7 @@ class InterpolateMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::resampling_forward> {
  public:
   InterpolateMKLDNNHandler(const dnnl::algorithm algo,
-                           const paddle::platform::MKLDNNDeviceContext& dev_ctx,
+                           const platform::MKLDNNDeviceContext& dev_ctx,
                            const dnnl::engine engine, platform::Place cpu_place,
                            const Tensor* x, Tensor* z,
                            const std::string& uniq_name)
@@ -94,19 +94,32 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
         out_dims = out_size_data;
       }
     } else {
-      float scale;
+      std::vector<float> scale;
+      scale.reserve(3);
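+      // One scale factor per spatial axis, in (depth, height, width) order;
+      // scalar sources are broadcast to all three entries below.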
       auto scale_tensor = ctx.Input<Tensor>("Scale");
       if (scale_tensor != nullptr) {
         auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
-        scale = scale_data[0];
+        scale.resize(3, scale_data[0]);
+        std::copy(scale_data.begin(), scale_data.end(), scale.begin());
       } else {
-        scale = ctx.Attr<float>("scale");
+        std::string op_type = ctx.Type();
+
+        if (op_type.find("v2") == std::string::npos) {  // v1
+          scale.push_back(ctx.Attr<float>("scale"));
+          scale.push_back(scale[0]);
+          scale.push_back(scale[0]);
+        } else {  // v2
+          std::vector<float> scale_attr = ctx.Attr<std::vector<float>>("scale");
+          scale.resize(3, scale_attr[0]);
+          std::copy(scale_attr.begin(), scale_attr.end(), scale.begin());
+        }
       }
-      if (scale > 0) {
+      if (scale[0] > 0.0f && scale[1] > 0.0f && scale[2] > 0.0f) {
+        int j = 0;
         std::vector<int64_t> in_dhw_vec = framework::vectorize(in_dhw_dims);
         std::transform(
             in_dhw_vec.begin(), in_dhw_vec.end(), out_dims.begin(),
-            [&](int64_t i) -> int { return static_cast<int>(i * scale); });
+            [&](int64_t i) -> int { return static_cast<int>(i * scale[j++]); });
       }
     }
@@ -172,3 +185,8 @@ REGISTER_OP_KERNEL(nearest_interp, MKLDNN, ::paddle::platform::CPUPlace,
                    ops::InterpolateMKLDNNKernel<float>);
 REGISTER_OP_KERNEL(bilinear_interp, MKLDNN, ::paddle::platform::CPUPlace,
                    ops::InterpolateMKLDNNKernel<float>);
+
+REGISTER_OP_KERNEL(nearest_interp_v2, MKLDNN, ::paddle::platform::CPUPlace,
+                   ops::InterpolateMKLDNNKernel<float>);
+REGISTER_OP_KERNEL(bilinear_interp_v2, MKLDNN, ::paddle::platform::CPUPlace,
+                   ops::InterpolateMKLDNNKernel<float>);
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_bilinear_interp_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_bilinear_interp_mkldnn_op.py
index e86273ea1c28ef56cd5786ca41715efe80ea6f5b..e740efa14c575b0d18876da2a9d3fa62472b51cf 100644
--- a/python/paddle/fluid/tests/unittests/mkldnn/test_bilinear_interp_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_bilinear_interp_mkldnn_op.py
@@ -198,4 +198,6 @@ class TestBilinearNeighborInterpSame(TestBilinearInterpMKLDNNOp):
 
 
 if __name__ == "__main__":
+    from paddle import enable_static
+    enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_bilinear_interp_v2_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_bilinear_interp_v2_mkldnn_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3b0639289ab286954fe51b7fcd9aee7731deb7e
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_bilinear_interp_v2_mkldnn_op.py
@@ -0,0 +1,210 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import math
+import paddle
+import paddle.fluid.core as core
+import paddle.fluid as fluid
+from paddle.fluid.tests.unittests.op_test import OpTest
+from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci
+
+
+def bilinear_interp_mkldnn_np(input,
+                              out_h,
+                              out_w,
+                              out_size=None,
+                              actual_shape=None,
+                              data_layout='NCHW'):
+    """bilinear interpolation implemented for shape [N, C, H, W]"""
+    if data_layout == "NHWC":
+        input = np.transpose(input, (0, 3, 1, 2))  # NHWC => NCHW
+    if out_size is not None:
+        out_h = out_size[0]
+        out_w = out_size[1]
+    if actual_shape is not None:
+        out_h = actual_shape[0]
+        out_w = actual_shape[1]
+    batch_size, channel, in_h, in_w = input.shape
+
+    out = np.zeros((batch_size, channel, out_h, out_w))
+
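+    # Half-pixel mapping (align_corners=False): each output pixel center
+    # (oh + 0.5, ow + 0.5) is projected back onto the input grid, and the
+    # four neighboring input pixels are blended by fractional weights Wh, Ww.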
+    for oh in range(out_h):
+        h0 = int(math.floor((oh + 0.5) * in_h / out_h - 0.5))
+        h1 = int(math.ceil((oh + 0.5) * in_h / out_h - 0.5))
+        h0 = max(h0, 0)
+        h1 = min(h1, in_h - 1)
+        Wh = (oh + 0.5) * in_h / out_h - 0.5 - h0
+        for ow in range(out_w):
+            w0 = int(math.floor((ow + 0.5) * in_w / out_w - 0.5))
+            w1 = int(math.ceil((ow + 0.5) * in_w / out_w - 0.5))
+            w0 = max(w0, 0)
+            w1 = min(w1, in_w - 1)
+            Ww = (ow + 0.5) * in_w / out_w - 0.5 - w0
+            input_h0_w0 = input[:, :, h0, w0]
+            input_h1_w0 = input[:, :, h1, w0]
+            input_h0_w1 = input[:, :, h0, w1]
+            input_h1_w1 = input[:, :, h1, w1]
+            out[:, :, oh, ow] = input_h0_w0 * (1 - Wh) * (
+                1 - Ww) + input_h1_w0 * Wh * (1 - Ww) + input_h0_w1 * (
+                    1 - Wh) * Ww + input_h1_w1 * Wh * Ww
+
+    if data_layout == "NHWC":
+        out = np.transpose(out, (0, 2, 3, 1))  # NCHW => NHWC
+
+    return out.astype(input.dtype)
+
+
+@skip_check_grad_ci(reason="Interpolate grad kernel is not implemented yet.")
+class TestBilinearInterpMKLDNNOp(OpTest):
+    def init_test_case(self):
+        pass
+
+    def setUp(self):
+        self.op_type = "bilinear_interp_v2"
+        self.interp_method = 'bilinear'
+        self._cpu_only = True
+        self.use_mkldnn = True
+        self.input_shape = [1, 1, 2, 2]
+        self.data_layout = 'NCHW'
+        # priority: actual_shape > out_size > scale > out_h & out_w
+        self.out_h = 1
+        self.out_w = 1
+        self.scale = 2.0
+        self.out_size = None
+        self.actual_shape = None
+
+        self.init_test_case()
+
+        input_np = np.random.random(self.input_shape).astype("float32")
+        if self.data_layout == "NCHW":
+            in_h = self.input_shape[2]
+            in_w = self.input_shape[3]
+        else:
+            in_h = self.input_shape[1]
+            in_w = self.input_shape[2]
+
+        scale_h = 0
+        scale_w = 0
+
+        if self.scale:
+            if isinstance(self.scale, float) or isinstance(self.scale, int):
+                scale_h = float(self.scale)
+                scale_w = float(self.scale)
+            if isinstance(self.scale, list) and len(self.scale) == 1:
+                scale_w = self.scale[0]
+                scale_h = self.scale[0]
+            elif isinstance(self.scale, list) and len(self.scale) > 1:
+                scale_w = self.scale[1]
+                scale_h = self.scale[0]
+
+        if scale_h > 0 and scale_w > 0:
+            out_h = int(in_h * scale_h)
+            out_w = int(in_w * scale_w)
+        else:
+            out_h = self.out_h
+            out_w = self.out_w
+
+        output_np = bilinear_interp_mkldnn_np(input_np, out_h, out_w,
+                                              self.out_size, self.actual_shape,
+                                              self.data_layout)
+
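+        # bilinear_interp_v2 takes "scale" as a list attribute, so a scalar
+        # test setting is normalized to [scale_h, scale_w] before being fed
+        # into self.attrs.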
+        if isinstance(self.scale, float):
+            self.scale = [self.scale, self.scale]
+
+        self.inputs = {'X': input_np}
+        if self.out_size is not None:
+            self.inputs['OutSize'] = self.out_size
+        if self.actual_shape is not None:
+            self.inputs['OutSize'] = self.actual_shape
+        self.attrs = {
+            'interp_method': self.interp_method,
+            'out_h': self.out_h,
+            'out_w': self.out_w,
+            'scale': self.scale,
+            'data_layout': self.data_layout,
+            'use_mkldnn': self.use_mkldnn
+        }
+        self.outputs = {'Out': output_np}
+
+    def test_check_output(self):
+        self.check_output(check_dygraph=False)
+
+
+class TestBilinearInterpOpMKLDNNNHWC(TestBilinearInterpMKLDNNOp):
+    def init_test_case(self):
+        self.input_shape = [3, 2, 32, 16]
+        self.out_h = 27
+        self.out_w = 49
+        self.scale = [2.0, 3.0]
+        self.data_layout = 'NHWC'
+
+
+class TestBilinearInterpMKLDNNCase2(TestBilinearInterpMKLDNNOp):
+    def init_test_case(self):
+        self.input_shape = [3, 3, 9, 6]
+        self.out_h = 12
+        self.out_w = 12
+
+
+class TestBilinearInterpCase3(TestBilinearInterpMKLDNNOp):
+    def init_test_case(self):
+        self.input_shape = [1, 1, 32, 64]
+        self.out_h = 64
+        self.out_w = 128
+        self.scale = [0.1, 0.05]
+
+
+class TestBilinearInterpCase4(TestBilinearInterpMKLDNNOp):
+    def init_test_case(self):
+        self.input_shape = [1, 1, 32, 64]
+        self.out_h = 64
+        self.out_w = 32
+        self.scale = [13.0, 15.0]
+        self.out_size = np.array([65, 129]).astype("int32")
+
+
+class TestBilinearInterpCase5(TestBilinearInterpMKLDNNOp):
+    def init_test_case(self):
+        self.input_shape = [1, 1, 9, 6]
+        self.out_h = 12
+        self.out_w = 12
+        self.out_size = np.array([13, 13]).astype("int32")
+
+
+class TestBilinearInterpCase6(TestBilinearInterpMKLDNNOp):
+    def init_test_case(self):
+        self.input_shape = [1, 1, 32, 64]
+        self.out_h = 64
+        self.out_w = 32
+        self.scale = 1.0
+        self.out_size = np.array([65, 129]).astype("int32")
+
+
+class TestBilinearInterpSame(TestBilinearInterpMKLDNNOp):
+    def init_test_case(self):
+        self.input_shape = [2, 3, 32, 64]
+        self.out_h = 32
+        self.out_w = 64
+        self.scale = 2.0
+        self.out_size = np.array([65, 129]).astype("int32")
+
+
+if __name__ == "__main__":
+    from paddle import enable_static
+    enable_static()
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_nearest_interp_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_nearest_interp_mkldnn_op.py
index 1e4bfd5f0cf017359a88d3b4c3754becb61ab77e..9f39826cb3ed2875993452269a66559ec2e84782 100755
--- a/python/paddle/fluid/tests/unittests/mkldnn/test_nearest_interp_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_nearest_interp_mkldnn_op.py
@@ -163,4 +163,6 @@ class TestNearestNeighborInterpSame(TestNearestInterpMKLDNNOp):
 
 
 if __name__ == "__main__":
+    from paddle import enable_static
+    enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_nearest_interp_v2_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_nearest_interp_v2_mkldnn_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..b608ca3af2f3660347278135e1118bb3a3c817d5
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_nearest_interp_v2_mkldnn_op.py
@@ -0,0 +1,184 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import paddle
+import paddle.fluid.core as core
+import paddle.fluid as fluid
+from paddle.fluid.tests.unittests.op_test import OpTest
+from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci
+
+
+def nearest_neighbor_interp_mkldnn_np(X,
+                                      out_h,
+                                      out_w,
+                                      out_size=None,
+                                      actual_shape=None,
+                                      data_layout='NCHW'):
+    """nearest neighbor interpolation implemented for shape [N, C, H, W]"""
+    if data_layout == "NHWC":
+        X = np.transpose(X, (0, 3, 1, 2))  # NHWC => NCHW
+    if out_size is not None:
+        out_h = out_size[0]
+        out_w = out_size[1]
+    if actual_shape is not None:
+        out_h = actual_shape[0]
+        out_w = actual_shape[1]
+
+    n, c, in_h, in_w = X.shape
+
+    fh = fw = 0.0
+    if (out_h > 1):
+        fh = out_h * 1.0 / in_h
+    if (out_w > 1):
+        fw = out_w * 1.0 / in_w
+
+    out = np.zeros((n, c, out_h, out_w))
+
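+    # Assumes out_h > 1 and out_w > 1 (true for every test case below), so
+    # fh and fw are non-zero by the time they are used as divisors here.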
+    for oh in range(out_h):
+        ih = int(round((oh + 0.5) / fh - 0.5))
+        for ow in range(out_w):
+            iw = int(round((ow + 0.5) / fw - 0.5))
+            out[:, :, oh, ow] = X[:, :, ih, iw]
+
+    if data_layout == "NHWC":
+        out = np.transpose(out, (0, 2, 3, 1))  # NCHW => NHWC
+
+    return out.astype(X.dtype)
+
+
+@skip_check_grad_ci(reason="Interpolate grad kernel is not implemented yet.")
+class TestNearestInterpV2MKLDNNOp(OpTest):
+    def init_test_case(self):
+        pass
+
+    def setUp(self):
+        self.op_type = "nearest_interp_v2"
+        self.interp_method = 'nearest'
+        self._cpu_only = True
+        self.use_mkldnn = True
+        self.input_shape = [1, 1, 2, 2]
+        self.data_layout = 'NCHW'
+        # priority: actual_shape > out_size > scale > out_h & out_w
+        self.out_h = 1
+        self.out_w = 1
+        self.scale = [2.0, 3.0]
+        self.out_size = None
+        self.actual_shape = None
+
+        self.init_test_case()
+
+        input_np = np.random.random(self.input_shape).astype("float32")
+        if self.data_layout == "NCHW":
+            in_h = self.input_shape[2]
+            in_w = self.input_shape[3]
+        else:
+            in_h = self.input_shape[1]
+            in_w = self.input_shape[2]
+
+        scale_h = 0
+        scale_w = 0
+
+        if self.scale:
+            if isinstance(self.scale, float) or isinstance(self.scale, int):
+                scale_h = float(self.scale)
+                scale_w = float(self.scale)
+            if isinstance(self.scale, list) and len(self.scale) == 1:
+                scale_w = self.scale[0]
+                scale_h = self.scale[0]
+            elif isinstance(self.scale, list) and len(self.scale) > 1:
+                scale_w = self.scale[1]
+                scale_h = self.scale[0]
+
+        if scale_h > 0 and scale_w > 0:
+            out_h = int(in_h * scale_h)
+            out_w = int(in_w * scale_w)
+        else:
+            out_h = self.out_h
+            out_w = self.out_w
+
+        output_np = nearest_neighbor_interp_mkldnn_np(
+            input_np, out_h, out_w, self.out_size, self.actual_shape,
+            self.data_layout)
+
+        if isinstance(self.scale, float):
+            self.scale = [self.scale]
+
+        self.inputs = {'X': input_np}
+        if self.out_size is not None:
+            self.inputs['OutSize'] = self.out_size
+        if self.actual_shape is not None:
+            self.inputs['OutSize'] = self.actual_shape
+        self.attrs = {
+            'interp_method': self.interp_method,
+            'out_h': self.out_h,
+            'out_w': self.out_w,
+            'scale': self.scale,
+            'data_layout': self.data_layout,
+            'use_mkldnn': self.use_mkldnn
+        }
+        self.outputs = {'Out': output_np}
+
+    def test_check_output(self):
+        self.check_output(check_dygraph=False)
+
+
+class TestNearestInterpOpV2MKLDNNNHWC(TestNearestInterpV2MKLDNNOp):
+    def init_test_case(self):
+        self.input_shape = [3, 2, 32, 16]
+        self.out_h = 27
+        self.out_w = 49
+        self.scale = [2.0, 3.0]
+        self.data_layout = 'NHWC'
+
+
+class TestNearestNeighborInterpV2MKLDNNCase2(TestNearestInterpV2MKLDNNOp):
+    def init_test_case(self):
+        self.input_shape = [3, 3, 9, 6]
+        self.out_h = 12
+        self.out_w = 12
+
+
+class TestNearestNeighborInterpV2MKLDNNCase3(TestNearestInterpV2MKLDNNOp):
+    def init_test_case(self):
+        self.input_shape = [1, 1, 32, 64]
+        self.out_h = 64
+        self.out_w = 128
+        self.scale = [0.1, 0.05]
+
+
+class TestNearestNeighborInterpV2MKLDNNCase4(TestNearestInterpV2MKLDNNOp):
+    def init_test_case(self):
+        self.input_shape = [1, 1, 32, 64]
+        self.out_h = 64
+        self.out_w = 32
+        self.scale = [13.0, 15.0]
+        self.out_size = np.array([65, 129]).astype("int32")
+
+
+class TestNearestNeighborInterpV2MKLDNNSame(TestNearestInterpV2MKLDNNOp):
+    def init_test_case(self):
+        self.input_shape = [2, 3, 32, 64]
+        self.out_h = 32
+        self.out_w = 64
+        self.out_size = np.array([65, 129]).astype("int32")
+
+
+if __name__ == "__main__":
+    from paddle import enable_static
+    enable_static()
+    unittest.main()
diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py
index 5de4bffd1601c8f3fa345a34cb70cf013253c011..7c1f54adfb3d928871f8fe4179642198cb2cfbbf 100644
--- a/tools/static_mode_white_list.py
+++ b/tools/static_mode_white_list.py
@@ -603,7 +603,9 @@ STATIC_MODE_TESTING_LIST = [
     'test_fc_mkldnn_op',
     'test_fc_bf16_mkldnn_op',
     'test_nearest_interp_mkldnn_op',
+    'test_nearest_interp_v2_mkldnn_op',
     'test_bilinear_interp_mkldnn_op',
+    'test_bilinear_interp_v2_mkldnn_op',
     'test_fusion_gru_int8_mkldnn_op',
     'test_fusion_gru_bf16_mkldnn_op',
     'test_fusion_gru_mkldnn_op',