From a63fb4c881a7662c60fbd997b46fd5f35ac8466d Mon Sep 17 00:00:00 2001
From: LyndonKong
Date: Wed, 17 May 2023 11:04:54 +0800
Subject: [PATCH] =?UTF-8?q?=E3=80=90Hackathon=204=20No.21=E3=80=91Add=20i1?=
 =?UTF-8?q?=20/=20i1e=20to=20paddle=20(#53210)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Add i1 and i1e op

* resolve merge conflicts
---
 paddle/phi/api/yaml/backward.yaml             |  20 +++
 paddle/phi/api/yaml/ops.yaml                  |  18 +++
 paddle/phi/kernels/cpu/i1_grad_kernel.cc      |  44 +++++
 paddle/phi/kernels/cpu/i1_kernel.cc           |  37 +++++
 paddle/phi/kernels/cpu/i1e_grad_kernel.cc     |  44 +++++
 paddle/phi/kernels/cpu/i1e_kernel.cc          |  37 +++++
 paddle/phi/kernels/gpu/i1_grad_kernel.cu      |  37 +++++
 paddle/phi/kernels/gpu/i1_kernel.cu           |  32 ++++
 paddle/phi/kernels/gpu/i1e_grad_kernel.cu     |  37 +++++
 paddle/phi/kernels/gpu/i1e_kernel.cu          |  32 ++++
 paddle/phi/kernels/i1_grad_kernel.h           |  38 +++++
 paddle/phi/kernels/i1_kernel.h                |  32 ++++
 paddle/phi/kernels/i1e_grad_kernel.h          |  29 ++++
 paddle/phi/kernels/i1e_kernel.h               |  30 ++++
 .../impl/bessel_grad_kernel_cuda_impl.h       |  76 +++++++++
 .../kernels/impl/bessel_grad_kernel_impl.h    |  99 ++++++++++++
 .../kernels/impl/bessel_kernel_cuda_impl.h    |  48 ++++++
 paddle/phi/kernels/impl/bessel_kernel_impl.h  |  61 +++++++
 python/paddle/__init__.py                     |   4 +
 .../fluid/tests/unittests/test_i1_op.py       | 151 ++++++++++++++++++
 .../fluid/tests/unittests/test_i1e_op.py      | 151 ++++++++++++++++++
 python/paddle/tensor/__init__.py              |   4 +
 python/paddle/tensor/math.py                  |  67 ++++++++
 23 files changed, 1128 insertions(+)
 create mode 100644 paddle/phi/kernels/cpu/i1_grad_kernel.cc
 create mode 100644 paddle/phi/kernels/cpu/i1_kernel.cc
 create mode 100644 paddle/phi/kernels/cpu/i1e_grad_kernel.cc
 create mode 100644 paddle/phi/kernels/cpu/i1e_kernel.cc
 create mode 100644 paddle/phi/kernels/gpu/i1_grad_kernel.cu
 create mode 100644 paddle/phi/kernels/gpu/i1_kernel.cu
 create mode 100644 paddle/phi/kernels/gpu/i1e_grad_kernel.cu
 create mode 100644 paddle/phi/kernels/gpu/i1e_kernel.cu
 create mode 100644 paddle/phi/kernels/i1_grad_kernel.h
 create mode 100644 paddle/phi/kernels/i1_kernel.h
 create mode 100644 paddle/phi/kernels/i1e_grad_kernel.h
 create mode 100644 paddle/phi/kernels/i1e_kernel.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_i1_op.py
 create mode 100644 python/paddle/fluid/tests/unittests/test_i1e_op.py

diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index 1c38f5ba51c..895f0ccb112 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -821,6 +821,26 @@
   kernel :
     func : i0e_grad
 
+- backward_op : i1_grad
+  forward : i1 (Tensor x) -> Tensor(out)
+  args : (Tensor x, Tensor out, Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : i1_grad
+
+- backward_op : i1e_grad
+  forward : i1e (Tensor x) -> Tensor(out)
+  args : (Tensor x, Tensor out, Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : i1e_grad
+
 - backward_op : imag_grad
   forward : imag (Tensor x) -> Tensor(out)
   args : (Tensor out_grad)
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index 39799238364..8928a825e46 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -948,6 +948,24 @@
     func : i0e
   backward : i0e_grad
 
+- op : i1
+  args : (Tensor x)
+  output : Tensor(out)
+  infer_meta :
+    func : UnchangedInferMeta
+  kernel :
+    func : i1
+  backward : i1_grad
+
+- op : i1e
+  args : (Tensor x)
+  output : Tensor(out)
+  infer_meta :
+    func : UnchangedInferMeta
+  kernel :
+    func : i1e
+  backward : i1e_grad
+
 - op : imag
   args : (Tensor x)
   output : Tensor (out)
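For orientation, the backward ops registered above implement the standard derivative identities for the modified Bessel functions; the same formulas appear as SciPy-based reference gradients in the unit tests further down:

    \frac{d}{dx} I_1(x) = I_0(x) - \frac{I_1(x)}{x},
    \qquad
    \frac{d}{dx}\,\mathrm{i1e}(x) = \mathrm{i0e}(x) - \mathrm{i1e}(x)\left(\operatorname{sign}(x) + \frac{1}{x}\right).

Since I_1(x)/x -> 1/2 as x -> 0, the first expression tends to I_0(0) - 1/2 = 1/2, which is why the grad kernels below return 0.5 * out_grad when |x| is below machine epsilon. As a worked check at x = 2: I_0(2) ≈ 2.27959 and I_1(2) ≈ 1.59064, so the gradient is roughly 2.27959 - 1.59064/2 ≈ 1.48427.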
diff --git a/paddle/phi/kernels/cpu/i1_grad_kernel.cc b/paddle/phi/kernels/cpu/i1_grad_kernel.cc
new file mode 100644
index 00000000000..3a4f140e190
--- /dev/null
+++ b/paddle/phi/kernels/cpu/i1_grad_kernel.cc
@@ -0,0 +1,44 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/i1_grad_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
+#include "paddle/phi/kernels/impl/bessel_grad_kernel_impl.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void I1GradKernel(const Context& ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& out,
+                  const DenseTensor& out_grad,
+                  DenseTensor* x_grad) {
+  const int64_t size = x.numel();
+  const T* x_data = x.data<T>();
+  const T* out_data = out.data<T>();
+  const T* out_grad_data = out_grad.data<T>();
+  T* x_grad_data = ctx.template Alloc<T>(x_grad);
+
+  phi::funcs::ForRange<Context> for_range(ctx, size);
+  I1GradFunctor<T> functor(x_data, out_data, out_grad_data, x_grad_data, size);
+  for_range(functor);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(i1_grad, CPU, ALL_LAYOUT, phi::I1GradKernel, float, double) {
+}
diff --git a/paddle/phi/kernels/cpu/i1_kernel.cc b/paddle/phi/kernels/cpu/i1_kernel.cc
new file mode 100644
index 00000000000..18388159a9b
--- /dev/null
+++ b/paddle/phi/kernels/cpu/i1_kernel.cc
@@ -0,0 +1,37 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/i1_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
+#include "paddle/phi/kernels/impl/bessel_kernel_impl.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void I1Kernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
+  const int64_t size = x.numel();
+  const T* x_data = x.data<T>();
+  T* out_data = ctx.template Alloc<T>(out);
+
+  phi::funcs::ForRange<Context> for_range(ctx, size);
+  I1Functor<T> functor(x_data, out_data, size);
+  for_range(functor);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(i1, CPU, ALL_LAYOUT, phi::I1Kernel, float, double) {}
diff --git a/paddle/phi/kernels/cpu/i1e_grad_kernel.cc b/paddle/phi/kernels/cpu/i1e_grad_kernel.cc
new file mode 100644
index 00000000000..e97d12c4c12
--- /dev/null
+++ b/paddle/phi/kernels/cpu/i1e_grad_kernel.cc
@@ -0,0 +1,44 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/i1e_grad_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
+#include "paddle/phi/kernels/impl/bessel_grad_kernel_impl.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void I1eGradKernel(const Context& ctx,
+                   const DenseTensor& x,
+                   const DenseTensor& out,
+                   const DenseTensor& out_grad,
+                   DenseTensor* x_grad) {
+  const int64_t size = x.numel();
+  const T* x_data = x.data<T>();
+  const T* out_data = out.data<T>();
+  const T* out_grad_data = out_grad.data<T>();
+  T* x_grad_data = ctx.template Alloc<T>(x_grad);
+
+  phi::funcs::ForRange<Context> for_range(ctx, size);
+  I1eGradFunctor<T> functor(x_data, out_data, out_grad_data, x_grad_data, size);
+  for_range(functor);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    i1e_grad, CPU, ALL_LAYOUT, phi::I1eGradKernel, float, double) {}
diff --git a/paddle/phi/kernels/cpu/i1e_kernel.cc b/paddle/phi/kernels/cpu/i1e_kernel.cc
new file mode 100644
index 00000000000..5f50d723e9a
--- /dev/null
+++ b/paddle/phi/kernels/cpu/i1e_kernel.cc
@@ -0,0 +1,37 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/i1e_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
+#include "paddle/phi/kernels/impl/bessel_kernel_impl.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void I1eKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
+  const int64_t size = x.numel();
+  const T* x_data = x.data<T>();
+  T* out_data = ctx.template Alloc<T>(out);
+
+  phi::funcs::ForRange<Context> for_range(ctx, size);
+  I1eFunctor<T> functor(x_data, out_data, size);
+  for_range(functor);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(i1e, CPU, ALL_LAYOUT, phi::I1eKernel, float, double) {}
diff --git a/paddle/phi/kernels/gpu/i1_grad_kernel.cu b/paddle/phi/kernels/gpu/i1_grad_kernel.cu
new file mode 100644
index 00000000000..35ab69a176e
--- /dev/null
+++ b/paddle/phi/kernels/gpu/i1_grad_kernel.cu
@@ -0,0 +1,37 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//     http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/i1_grad_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
+#include "paddle/phi/kernels/impl/bessel_grad_kernel_cuda_impl.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void I1GradKernel(const Context& ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& out,
+                  const DenseTensor& out_grad,
+                  DenseTensor* x_grad) {
+  ctx.template Alloc<T>(x_grad);
+  std::vector<const DenseTensor*> ins = {&x, &out, &out_grad};
+  std::vector<DenseTensor*> outs = {x_grad};
+  auto functor = CudaI1GradFunctor<T>();
+  phi::funcs::ElementwiseKernel<T>(ctx, ins, &outs, functor);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(i1_grad, GPU, ALL_LAYOUT, phi::I1GradKernel, float, double) {
+}
diff --git a/paddle/phi/kernels/gpu/i1_kernel.cu b/paddle/phi/kernels/gpu/i1_kernel.cu
new file mode 100644
index 00000000000..c0387f451e5
--- /dev/null
+++ b/paddle/phi/kernels/gpu/i1_kernel.cu
@@ -0,0 +1,32 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//     http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/i1_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
+#include "paddle/phi/kernels/impl/bessel_kernel_cuda_impl.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void I1Kernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
+  ctx.template Alloc<T>(out);
+  std::vector<const DenseTensor*> ins = {&x};
+  std::vector<DenseTensor*> outs = {out};
+  auto functor = CudaI1Functor<T>();
+  phi::funcs::ElementwiseKernel<T>(ctx, ins, &outs, functor);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(i1, GPU, ALL_LAYOUT, phi::I1Kernel, float, double) {}
diff --git a/paddle/phi/kernels/gpu/i1e_grad_kernel.cu b/paddle/phi/kernels/gpu/i1e_grad_kernel.cu
new file mode 100644
index 00000000000..a01a3ab08f4
--- /dev/null
+++ b/paddle/phi/kernels/gpu/i1e_grad_kernel.cu
@@ -0,0 +1,37 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//     http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/i1e_grad_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
+#include "paddle/phi/kernels/impl/bessel_grad_kernel_cuda_impl.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void I1eGradKernel(const Context& ctx,
+                   const DenseTensor& x,
+                   const DenseTensor& out,
+                   const DenseTensor& out_grad,
+                   DenseTensor* x_grad) {
+  ctx.template Alloc<T>(x_grad);
+  std::vector<const DenseTensor*> ins = {&x, &out, &out_grad};
+  std::vector<DenseTensor*> outs = {x_grad};
+  auto functor = CudaI1eGradFunctor<T>();
+  phi::funcs::ElementwiseKernel<T>(ctx, ins, &outs, functor);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    i1e_grad, GPU, ALL_LAYOUT, phi::I1eGradKernel, float, double) {}
diff --git a/paddle/phi/kernels/gpu/i1e_kernel.cu b/paddle/phi/kernels/gpu/i1e_kernel.cu
new file mode 100644
index 00000000000..8dff432a628
--- /dev/null
+++ b/paddle/phi/kernels/gpu/i1e_kernel.cu
@@ -0,0 +1,32 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//     http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/i1e_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
+#include "paddle/phi/kernels/impl/bessel_kernel_cuda_impl.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void I1eKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
+  ctx.template Alloc<T>(out);
+  std::vector<const DenseTensor*> ins = {&x};
+  std::vector<DenseTensor*> outs = {out};
+  auto functor = CudaI1eFunctor<T>();
+  phi::funcs::ElementwiseKernel<T>(ctx, ins, &outs, functor);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(i1e, GPU, ALL_LAYOUT, phi::I1eKernel, float, double) {}
diff --git a/paddle/phi/kernels/i1_grad_kernel.h b/paddle/phi/kernels/i1_grad_kernel.h
new file mode 100644
index 00000000000..1d95c10b6c9
--- /dev/null
+++ b/paddle/phi/kernels/i1_grad_kernel.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
+
+namespace phi {
+
+/**
+ * @brief This kernel calculates the gradient of the modified Bessel function
+ *        of the first kind, order 1.
+ * @param ctx device context
+ * @param x input tensor of the forward op
+ * @param out output tensor of the forward op
+ * @param out_grad gradient of out
+ * @param x_grad computed gradient of x
+ */
+template <typename T, typename Context>
+void I1GradKernel(const Context& ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& out,
+                  const DenseTensor& out_grad,
+                  DenseTensor* x_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/i1_kernel.h b/paddle/phi/kernels/i1_kernel.h
new file mode 100644
index 00000000000..8865c573c9a
--- /dev/null
+++ b/paddle/phi/kernels/i1_kernel.h
@@ -0,0 +1,32 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+/**
+ * @brief This kernel calculates the modified Bessel function of the first
+ *        kind, order 1.
+ * @param ctx device context
+ * @param x The input tensor of i1
+ * @param out The output tensor of the i1 kernel; it has the same shape and
+ *            dtype as the input, and each element is computed from the
+ *            corresponding input element.
+ */
+template <typename T, typename Context>
+void I1Kernel(const Context& ctx, const DenseTensor& x, DenseTensor* out);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/i1e_grad_kernel.h b/paddle/phi/kernels/i1e_grad_kernel.h
new file mode 100644
index 00000000000..492305022d0
--- /dev/null
+++ b/paddle/phi/kernels/i1e_grad_kernel.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void I1eGradKernel(const Context& ctx,
+                   const DenseTensor& x,
+                   const DenseTensor& out,
+                   const DenseTensor& out_grad,
+                   DenseTensor* x_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/i1e_kernel.h b/paddle/phi/kernels/i1e_kernel.h
new file mode 100644
index 00000000000..6ca3d9f767f
--- /dev/null
+++ b/paddle/phi/kernels/i1e_kernel.h
@@ -0,0 +1,30 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//     http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+/**
+ * @brief This kernel calculates the exponentially scaled modified Bessel
+ *        function of the first kind, order 1.
+ * @param ctx device context
+ * @param x The input tensor of i1e
+ * @param out The output tensor of the i1e kernel; it has the same shape and
+ *            dtype as the input, and each element is computed from the
+ *            corresponding input element.
+ */
+template <typename T, typename Context>
+void I1eKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/impl/bessel_grad_kernel_cuda_impl.h b/paddle/phi/kernels/impl/bessel_grad_kernel_cuda_impl.h
index 7789e3d8e93..b3e272800b4 100644
--- a/paddle/phi/kernels/impl/bessel_grad_kernel_cuda_impl.h
+++ b/paddle/phi/kernels/impl/bessel_grad_kernel_cuda_impl.h
@@ -87,4 +87,80 @@ struct CudaI0eGradFunctor {
   }
 };
 
+template <typename T>
+struct CudaI1GradFunctor {
+  __device__ __forceinline__ T operator()(const T _x,
+                                          const T _out,
+                                          const T _out_grad) const {
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+    const MT mp_x = static_cast<MT>(_x);
+    const MT mp_out = static_cast<MT>(_out);
+    const MT mp_out_grad = static_cast<MT>(_out_grad);
+    MT x = std::abs(mp_x);
+    if (x <= MT{8.0}) {
+      auto coeff_pair_A = ChebyshevCoefficientsI0e_A<MT>();
+      auto A = std::get<0>(coeff_pair_A);
+      auto len = std::get<1>(coeff_pair_A);
+      MT y = (x / MT{2.0}) - MT{2.0};
+      MT eps = static_cast<MT>(std::numeric_limits<T>::epsilon());
+
+      if (x <= eps) {
+        MT out = (MT{0.5}) * mp_out_grad;
+        return static_cast<T>(out);
+      } else {
+        return static_cast<T>(
+            (std::exp(x) * Chbevl<MT>(y, A, len) - mp_out / mp_x) *
+            mp_out_grad);
+      }
+    }
+    auto coeff_pair_B = ChebyshevCoefficientsI0e_B<MT>();
+    auto B = std::get<0>(coeff_pair_B);
+    auto len = std::get<1>(coeff_pair_B);
+    MT y = (MT{32.0} / x) - MT{2.0};
+
+    return static_cast<T>(
+        (std::exp(x) * Chbevl<MT>(y, B, len) / std::sqrt(x) - mp_out / mp_x) *
+        mp_out_grad);
+  }
+};
+
+template <typename T>
+struct CudaI1eGradFunctor {
+  __device__ __forceinline__ T operator()(const T _x,
+                                          const T _out,
+                                          const T _out_grad) const {
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+    const MT mp_x = static_cast<MT>(_x);
+    const MT mp_out = static_cast<MT>(_out);
+    const MT mp_out_grad = static_cast<MT>(_out_grad);
+    MT x = std::abs(mp_x);
+    if (x <= MT{8.0}) {
+      auto coeff_pair_A = ChebyshevCoefficientsI0e_A<MT>();
+      auto A = std::get<0>(coeff_pair_A);
+      auto len = std::get<1>(coeff_pair_A);
+      MT y = (x / MT{2.0}) - MT{2.0};
+      MT eps = static_cast<MT>(std::numeric_limits<T>::epsilon());
+
+      if (x <= eps) {
+        MT out = (MT{0.5}) * mp_out_grad;
+        return static_cast<T>(out);
+      } else {
+        MT out = (Chbevl<MT>(y, A, len) -
+                  mp_out * (std::copysign(MT{1.0}, mp_x) + (MT{1.0}) / mp_x)) *
+                 mp_out_grad;
+        return static_cast<T>(out);
+      }
+    }
+    auto coeff_pair_B = ChebyshevCoefficientsI0e_B<MT>();
+    auto B = std::get<0>(coeff_pair_B);
+    auto len = std::get<1>(coeff_pair_B);
+    MT y = (MT{32.0} / x) - MT{2.0};
+
+    return static_cast<T>(
+        (Chbevl<MT>(y, B, len) / std::sqrt(x) -
+         mp_out * (std::copysign(MT{1.0}, mp_x) + (MT{1.0}) / mp_x)) *
+        mp_out_grad);
+  }
+};
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/impl/bessel_grad_kernel_impl.h b/paddle/phi/kernels/impl/bessel_grad_kernel_impl.h
index 1de9a9fcbfc..e48c0c8ec11 100644
--- a/paddle/phi/kernels/impl/bessel_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/bessel_grad_kernel_impl.h
@@ -109,4 +109,103 @@ struct I0eGradFunctor {
   int64_t numel_;
 };
 
+template <typename T>
+struct I1GradFunctor {
+  I1GradFunctor(
+      const T* x, const T* out, const T* out_grad, T* x_grad, int64_t numel)
+      : input_x_(x),
+        input_out_(out),
+        input_out_grad_(out_grad),
+        output_x_grad_(x_grad),
+        numel_(numel) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const {
+    T x = std::abs(input_x_[idx]);
+    T x_ = input_x_[idx];
+    T out_ = input_out_[idx];
+    T out_grad_ = input_out_grad_[idx];
+    if (x <= T{8.0}) {
+      auto coeff_pair_A = ChebyshevCoefficientsI0e_A<T>();
+      auto A =
+          std::get<0>(coeff_pair_A);
+      auto len = std::get<1>(coeff_pair_A);
+      T y = (x / T{2.0}) - T{2.0};
+      T eps = std::numeric_limits<T>::epsilon();
+
+      if (x <= eps) {
+        output_x_grad_[idx] = static_cast<T>(T{0.5} * out_grad_);
+      } else {
+        output_x_grad_[idx] = static_cast<T>(
+            (std::exp(x) * Chbevl<T>(y, A, len) - out_ / x_) * out_grad_);
+      }
+    } else {
+      auto coeff_pair_B = ChebyshevCoefficientsI0e_B<T>();
+      auto B = std::get<0>(coeff_pair_B);
+      auto len = std::get<1>(coeff_pair_B);
+      T y = (T{32.0} / x) - T{2.0};
+
+      output_x_grad_[idx] = static_cast<T>(
+          (std::exp(x) * Chbevl<T>(y, B, len) / std::sqrt(x) - out_ / x_) *
+          out_grad_);
+    }
+  }
+
+ private:
+  const T* input_x_;
+  const T* input_out_;
+  const T* input_out_grad_;
+  T* output_x_grad_;
+  int64_t numel_;
+};
+
+template <typename T>
+struct I1eGradFunctor {
+  I1eGradFunctor(
+      const T* x, const T* out, const T* out_grad, T* x_grad, int64_t numel)
+      : input_x_(x),
+        input_out_(out),
+        input_out_grad_(out_grad),
+        output_x_grad_(x_grad),
+        numel_(numel) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const {
+    T x = std::abs(input_x_[idx]);
+    T x_ = input_x_[idx];
+    T out_ = input_out_[idx];
+    T out_grad_ = input_out_grad_[idx];
+    if (x <= T{8.0}) {
+      auto coeff_pair_A = ChebyshevCoefficientsI0e_A<T>();
+      auto A = std::get<0>(coeff_pair_A);
+      auto len = std::get<1>(coeff_pair_A);
+      T y = (x / T{2.0}) - T{2.0};
+      T eps = std::numeric_limits<T>::epsilon();
+
+      if (x <= eps) {
+        output_x_grad_[idx] = static_cast<T>(T{0.5} * out_grad_);
+      } else {
+        output_x_grad_[idx] =
+            static_cast<T>((Chbevl<T>(y, A, len) -
+                            out_ * (std::copysign(T{1.0}, x_) + T{1.0} / x_)) *
+                           out_grad_);
+      }
+    } else {
+      auto coeff_pair_B = ChebyshevCoefficientsI0e_B<T>();
+      auto B = std::get<0>(coeff_pair_B);
+      auto len = std::get<1>(coeff_pair_B);
+      T y = (T{32.0} / x) - T{2.0};
+
+      output_x_grad_[idx] =
+          static_cast<T>((Chbevl<T>(y, B, len) / std::sqrt(x) -
+                          out_ * (std::copysign(T{1.0}, x_) + T{1.0} / x_)) *
+                         out_grad_);
+    }
+  }
+
+ private:
+  const T* input_x_;
+  const T* input_out_;
+  const T* input_out_grad_;
+  T* output_x_grad_;
+  int64_t numel_;
+};
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/impl/bessel_kernel_cuda_impl.h b/paddle/phi/kernels/impl/bessel_kernel_cuda_impl.h
index 32649c90190..63ae22d046a 100644
--- a/paddle/phi/kernels/impl/bessel_kernel_cuda_impl.h
+++ b/paddle/phi/kernels/impl/bessel_kernel_cuda_impl.h
@@ -228,4 +228,52 @@ ChebyshevCoefficientsI1e_B() {
   return std::make_tuple(coeff, 7);
 }
 
+template <typename T>
+struct CudaI1Functor {
+  __device__ __forceinline__ T operator()(const T _x) const {
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+    const MT mp_x = static_cast<MT>(_x);
+    MT x = std::abs(mp_x);
+    if (x <= MT{8.0}) {
+      auto coeff_pair_A = ChebyshevCoefficientsI1e_A<MT>();
+      auto A = std::get<0>(coeff_pair_A);
+      auto len = std::get<1>(coeff_pair_A);
+      MT y = (x / MT{2.0}) - MT{2.0};
+      const T out = std::exp(x) * x * Chbevl<MT>(y, A, len);
+      return (mp_x < MT{0.0}) ? -out : out;
+    }
+    auto coeff_pair_B = ChebyshevCoefficientsI1e_B<MT>();
+    auto B = std::get<0>(coeff_pair_B);
+    auto len = std::get<1>(coeff_pair_B);
+    MT y = (MT{32.0} / x) - MT{2.0};
+    const T out = (std::exp(x) * Chbevl<MT>(y, B, len)) / std::sqrt(x);
+    return (mp_x < MT{0.0}) ?
+        -out : out;
+  }
+};
+
+template <typename T>
+struct CudaI1eFunctor {
+  __device__ __forceinline__ T operator()(const T _x) const {
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+    const MT mp_x = static_cast<MT>(_x);
+    MT x = std::abs(mp_x);
+    if (x <= MT{8.0}) {
+      auto coeff_pair_A = ChebyshevCoefficientsI1e_A<MT>();
+      auto A = std::get<0>(coeff_pair_A);
+      auto len = std::get<1>(coeff_pair_A);
+      MT y = (x / MT{2.0}) - MT{2.0};
+
+      const T out = static_cast<T>(Chbevl<MT>(y, A, len) * x);
+      return (mp_x < MT{0.0}) ? -out : out;
+    }
+    auto coeff_pair_B = ChebyshevCoefficientsI1e_B<MT>();
+    auto B = std::get<0>(coeff_pair_B);
+    auto len = std::get<1>(coeff_pair_B);
+    MT y = (MT{32.0} / x) - MT{2.0};
+
+    const T out = static_cast<T>(Chbevl<MT>(y, B, len) / std::sqrt(x));
+    return (mp_x < MT{0.0}) ? -out : out;
+  }
+};
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/impl/bessel_kernel_impl.h b/paddle/phi/kernels/impl/bessel_kernel_impl.h
index 1df09075393..12eae74ce96 100644
--- a/paddle/phi/kernels/impl/bessel_kernel_impl.h
+++ b/paddle/phi/kernels/impl/bessel_kernel_impl.h
@@ -251,4 +251,65 @@ ChebyshevCoefficientsI1e_B() {
   return std::make_tuple(coeff, 7);
 }
 
+template <typename T>
+struct I1Functor {
+  I1Functor(const T* input, T* output, int64_t numel)
+      : input_(input), output_(output), numel_(numel) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const {
+    T x = std::abs(input_[idx]);
+    if (x <= T{8.0}) {
+      auto coeff_pair_A = ChebyshevCoefficientsI1e_A<T>();
+      auto A = std::get<0>(coeff_pair_A);
+      auto len = std::get<1>(coeff_pair_A);
+      T y = (x / T{2.0}) - T{2.0};
+      const T out = std::exp(x) * x * Chbevl<T>(y, A, len);
+      output_[idx] = (input_[idx] < T{0.0}) ? -out : out;
+    } else {
+      auto coeff_pair_B = ChebyshevCoefficientsI1e_B<T>();
+      auto B = std::get<0>(coeff_pair_B);
+      auto len = std::get<1>(coeff_pair_B);
+      T y = (T{32.0} / x) - T{2.0};
+      const T out = (std::exp(x) * Chbevl<T>(y, B, len)) / std::sqrt(x);
+      output_[idx] = (input_[idx] < T{0.0}) ? -out : out;
+    }
+  }
+
+ private:
+  const T* input_;
+  T* output_;
+  int64_t numel_;
+};
+
+template <typename T>
+struct I1eFunctor {
+  I1eFunctor(const T* input, T* output, int64_t numel)
+      : input_(input), output_(output), numel_(numel) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const {
+    T x = std::abs(input_[idx]);
+    if (x <= T{8.0}) {
+      auto coeff_pair_A = ChebyshevCoefficientsI1e_A<T>();
+      auto A = std::get<0>(coeff_pair_A);
+      auto len = std::get<1>(coeff_pair_A);
+      T y = (x / T{2.0}) - T{2.0};
+      const T out = Chbevl<T>(y, A, len) * x;
+      output_[idx] = (input_[idx] < T{0.0}) ? -out : out;
+    } else {
+      auto coeff_pair_B = ChebyshevCoefficientsI1e_B<T>();
+      auto B = std::get<0>(coeff_pair_B);
+      auto len = std::get<1>(coeff_pair_B);
+      T y = (T{32.0} / x) - T{2.0};
+
+      const T out = Chbevl<T>(y, B, len) / std::sqrt(x);
+      output_[idx] = (input_[idx] < T{0.0}) ? -out : out;
+    }
+  }
+
+ private:
+  const T* input_;
+  T* output_;
+  int64_t numel_;
+};
+
 }  // namespace phi
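A note on the numerics above: Chbevl appears to follow the classic Cephes chbevl routine, a Clenshaw-style backward recurrence over Chebyshev coefficients, with one coefficient table for |x| <= 8 (argument y = x/2 - 2) and one for |x| > 8 (argument y = 32/x - 2). A minimal Python sketch of that evaluation scheme, for orientation only (the function name, coefficient arguments, and helper below are illustrative, not Paddle or Cephes APIs):

    import math

    def chbevl(y, coeffs):
        # Clenshaw-style backward recurrence over Chebyshev coefficients,
        # in the manner of the Cephes chbevl routine.
        b0, b1, b2 = coeffs[0], 0.0, 0.0
        for c in coeffs[1:]:
            b2, b1 = b1, b0
            b0 = y * b1 - b2 + c
        return 0.5 * (b0 - b2)

    def i1e_sketch(x, coeffs_a, coeffs_b):
        # Piecewise evaluation mirroring I1eFunctor: series A on |x| <= 8,
        # series B with a 1/sqrt(|x|) factor beyond that; the sign of the
        # result is restored from the original input.
        ax = abs(x)
        if ax <= 8.0:
            out = chbevl(ax / 2.0 - 2.0, coeffs_a) * ax
        else:
            out = chbevl(32.0 / ax - 2.0, coeffs_b) / math.sqrt(ax)
        return -out if x < 0 else out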
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 8626e5e9f37..f878a6b8d7f 100644
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -306,6 +306,8 @@ from .tensor.math import vander  # noqa: F401
 from .tensor.math import nextafter  # noqa: F401
 from .tensor.math import i0  # noqa: F401
 from .tensor.math import i0e  # noqa: F401
+from .tensor.math import i1  # noqa: F401
+from .tensor.math import i1e  # noqa: F401
 
 from .tensor.random import bernoulli  # noqa: F401
 from .tensor.random import poisson  # noqa: F401
@@ -702,4 +704,6 @@ __all__ = [  # noqa
     'nextafter',
     'i0',
     'i0e',
+    'i1',
+    'i1e',
 ]
"float32" + DATA = [9, 10, 11, 12, 13, 14, 15, 16, 17] + + +class Testi1Float64Zero2EightCase(Testi1API): + DTYPE = "float64" + DATA = [0, 1, 2, 3, 4, 5, 6, 7, 8] + + +class Testi1Float64OverEightCase(Testi1API): + DTYPE = "float64" + DATA = [9, 10, 11, 12, 13, 14, 15, 16, 17] + + +class TestI1Op(OpTest): + # 配置 op 信息以及输入输出等参数 + def setUp(self): + self.op_type = "i1" + self.python_api = paddle.i1 + self.init_config() + self.outputs = {'out': self.target} + + # 测试前向输出结果 + def test_check_output(self): + self.check_output() + + # 测试反向梯度输出 + def test_check_grad(self): + self.check_grad( + ['x'], + 'out', + user_defined_grads=[ + reference_i1_grad( + self.case, + 1 / self.case.size, + ) + ], + ) + + def init_config(self): + # 生成随机的输入数据 + zero_case = np.zeros(1).astype('float64') + rand_case = np.random.randn(250).astype('float64') + over_eight_case = np.random.uniform(low=8, high=9, size=250).astype( + 'float64' + ) + self.case = np.concatenate([zero_case, rand_case, over_eight_case]) + self.inputs = {'x': self.case} + self.target = reference_i1(self.inputs['x']) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_i1e_op.py b/python/paddle/fluid/tests/unittests/test_i1e_op.py new file mode 100644 index 00000000000..a4c360ae9a3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_i1e_op.py @@ -0,0 +1,151 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/python/paddle/fluid/tests/unittests/test_i1e_op.py b/python/paddle/fluid/tests/unittests/test_i1e_op.py
new file mode 100644
index 00000000000..a4c360ae9a3
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_i1e_op.py
@@ -0,0 +1,151 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from eager_op_test import OpTest
+from scipy import special
+
+import paddle
+from paddle.fluid import core
+
+np.random.seed(42)
+paddle.seed(42)
+
+
+def reference_i1e(x):
+    return special.i1e(x)
+
+
+def reference_i1e_grad(x, dout):
+    eps = np.finfo(x.dtype).eps
+    not_tiny = abs(x) > eps
+    safe_x = np.where(not_tiny, x, eps)
+    gradx = special.i0e(safe_x) - special.i1e(x) * (np.sign(x) + 1 / safe_x)
+    gradx = np.where(not_tiny, gradx, 0.5)
+    return dout * gradx
+
+
+class TestI1e_API(unittest.TestCase):
+    DTYPE = "float64"
+    DATA = [0, 1, 2, 3, 4, 5]
+
+    def setUp(self):
+        self.x = np.array(self.DATA).astype(self.DTYPE)
+        self.place = [paddle.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            self.place.append(paddle.CUDAPlace(0))
+
+    def test_api_static(self):
+        def run(place):
+            paddle.enable_static()
+            with paddle.static.program_guard(paddle.static.Program()):
+                x = paddle.static.data(
+                    name="x", shape=self.x.shape, dtype=self.DTYPE
+                )
+                y = paddle.i1e(x)
+                exe = paddle.static.Executor(place)
+                res = exe.run(
+                    paddle.static.default_main_program(),
+                    feed={"x": self.x},
+                    fetch_list=[y],
+                )
+                out_ref = reference_i1e(self.x)
+                np.testing.assert_allclose(out_ref, res[0], rtol=1e-5)
+            paddle.disable_static()
+
+        for place in self.place:
+            run(place)
+
+    def test_api_dygraph(self):
+        def run(place):
+            paddle.disable_static(place)
+            x = paddle.to_tensor(self.x)
+            out = paddle.i1e(x)
+
+            out_ref = reference_i1e(self.x)
+            np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-5)
+            paddle.enable_static()
+
+        for place in self.place:
+            run(place)
+
+    def test_empty_input_error(self):
+        for place in self.place:
+            paddle.disable_static(place)
+            x = None
+            self.assertRaises(ValueError, paddle.i1e, x)
+            paddle.enable_static()
+
+
+class Testi1eFloat32Zero2EightCase(TestI1e_API):
+    DTYPE = "float32"
+    DATA = [0, 1, 2, 3, 4, 5, 6, 7, 8]
+
+
+class Testi1eFloat32OverEightCase(TestI1e_API):
+    DTYPE = "float32"
+    DATA = [9, 10, 11, 12, 13, 14, 15, 16, 17]
+
+
+class Testi1eFloat64Zero2EightCase(TestI1e_API):
+    DTYPE = "float64"
+    DATA = [0, 1, 2, 3, 4, 5, 6, 7, 8]
+
+
+class Testi1eFloat64OverEightCase(TestI1e_API):
+    DTYPE = "float64"
+    DATA = [9, 10, 11, 12, 13, 14, 15, 16, 17]
+
+
+class TestI1eOp(OpTest):
+    # Configure the op information and the input/output parameters
+    def setUp(self):
+        self.op_type = "i1e"
+        self.python_api = paddle.i1e
+        self.init_config()
+        self.outputs = {'out': self.target}
+
+    # Check the forward output
+    def test_check_output(self):
+        self.check_output()
+
+    # Check the backward gradients
+    def test_check_grad(self):
+        self.check_grad(
+            ['x'],
+            'out',
+            user_defined_grads=[
+                reference_i1e_grad(
+                    self.case,
+                    1 / self.case.size,
+                )
+            ],
+        )
+
+    # Generate random input data and compute the expected outputs
+    def init_config(self):
+        zero_case = np.zeros(1).astype('float64')
+        rand_case = np.random.randn(250).astype('float64')
+        over_eight_case = np.random.uniform(low=8, high=9, size=250).astype(
+            'float64'
+        )
+        self.case = np.concatenate([zero_case, rand_case, over_eight_case])
+        self.inputs = {'x': self.case}
+        self.target = reference_i1e(self.inputs['x'])
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index 1e26e4ce56d..ba13034fc56 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -258,6 +258,8 @@ from .math import vander  # noqa: F401
 from .math import nextafter  # noqa: F401
 from .math import i0  # noqa: F401
 from .math import i0e  # noqa: F401
+from .math import i1  # noqa: F401
+from .math import i1e  # noqa: F401
 
 from .random import multinomial  # noqa: F401
 from .random import standard_normal  # noqa: F401
@@ -554,6 +556,8 @@ tensor_method_func = [  # noqa
     'unflatten',
     'i0',
     'i0e',
+    'i1',
+    'i1e',
 ]
 
 # this list used in math_op_patch.py for magic_method bind
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index debf74e87ac..f3635a69651 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -5657,3 +5657,70 @@ def i0e(x, name=None):
         out = helper.create_variable_for_type_inference(dtype=x.dtype)
         helper.append_op(type='i0e', inputs={'x': x}, outputs={'out': out})
         return out
+
+
+def i1(x, name=None):
+    """
+    The function is used to calculate the modified Bessel function of the first kind, order 1, of x.
+
+    Args:
+        x (Tensor): The input tensor, its data type should be float32, float64.
+        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
+
+    Returns:
+        - out (Tensor), the value of the modified Bessel function of the first kind, order 1, at x.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            x = paddle.to_tensor([0, 1, 2, 3, 4], dtype="float32")
+            print(paddle.i1(x))
+            # Tensor(shape=[5], dtype=float32, place=Place(cpu), stop_gradient=True,
+            #        [0., 0.5651591, 1.59063685, 3.95337022, 9.75946515])
+    """
+    if in_dygraph_mode():
+        return _C_ops.i1(x)
+    else:
+        check_variable_and_dtype(x, "x", ["float32", "float64"], "i1")
+
+        helper = LayerHelper("i1", **locals())
+        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+        helper.append_op(
+            type='i1', inputs={'x': x}, outputs={'out': out}, attrs={}
+        )
+        return out
+
+
+def i1e(x, name=None):
+    """
+    The function is used to calculate the exponentially scaled modified Bessel function of the first kind, order 1, of x.
+
+    Args:
+        x (Tensor): The input tensor, its data type should be float32, float64.
+        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
+
+    Returns:
+        - out (Tensor), the value of the exponentially scaled modified Bessel function of the first kind, order 1, at x.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            x = paddle.to_tensor([0, 1, 2, 3, 4], dtype="float32")
+            print(paddle.i1e(x))
+            # Tensor(shape=[5], dtype=float32, place=Place(cpu), stop_gradient=True,
+            #        [0., 0.20791042, 0.21526929, 0.19682671, 0.17875084])
+    """
+    if in_dygraph_mode():
+        return _C_ops.i1e(x)
+    else:
+        check_variable_and_dtype(x, "x", ["float32", "float64"], "i1e")
+
+        helper = LayerHelper("i1e", **locals())
+        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+        helper.append_op(
+            type='i1e', inputs={'x': x}, outputs={'out': out}, attrs={}
+        )
+        return out
--
GitLab
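As a closing sanity check, here is a short standalone script (not part of the patch) that exercises the new Python APIs against SciPy; it assumes a Paddle build that includes this change:

    import numpy as np
    import paddle
    from scipy import special

    # Spans both evaluation branches (|x| <= 8 and |x| > 8).
    x_np = np.linspace(0.5, 12.0, 8).astype("float32")
    x = paddle.to_tensor(x_np, stop_gradient=False)

    # Forward values should track SciPy's i1 / i1e.
    np.testing.assert_allclose(paddle.i1(x).numpy(), special.i1(x_np), rtol=1e-5)
    np.testing.assert_allclose(paddle.i1e(x).numpy(), special.i1e(x_np), rtol=1e-5)

    # The gradient of sum(i1(x)) should match i0(x) - i1(x)/x.
    paddle.i1(x).sum().backward()
    np.testing.assert_allclose(
        x.grad.numpy(), special.i0(x_np) - special.i1(x_np) / x_np, rtol=1e-5
    )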