未验证 提交 82ac3913 编写于 作者: C cyberslack_lee 提交者: GitHub

【Hackathon4】No5 nextafter (#52544)

上级 bfeedd29
......@@ -1421,6 +1421,16 @@
data_transform :
skip_transform : out_size, size_tensor, scale_tensor
- op : nextafter
args : (Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
param: [x, y]
kernel :
func : nextafter
data_type : x
- op : nll_loss
args : (Tensor input, Tensor label, Tensor weight, int64_t ignore_index = -100, str reduction = "mean")
output : Tensor(out), Tensor(total_weight)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/nextafter_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/nextafter_kernel_impl.h"
PD_REGISTER_KERNEL(
nextafter, CPU, ALL_LAYOUT, phi::NextafterKernel, float, double) {}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/nextafter_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/nextafter_kernel_impl.h"
PD_REGISTER_KERNEL(
nextafter, GPU, ALL_LAYOUT, phi::NextafterKernel, float, double) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/math.h"
#include "paddle/phi/kernels/nextafter_kernel.h"
namespace phi {
template <typename T>
struct NextafterOut {
using type = T;
};
template <>
struct NextafterOut<int32_t> {
using type = double;
};
template <>
struct NextafterOut<int64_t> {
using type = double;
};
template <typename T>
struct NextafterFunctor {
NextafterFunctor(const T* x,
const T* y,
typename NextafterOut<T>::type* out,
int64_t numel)
: x_(x), y_(y), out_(out), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
out_[idx] = static_cast<typename NextafterOut<T>::type>(std::nextafter(
static_cast<float>(x_[idx]), static_cast<float>(y_[idx])));
}
const T* x_;
const T* y_;
typename NextafterOut<T>::type* out_;
int64_t numel_;
};
template <>
struct NextafterFunctor<double> {
NextafterFunctor(const double* x, const double* y, double* out, int64_t numel)
: x_(x), y_(y), out_(out), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
out_[idx] = std::nextafter(x_[idx], y_[idx]);
}
const double* x_;
const double* y_;
double* out_;
int64_t numel_;
};
template <typename T, typename Context>
void NextafterKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
auto* out_data = ctx.template Alloc<T>(out);
auto x_data = x.data<T>();
auto y_data = y.data<T>();
auto x_numel = x.numel();
phi::funcs::ForRange<Context> for_range(ctx, x_numel);
phi::NextafterFunctor<T> functor(x_data, y_data, out_data, x_numel);
for_range(functor);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void NextafterKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);
} // namespace phi
......@@ -299,6 +299,7 @@ from .tensor.math import frexp # noqa: F401
from .tensor.math import trapezoid # noqa: F401
from .tensor.math import cumulative_trapezoid # noqa: F401
from .tensor.math import vander # noqa: F401
from .tensor.math import nextafter # noqa: F401
from .tensor.random import bernoulli # noqa: F401
from .tensor.random import poisson # noqa: F401
......@@ -688,4 +689,5 @@ __all__ = [ # noqa
'cumulative_trapezoid',
'polar',
'vander',
'nextafter',
]
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from eager_op_test import OpTest
import paddle
def ref_nextafter(x, y):
out = np.nextafter(x, y)
return out
class TestNextafterAPI(unittest.TestCase):
def setUp(self):
self.x = np.random.rand(2, 3, 4, 5).astype('float32')
self.y = np.random.rand(2, 3, 4, 5).astype('float32')
self.x1 = np.array([0, 0, 10]).astype("float32")
self.y1 = np.array([np.inf, -np.inf, 10]).astype("float32")
self.x2 = np.random.rand(100).astype("float32")
self.y2 = np.random.rand(100).astype("float32")
self.place = (
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
else paddle.CPUPlace()
)
def test_static_api(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(
name='x', shape=self.x.shape, dtype='float32'
)
y = paddle.static.data(
name='y', shape=self.y.shape, dtype='float32'
)
out = paddle.nextafter(x, y)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'x': self.x, 'y': self.y}, fetch_list=[out])
out_ref = ref_nextafter(self.x, self.y)
np.testing.assert_allclose(out_ref, res[0], rtol=1e-05)
with paddle.static.program_guard(paddle.static.Program()):
x1 = paddle.static.data(
name='x', shape=self.x1.shape, dtype='float32'
)
y1 = paddle.static.data(
name='y', shape=self.y1.shape, dtype='float32'
)
out = paddle.nextafter(x1, y1)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'x': self.x1, 'y': self.y1}, fetch_list=[out])
out_ref = ref_nextafter(self.x1, self.y1)
np.testing.assert_allclose(out_ref, res[0], rtol=1e-05)
with paddle.static.program_guard(paddle.static.Program()):
x2 = paddle.static.data(
name='x', shape=self.x2.shape, dtype='float32'
)
y2 = paddle.static.data(
name='y', shape=self.y2.shape, dtype='float32'
)
out = paddle.nextafter(x2, y2)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'x': self.x2, 'y': self.y2}, fetch_list=[out])
out_ref = ref_nextafter(self.x2, self.y2)
np.testing.assert_allclose(out_ref, res[0], rtol=1e-05)
def test_dygraph_api(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x)
y = paddle.to_tensor(self.y)
out = paddle.nextafter(x, y)
out_ref = ref_nextafter(self.x, self.y)
np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-05)
paddle.enable_static()
class TestNextafterOP(OpTest):
def setUp(self):
self.op_type = "nextafter"
self.python_api = paddle.nextafter
self.init_dtype()
x = np.array([1, 2]).astype(self.dtype)
y = np.array([2, 1]).astype(self.dtype)
out = np.nextafter(x, y)
self.inputs = {'x': x, 'y': y}
self.outputs = {'out': out}
def test_check_output(self):
self.check_output()
def init_dtype(self):
self.dtype = np.float64
class TestNextafterOPFP32(TestNextafterOP):
def init_dtype(self):
self.dtype = np.float32
if __name__ == "__main__":
unittest.main()
......@@ -251,6 +251,7 @@ from .math import cumulative_trapezoid # noqa: F401
from .math import sigmoid # noqa: F401
from .math import sigmoid_ # noqa: F401
from .math import vander # noqa: F401
from .math import nextafter # noqa: F401
from .random import multinomial # noqa: F401
from .random import standard_normal # noqa: F401
......@@ -540,6 +541,7 @@ tensor_method_func = [ # noqa
'sigmoid',
'sigmoid_',
'vander',
'nextafter',
]
# this list used in math_op_patch.py for magic_method bind
......
......@@ -5486,3 +5486,39 @@ def vander(x, n=None, increasing=False, name=None):
res[:, 1:] = paddle.cumprod(res[:, 1:], dim=-1)
res = res[:, ::-1] if not increasing else res
return res
def nextafter(x, y, name=None):
r"""
Return the next floating-point value after input towards other, elementwise.
The shapes of input and other must be broadcastable.
Args:
x (Tensor): An N-D Tensor, the data type is float32, float64.
y (Tensor): An N-D Tensor, the data type is float32, float64.
name(str, optional):Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
out (Tensor): An N-D Tensor, the shape and data type is the same with input.
Examples:
.. code-block:: python
import paddle
out = paddle.nextafter(paddle.to_tensor([1.0,2.0]),paddle.to_tensor([2.0,1.0]))
print(out)
#Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
# [1.00000012, 1.99999988])
"""
if in_dygraph_mode():
return _C_ops.nextafter(x, y)
else:
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'nextafter')
check_variable_and_dtype(y, 'y', ['float32', 'float64'], 'nextafter')
op_type = "nextafter"
helper = LayerHelper(op_type, **locals())
inputs = {"x": x, "y": y}
out = helper.create_variable_for_type_inference(dtype=paddle.float32)
outputs = {"out": out}
helper.append_op(type=op_type, inputs=inputs, outputs=outputs)
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册