From 82ac39132c49ca2792a37c2166f9d9242069d63d Mon Sep 17 00:00:00 2001 From: cyberslack_lee Date: Thu, 27 Apr 2023 10:15:50 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Hackathon4=E3=80=91No5=20nextafter=20?= =?UTF-8?q?=20(#52544)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/phi/api/yaml/ops.yaml | 10 ++ paddle/phi/kernels/cpu/nextafter_kernel.cc | 22 ++++ paddle/phi/kernels/gpu/nextafter_kernel.cu | 22 ++++ .../phi/kernels/impl/nextafter_kernel_impl.h | 84 +++++++++++++ paddle/phi/kernels/nextafter_kernel.h | 28 +++++ python/paddle/__init__.py | 2 + .../tests/unittests/test_nextafter_op.py | 118 ++++++++++++++++++ python/paddle/tensor/__init__.py | 2 + python/paddle/tensor/math.py | 36 ++++++ 9 files changed, 324 insertions(+) create mode 100644 paddle/phi/kernels/cpu/nextafter_kernel.cc create mode 100644 paddle/phi/kernels/gpu/nextafter_kernel.cu create mode 100644 paddle/phi/kernels/impl/nextafter_kernel_impl.h create mode 100644 paddle/phi/kernels/nextafter_kernel.h create mode 100644 python/paddle/fluid/tests/unittests/test_nextafter_op.py diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 507fef33096..8fc8c4c9b08 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1421,6 +1421,16 @@ data_transform : skip_transform : out_size, size_tensor, scale_tensor +- op : nextafter + args : (Tensor x, Tensor y) + output : Tensor(out) + infer_meta : + func : ElementwiseInferMeta + param: [x, y] + kernel : + func : nextafter + data_type : x + - op : nll_loss args : (Tensor input, Tensor label, Tensor weight, int64_t ignore_index = -100, str reduction = "mean") output : Tensor(out), Tensor(total_weight) diff --git a/paddle/phi/kernels/cpu/nextafter_kernel.cc b/paddle/phi/kernels/cpu/nextafter_kernel.cc new file mode 100644 index 00000000000..ac4ab00a4d3 --- /dev/null +++ b/paddle/phi/kernels/cpu/nextafter_kernel.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/nextafter_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/nextafter_kernel_impl.h" + +PD_REGISTER_KERNEL( + nextafter, CPU, ALL_LAYOUT, phi::NextafterKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/nextafter_kernel.cu b/paddle/phi/kernels/gpu/nextafter_kernel.cu new file mode 100644 index 00000000000..e0ac8212853 --- /dev/null +++ b/paddle/phi/kernels/gpu/nextafter_kernel.cu @@ -0,0 +1,22 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/nextafter_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/nextafter_kernel_impl.h" + +PD_REGISTER_KERNEL( + nextafter, GPU, ALL_LAYOUT, phi::NextafterKernel, float, double) {} diff --git a/paddle/phi/kernels/impl/nextafter_kernel_impl.h b/paddle/phi/kernels/impl/nextafter_kernel_impl.h new file mode 100644 index 00000000000..6d540092825 --- /dev/null +++ b/paddle/phi/kernels/impl/nextafter_kernel_impl.h @@ -0,0 +1,84 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/kernels/funcs/for_range.h" +#include "paddle/phi/kernels/funcs/math.h" +#include "paddle/phi/kernels/nextafter_kernel.h" +namespace phi { +template +struct NextafterOut { + using type = T; +}; + +template <> +struct NextafterOut { + using type = double; +}; + +template <> +struct NextafterOut { + using type = double; +}; +template +struct NextafterFunctor { + NextafterFunctor(const T* x, + const T* y, + typename NextafterOut::type* out, + int64_t numel) + : x_(x), y_(y), out_(out), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + out_[idx] = static_cast::type>(std::nextafter( + static_cast(x_[idx]), static_cast(y_[idx]))); + } + const T* x_; + const T* y_; + typename NextafterOut::type* out_; + int64_t numel_; +}; +template <> +struct NextafterFunctor { + NextafterFunctor(const double* x, const double* y, double* out, int64_t numel) + : x_(x), y_(y), out_(out), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + out_[idx] = std::nextafter(x_[idx], y_[idx]); + } + + const double* x_; + const double* y_; + double* out_; + int64_t numel_; +}; + +template +void NextafterKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + auto* out_data = ctx.template Alloc(out); + auto x_data = x.data(); + auto y_data = y.data(); + auto x_numel = x.numel(); + + phi::funcs::ForRange for_range(ctx, x_numel); + phi::NextafterFunctor functor(x_data, y_data, out_data, x_numel); + for_range(functor); +} + +} // namespace phi diff --git a/paddle/phi/kernels/nextafter_kernel.h b/paddle/phi/kernels/nextafter_kernel.h new file mode 100644 index 00000000000..3a185e39bd9 --- /dev/null +++ b/paddle/phi/kernels/nextafter_kernel.h @@ -0,0 +1,28 @@ + +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void NextafterKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out); + +} // namespace phi diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index f319ab27c06..d155dce987c 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -299,6 +299,7 @@ from .tensor.math import frexp # noqa: F401 from .tensor.math import trapezoid # noqa: F401 from .tensor.math import cumulative_trapezoid # noqa: F401 from .tensor.math import vander # noqa: F401 +from .tensor.math import nextafter # noqa: F401 from .tensor.random import bernoulli # noqa: F401 from .tensor.random import poisson # noqa: F401 @@ -688,4 +689,5 @@ __all__ = [ # noqa 'cumulative_trapezoid', 'polar', 'vander', + 'nextafter', ] diff --git a/python/paddle/fluid/tests/unittests/test_nextafter_op.py b/python/paddle/fluid/tests/unittests/test_nextafter_op.py new file mode 100644 index 00000000000..5048778e1b7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_nextafter_op.py @@ -0,0 +1,118 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from eager_op_test import OpTest + +import paddle + + +def ref_nextafter(x, y): + out = np.nextafter(x, y) + return out + + +class TestNextafterAPI(unittest.TestCase): + def setUp(self): + self.x = np.random.rand(2, 3, 4, 5).astype('float32') + self.y = np.random.rand(2, 3, 4, 5).astype('float32') + self.x1 = np.array([0, 0, 10]).astype("float32") + self.y1 = np.array([np.inf, -np.inf, 10]).astype("float32") + self.x2 = np.random.rand(100).astype("float32") + self.y2 = np.random.rand(100).astype("float32") + self.place = ( + paddle.CUDAPlace(0) + if paddle.is_compiled_with_cuda() + else paddle.CPUPlace() + ) + + def test_static_api(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data( + name='x', shape=self.x.shape, dtype='float32' + ) + y = paddle.static.data( + name='y', shape=self.y.shape, dtype='float32' + ) + out = paddle.nextafter(x, y) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'x': self.x, 'y': self.y}, fetch_list=[out]) + out_ref = ref_nextafter(self.x, self.y) + np.testing.assert_allclose(out_ref, res[0], rtol=1e-05) + + with paddle.static.program_guard(paddle.static.Program()): + x1 = paddle.static.data( + name='x', shape=self.x1.shape, dtype='float32' + ) + y1 = paddle.static.data( + name='y', shape=self.y1.shape, dtype='float32' + ) + out = paddle.nextafter(x1, y1) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'x': self.x1, 'y': self.y1}, fetch_list=[out]) + out_ref = ref_nextafter(self.x1, self.y1) + np.testing.assert_allclose(out_ref, res[0], rtol=1e-05) + + with paddle.static.program_guard(paddle.static.Program()): + x2 = paddle.static.data( + name='x', shape=self.x2.shape, dtype='float32' + ) + y2 = paddle.static.data( + name='y', shape=self.y2.shape, dtype='float32' + ) + out = paddle.nextafter(x2, y2) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'x': self.x2, 'y': self.y2}, fetch_list=[out]) + out_ref = ref_nextafter(self.x2, self.y2) + np.testing.assert_allclose(out_ref, res[0], rtol=1e-05) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.x) + y = paddle.to_tensor(self.y) + out = paddle.nextafter(x, y) + out_ref = ref_nextafter(self.x, self.y) + np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-05) + paddle.enable_static() + + +class TestNextafterOP(OpTest): + def setUp(self): + self.op_type = "nextafter" + self.python_api = paddle.nextafter + self.init_dtype() + + x = np.array([1, 2]).astype(self.dtype) + y = np.array([2, 1]).astype(self.dtype) + out = np.nextafter(x, y) + self.inputs = {'x': x, 'y': y} + self.outputs = {'out': out} + + def test_check_output(self): + self.check_output() + + def init_dtype(self): + self.dtype = np.float64 + + +class TestNextafterOPFP32(TestNextafterOP): + def init_dtype(self): + self.dtype = np.float32 + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index b78ac0e57c2..bea2fd7323d 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -251,6 +251,7 @@ from .math import cumulative_trapezoid # noqa: F401 from .math import sigmoid # noqa: F401 from .math import sigmoid_ # noqa: F401 from .math import vander # noqa: F401 +from .math import nextafter # noqa: F401 from .random import multinomial # noqa: F401 from .random import standard_normal # noqa: F401 @@ -540,6 +541,7 @@ tensor_method_func = [ # noqa 'sigmoid', 'sigmoid_', 'vander', + 'nextafter', ] # this list used in math_op_patch.py for magic_method bind diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 2f94f0a7e20..37519996a2a 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -5486,3 +5486,39 @@ def vander(x, n=None, increasing=False, name=None): res[:, 1:] = paddle.cumprod(res[:, 1:], dim=-1) res = res[:, ::-1] if not increasing else res return res + + +def nextafter(x, y, name=None): + r""" + Return the next floating-point value after input towards other, elementwise. + The shapes of input and other must be broadcastable. + + Args: + x (Tensor): An N-D Tensor, the data type is float32, float64. + y (Tensor): An N-D Tensor, the data type is float32, float64. + name(str, optional):Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + out (Tensor): An N-D Tensor, the shape and data type is the same with input. + + Examples: + .. code-block:: python + + import paddle + out = paddle.nextafter(paddle.to_tensor([1.0,2.0]),paddle.to_tensor([2.0,1.0])) + print(out) + #Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, + # [1.00000012, 1.99999988]) + """ + if in_dygraph_mode(): + return _C_ops.nextafter(x, y) + else: + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'nextafter') + check_variable_and_dtype(y, 'y', ['float32', 'float64'], 'nextafter') + op_type = "nextafter" + helper = LayerHelper(op_type, **locals()) + inputs = {"x": x, "y": y} + out = helper.create_variable_for_type_inference(dtype=paddle.float32) + outputs = {"out": out} + helper.append_op(type=op_type, inputs=inputs, outputs=outputs) + return out -- GitLab