From af8ded773ab0c254c37aaea34cd1c7f4e4a6f439 Mon Sep 17 00:00:00 2001 From: TTerror Date: Wed, 16 Dec 2020 23:35:24 +0800 Subject: [PATCH] update activation op on kunlun (#29577) * fix expand && concat/transpose to new api * update xpu_header * update activation op on kunlun * update activation op on kunlun * update activation op on kunlun * update activation op on kunlun * update activation op on kunlun * add nearest_interp on kunlun * update error message --- cmake/external/xpu.cmake | 2 +- paddle/fluid/operators/activation_op_xpu.cc | 333 +++++++++----- paddle/fluid/operators/interpolate_op_xpu.cc | 7 +- .../unittests/xpu/test_activation_op_xpu.py | 59 ++- .../xpu/test_nearest_interp_op_xpu.py | 432 ++++++++++++++++++ 5 files changed, 722 insertions(+), 111 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_nearest_interp_op_xpu.py diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index 75e0eb2e27..6b24354440 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -4,7 +4,7 @@ endif() INCLUDE(ExternalProject) SET(XPU_PROJECT "extern_xpu") -SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2020_12_11.tar.gz" CACHE STRING "" FORCE) +SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2020_12_15.tar.gz" CACHE STRING "" FORCE) SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu") SET(XPU_DOWNLOAD_DIR "${XPU_SOURCE_DIR}/src/${XPU_PROJECT}") SET(XPU_INSTALL_DIR "${THIRD_PARTY_PATH}/install/xpu") diff --git a/paddle/fluid/operators/activation_op_xpu.cc b/paddle/fluid/operators/activation_op_xpu.cc index 48e55e8f61..2c7219ef68 100644 --- a/paddle/fluid/operators/activation_op_xpu.cc +++ b/paddle/fluid/operators/activation_op_xpu.cc @@ -54,55 +54,27 @@ class XPUActivationGradKernel }; template -void xpu_activation_forward(const framework::ExecutionContext &ctx, - xpu::Activation_t type) { +void xpu_activation_forward( + const framework::ExecutionContext &ctx, + std::function func) { const auto *x = ctx.Input("X"); auto *y = ctx.Output("Out"); const T *x_data = x->data(); T *y_data = y->mutable_data(ctx.GetPlace()); - int r = 0; - auto xpu_context = ctx.device_context().x_context(); - - switch (type.type) { - case xpu::Activation_t::HARD_SWISH: { - float threshold = ctx.Attr("threshold"); - float scale = ctx.Attr("scale"); - float offset = ctx.Attr("offset"); - PADDLE_ENFORCE_EQ(threshold, 6.0f, - platform::errors::External( - "Not support threshold [%f] in XPU", threshold)); - PADDLE_ENFORCE_EQ( - scale, 6.0f, - platform::errors::External("Not support scale [%f] in XPU", scale)); - PADDLE_ENFORCE_EQ( - offset, 3.0f, - platform::errors::External("Not support offset [%f] in XPU", offset)); - - r = xpu::hard_swish(xpu_context, reinterpret_cast(x_data), - reinterpret_cast(y_data), x->numel()); - break; - } - case xpu::Activation_t::ACT_POW: { - type.pow_factor = ctx.Attr("factor"); - } - default: { - r = xpu::activation_forward(xpu_context, type, x->numel(), - reinterpret_cast(x_data), - reinterpret_cast(y_data)); - break; - } - } - PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, - platform::errors::External( - "XPU API return wrong value[%d], please check whether " - "Baidu Kunlun Card is properly installed.", - r)); + auto xpu_context = ctx.device_context().x_context(); + int r = func(xpu_context, x_data, y_data, x->numel()); + PADDLE_ENFORCE_EQ( + r, xpu::Error_t::SUCCESS, + platform::errors::External("XPU activation op return wrong value[%d %s].", + r, XPUAPIErrorMsg[r])); } template void xpu_activation_backward(const framework::ExecutionContext &ctx, - xpu::Activation_t type) { + std::function + func) { /* TODO: relu tanh sigmoid are inplace */ const auto *x = ctx.Input("X"); auto *y = ctx.Input("Out"); @@ -115,99 +87,248 @@ void xpu_activation_backward(const framework::ExecutionContext &ctx, if (y != nullptr) y_data = y->data(); if (dOut != nullptr) y_grad = dOut->data(); T *x_grad = dX->mutable_data(ctx.GetPlace()); - int r = 0; auto xpu_context = ctx.device_context().x_context(); - switch (type.type) { - case xpu::Activation_t::HARD_SWISH: { - float threshold = ctx.Attr("threshold"); - float scale = ctx.Attr("scale"); - float offset = ctx.Attr("offset"); - PADDLE_ENFORCE_EQ(threshold, 6.0f, - platform::errors::External( - "Not support threshold [%f] in XPU", threshold)); - PADDLE_ENFORCE_EQ( - scale, 6.0f, - platform::errors::External("Not support scale [%f] in XPU", scale)); - PADDLE_ENFORCE_EQ( - offset, 3.0f, - platform::errors::External("Not support offset [%f] in XPU", offset)); - r = xpu::hard_swish_grad(xpu_context, - reinterpret_cast(x_data), - reinterpret_cast(y_data), - reinterpret_cast(y_grad), - reinterpret_cast(x_grad), dX->numel()); - break; - } - default: { - r = xpu::activation_backward(xpu_context, type, dX->numel(), - reinterpret_cast(x_data), - reinterpret_cast(y_data), - reinterpret_cast(y_grad), - reinterpret_cast(x_grad)); - break; - } - } - - PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + int r = func(xpu_context, x_data, y_data, y_grad, x_grad, dX->numel()); + PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, platform::errors::External( - "XPU API return wrong value[%d], please check whether " - "Baidu Kunlun Card is properly installed.", - r)); + "XPU activation grad op return wrong value[%d %s].", r, + XPUAPIErrorMsg[r])); } -template -struct XPUActivationFunc : public BaseActivationFunctor { +template +struct XPUReluFunctor : public BaseActivationFunctor { void operator()(const framework::ExecutionContext &ctx) const { xpu_activation_forward(ctx, - algorithm); + xpu::relu); } }; -template -struct XPUActivationGradFunc : public BaseActivationFunctor { +template +struct XPUSigmoidFunctor : public BaseActivationFunctor { void operator()(const framework::ExecutionContext &ctx) const { - xpu_activation_backward(ctx, - algorithm); + xpu_activation_forward( + ctx, xpu::sigmoid); } }; template -using XPUReluFunctor = XPUActivationFunc; +struct XPUTanhFunctor : public BaseActivationFunctor { + void operator()(const framework::ExecutionContext &ctx) const { + xpu_activation_forward(ctx, + xpu::tanh); + } +}; + template -using XPUSigmoidFunctor = XPUActivationFunc; +struct XPUGeluFunctor : public BaseActivationFunctor { + void operator()(const framework::ExecutionContext &ctx) const { + xpu_activation_forward(ctx, + xpu::gelu); + } +}; + template -using XPUTanhFunctor = XPUActivationFunc; +struct XPULogFunctor : public BaseActivationFunctor { + void operator()(const framework::ExecutionContext &ctx) const { + xpu_activation_forward(ctx, + xpu::log); + } +}; + template -using XPUGeluFunctor = XPUActivationFunc; +struct XPUSquareFunctor : public BaseActivationFunctor { + void operator()(const framework::ExecutionContext &ctx) const { + xpu_activation_forward( + ctx, xpu::square); + } +}; + template -using XPULogFunctor = XPUActivationFunc; +struct XPUSqrtFunctor : public BaseActivationFunctor { + void operator()(const framework::ExecutionContext &ctx) const { + xpu_activation_forward(ctx, + xpu::sqrt); + } +}; + template -using XPUSquareFunctor = XPUActivationFunc; +struct XPUAbsFunctor : public BaseActivationFunctor { + void operator()(const framework::ExecutionContext &ctx) const { + xpu_activation_forward(ctx, + xpu::abs); + } +}; + template -using XPUHardSwishFunctor = XPUActivationFunc; +struct XPUPowFunctor : public BaseActivationFunctor { + void operator()(const framework::ExecutionContext &ctx) const { + const auto *x = ctx.Input("X"); + auto *y = ctx.Output("Out"); + auto pow_factor = ctx.Attr("factor"); + const T *x_data = x->data(); + T *y_data = y->mutable_data(ctx.GetPlace()); + T *factor_data = nullptr; + + auto xpu_context = + ctx.device_context().x_context(); + PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast(&factor_data), + x->numel() * sizeof(T)), + XPU_SUCCESS, platform::errors::ResourceExhausted( + "XPU has no enough memory")); + int r = xpu::constant(xpu_context, factor_data, x->numel(), pow_factor); + PADDLE_ENFORCE_EQ( + r, xpu::Error_t::SUCCESS, + platform::errors::External("XPU constant op return" + " wrong value[%d %s] in pow op.", + r, XPUAPIErrorMsg[r])); + r = xpu::pow(xpu_context, x_data, factor_data, y_data, x->numel()); + PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, + platform::errors::External("XPU pow op return" + " wrong value[%d %s].", + r, XPUAPIErrorMsg[r])); + if (xpu_context->xpu_stream != nullptr) { + xpu_wait(xpu_context->xpu_stream); + } + xpu_free(factor_data); + } +}; + template -using XPUSuareGradFunctor = XPUActivationGradFunc; +struct XPUHardSwishFunctor : public BaseActivationFunctor { + void operator()(const framework::ExecutionContext &ctx) const { + float threshold = ctx.Attr("threshold"); + float scale = ctx.Attr("scale"); + float offset = ctx.Attr("offset"); + PADDLE_ENFORCE_EQ(threshold, 6.0f, + platform::errors::External( + "Not support threshold [%f] in XPU", threshold)); + PADDLE_ENFORCE_EQ(scale, 6.0f, platform::errors::External( + "Not support scale [%f] in XPU", scale)); + PADDLE_ENFORCE_EQ( + offset, 3.0f, + platform::errors::External("Not support offset [%f] in XPU", offset)); + xpu_activation_forward( + ctx, xpu::hard_swish); + } +}; + template -using XPUReluGradFunctor = XPUActivationGradFunc; +struct XPUReluGradFunctor : public BaseActivationFunctor { + void operator()(const framework::ExecutionContext &ctx) const { + xpu_activation_backward( + ctx, xpu::relu_grad); + } +}; + template -using XPUSigmoidGradFunctor = - XPUActivationGradFunc; +struct XPUTanhGradFunctor : public BaseActivationFunctor { + void operator()(const framework::ExecutionContext &ctx) const { + xpu_activation_backward( + ctx, xpu::tanh_grad); + } +}; + template -using XPUTanhGradFunctor = XPUActivationGradFunc; +struct XPUSigmoidGradFunctor : public BaseActivationFunctor { + void operator()(const framework::ExecutionContext &ctx) const { + xpu_activation_backward( + ctx, xpu::sigmoid_grad); + } +}; + template -using XPUGeluGradFunctor = XPUActivationGradFunc; +struct XPUGeluGradFunctor : public BaseActivationFunctor { + void operator()(const framework::ExecutionContext &ctx) const { + xpu_activation_backward( + ctx, xpu::gelu_grad); + } +}; + template -using XPUSqrtFunctor = XPUActivationFunc; +struct XPUSqrtGradFunctor : public BaseActivationFunctor { + void operator()(const framework::ExecutionContext &ctx) const { + xpu_activation_backward( + ctx, xpu::sqrt_grad); + } +}; + template -using XPUSqrtGradFunctor = XPUActivationGradFunc; +struct XPUSquareGradFunctor : public BaseActivationFunctor { + void operator()(const framework::ExecutionContext &ctx) const { + xpu_activation_backward( + ctx, xpu::square_grad); + } +}; + template -using XPUHardSwishGradFunctor = - XPUActivationGradFunc; +struct XPUHardSwishGradFunctor : public BaseActivationFunctor { + void operator()(const framework::ExecutionContext &ctx) const { + float threshold = ctx.Attr("threshold"); + float scale = ctx.Attr("scale"); + float offset = ctx.Attr("offset"); + PADDLE_ENFORCE_EQ(threshold, 6.0f, + platform::errors::External( + "Not support threshold [%f] in XPU", threshold)); + PADDLE_ENFORCE_EQ(scale, 6.0f, platform::errors::External( + "Not support scale [%f] in XPU", scale)); + PADDLE_ENFORCE_EQ( + offset, 3.0f, + platform::errors::External("Not support offset [%f] in XPU", offset)); + xpu_activation_backward( + ctx, xpu::hard_swish_grad); + } +}; + template -using XPUACTPowFunctor = XPUActivationFunc; +struct XPULeakyReluFunctor : public BaseActivationFunctor { + void operator()(const framework::ExecutionContext &ctx) const { + const auto *x = ctx.Input("X"); + auto *y = ctx.Output("Out"); + float alpha = ctx.Attr("alpha"); + const T *x_data = x->data(); + T *y_data = y->mutable_data(ctx.GetPlace()); + + auto xpu_context = + ctx.device_context().x_context(); + int r = xpu::leaky_relu(xpu_context, x_data, y_data, x->numel(), alpha); + PADDLE_ENFORCE_EQ( + r, xpu::Error_t::SUCCESS, + platform::errors::External("XPU leaky_relu return wrong value[%d %s].", + r, XPUAPIErrorMsg[r])); + } +}; + template -using XPUABSFunctor = XPUActivationFunc; +struct XPULeakyReluGradFunctor : public BaseActivationFunctor { + void operator()(const framework::ExecutionContext &ctx) const { + const auto *x = ctx.Input("X"); + auto *dOut = ctx.Input(framework::GradVarName("Out")); + auto *dX = ctx.Output(framework::GradVarName("X")); + float alpha = ctx.Attr("alpha"); + const T *x_data = nullptr; + const T *y_grad = nullptr; + if (x != nullptr) x_data = x->data(); + if (dOut != nullptr) y_grad = dOut->data(); + T *x_grad = dX->mutable_data(ctx.GetPlace()); + auto xpu_context = + ctx.device_context().x_context(); + + // The signs of x and y are the same, + // y == nullptr here, + // so we give 2 x to the api + int r = xpu::leaky_relu_grad( + xpu_context, reinterpret_cast(x_data), + reinterpret_cast(x_data), + reinterpret_cast(y_grad), + reinterpret_cast(x_grad), dX->numel(), alpha); + PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, + platform::errors::External( + "XPU leaky_relu_grad return wrong value[%d %s].", r, + XPUAPIErrorMsg[r])); + } +}; + } // namespace operators } // namespace paddle @@ -226,14 +347,16 @@ REGISTER_ACTIVATION_XPU_KERNEL(sigmoid, XPUSigmoidFunctor, XPUSigmoidGradFunctor) REGISTER_ACTIVATION_XPU_KERNEL(gelu, XPUGeluFunctor, XPUGeluGradFunctor) REGISTER_ACTIVATION_XPU_KERNEL(sqrt, XPUSqrtFunctor, XPUSqrtGradFunctor) -REGISTER_ACTIVATION_XPU_KERNEL(square, XPUSquareFunctor, XPUSuareGradFunctor) +REGISTER_ACTIVATION_XPU_KERNEL(square, XPUSquareFunctor, XPUSquareGradFunctor) REGISTER_ACTIVATION_XPU_KERNEL(hard_swish, XPUHardSwishFunctor, XPUHardSwishGradFunctor) +REGISTER_ACTIVATION_XPU_KERNEL(leaky_relu, XPULeakyReluFunctor, + XPULeakyReluGradFunctor) REGISTER_OP_XPU_KERNEL(log, ops::XPUActivationKernel>); REGISTER_OP_XPU_KERNEL(pow, - ops::XPUActivationKernel>); + ops::XPUActivationKernel>); REGISTER_OP_XPU_KERNEL(abs, - ops::XPUActivationKernel>); + ops::XPUActivationKernel>); #endif // PADDLE_WITH_XPU diff --git a/paddle/fluid/operators/interpolate_op_xpu.cc b/paddle/fluid/operators/interpolate_op_xpu.cc index 6dc4252546..882edc00f2 100644 --- a/paddle/fluid/operators/interpolate_op_xpu.cc +++ b/paddle/fluid/operators/interpolate_op_xpu.cc @@ -229,9 +229,7 @@ class InterpolateGradXPUKernel : public framework::OpKernel { int trans_mode = (align_corners) ? (0) : ((align_mode == 0) ? (1) : (2)); if (nearest) { - PADDLE_ENFORCE_EQ((data_layout == DataLayout::kNCHW), true, - platform::errors::InvalidArgument( - "XPU nearest is only support NCHW")); + trans_mode = (align_corners) ? (0) : (2); } r = xpu::interpolate2d_grad(dev_ctx.x_context(), output_grad->data(), @@ -252,7 +250,10 @@ class InterpolateGradXPUKernel : public framework::OpKernel { namespace ops = paddle::operators; REGISTER_OP_XPU_KERNEL(bilinear_interp, ops::InterpolateXPUKernel); +REGISTER_OP_XPU_KERNEL(nearest_interp, ops::InterpolateXPUKernel); REGISTER_OP_XPU_KERNEL(bilinear_interp_grad, ops::InterpolateGradXPUKernel); +REGISTER_OP_XPU_KERNEL(nearest_interp_grad, + ops::InterpolateGradXPUKernel); #endif diff --git a/python/paddle/fluid/tests/unittests/xpu/test_activation_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_activation_op_xpu.py index 8635a7db36..9f807b06cb 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_activation_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_activation_op_xpu.py @@ -73,8 +73,7 @@ class TestXPUSigmoid(TestXPUActivation): def test_check_grad(self): if paddle.is_compiled_with_xpu(): place = paddle.XPUPlace(0) - self.check_grad_with_place( - place, ['X'], 'Out', max_relative_error=0.01) + self.check_grad_with_place(place, ['X'], 'Out') @unittest.skipIf(not paddle.is_compiled_with_xpu(), @@ -90,6 +89,11 @@ class TestXPUTanh(TestXPUActivation): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + def test_check_grad(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X'], 'Out') + @unittest.skipIf(not paddle.is_compiled_with_xpu(), "core is not compiled with XPU") @@ -105,6 +109,11 @@ class TestXPUSqrt(TestXPUActivation): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + def test_check_grad(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X'], 'Out') + @unittest.skipIf(not paddle.is_compiled_with_xpu(), "core is not compiled with XPU") @@ -142,6 +151,11 @@ class TestXPURelu(TestXPUActivation): self.inputs = {'X': x} self.outputs = {'Out': out} + def test_check_grad(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X'], 'Out') + @unittest.skipIf(not paddle.is_compiled_with_xpu(), "core is not compiled with XPU") @@ -157,6 +171,11 @@ class TestXPUGelu(TestXPUActivation): self.outputs = {'Out': out} self.attrs = {"approximate": approximate, 'use_xpu': True} + def test_check_grad(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X'], 'Out') + def gelu(x, approximate): if approximate: @@ -223,6 +242,11 @@ class TestXPUSquare(TestXPUActivation): self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} + def test_check_grad(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X'], 'Out') + @unittest.skipIf(not paddle.is_compiled_with_xpu(), "core is not compiled with XPU") @@ -239,5 +263,36 @@ class TestXPUPow(TestXPUActivation): self.outputs = {'Out': out} +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPULeakyRelu(TestXPUActivation): + def setUp(self): + self.op_type = "leaky_relu" + self.init_dtype() + x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) + alpha = np.random.uniform( + 0, + 1, ) + out = leaky_relu(x, alpha) + + self.inputs = {'X': x} + self.outputs = {'Out': out} + self.attrs = {'use_xpu': True, 'alpha': alpha} + + def test_check_grad(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X'], 'Out') + + +def leaky_relu(x, alpha): + if (alpha < 1): + y_ref = np.maximum(x, alpha * x) + else: + y_ref = np.minimum(x, alpha * x) + return y_ref.astype(x.dtype) + + if __name__ == "__main__": + paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_nearest_interp_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_nearest_interp_op_xpu.py new file mode 100644 index 0000000000..35dadb59bf --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_nearest_interp_op_xpu.py @@ -0,0 +1,432 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid.core as core +import sys +sys.path.append("..") +from op_test_xpu import XPUOpTest +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard + +paddle.enable_static() + + +def nearest_neighbor_interp_np(X, + out_h, + out_w, + out_size=None, + actual_shape=None, + align_corners=True, + data_layout='NCHW'): + """nearest neighbor interpolation implement in shape [N, C, H, W]""" + if data_layout == "NHWC": + X = np.transpose(X, (0, 3, 1, 2)) # NHWC => NCHW + if out_size is not None: + out_h = out_size[0] + out_w = out_size[1] + if actual_shape is not None: + out_h = actual_shape[0] + out_w = actual_shape[1] + n, c, in_h, in_w = X.shape + + ratio_h = ratio_w = 0.0 + if (out_h > 1): + if (align_corners): + ratio_h = (in_h - 1.0) / (out_h - 1.0) + else: + ratio_h = 1.0 * in_h / out_h + if (out_w > 1): + if (align_corners): + ratio_w = (in_w - 1.0) / (out_w - 1.0) + else: + ratio_w = 1.0 * in_w / out_w + + out = np.zeros((n, c, out_h, out_w)) + + if align_corners: + for i in range(out_h): + in_i = int(ratio_h * i + 0.5) + for j in range(out_w): + in_j = int(ratio_w * j + 0.5) + out[:, :, i, j] = X[:, :, in_i, in_j] + else: + for i in range(out_h): + in_i = int(ratio_h * i) + for j in range(out_w): + in_j = int(ratio_w * j) + out[:, :, i, j] = X[:, :, in_i, in_j] + + if data_layout == "NHWC": + out = np.transpose(out, (0, 2, 3, 1)) # NCHW => NHWC + + return out.astype(X.dtype) + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestNearestInterpOp(XPUOpTest): + def setUp(self): + self.use_xpu = True + self.out_size = None + self.actual_shape = None + self.data_layout = 'NCHW' + self.init_test_case() + self.op_type = "nearest_interp" + input_np = np.random.random(self.input_shape).astype("float32") + + if self.data_layout == "NCHW": + in_h = self.input_shape[2] + in_w = self.input_shape[3] + else: + in_h = self.input_shape[1] + in_w = self.input_shape[2] + + if self.scale > 0: + out_h = int(in_h * self.scale) + out_w = int(in_w * self.scale) + else: + out_h = self.out_h + out_w = self.out_w + + output_np = nearest_neighbor_interp_np( + input_np, out_h, out_w, self.out_size, self.actual_shape, + self.align_corners, self.data_layout) + self.inputs = {'X': input_np} + if self.out_size is not None: + self.inputs['OutSize'] = self.out_size + if self.actual_shape is not None: + self.inputs['OutSize'] = self.actual_shape + self.attrs = { + 'out_h': self.out_h, + 'out_w': self.out_w, + 'scale': self.scale, + 'interp_method': self.interp_method, + 'align_corners': self.align_corners, + 'data_layout': self.data_layout + } + self.outputs = {'Out': output_np} + + def test_check_output(self): + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + def test_check_grad(self): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X'], 'Out', in_place=True) + + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [2, 3, 4, 5] + self.out_h = 2 + self.out_w = 2 + self.scale = 0. + self.out_size = np.array([3, 3]).astype("int32") + self.align_corners = True + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestNearestNeighborInterpCase1(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [4, 1, 7, 8] + self.out_h = 1 + self.out_w = 1 + self.scale = 0. + self.align_corners = True + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestNearestNeighborInterpCase2(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + self.scale = 0. + self.align_corners = True + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestNearestNeighborInterpCase3(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [1, 1, 32, 64] + self.out_h = 64 + self.out_w = 32 + self.scale = 0. + self.align_corners = True + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestNearestNeighborInterpCase4(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [4, 1, 7, 8] + self.out_h = 1 + self.out_w = 1 + self.scale = 0. + self.out_size = np.array([2, 2]).astype("int32") + self.align_corners = True + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestNearestNeighborInterpCase5(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + self.scale = 0. + self.out_size = np.array([11, 11]).astype("int32") + self.align_corners = True + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestNearestNeighborInterpCase6(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [1, 1, 32, 64] + self.out_h = 64 + self.out_w = 32 + self.scale = 0. + self.out_size = np.array([65, 129]).astype("int32") + self.align_corners = True + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestNearestNeighborInterpSame(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [2, 3, 32, 64] + self.out_h = 32 + self.out_w = 64 + self.scale = 0. + self.align_corners = True + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestNearestNeighborInterpActualShape(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 2, 32, 16] + self.out_h = 64 + self.out_w = 32 + self.scale = 0. + self.out_size = np.array([66, 40]).astype("int32") + self.align_corners = True + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestNearestNeighborInterpDataLayout(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [2, 4, 4, 5] + self.out_h = 2 + self.out_w = 2 + self.scale = 0. + self.out_size = np.array([3, 8]).astype("int32") + self.align_corners = True + self.data_layout = "NCHW" + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestNearestInterpWithoutCorners(TestNearestInterpOp): + def set_align_corners(self): + self.align_corners = False + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestNearestNeighborInterpScale1(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 2, 7, 5] + self.out_h = 64 + self.out_w = 32 + self.scale = 2. + self.out_size = np.array([66, 40]).astype("int32") + self.align_corners = True + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestNearestNeighborInterpScale2(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 2, 5, 7] + self.out_h = 64 + self.out_w = 32 + self.scale = 1.5 + self.out_size = np.array([66, 40]).astype("int32") + self.align_corners = True + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestNearestNeighborInterpScale3(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 2, 7, 5] + self.out_h = 64 + self.out_w = 32 + self.scale = 1. + self.out_size = np.array([66, 40]).astype("int32") + self.align_corners = True + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestNearestInterpOp_attr_tensor(XPUOpTest): + def setUp(self): + self.out_size = None + self.actual_shape = None + self.init_test_case() + self.op_type = "nearest_interp" + self.shape_by_1Dtensor = False + self.scale_by_1Dtensor = False + self.attrs = { + 'interp_method': self.interp_method, + 'align_corners': self.align_corners, + } + + input_np = np.random.random(self.input_shape).astype("float32") + self.inputs = {'X': input_np} + + if self.scale_by_1Dtensor: + self.inputs['Scale'] = np.array([self.scale]).astype("float32") + elif self.scale > 0: + out_h = int(self.input_shape[2] * self.scale) + out_w = int(self.input_shape[3] * self.scale) + self.attrs['scale'] = self.scale + else: + out_h = self.out_h + out_w = self.out_w + + if self.shape_by_1Dtensor: + self.inputs['OutSize'] = self.out_size + elif self.out_size is not None: + size_tensor = [] + for index, ele in enumerate(self.out_size): + size_tensor.append(("x" + str(index), np.ones( + (1)).astype('int32') * ele)) + self.inputs['SizeTensor'] = size_tensor + + self.attrs['out_h'] = self.out_h + self.attrs['out_w'] = self.out_w + output_np = nearest_neighbor_interp_np(input_np, out_h, out_w, + self.out_size, self.actual_shape, + self.align_corners) + self.outputs = {'Out': output_np} + + def test_check_output(self): + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + def test_check_grad(self): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X'], 'Out', in_place=True) + + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [2, 5, 4, 4] + self.out_h = 3 + self.out_w = 3 + self.scale = 0. + self.out_size = [3, 3] + self.align_corners = True + + +# out_size is a tensor list +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestNearestInterp_attr_tensor_Case1(TestNearestInterpOp_attr_tensor): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 3, 9, 6] + self.out_h = 12 + self.out_w = 12 + self.scale = 0. + self.out_size = [8, 12] + self.align_corners = True + + +# out_size is a 1-D tensor +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestNearestInterp_attr_tensor_Case2(TestNearestInterpOp_attr_tensor): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 2, 32, 16] + self.out_h = 64 + self.out_w = 32 + self.scale = 0. + self.out_size = np.array([66, 40]).astype("int32") + self.align_corners = True + self.shape_by_1Dtensor = True + + +# scale is a 1-D tensor +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestNearestInterp_attr_tensor_Case3(TestNearestInterpOp_attr_tensor): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 2, 32, 16] + self.out_h = 64 + self.out_w = 32 + self.scale = 2.0 + self.out_size = None + self.align_corners = True + self.scale_by_1Dtensor = True + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestNearestInterpException(unittest.TestCase): + def test_exception(self): + input = fluid.data(name="input", shape=[1, 3, 6, 6], dtype="float32") + + def attr_data_format(): + # for 4-D input, data_format can only be NCHW or NHWC + out = fluid.layers.resize_nearest( + input, out_shape=[4, 8], data_format='NDHWC') + + def attr_scale_type(): + out = fluid.layers.resize_nearest(input, scale='scale') + + def attr_scale_value(): + out = fluid.layers.resize_nearest(input, scale=-0.3) + + self.assertRaises(ValueError, attr_data_format) + self.assertRaises(TypeError, attr_scale_type) + self.assertRaises(ValueError, attr_scale_value) + + +if __name__ == "__main__": + unittest.main() -- GitLab