Unverified commit ce256f75, authored by PommesPeter and committed by GitHub

【Hackathon 4 No.20】Add i0 / i0e to paddle (#52058)

* added base code for i0 and i0e

* added grad base code for i0 and i0e

* added i0 and i0e python code

* added ops and backward yaml config

* added i0 and i0e cpu kernels, not yet tested

* added i0 and i0e code and unittest files

* added test files

* added i0/i0e gpu implementation code

* updated code style

* updated code style

* fixed unittest code

* updated i0 with eigen3

* fixed bug and added more test cases

* refactor: fixed static graph bug

* refactor: removed i0 and i0e from op_compat

* refactor: updated code style

* refactor: updated op_compat.yaml

* refactor: updated op_compat.yaml

* refactor: fixed op name mapping and optimize unittest case

* refactor: manually implement i0 / i0e

* refactor: added grad kernel for i0 / i0e, not finished yet

* Update math.py

* refactor: added equation to doc in English and added comments for computing i0 / i0e gradient

* refactor: removed eigen implementation

* refactor: finished i0 / i0e cpu and gpu op

* refactor: updated code style

* fix: found a bug, not yet fixed

* fix: incorrect unittest cases

* update: updated code style and remove my file

* update: updated unittest case

* fix: fixed sign error

* fix: fixed mistakes when merging

* refactor: updated code style

* refactor: remove unused code

* refactor: updated code style
Parent commit: 4d39cc7f
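For quick orientation, a minimal usage sketch of the two new Python APIs added by this PR (adapted from the docstring examples further down in this diff; the printed values are the ones quoted there and may vary slightly by device):

import paddle

x = paddle.to_tensor([0, 1, 2, 3, 4], dtype="float32")
print(paddle.i0(x))   # modified Bessel function of order 0
# [0.99999994, 1.26606596, 2.27958512, 4.88079262, 11.30192089]
print(paddle.i0e(x))  # exponentially scaled variant: exp(-|x|) * I0(x)
# [1., 0.46575961, 0.30850832, 0.24300035, 0.20700192]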
@@ -775,6 +775,26 @@
kernel :
func : huber_loss_grad
- backward_op : i0_grad
forward : i0 (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : i0_grad
- backward_op : i0e_grad
forward : i0e (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : i0e_grad
- backward_op : imag_grad
forward : imag (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
...
@@ -920,6 +920,24 @@
intermediate : residual
backward : huber_loss_grad
- op : i0
args : (Tensor x)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : i0
backward : i0_grad
- op : i0e
args : (Tensor x)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : i0e
backward : i0e_grad
- op : imag
args : (Tensor x)
output : Tensor (out)
...
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/i0_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/impl/bessel_grad_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void I0GradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
auto size = x.numel();
auto* x_data = x.data<T>();
auto* out_grad_data = out_grad.data<T>();
auto* x_grad_data = ctx.template Alloc<T>(x_grad);
phi::funcs::ForRange<Context> for_range(ctx, size);
I0GradFunctor<T> functor(x_data, out_grad_data, x_grad_data, size);
for_range(functor);
}
} // namespace phi
PD_REGISTER_KERNEL(i0_grad, CPU, ALL_LAYOUT, phi::I0GradKernel, float, double) {
}
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/i0_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/impl/bessel_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void I0Kernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
const int64_t size = x.numel();
const T* x_data = x.data<T>();
T* out_data = ctx.template Alloc<T>(out);
phi::funcs::ForRange<Context> for_range(ctx, size);
I0Functor<T> functor(x_data, out_data, size);
for_range(functor);
}
} // namespace phi
PD_REGISTER_KERNEL(i0, CPU, ALL_LAYOUT, phi::I0Kernel, float, double) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/i0e_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/impl/bessel_grad_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void I0eGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
auto size = x.numel();
auto* x_data = x.data<T>();
auto* out_data = out.data<T>();
auto* out_grad_data = out_grad.data<T>();
auto* x_grad_data = ctx.template Alloc<T>(x_grad);
phi::funcs::ForRange<Context> for_range(ctx, size);
I0eGradFunctor<T> functor(x_data, out_data, out_grad_data, x_grad_data, size);
for_range(functor);
}
} // namespace phi
PD_REGISTER_KERNEL(
i0e_grad, CPU, ALL_LAYOUT, phi::I0eGradKernel, float, double) {}
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/i0e_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/impl/bessel_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void I0eKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
int64_t size = x.numel();
const T* x_data = x.data<T>();
T* out_data = ctx.template Alloc<T>(out);
phi::funcs::ForRange<Context> for_range(ctx, size);
I0eFunctor<T> functor(x_data, out_data, size);
for_range(functor);
}
} // namespace phi
PD_REGISTER_KERNEL(i0e, CPU, ALL_LAYOUT, phi::I0eKernel, float, double) {}
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/i0_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/impl/bessel_grad_kernel_cuda_impl.h"
namespace phi {
template <typename T, typename Context>
void I0GradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
ctx.template Alloc<T>(x_grad);
std::vector<const DenseTensor*> ins = {&x, &out_grad};
std::vector<DenseTensor*> outs = {x_grad};
auto functor = CudaI0GradFunctor<T>();
phi::funcs::ElementwiseKernel<T>(ctx, ins, &outs, functor);
}
} // namespace phi
PD_REGISTER_KERNEL(i0_grad, GPU, ALL_LAYOUT, phi::I0GradKernel, float, double) {
}
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/i0_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/impl/bessel_kernel_cuda_impl.h"
namespace phi {
template <typename T, typename Context>
void I0Kernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
ctx.template Alloc<T>(out);
std::vector<const DenseTensor*> ins = {&x};
std::vector<DenseTensor*> outs = {out};
auto functor = CudaI0Functor<T>();
phi::funcs::ElementwiseKernel<T>(ctx, ins, &outs, functor);
}
} // namespace phi
PD_REGISTER_KERNEL(i0, GPU, ALL_LAYOUT, phi::I0Kernel, float, double) {}
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/i0e_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/impl/bessel_grad_kernel_cuda_impl.h"
namespace phi {
template <typename T, typename Context>
void I0eGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
ctx.template Alloc<T>(x_grad);
std::vector<const DenseTensor*> ins = {&x, &out, &out_grad};
std::vector<DenseTensor*> outs = {x_grad};
auto functor = CudaI0eGradFunctor<T>();
phi::funcs::ElementwiseKernel<T>(ctx, ins, &outs, functor);
}
} // namespace phi
PD_REGISTER_KERNEL(
i0e_grad, GPU, ALL_LAYOUT, phi::I0eGradKernel, float, double) {}
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/i0e_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/impl/bessel_kernel_cuda_impl.h"
namespace phi {
template <typename T, typename Context>
void I0eKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
ctx.template Alloc<T>(out);
std::vector<const DenseTensor*> ins = {&x};
std::vector<DenseTensor*> outs = {out};
auto functor = CudaI0eFunctor<T>();
phi::funcs::ElementwiseKernel<T>(ctx, ins, &outs, functor);
}
} // namespace phi
PD_REGISTER_KERNEL(i0e, GPU, ALL_LAYOUT, phi::I0eKernel, float, double) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
template <typename T, typename Context>
void I0GradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
/**
* @brief This kernel calculates the modified Bessel function of order 0.
* @param ctx device context
* @param x The input tensor of i0
* @param out The output tensor of the i0 kernel. It has the same shape and
* dtype as the input; each element is I0 of the corresponding input element.
*/
template <typename T, typename Context>
void I0Kernel(const Context& ctx, const DenseTensor& x, DenseTensor* out);
} // namespace phi
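For reference, the quantity this kernel evaluates is the series that also appears in the Python docstring added later in this PR:

I_0(x) = \sum_{k=0}^{\infty} \frac{(x^2/4)^k}{(k!)^2}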
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
template <typename T, typename Context>
void I0eGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& out_grad,
DenseTensor* x_grad);
} // namespace phi
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
/**
* @brief This kernel calculates the exponentially scaled modified Bessel
* function of order 0.
* @param ctx device context
* @param x The input tensor of i0e
* @param out The output tensor of the i0e kernel. It has the same shape and
* dtype as the input; each element is I0e of the corresponding input element.
*/
template <typename T, typename Context>
void I0eKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out);
} // namespace phi
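And the exponentially scaled variant computed by this kernel, in terms of I_0 above:

I_{0e}(x) = e^{-|x|}\, I_0(x)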
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/bessel_kernel_cuda_impl.h"
namespace phi {
template <typename T>
struct CudaI0GradFunctor {
__device__ __forceinline__ T operator()(const T _x, const T _out_grad) const {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
const MT mp_x = static_cast<MT>(_x);
const MT mp_out_grad = static_cast<MT>(_out_grad);
// get output of i1
MT x = std::abs(mp_x);
if (x <= MT{8.0}) {
auto coeff_pair_A = ChebyshevCoefficientsI1e_A<MT>();
auto A = std::get<0>(coeff_pair_A);
auto len = std::get<1>(coeff_pair_A);
MT y = (x / MT{2.0}) - MT{2.0};
const MT i1_out = std::exp(x) * x * Chbevl<MT>(y, A, len);
const MT i1_data = (mp_x < MT{0.0}) ? -i1_out : i1_out;
// calculate i0 gradient
return static_cast<T>(i1_data * mp_out_grad);
}
auto coeff_pair_B = ChebyshevCoefficientsI1e_B<MT>();
auto B = std::get<0>(coeff_pair_B);
auto len = std::get<1>(coeff_pair_B);
MT y = (MT{32.0} / x) - MT{2.0};
const MT i1_out = (std::exp(x) * Chbevl<MT>(y, B, len)) / std::sqrt(x);
const MT i1_data = (mp_x < MT{0.0}) ? -i1_out : i1_out;
return static_cast<T>(i1_data * mp_out_grad);
}
};
template <typename T>
struct CudaI0eGradFunctor {
__device__ __forceinline__ T operator()(const T _x,
const T _out,
const T _out_grad) const {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
const MT mp_x = static_cast<MT>(_x);
const MT mp_out = static_cast<MT>(_out);
const MT mp_out_grad = static_cast<MT>(_out_grad);
// get output of i1e
MT x = std::abs(mp_x);
if (x <= MT{8.0}) {
auto coeff_pair_A = ChebyshevCoefficientsI1e_A<MT>();
auto A = std::get<0>(coeff_pair_A);
auto len = std::get<1>(coeff_pair_A);
MT y = (x / MT{2.0}) - MT{2.0};
const MT i1e_out = Chbevl<MT>(y, A, len) * x;
const MT i1e_data = (mp_x < MT{0.0}) ? -i1e_out : i1e_out;
// calculate i0e gradient
return static_cast<T>((i1e_data - std::copysign(MT{1.0}, mp_x) * mp_out) *
mp_out_grad);
}
auto coeff_pair_B = ChebyshevCoefficientsI1e_B<MT>();
auto B = std::get<0>(coeff_pair_B);
auto len = std::get<1>(coeff_pair_B);
MT y = (MT{32.0} / x) - MT{2.0};
const MT i1e_out = Chbevl<MT>(y, B, len) / std::sqrt(x);
const MT i1e_data = (mp_x < MT{0.0}) ? -i1e_out : i1e_out;
return static_cast<T>((i1e_data - std::copysign(MT{1.0}, mp_x) * mp_out) *
mp_out_grad);
}
};
} // namespace phi
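The two gradient functors above (and their CPU counterparts later in this diff) rely on the following identities, which also match the SciPy-based references used in the unit tests (special.i1 for i0, and special.i1e minus sign(x) times i0e for i0e):

\frac{d}{dx} I_0(x) = I_1(x)

\frac{d}{dx} I_{0e}(x) = \frac{d}{dx}\bigl(e^{-|x|} I_0(x)\bigr) = I_{1e}(x) - \operatorname{sign}(x)\, I_{0e}(x)

So x_grad = I_1(x) * out_grad for i0, and x_grad = (I_{1e}(x) - sign(x) * out) * out_grad for i0e, which is what CudaI0GradFunctor and CudaI0eGradFunctor compute via the Chebyshev expansions of I_1 / I_{1e}.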
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/impl/bessel_kernel_impl.h"
namespace phi {
template <typename T>
struct I0GradFunctor {
I0GradFunctor(const T* x, const T* out_grad, T* x_grad, int64_t numel)
: inp_x_(x),
inp_out_grad_(out_grad),
output_x_grad_(x_grad),
numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
const MT mp_x = static_cast<MT>(inp_x_[idx]);
const MT mp_out_grad = static_cast<MT>(inp_out_grad_[idx]);
MT x = std::abs(mp_x);
if (x <= MT{8.0}) {
auto coeff_pair_A = ChebyshevCoefficientsI1e_A<MT>();
auto A = std::get<0>(coeff_pair_A);
auto len = std::get<1>(coeff_pair_A);
MT y = (x / MT{2.0}) - MT{2.0};
const MT i1_out = std::exp(x) * x * Chbevl<MT>(y, A, len);
const MT i1_data = (mp_x < MT{0.0}) ? -i1_out : i1_out;
output_x_grad_[idx] = static_cast<T>(i1_data * mp_out_grad);
} else {
auto coeff_pair_B = ChebyshevCoefficientsI1e_B<MT>();
auto B = std::get<0>(coeff_pair_B);
auto len = std::get<1>(coeff_pair_B);
MT y = (MT{32.0} / x) - MT{2.0};
const MT i1_out = (std::exp(x) * Chbevl<MT>(y, B, len)) / std::sqrt(x);
const MT i1_data = (mp_x < MT{0.0}) ? -i1_out : i1_out;
output_x_grad_[idx] = static_cast<T>(i1_data * mp_out_grad);
}
}
private:
const T* inp_x_;
const T* inp_out_grad_;
T* output_x_grad_;
int64_t numel_;
};
template <typename T>
struct I0eGradFunctor {
I0eGradFunctor(
const T* x, const T* out, const T* out_grad, T* x_grad, int64_t numel)
: inp_x_(x),
inp_out_(out),
inp_out_grad_(out_grad),
output_x_grad_(x_grad),
numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
T x = std::abs(inp_x_[idx]);
if (x <= T{8.0}) {
auto coeff_pair_A = ChebyshevCoefficientsI1e_A<T>();
auto A = std::get<0>(coeff_pair_A);
auto len = std::get<1>(coeff_pair_A);
T y = (x / T{2.0}) - T{2.0};
const T out = Chbevl<T>(y, A, len) * x;
const T i1e_out = (inp_x_[idx] < T{0.0}) ? -out : out;
output_x_grad_[idx] =
(i1e_out - std::copysign(T{1.0}, inp_x_[idx]) * inp_out_[idx]) *
inp_out_grad_[idx];
} else {
auto coeff_pair_B = ChebyshevCoefficientsI1e_B<T>();
auto B = std::get<0>(coeff_pair_B);
auto len = std::get<1>(coeff_pair_B);
T y = (T{32.0} / x) - T{2.0};
const T out = Chbevl<T>(y, B, len) / std::sqrt(x);
const T i1e_out = (inp_x_[idx] < T{0.0}) ? -out : out;
output_x_grad_[idx] =
(i1e_out - std::copysign(T{1.0}, inp_x_[idx]) * inp_out_[idx]) *
inp_out_grad_[idx];
}
}
private:
const T* inp_x_;
const T* inp_out_;
const T* inp_out_grad_;
T* output_x_grad_;
int64_t numel_;
};
} // namespace phi
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T>
__host__ __device__ std::tuple<const T*, size_t> ChebyshevCoefficientsI0e_A() {
/* Chebyshev coefficients for I0e(x) in the interval [0,8]. */
static const T coeff[] = {
-4.41534164647933937950E-18, 3.33079451882223809783E-17,
-2.43127984654795469359E-16, 1.71539128555513303061E-15,
-1.16853328779934516808E-14, 7.67618549860493561688E-14,
-4.85644678311192946090E-13, 2.95505266312963983461E-12,
-1.72682629144155570723E-11, 9.67580903537323691224E-11,
-5.18979560163526290666E-10, 2.65982372468238665035E-9,
-1.30002500998624804212E-8, 6.04699502254191894932E-8,
-2.67079385394061173391E-7, 1.11738753912010371815E-6,
-4.41673835845875056359E-6, 1.64484480707288970893E-5,
-5.75419501008210370398E-5, 1.88502885095841655729E-4,
-5.76375574538582365885E-4, 1.63947561694133579842E-3,
-4.32430999505057594430E-3, 1.05464603945949983183E-2,
-2.37374148058994688156E-2, 4.93052842396707084878E-2,
-9.49010970480476444210E-2, 1.71620901522208775349E-1,
-3.04682672343198398683E-1, 6.76795274409476084995E-1};
return std::make_tuple(coeff, 30);
}
template <typename T>
__host__ __device__ std::tuple<const T*, size_t> ChebyshevCoefficientsI0e_B() {
/* Chebyshev coefficients for I0e(x) in the inverted interval [8,infinity]. */
static const T coeff[] = {
-7.23318048787475395456E-18, -4.83050448594418207126E-18,
4.46562142029675999901E-17, 3.46122286769746109310E-17,
-2.82762398051658348494E-16, -3.42548561967721913462E-16,
1.77256013305652638360E-15, 3.81168066935262242075E-15,
-9.55484669882830764870E-15, -4.15056934728722208663E-14,
1.54008621752140982691E-14, 3.85277838274214270114E-13,
7.18012445138366623367E-13, -1.79417853150680611778E-12,
-1.32158118404477131188E-11, -3.14991652796324136454E-11,
1.18891471078464383424E-11, 4.94060238822496958910E-10,
3.39623202570838634515E-9, 2.26666899049817806459E-8,
2.04891858946906374183E-7, 2.89137052083475648297E-6,
6.88975834691682398426E-5, 3.36911647825569408990E-3,
8.04490411014108831608E-1};
return std::make_tuple(coeff, 25);
}
template <typename T>
__host__ __device__ T Chbevl(T x, const T array[], size_t len) {
T b0, b1, b2;
b0 = array[0];
b1 = static_cast<T>(0.0);
for (size_t i = 1; i < len; ++i) {
b2 = b1;
b1 = b0;
b0 = x * b1 - b2 + array[i];
}
return (static_cast<T>(0.5) * (b0 - b2));
}
template <typename T>
struct CudaI0Functor {
__device__ __forceinline__ T operator()(const T _x) const {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
const MT mp_x = static_cast<MT>(_x);
MT x = std::abs(mp_x);
if (x <= MT{8.0}) {
auto coeff_pair_A = ChebyshevCoefficientsI0e_A<MT>();
auto A = std::get<0>(coeff_pair_A);
auto len = std::get<1>(coeff_pair_A);
MT y = (x / MT{2.0}) - MT{2.0};
return static_cast<T>(std::exp(x) * Chbevl<MT>(y, A, len));
}
auto coeff_pair_B = ChebyshevCoefficientsI0e_B<MT>();
auto B = std::get<0>(coeff_pair_B);
auto len = std::get<1>(coeff_pair_B);
MT y = (MT{32.0} / x) - MT{2.0};
return static_cast<T>(std::exp(x) * Chbevl<MT>(y, B, len) / std::sqrt(x));
}
};
template <typename T>
struct CudaI0eFunctor {
__device__ __forceinline__ T operator()(const T _x) const {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
const MT mp_x = static_cast<MT>(_x);
MT x = std::abs(mp_x);
if (x <= MT{8.0}) {
auto coeff_pair_A = ChebyshevCoefficientsI0e_A<MT>();
auto A = std::get<0>(coeff_pair_A);
auto len = std::get<1>(coeff_pair_A);
MT y = (x / MT{2.0}) - MT{2.0};
return static_cast<T>(Chbevl<MT>(y, A, len));
}
auto coeff_pair_B = ChebyshevCoefficientsI0e_B<MT>();
auto B = std::get<0>(coeff_pair_B);
auto len = std::get<1>(coeff_pair_B);
MT y = (MT{32.0} / x) - MT{2.0};
return static_cast<T>(Chbevl<MT>(y, B, len) / std::sqrt(x));
}
};
template <typename T>
__host__ __device__ typename std::enable_if<std::is_same<double, T>::value,
std::tuple<const T*, size_t>>::type
ChebyshevCoefficientsI1e_A() {
/* Chebyshev coefficients for exp(-x) I1(x)
* in the interval [0,8].
*
* lim(x->0){ exp(-x) I1(x) / x } = 1/2.
*/
static const T coeff[] = {
2.77791411276104639959E-18, -2.11142121435816608115E-17,
1.55363195773620046921E-16, -1.10559694773538630805E-15,
7.60068429473540693410E-15, -5.04218550472791168711E-14,
3.22379336594557470981E-13, -1.98397439776494371520E-12,
1.17361862988909016308E-11, -6.66348972350202774223E-11,
3.62559028155211703701E-10, -1.88724975172282928790E-9,
9.38153738649577178388E-9, -4.44505912879632808065E-8,
2.00329475355213526229E-7, -8.56872026469545474066E-7,
3.47025130813767847674E-6, -1.32731636560394358279E-5,
4.78156510755005422638E-5, -1.61760815825896745588E-4,
5.12285956168575772895E-4, -1.51357245063125314899E-3,
4.15642294431288815669E-3, -1.05640848946261981558E-2,
2.47264490306265168283E-2, -5.29459812080949914269E-2,
1.02643658689847095384E-1, -1.76416518357834055153E-1,
2.52587186443633654823E-1};
return std::make_tuple(coeff, 29);
}
template <typename T>
__host__ __device__ typename std::enable_if<std::is_same<float, T>::value,
std::tuple<const T*, size_t>>::type
ChebyshevCoefficientsI1e_A() {
/* Chebyshev coefficients for exp(-x) I1(x)
* in the interval [0,8].
*
* lim(x->0){ exp(-x) I1(x) / x } = 1/2.
*/
static const T coeff[] = {9.38153738649577178388E-9f,
-4.44505912879632808065E-8f,
2.00329475355213526229E-7f,
-8.56872026469545474066E-7f,
3.47025130813767847674E-6f,
-1.32731636560394358279E-5f,
4.78156510755005422638E-5f,
-1.61760815825896745588E-4f,
5.12285956168575772895E-4f,
-1.51357245063125314899E-3f,
4.15642294431288815669E-3f,
-1.05640848946261981558E-2f,
2.47264490306265168283E-2f,
-5.29459812080949914269E-2f,
1.02643658689847095384E-1f,
-1.76416518357834055153E-1f,
2.52587186443633654823E-1f};
return std::make_tuple(coeff, 17);
}
template <typename T>
__host__ __device__ typename std::enable_if<std::is_same<double, T>::value,
std::tuple<const T*, size_t>>::type
ChebyshevCoefficientsI1e_B() {
/* Chebyshev coefficients for exp(-x) sqrt(x) I1(x)
* in the inverted interval [8,infinity].
*
* lim(x->inf){ exp(-x) sqrt(x) I1(x) } = 1/sqrt(2pi).
*/
static const T coeff[] = {
7.51729631084210481353E-18, 4.41434832307170791151E-18,
-4.65030536848935832153E-17, -3.20952592199342395980E-17,
2.96262899764595013876E-16, 3.30820231092092828324E-16,
-1.88035477551078244854E-15, -3.81440307243700780478E-15,
1.04202769841288027642E-14, 4.27244001671195135429E-14,
-2.10154184277266431302E-14, -4.08355111109219731823E-13,
-7.19855177624590851209E-13, 2.03562854414708950722E-12,
1.41258074366137813316E-11, 3.25260358301548823856E-11,
-1.89749581235054123450E-11, -5.58974346219658380687E-10,
-3.83538038596423702205E-9, -2.63146884688951950684E-8,
-2.51223623787020892529E-7, -3.88256480887769039346E-6,
-1.10588938762623716291E-4, -9.76109749136146840777E-3,
7.78576235018280120474E-1};
return std::make_tuple(coeff, 25);
}
template <typename T>
__host__ __device__ typename std::enable_if<std::is_same<float, T>::value,
std::tuple<const T*, size_t>>::type
ChebyshevCoefficientsI1e_B() {
/* Chebyshev coefficients for exp(-x) sqrt(x) I1(x)
* in the inverted interval [8,infinity].
*
* lim(x->inf){ exp(-x) sqrt(x) I1(x) } = 1/sqrt(2pi).
*/
static const T coeff[] = {-3.83538038596423702205E-9f,
-2.63146884688951950684E-8f,
-2.51223623787020892529E-7f,
-3.88256480887769039346E-6f,
-1.10588938762623716291E-4f,
-9.76109749136146840777E-3f,
7.78576235018280120474E-1f};
return std::make_tuple(coeff, 7);
}
} // namespace phi
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T>
static inline std::tuple<const T*, size_t> ChebyshevCoefficientsI0e_A() {
/* Chebyshev coefficients for exp(-x) I0(x)
* in the interval [0,8].
*
* lim(x->0) { exp(-x) I0(x) } = 1.
*/
static const T coeff[] = {
-4.41534164647933937950E-18, 3.33079451882223809783E-17,
-2.43127984654795469359E-16, 1.71539128555513303061E-15,
-1.16853328779934516808E-14, 7.67618549860493561688E-14,
-4.85644678311192946090E-13, 2.95505266312963983461E-12,
-1.72682629144155570723E-11, 9.67580903537323691224E-11,
-5.18979560163526290666E-10, 2.65982372468238665035E-9,
-1.30002500998624804212E-8, 6.04699502254191894932E-8,
-2.67079385394061173391E-7, 1.11738753912010371815E-6,
-4.41673835845875056359E-6, 1.64484480707288970893E-5,
-5.75419501008210370398E-5, 1.88502885095841655729E-4,
-5.76375574538582365885E-4, 1.63947561694133579842E-3,
-4.32430999505057594430E-3, 1.05464603945949983183E-2,
-2.37374148058994688156E-2, 4.93052842396707084878E-2,
-9.49010970480476444210E-2, 1.71620901522208775349E-1,
-3.04682672343198398683E-1, 6.76795274409476084995E-1};
return std::make_tuple(coeff, 30);
}
template <typename T>
static inline std::tuple<const T*, size_t> ChebyshevCoefficientsI0e_B() {
/* Chebyshev coefficients for exp(-x) sqrt(x) I0(x)
* in the inverted interval [8,infinity].
*
* lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi).
*/
static const T coeff[] = {
-7.23318048787475395456E-18, -4.83050448594418207126E-18,
4.46562142029675999901E-17, 3.46122286769746109310E-17,
-2.82762398051658348494E-16, -3.42548561967721913462E-16,
1.77256013305652638360E-15, 3.81168066935262242075E-15,
-9.55484669882830764870E-15, -4.15056934728722208663E-14,
1.54008621752140982691E-14, 3.85277838274214270114E-13,
7.18012445138366623367E-13, -1.79417853150680611778E-12,
-1.32158118404477131188E-11, -3.14991652796324136454E-11,
1.18891471078464383424E-11, 4.94060238822496958910E-10,
3.39623202570838634515E-9, 2.26666899049817806459E-8,
2.04891858946906374183E-7, 2.89137052083475648297E-6,
6.88975834691682398426E-5, 3.36911647825569408990E-3,
8.04490411014108831608E-1};
return std::make_tuple(coeff, 25);
}
template <typename T>
static inline T Chbevl(T x, const T array[], size_t len) {
T b0, b1, b2;
b0 = array[0];
b1 = static_cast<T>(0.0);
for (size_t i = 1; i < len; ++i) {
b2 = b1;
b1 = b0;
b0 = x * b1 - b2 + array[i];
}
return (static_cast<T>(0.5) * (b0 - b2));
}
template <typename T>
struct I0eFunctor {
I0eFunctor(const T* input, T* output, int64_t numel)
: input_(input), output_(output), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
T x = std::abs(input_[idx]);
if (x <= T{8.0}) {
auto coeff_pair_A = ChebyshevCoefficientsI0e_A<T>();
auto A = std::get<0>(coeff_pair_A);
auto len = std::get<1>(coeff_pair_A);
T y = (x / T{2.0}) - T{2.0};
output_[idx] = static_cast<T>(Chbevl<T>(y, A, len));
} else {
auto coeff_pair_B = ChebyshevCoefficientsI0e_B<T>();
auto B = std::get<0>(coeff_pair_B);
auto len = std::get<1>(coeff_pair_B);
T y = (T{32.0} / x) - T{2.0};
output_[idx] = static_cast<T>(Chbevl<T>(y, B, len) / std::sqrt(x));
}
}
private:
const T* input_;
T* output_;
int64_t numel_;
};
template <typename T>
struct I0Functor {
I0Functor(const T* input, T* output, int64_t numel)
: input_(input), output_(output), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
T x = std::abs(input_[idx]);
if (x <= T{8.0}) {
auto coeff_pair_A = ChebyshevCoefficientsI0e_A<T>();
auto A = std::get<0>(coeff_pair_A);
auto len = std::get<1>(coeff_pair_A);
T y = (x / T{2.0}) - T{2.0};
output_[idx] = static_cast<T>(std::exp(x) * Chbevl<T>(y, A, len));
} else {
auto coeff_pair_B = ChebyshevCoefficientsI0e_B<T>();
auto B = std::get<0>(coeff_pair_B);
auto len = std::get<1>(coeff_pair_B);
T y = (T{32.0} / x) - T{2.0};
output_[idx] =
static_cast<T>(std::exp(x) * Chbevl<T>(y, B, len) / std::sqrt(x));
}
}
private:
const T* input_;
T* output_;
int64_t numel_;
};
template <typename T>
static inline typename std::enable_if<std::is_same<double, T>::value,
std::tuple<const T*, size_t>>::type
ChebyshevCoefficientsI1e_A() {
/* Chebyshev coefficients for exp(-x) I1(x)
* in the interval [0,8].
*
* lim(x->0){ exp(-x) I1(x) / x } = 1/2.
*/
static const T coeff[] = {
2.77791411276104639959E-18, -2.11142121435816608115E-17,
1.55363195773620046921E-16, -1.10559694773538630805E-15,
7.60068429473540693410E-15, -5.04218550472791168711E-14,
3.22379336594557470981E-13, -1.98397439776494371520E-12,
1.17361862988909016308E-11, -6.66348972350202774223E-11,
3.62559028155211703701E-10, -1.88724975172282928790E-9,
9.38153738649577178388E-9, -4.44505912879632808065E-8,
2.00329475355213526229E-7, -8.56872026469545474066E-7,
3.47025130813767847674E-6, -1.32731636560394358279E-5,
4.78156510755005422638E-5, -1.61760815825896745588E-4,
5.12285956168575772895E-4, -1.51357245063125314899E-3,
4.15642294431288815669E-3, -1.05640848946261981558E-2,
2.47264490306265168283E-2, -5.29459812080949914269E-2,
1.02643658689847095384E-1, -1.76416518357834055153E-1,
2.52587186443633654823E-1};
return std::make_tuple(coeff, 29);
}
template <typename T>
static inline typename std::enable_if<std::is_same<float, T>::value,
std::tuple<const T*, size_t>>::type
ChebyshevCoefficientsI1e_A() {
/* Chebyshev coefficients for exp(-x) I1(x)
* in the interval [0,8].
*
* lim(x->0){ exp(-x) I1(x) / x } = 1/2.
*/
static const T coeff[] = {9.38153738649577178388E-9f,
-4.44505912879632808065E-8f,
2.00329475355213526229E-7f,
-8.56872026469545474066E-7f,
3.47025130813767847674E-6f,
-1.32731636560394358279E-5f,
4.78156510755005422638E-5f,
-1.61760815825896745588E-4f,
5.12285956168575772895E-4f,
-1.51357245063125314899E-3f,
4.15642294431288815669E-3f,
-1.05640848946261981558E-2f,
2.47264490306265168283E-2f,
-5.29459812080949914269E-2f,
1.02643658689847095384E-1f,
-1.76416518357834055153E-1f,
2.52587186443633654823E-1f};
return std::make_tuple(coeff, 17);
}
template <typename T>
static inline typename std::enable_if<std::is_same<double, T>::value,
std::tuple<const T*, size_t>>::type
ChebyshevCoefficientsI1e_B() {
/* Chebyshev coefficients for exp(-x) sqrt(x) I1(x)
* in the inverted interval [8,infinity].
*
* lim(x->inf){ exp(-x) sqrt(x) I1(x) } = 1/sqrt(2pi).
*/
static const T coeff[] = {
7.51729631084210481353E-18, 4.41434832307170791151E-18,
-4.65030536848935832153E-17, -3.20952592199342395980E-17,
2.96262899764595013876E-16, 3.30820231092092828324E-16,
-1.88035477551078244854E-15, -3.81440307243700780478E-15,
1.04202769841288027642E-14, 4.27244001671195135429E-14,
-2.10154184277266431302E-14, -4.08355111109219731823E-13,
-7.19855177624590851209E-13, 2.03562854414708950722E-12,
1.41258074366137813316E-11, 3.25260358301548823856E-11,
-1.89749581235054123450E-11, -5.58974346219658380687E-10,
-3.83538038596423702205E-9, -2.63146884688951950684E-8,
-2.51223623787020892529E-7, -3.88256480887769039346E-6,
-1.10588938762623716291E-4, -9.76109749136146840777E-3,
7.78576235018280120474E-1};
return std::make_tuple(coeff, 25);
}
template <typename T>
static inline typename std::enable_if<std::is_same<float, T>::value,
std::tuple<const T*, size_t>>::type
ChebyshevCoefficientsI1e_B() {
/* Chebyshev coefficients for exp(-x) sqrt(x) I1(x)
* in the inverted interval [8,infinity].
*
* lim(x->inf){ exp(-x) sqrt(x) I1(x) } = 1/sqrt(2pi).
*/
static const T coeff[] = {-3.83538038596423702205E-9f,
-2.63146884688951950684E-8f,
-2.51223623787020892529E-7f,
-3.88256480887769039346E-6f,
-1.10588938762623716291E-4f,
-9.76109749136146840777E-3f,
7.78576235018280120474E-1f};
return std::make_tuple(coeff, 7);
}
} // namespace phi
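As a readability aid, here is a short Python sketch of the evaluation scheme used by the functors above: a Cephes-style Chebyshev series evaluation plus the two-interval split at x = 8. This is only an illustrative mirror of the C++ code; coeff_A and coeff_B are placeholders for the ChebyshevCoefficientsI0e_A / _B tables listed above, and scipy.special.i0e is mentioned purely as an independent cross-check.

import numpy as np


def chbevl(x, coeffs):
    # Mirrors Chbevl<T>: the argument is assumed to be pre-mapped onto the
    # series' interval, and the recurrence is b0 = x*b1 - b2 + c.
    b0, b1, b2 = coeffs[0], 0.0, 0.0
    for c in coeffs[1:]:
        b2, b1 = b1, b0
        b0 = x * b1 - b2 + c
    return 0.5 * (b0 - b2)


def i0e_sketch(v, coeff_A, coeff_B):
    # coeff_A / coeff_B stand in for ChebyshevCoefficientsI0e_A / _B above.
    x = abs(v)
    if x <= 8.0:
        return chbevl(x / 2.0 - 2.0, coeff_A)             # y = x/2 - 2
    return chbevl(32.0 / x - 2.0, coeff_B) / np.sqrt(x)   # y = 32/x - 2

# Cross-check idea (not run here):
#   np.testing.assert_allclose(i0e_sketch(3.0, A, B), scipy.special.i0e(3.0), rtol=1e-6)
# i0(x) is then recovered as exp(|x|) * i0e(x), matching I0Functor above.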
@@ -304,6 +304,8 @@ from .tensor.math import trapezoid # noqa: F401
from .tensor.math import cumulative_trapezoid # noqa: F401
from .tensor.math import vander # noqa: F401
from .tensor.math import nextafter # noqa: F401
from .tensor.math import i0 # noqa: F401
from .tensor.math import i0e # noqa: F401
from .tensor.random import bernoulli # noqa: F401
from .tensor.random import poisson # noqa: F401
@@ -698,4 +700,6 @@ __all__ = [ # noqa
'vander',
'unflatten',
'nextafter',
'i0',
'i0e',
]
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from eager_op_test import OpTest
from scipy import special
import paddle
from paddle.fluid import core
np.random.seed(100)
paddle.seed(100)
def output_i0(x):
return special.i0(x)
def ref_i0_grad(x, dout):
gradx = special.i1(x)
return dout * gradx
class TestI0API(unittest.TestCase):
DTYPE = "float64"
DATA = [0, 1, 2, 3, 4, 5]
def setUp(self):
self.x = np.array(self.DATA).astype(self.DTYPE)
self.place = [paddle.CPUPlace()]
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))
def test_api_static(self):
def run(place):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(
name="x", shape=self.x.shape, dtype=self.DTYPE
)
out = paddle.i0(x)
exe = paddle.static.Executor(place)
res = exe.run(
paddle.static.default_main_program(),
feed={"x": self.x},
fetch_list=[out],
)
out_ref = output_i0(self.x)
np.testing.assert_allclose(res[0], out_ref, rtol=1e-5)
paddle.disable_static()
for place in self.place:
run(place)
def test_api_dygraph(self):
def run(place):
paddle.disable_static(place)
x = paddle.to_tensor(self.x)
out = paddle.i0(x)
out_ref = output_i0(self.x)
np.testing.assert_allclose(out.numpy(), out_ref, rtol=1e-5)
paddle.enable_static()
for place in self.place:
run(place)
def test_empty_input_error(self):
for place in self.place:
paddle.disable_static(place)
x = None
self.assertRaises(ValueError, paddle.i0, x)
paddle.enable_static()
class TestI0Float32Zero2EightCase(TestI0API):
DTYPE = "float32"
DATA = [0, 1, 2, 3, 4, 5, 6, 7, 8]
class TestI0Float32OverEightCase(TestI0API):
DTYPE = "float32"
DATA = [9, 10, 11, 12]
class TestI0Float64Zero2EightCase(TestI0API):
DTYPE = "float64"
DATA = [0, 1, 2, 3, 4, 5, 6, 7, 8]
class TestI0Float64OverEightCase(TestI0API):
DTYPE = "float64"
DATA = [9, 10, 11, 12]
class TestI0Op(OpTest):
def setUp(self) -> None:
self.op_type = "i0"
self.python_api = paddle.i0
self.init_config()
self.outputs = {"out": self.target}
def init_config(self):
self.dtype = np.float64
zero_case = np.zeros(1).astype(self.dtype)
rand_case = np.random.randn(100).astype(self.dtype)
one2eight_case = np.random.uniform(low=1, high=8, size=100).astype(
self.dtype
)
over_eight_case = np.random.uniform(low=9, high=15, size=100).astype(
self.dtype
)
self.case = np.concatenate(
[zero_case, rand_case, one2eight_case, over_eight_case]
)
self.inputs = {'x': self.case}
self.target = output_i0(self.inputs['x'])
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(
['x'],
'out',
user_defined_grads=[ref_i0_grad(self.case, 1 / self.case.size)],
)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from eager_op_test import OpTest
from scipy import special
import paddle
from paddle.fluid import core
np.random.seed(100)
paddle.seed(100)
def output_i0e(x):
return special.i0e(x)
def ref_i0e_grad(x, dout):
eps = np.finfo(x.dtype).eps
not_tiny = abs(x) > eps
safe_x = np.where(not_tiny, x, eps)
gradx = special.i1e(x) - np.sign(x) * output_i0e(safe_x)
gradx = np.where(not_tiny, gradx, -1.0)
return dout * gradx
class TestI0eAPI(unittest.TestCase):
DTYPE = "float64"
DATA = [0, 1, 2, 3, 4, 5]
def setUp(self):
self.x = np.array(self.DATA).astype(self.DTYPE)
self.place = [paddle.CPUPlace()]
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))
def test_api_static(self):
def run(place):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(
name="x", shape=self.x.shape, dtype=self.DTYPE
)
y = paddle.i0e(x)
exe = paddle.static.Executor(place)
res = exe.run(
paddle.static.default_main_program(),
feed={"x": self.x},
fetch_list=[y],
)
out_ref = output_i0e(self.x)
np.testing.assert_allclose(out_ref, res[0], rtol=1e-5)
paddle.disable_static()
for place in self.place:
run(place)
def test_api_dygraph(self):
def run(place):
paddle.disable_static(place)
x = paddle.to_tensor(self.x)
out = paddle.i0e(x)
out_ref = output_i0e(self.x)
np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-5)
paddle.enable_static()
for place in self.place:
run(place)
def test_empty_input_error(self):
for place in self.place:
paddle.disable_static(place)
x = None
self.assertRaises(ValueError, paddle.i0e, x)
paddle.enable_static()
class TestI0eFloat32Zero2EightCase(TestI0eAPI):
DTYPE = "float32"
DATA = [0, 1, 2, 3, 4, 5, 6, 7, 8]
class TestI0eFloat32OverEightCase(TestI0eAPI):
DTYPE = "float32"
DATA = [9, 10, 11, 12]
class TestI0eFloat64Zero2EightCase(TestI0eAPI):
DTYPE = "float64"
DATA = [0, 1, 2, 3, 4, 5, 6, 7, 8]
class TestI0eFloat64OverEightCase(TestI0eAPI):
DTYPE = "float64"
DATA = [9, 10, 11, 12]
class TestI0eOp(OpTest):
def setUp(self) -> None:
self.op_type = "i0e"
self.python_api = paddle.i0e
self.init_config()
self.outputs = {"out": self.target}
def init_config(self):
self.dtype = np.float64
zero_case = np.zeros(1).astype(self.dtype)
rand_case = np.random.randn(100).astype(self.dtype)
one2eight_case = np.random.uniform(low=1, high=8, size=100).astype(
self.dtype
)
over_eight_case = np.random.uniform(low=9, high=15, size=100).astype(
self.dtype
)
self.case = np.concatenate(
[zero_case, rand_case, one2eight_case, over_eight_case]
)
self.inputs = {'x': self.case}
self.target = output_i0e(self.inputs['x'])
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(
['x'],
'out',
user_defined_grads=[ref_i0e_grad(self.case, 1 / self.case.size)],
)
if __name__ == "__main__":
unittest.main()
@@ -256,6 +256,8 @@ from .math import sigmoid # noqa: F401
from .math import sigmoid_ # noqa: F401
from .math import vander # noqa: F401
from .math import nextafter # noqa: F401
from .math import i0 # noqa: F401
from .math import i0e # noqa: F401
from .random import multinomial # noqa: F401
from .random import standard_normal # noqa: F401
@@ -550,6 +552,8 @@ tensor_method_func = [ # noqa
'vander',
'nextafter',
'unflatten',
'i0',
'i0e',
]
# this list used in math_op_patch.py for magic_method bind
...
@@ -5584,3 +5584,76 @@ def nextafter(x, y, name=None):
outputs = {"out": out}
helper.append_op(type=op_type, inputs=inputs, outputs=outputs)
return out
def i0(x, name=None):
r"""
Calculate the modified Bessel function of order 0 of the input, element-wise.
Equation:
.. math::
I_0(x) = \sum^{\infty}_{k=0}\frac{(x^2/4)^k}{(k!)^2}
Args:
x (Tensor): The input tensor; its data type should be float32 or float64.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
Returns:
- out (Tensor), the value of the modified Bessel function of order 0 at x.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([0, 1, 2, 3, 4], dtype="float32")
print(paddle.i0(x))
# Tensor(shape=[5], dtype=float32, place=Place(cpu), stop_gradient=True,
#        [0.99999994 , 1.26606596 , 2.27958512 , 4.88079262 , 11.30192089])
"""
if in_dygraph_mode():
return _C_ops.i0(x)
else:
check_variable_and_dtype(x, "x", ["float32", "float64"], "i0")
helper = LayerHelper("i0", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type='i0', inputs={'x': x}, outputs={'out': out})
return out
def i0e(x, name=None):
r"""
Calculate the exponentially scaled modified Bessel function of order 0 of the input, element-wise.
Equation:
.. math::
I_0(x) = \sum^{\infty}_{k=0}\frac{(x^2/4)^k}{(k!)^2} \\
I_{0e}(x) = e^{-|x|}I_0(x)
Args:
x (Tensor): The input tensor; its data type should be float32 or float64.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
Returns:
- out (Tensor), the value of the exponentially scaled modified Bessel function of order 0 at x.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([0, 1, 2, 3, 4], dtype="float32")
print(paddle.i0e(x))
# Tensor(shape=[5], dtype=float32, place=Place(cpu), stop_gradient=True,
#        [1., 0.46575961, 0.30850832, 0.24300035, 0.20700192])
"""
if in_dygraph_mode():
return _C_ops.i0e(x)
else:
check_variable_and_dtype(x, "x", ["float32", "float64"], "i0e")
helper = LayerHelper("i0e", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type='i0e', inputs={'x': x}, outputs={'out': out})
return out
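As a quick sanity check on the backward kernels registered above, the sketch below compares autograd results with the SciPy references used in the unit tests. It is illustrative only and assumes a Paddle build that already contains this PR.

import numpy as np
import paddle
from scipy import special

x_np = np.array([0.5, 1.0, 4.0, 9.0], dtype="float64")
x = paddle.to_tensor(x_np, stop_gradient=False)

# d/dx i0(x) = i1(x)
(di0,) = paddle.grad(paddle.i0(x).sum(), [x])
np.testing.assert_allclose(di0.numpy(), special.i1(x_np), rtol=1e-6)

# d/dx i0e(x) = i1e(x) - sign(x) * i0e(x)
(di0e,) = paddle.grad(paddle.i0e(x).sum(), [x])
np.testing.assert_allclose(
    di0e.numpy(), special.i1e(x_np) - np.sign(x_np) * special.i0e(x_np), rtol=1e-6
)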