Unverified commit a63fb4c8, authored by LyndonKong, committed by GitHub

【Hackathon 4 No.21】Add i1 / i1e to paddle (#53210)

* Add i1 and i1e op

* resolve merge conflicts
Parent 89653668
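For reference, the two functions added by this PR, and the analytic gradients that the backward kernels implement (the same formulas appear in reference_i1_grad / reference_i1e_grad in the unit tests below), are:

I_1(x) = \sum_{k=0}^{\infty} \frac{(x/2)^{2k+1}}{k!\,(k+1)!}, \qquad \mathrm{i1e}(x) = e^{-|x|}\, I_1(x)

\frac{d}{dx} I_1(x) = I_0(x) - \frac{I_1(x)}{x}, \qquad \frac{d}{dx}\,\mathrm{i1e}(x) = \mathrm{i0e}(x) - \mathrm{i1e}(x)\left(\operatorname{sign}(x) + \frac{1}{x}\right)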
@@ -821,6 +821,26 @@
kernel :
func : i0e_grad
- backward_op : i1_grad
forward : i1 (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : i1_grad
- backward_op : i1e_grad
forward : i1e (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : i1e_grad
- backward_op : imag_grad
forward : imag (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
@@ -948,6 +948,24 @@
func : i0e
backward : i0e_grad
- op : i1
args : (Tensor x)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : i1
backward : i1_grad
- op : i1e
args : (Tensor x)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : i1e
backward : i1e_grad
- op : imag
args : (Tensor x)
output : Tensor (out)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/i1_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/impl/bessel_grad_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void I1GradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
const int64_t size = x.numel();
const T* x_data = x.data<T>();
const T* out_data = out.data<T>();
const T* out_grad_data = out_grad.data<T>();
T* x_grad_data = ctx.template Alloc<T>(x_grad);
phi::funcs::ForRange<Context> for_range(ctx, size);
I1GradFunctor<T> functor(x_data, out_data, out_grad_data, x_grad_data, size);
for_range(functor);
}
} // namespace phi
PD_REGISTER_KERNEL(i1_grad, CPU, ALL_LAYOUT, phi::I1GradKernel, float, double) {
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/i1_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/impl/bessel_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void I1Kernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
const int64_t size = x.numel();
const T* x_data = x.data<T>();
T* out_data = ctx.template Alloc<T>(out);
phi::funcs::ForRange<Context> for_range(ctx, size);
I1Functor<T> functor(x_data, out_data, size);
for_range(functor);
}
} // namespace phi
PD_REGISTER_KERNEL(i1, CPU, ALL_LAYOUT, phi::I1Kernel, float, double) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/i1e_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/impl/bessel_grad_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void I1eGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
const int64_t size = x.numel();
const T* x_data = x.data<T>();
const T* out_data = out.data<T>();
const T* out_grad_data = out_grad.data<T>();
T* x_grad_data = ctx.template Alloc<T>(x_grad);
phi::funcs::ForRange<Context> for_range(ctx, size);
I1eGradFunctor<T> functor(x_data, out_data, out_grad_data, x_grad_data, size);
for_range(functor);
}
} // namespace phi
PD_REGISTER_KERNEL(
i1e_grad, CPU, ALL_LAYOUT, phi::I1eGradKernel, float, double) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/i1e_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/impl/bessel_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void I1eKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
const int64_t size = x.numel();
const T* x_data = x.data<T>();
T* out_data = ctx.template Alloc<T>(out);
phi::funcs::ForRange<Context> for_range(ctx, size);
I1eFunctor<T> functor(x_data, out_data, size);
for_range(functor);
}
} // namespace phi
PD_REGISTER_KERNEL(i1e, CPU, ALL_LAYOUT, phi::I1eKernel, float, double) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/i1_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/impl/bessel_grad_kernel_cuda_impl.h"
namespace phi {
template <typename T, typename Context>
void I1GradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
ctx.template Alloc<T>(x_grad);
std::vector<const DenseTensor*> ins = {&x, &out, &out_grad};
std::vector<DenseTensor*> outs = {x_grad};
auto functor = CudaI1GradFunctor<T>();
phi::funcs::ElementwiseKernel<T>(ctx, ins, &outs, functor);
}
} // namespace phi
PD_REGISTER_KERNEL(i1_grad, GPU, ALL_LAYOUT, phi::I1GradKernel, float, double) {
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/i1_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/impl/bessel_kernel_cuda_impl.h"
namespace phi {
template <typename T, typename Context>
void I1Kernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
ctx.template Alloc<T>(out);
std::vector<const DenseTensor*> ins = {&x};
std::vector<DenseTensor*> outs = {out};
auto functor = CudaI1Functor<T>();
phi::funcs::ElementwiseKernel<T>(ctx, ins, &outs, functor);
}
} // namespace phi
PD_REGISTER_KERNEL(i1, GPU, ALL_LAYOUT, phi::I1Kernel, float, double) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/i1e_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/impl/bessel_grad_kernel_cuda_impl.h"
namespace phi {
template <typename T, typename Context>
void I1eGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
ctx.template Alloc<T>(x_grad);
std::vector<const DenseTensor*> ins = {&x, &out, &out_grad};
std::vector<DenseTensor*> outs = {x_grad};
auto functor = CudaI1eGradFunctor<T>();
phi::funcs::ElementwiseKernel<T>(ctx, ins, &outs, functor);
}
} // namespace phi
PD_REGISTER_KERNEL(
i1e_grad, GPU, ALL_LAYOUT, phi::I1eGradKernel, float, double) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/i1e_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/impl/bessel_kernel_cuda_impl.h"
namespace phi {
template <typename T, typename Context>
void I1eKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
ctx.template Alloc<T>(out);
std::vector<const DenseTensor*> ins = {&x};
std::vector<DenseTensor*> outs = {out};
auto functor = CudaI1eFunctor<T>();
phi::funcs::ElementwiseKernel<T>(ctx, ins, &outs, functor);
}
} // namespace phi
PD_REGISTER_KERNEL(i1e, GPU, ALL_LAYOUT, phi::I1eKernel, float, double) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
/**
* @brief This kernel calculates the gradient of the modified Bessel function of
* order 1.
* @param ctx device context
* @param x The input tensor of the forward i1 op
* @param out The output tensor of the forward i1 op
* @param out_grad The gradient of out
* @param x_grad The gradient of x, with the same shape and dtype as x
*/
template <typename T, typename Context>
void I1GradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& out_grad,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
/**
* @brief This kernel calculates the modified Bessel function of order 1.
* @param ctx device context
* @param x The input tensor of i1
* @param out The output tensor of the i1 kernel; it has the same shape and
* dtype as the input, and each element is the function value at the
* corresponding input element
*/
template <typename T, typename Context>
void I1Kernel(const Context& ctx, const DenseTensor& x, DenseTensor* out);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
template <typename T, typename Context>
void I1eGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& out_grad,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
/**
* @brief This kernel calculates the exponentially scaled modified Bessel
* function of order 1.
* @param ctx device context
* @param x The input tensor of i1e
* @param out The output tensor of the i1e kernel; it has the same shape and
* dtype as the input, and each element is the function value at the
* corresponding input element
*/
template <typename T, typename Context>
void I1eKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out);
} // namespace phi
@@ -87,4 +87,80 @@ struct CudaI0eGradFunctor {
}
};
template <typename T>
struct CudaI1GradFunctor {
__device__ __forceinline__ T operator()(const T _x,
const T _out,
const T _out_grad) const {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
const MT mp_x = static_cast<MT>(_x);
const MT mp_out = static_cast<MT>(_out);
const MT mp_out_grad = static_cast<MT>(_out_grad);
MT x = std::abs(mp_x);
if (x <= MT{8.0}) {
auto coeff_pair_A = ChebyshevCoefficientsI0e_A<MT>();
auto A = std::get<0>(coeff_pair_A);
auto len = std::get<1>(coeff_pair_A);
MT y = (x / MT{2.0}) - MT{2.0};
MT eps = static_cast<MT>(std::numeric_limits<T>::epsilon());
if (x <= eps) {
MT out = (MT{0.5}) * mp_out_grad;
return static_cast<T>(out);
} else {
return static_cast<T>(
(std::exp(x) * Chbevl<MT>(y, A, len) - mp_out / mp_x) *
mp_out_grad);
}
}
auto coeff_pair_B = ChebyshevCoefficientsI0e_B<MT>();
auto B = std::get<0>(coeff_pair_B);
auto len = std::get<1>(coeff_pair_B);
MT y = (MT{32.0} / x) - MT{2.0};
return static_cast<T>(
(std::exp(x) * Chbevl<MT>(y, B, len) / std::sqrt(x) - mp_out / mp_x) *
mp_out_grad);
}
};
template <typename T>
struct CudaI1eGradFunctor {
__device__ __forceinline__ T operator()(const T _x,
const T _out,
const T _out_grad) const {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
const MT mp_x = static_cast<MT>(_x);
const MT mp_out = static_cast<MT>(_out);
const MT mp_out_grad = static_cast<MT>(_out_grad);
MT x = std::abs(mp_x);
if (x <= MT{8.0}) {
auto coeff_pair_A = ChebyshevCoefficientsI0e_A<MT>();
auto A = std::get<0>(coeff_pair_A);
auto len = std::get<1>(coeff_pair_A);
MT y = (x / MT{2.0}) - MT{2.0};
MT eps = static_cast<MT>(std::numeric_limits<T>::epsilon());
if (x <= eps) {
MT out = (MT{0.5}) * mp_out_grad;
return static_cast<T>(out);
} else {
MT out = (Chbevl<MT>(y, A, len) -
mp_out * (std::copysign(MT{1.0}, mp_x) + (MT{1.0}) / mp_x)) *
mp_out_grad;
return static_cast<T>(out);
}
}
auto coeff_pair_B = ChebyshevCoefficientsI0e_B<MT>();
auto B = std::get<0>(coeff_pair_B);
auto len = std::get<1>(coeff_pair_B);
MT y = (MT{32.0} / x) - MT{2.0};
return static_cast<T>(
(Chbevl<MT>(y, B, len) / std::sqrt(x) -
mp_out * (std::copysign(MT{1.0}, mp_x) + (MT{1.0}) / mp_x)) *
mp_out_grad);
}
};
} // namespace phi
@@ -109,4 +109,103 @@ struct I0eGradFunctor {
int64_t numel_;
};
template <typename T>
struct I1GradFunctor {
I1GradFunctor(
const T* x, const T* out, const T* out_grad, T* x_grad, int64_t numel)
: input_x_(x),
input_out_(out),
input_out_grad_(out_grad),
output_x_grad_(x_grad),
numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
T x = std::abs(input_x_[idx]);
T x_ = input_x_[idx];
T out_ = input_out_[idx];
T out_grad_ = input_out_grad_[idx];
if (x <= T{8.0}) {
auto coeff_pair_A = ChebyshevCoefficientsI0e_A<T>();
auto A = std::get<0>(coeff_pair_A);
auto len = std::get<1>(coeff_pair_A);
T y = (x / T{2.0}) - T{2.0};
T eps = std::numeric_limits<T>::epsilon();
if (x <= eps) {
output_x_grad_[idx] = static_cast<T>(T{0.5} * out_grad_);
} else {
output_x_grad_[idx] = static_cast<T>(
(std::exp(x) * Chbevl<T>(y, A, len) - out_ / x_) * out_grad_);
}
} else {
auto coeff_pair_B = ChebyshevCoefficientsI0e_B<T>();
auto B = std::get<0>(coeff_pair_B);
auto len = std::get<1>(coeff_pair_B);
T y = (T{32.0} / x) - T{2.0};
output_x_grad_[idx] = static_cast<T>(
(std::exp(x) * Chbevl<T>(y, B, len) / std::sqrt(x) - out_ / x_) *
out_grad_);
}
}
private:
const T* input_x_;
const T* input_out_;
const T* input_out_grad_;
T* output_x_grad_;
int64_t numel_;
};
template <typename T>
struct I1eGradFunctor {
I1eGradFunctor(
const T* x, const T* out, const T* out_grad, T* x_grad, int64_t numel)
: input_x_(x),
input_out_(out),
input_out_grad_(out_grad),
output_x_grad_(x_grad),
numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
T x = std::abs(input_x_[idx]);
T x_ = input_x_[idx];
T out_ = input_out_[idx];
T out_grad_ = input_out_grad_[idx];
if (x <= T{8.0}) {
auto coeff_pair_A = ChebyshevCoefficientsI0e_A<T>();
auto A = std::get<0>(coeff_pair_A);
auto len = std::get<1>(coeff_pair_A);
T y = (x / T{2.0}) - T{2.0};
T eps = std::numeric_limits<T>::epsilon();
if (x <= eps) {
output_x_grad_[idx] = static_cast<T>(T{0.5} * out_grad_);
} else {
output_x_grad_[idx] =
static_cast<T>((Chbevl<T>(y, A, len) -
out_ * (std::copysign(T{1.0}, x_) + T{1.0} / x_)) *
out_grad_);
}
} else {
auto coeff_pair_B = ChebyshevCoefficientsI0e_B<T>();
auto B = std::get<0>(coeff_pair_B);
auto len = std::get<1>(coeff_pair_B);
T y = (T{32.0} / x) - T{2.0};
output_x_grad_[idx] =
static_cast<T>((Chbevl<T>(y, B, len) / std::sqrt(x) -
out_ * (std::copysign(T{1.0}, x_) + T{1.0} / x_)) *
out_grad_);
}
}
private:
const T* input_x_;
const T* input_out_;
const T* input_out_grad_;
T* output_x_grad_;
int64_t numel_;
};
} // namespace phi
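For reference, the |x| <= eps branches in the grad functors above return 0.5 * out_grad because that is the limiting value of the derivative at the origin:

\lim_{x \to 0}\left(I_0(x) - \frac{I_1(x)}{x}\right) = 1 - \frac{1}{2} = \frac{1}{2}

The same limit holds for the i1e gradient, since i0e(0) = 1, \mathrm{i1e}(x)\operatorname{sign}(x) \to 0 and \mathrm{i1e}(x)/x \to 1/2 as x \to 0.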
@@ -228,4 +228,52 @@ ChebyshevCoefficientsI1e_B() {
return std::make_tuple(coeff, 7);
}
template <typename T>
struct CudaI1Functor {
__device__ __forceinline__ T operator()(const T _x) const {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
const MT mp_x = static_cast<MT>(_x);
MT x = std::abs(mp_x);
if (x <= MT{8.0}) {
auto coeff_pair_A = ChebyshevCoefficientsI1e_A<MT>();
auto A = std::get<0>(coeff_pair_A);
auto len = std::get<1>(coeff_pair_A);
MT y = (x / MT{2.0}) - MT{2.0};
const MT out = std::exp(x) * x * Chbevl<MT>(y, A, len);
return static_cast<T>((mp_x < MT{0.0}) ? -out : out);
}
auto coeff_pair_B = ChebyshevCoefficientsI1e_B<MT>();
auto B = std::get<0>(coeff_pair_B);
auto len = std::get<1>(coeff_pair_B);
MT y = (MT{32.0} / x) - MT{2.0};
const MT out = (std::exp(x) * Chbevl<MT>(y, B, len)) / std::sqrt(x);
return static_cast<T>((mp_x < MT{0.0}) ? -out : out);
}
};
template <typename T>
struct CudaI1eFunctor {
__device__ __forceinline__ T operator()(const T _x) const {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
const MT mp_x = static_cast<MT>(_x);
MT x = std::abs(mp_x);
if (x <= MT{8.0}) {
auto coeff_pair_A = ChebyshevCoefficientsI1e_A<MT>();
auto A = std::get<0>(coeff_pair_A);
auto len = std::get<1>(coeff_pair_A);
MT y = (x / MT{2.0}) - MT{2.0};
const T out = static_cast<T>(Chbevl<MT>(y, A, len) * x);
return (mp_x < MT{0.0}) ? -out : out;
}
auto coeff_pair_B = ChebyshevCoefficientsI1e_B<MT>();
auto B = std::get<0>(coeff_pair_B);
auto len = std::get<1>(coeff_pair_B);
MT y = (MT{32.0} / x) - MT{2.0};
const T out = static_cast<T>(Chbevl<MT>(y, B, len) / std::sqrt(x));
return (mp_x < MT{0.0}) ? -out : out;
}
};
} // namespace phi
@@ -251,4 +251,65 @@ ChebyshevCoefficientsI1e_B() {
return std::make_tuple(coeff, 7);
}
template <typename T>
struct I1Functor {
I1Functor(const T* input, T* output, int64_t numel)
: input_(input), output_(output), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
T x = std::abs(input_[idx]);
if (x <= T{8.0}) {
auto coeff_pair_A = ChebyshevCoefficientsI1e_A<T>();
auto A = std::get<0>(coeff_pair_A);
auto len = std::get<1>(coeff_pair_A);
T y = (x / T{2.0}) - T{2.0};
const T out = std::exp(x) * x * Chbevl(y, A, len);
output_[idx] = (input_[idx] < T{0.0}) ? -out : out;
} else {
auto coeff_pair_B = ChebyshevCoefficientsI1e_B<T>();
auto B = std::get<0>(coeff_pair_B);
auto len = std::get<1>(coeff_pair_B);
T y = (T{32.0} / x) - T{2.0};
const T out = (std::exp(x) * Chbevl(y, B, len)) / std::sqrt(x);
output_[idx] = (input_[idx] < T{0.0}) ? -out : out;
}
}
private:
const T* input_;
T* output_;
int64_t numel_;
};
template <typename T>
struct I1eFunctor {
I1eFunctor(const T* input, T* output, int64_t numel)
: input_(input), output_(output), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
T x = std::abs(input_[idx]);
if (x <= T{8.0}) {
auto coeff_pair_A = ChebyshevCoefficientsI1e_A<T>();
auto A = std::get<0>(coeff_pair_A);
auto len = std::get<1>(coeff_pair_A);
T y = (x / T{2.0}) - T{2.0};
const T out = Chbevl<T>(y, A, len) * x;
output_[idx] = (input_[idx] < T{0.0}) ? -out : out;
} else {
auto coeff_pair_B = ChebyshevCoefficientsI1e_B<T>();
auto B = std::get<0>(coeff_pair_B);
auto len = std::get<1>(coeff_pair_B);
T y = (T{32.0} / x) - T{2.0};
const T out = Chbevl<T>(y, B, len) / std::sqrt(x);
output_[idx] = (input_[idx] < T{0.0}) ? -out : out;
}
}
private:
const T* input_;
T* output_;
int64_t numel_;
};
} // namespace phi
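Written out, the forward functors above (and their CUDA counterparts earlier in the diff) evaluate the function through the Chebyshev series Chbevl with two coefficient sets, split at |x| = 8:

\mathrm{i1e}(x) = \begin{cases} \operatorname{sign}(x)\,|x|\,\mathrm{Chbevl}\!\left(\tfrac{|x|}{2} - 2,\, A\right), & |x| \le 8, \\ \operatorname{sign}(x)\,\mathrm{Chbevl}\!\left(\tfrac{32}{|x|} - 2,\, B\right)\big/\sqrt{|x|}, & |x| > 8, \end{cases} \qquad I_1(x) = e^{|x|}\,\mathrm{i1e}(x)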
@@ -306,6 +306,8 @@ from .tensor.math import vander # noqa: F401
from .tensor.math import nextafter # noqa: F401
from .tensor.math import i0 # noqa: F401
from .tensor.math import i0e # noqa: F401
from .tensor.math import i1 # noqa: F401
from .tensor.math import i1e # noqa: F401
from .tensor.random import bernoulli # noqa: F401
from .tensor.random import poisson # noqa: F401
@@ -702,4 +704,6 @@ __all__ = [ # noqa
'nextafter',
'i0',
'i0e',
'i1',
'i1e',
]
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from eager_op_test import OpTest
from scipy import special
import paddle
from paddle.fluid import core
np.random.seed(42)
paddle.seed(42)
def reference_i1(x):
return special.i1(x)
def reference_i1_grad(x, dout):
eps = np.finfo(x.dtype).eps
not_tiny = abs(x) > eps
safe_x = np.where(not_tiny, x, eps)
gradx = special.i0(safe_x) - special.i1(x) / safe_x
gradx = np.where(not_tiny, gradx, 0.5)
return dout * gradx
class Testi1API(unittest.TestCase):
DTYPE = "float64"
DATA = [0, 1, 2, 3, 4, 5]
def setUp(self):
self.x = np.array(self.DATA).astype(self.DTYPE)
self.place = [paddle.CPUPlace()]
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))
def test_api_static(self):
def run(place):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(
name="x", shape=self.x.shape, dtype=self.DTYPE
)
out = paddle.i1(x)
exe = paddle.static.Executor(place)
res = exe.run(
paddle.static.default_main_program(),
feed={"x": self.x},
fetch_list=[out],
)
out_ref = reference_i1(self.x)
np.testing.assert_allclose(res[0], out_ref, rtol=1e-5)
paddle.disable_static()
for place in self.place:
run(place)
def test_api_dygraph(self):
def run(place):
paddle.disable_static(place)
x = paddle.to_tensor(self.x)
out = paddle.i1(x)
out_ref = reference_i1(self.x)
np.testing.assert_allclose(out.numpy(), out_ref, rtol=1e-5)
paddle.enable_static()
for place in self.place:
run(place)
def test_empty_input_error(self):
for place in self.place:
paddle.disable_static(place)
x = None
self.assertRaises(ValueError, paddle.i1, x)
paddle.enable_static()
class Testi1Float32Zero2EightCase(Testi1API):
DTYPE = "float32"
DATA = [0, 1, 2, 3, 4, 5, 6, 7, 8]
class Testi1Float32OverEightCase(Testi1API):
DTYPE = "float32"
DATA = [9, 10, 11, 12, 13, 14, 15, 16, 17]
class Testi1Float64Zero2EightCase(Testi1API):
DTYPE = "float64"
DATA = [0, 1, 2, 3, 4, 5, 6, 7, 8]
class Testi1Float64OverEightCase(Testi1API):
DTYPE = "float64"
DATA = [9, 10, 11, 12, 13, 14, 15, 16, 17]
class TestI1Op(OpTest):
# Configure the op information and the input/output parameters
def setUp(self):
self.op_type = "i1"
self.python_api = paddle.i1
self.init_config()
self.outputs = {'out': self.target}
# Check the forward outputs
def test_check_output(self):
self.check_output()
# Check the backward gradients
def test_check_grad(self):
self.check_grad(
['x'],
'out',
user_defined_grads=[
reference_i1_grad(
self.case,
1 / self.case.size,
)
],
)
def init_config(self):
# Generate random input data
zero_case = np.zeros(1).astype('float64')
rand_case = np.random.randn(250).astype('float64')
over_eight_case = np.random.uniform(low=8, high=9, size=250).astype(
'float64'
)
self.case = np.concatenate([zero_case, rand_case, over_eight_case])
self.inputs = {'x': self.case}
self.target = reference_i1(self.inputs['x'])
if __name__ == "__main__":
unittest.main()
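As a side note, the analytic gradient used in reference_i1_grad can be cross-checked against a central finite difference of scipy's i1; the following standalone sketch (not part of the test file) assumes only numpy and scipy:

import numpy as np
from scipy import special

# d/dx i1(x) = i0(x) - i1(x) / x, away from the x = 0 case handled by the eps branch.
xs = np.linspace(0.1, 8.0, 200)
h = 1e-6
numeric = (special.i1(xs + h) - special.i1(xs - h)) / (2.0 * h)
analytic = special.i0(xs) - special.i1(xs) / xs
print(np.max(np.abs(numeric - analytic)))  # expected to be tiny, roughly 1e-7 or smaller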
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from eager_op_test import OpTest
from scipy import special
import paddle
from paddle.fluid import core
np.random.seed(42)
paddle.seed(42)
def reference_i1e(x):
return special.i1e(x)
def reference_i1e_grad(x, dout):
eps = np.finfo(x.dtype).eps
not_tiny = abs(x) > eps
safe_x = np.where(not_tiny, x, eps)
gradx = special.i0e(safe_x) - special.i1e(x) * (np.sign(x) + 1 / safe_x)
gradx = np.where(not_tiny, gradx, 0.5)
return dout * gradx
class TestI1e_API(unittest.TestCase):
DTYPE = "float64"
DATA = [0, 1, 2, 3, 4, 5]
def setUp(self):
self.x = np.array(self.DATA).astype(self.DTYPE)
self.place = [paddle.CPUPlace()]
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))
def test_api_static(self):
def run(place):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(
name="x", shape=self.x.shape, dtype=self.DTYPE
)
y = paddle.i1e(x)
exe = paddle.static.Executor(place)
res = exe.run(
paddle.static.default_main_program(),
feed={"x": self.x},
fetch_list=[y],
)
out_ref = reference_i1e(self.x)
np.testing.assert_allclose(out_ref, res[0], rtol=1e-5)
paddle.disable_static()
for place in self.place:
run(place)
def test_api_dygraph(self):
def run(place):
paddle.disable_static(place)
x = paddle.to_tensor(self.x)
out = paddle.i1e(x)
out_ref = reference_i1e(self.x)
np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-5)
paddle.enable_static()
for place in self.place:
run(place)
def test_empty_input_error(self):
for place in self.place:
paddle.disable_static(place)
x = None
self.assertRaises(ValueError, paddle.i1e, x)
paddle.enable_static()
class Testi1eFloat32Zero2EightCase(TestI1e_API):
DTYPE = "float32"
DATA = [0, 1, 2, 3, 4, 5, 6, 7, 8]
class Testi1eFloat32OverEightCase(TestI1e_API):
DTYPE = "float32"
DATA = [9, 10, 11, 12, 13, 14, 15, 16, 17]
class Testi1eFloat64Zero2EightCase(TestI1e_API):
DTYPE = "float64"
DATA = [0, 1, 2, 3, 4, 5, 6, 7, 8]
class Testi1eFloat64OverEightCase(TestI1e_API):
DTYPE = "float64"
DATA = [9, 10, 11, 12, 13, 14, 15, 16, 17]
class TestI1eOp(OpTest):
# Configure the op information and the input/output parameters
def setUp(self):
self.op_type = "i1e"
self.python_api = paddle.i1e
self.init_config()
self.outputs = {'out': self.target}
# Check the forward outputs
def test_check_output(self):
self.check_output()
# Check the backward gradients
def test_check_grad(self):
self.check_grad(
['x'],
'out',
user_defined_grads=[
reference_i1e_grad(
self.case,
1 / self.case.size,
)
],
)
# Generate random input data and compute the corresponding outputs
def init_config(self):
zero_case = np.zeros(1).astype('float64')
rand_case = np.random.randn(250).astype('float64')
over_eight_case = np.random.uniform(low=8, high=9, size=250).astype(
'float64'
)
self.case = np.concatenate([zero_case, rand_case, over_eight_case])
self.inputs = {'x': self.case}
self.target = reference_i1e(self.inputs['x'])
if __name__ == "__main__":
unittest.main()
@@ -258,6 +258,8 @@ from .math import vander # noqa: F401
from .math import nextafter # noqa: F401
from .math import i0 # noqa: F401
from .math import i0e # noqa: F401
from .math import i1 # noqa: F401
from .math import i1e # noqa: F401
from .random import multinomial # noqa: F401
from .random import standard_normal # noqa: F401
@@ -554,6 +556,8 @@ tensor_method_func = [ # noqa
'unflatten',
'i0',
'i0e',
'i1',
'i1e',
]
# this list used in math_op_patch.py for magic_method bind
@@ -5657,3 +5657,70 @@ def i0e(x, name=None):
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type='i0e', inputs={'x': x}, outputs={'out': out})
return out
def i1(x, name=None):
"""
The function is used to calculate the modified Bessel function of order 1.
Args:
x (Tensor): The input tensor, it's data type should be float32, float64.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
Returns:
- out (Tensor), A Tensor. The value of the modified Bessel function of order 1 at x.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([0, 1, 2, 3, 4], dtype="float32")
print(paddle.i1(x))
# Tensor(shape=[5], dtype=float32, place=Place(cpu), stop_gradient=True, [0., 0.5651591, 1.59063685, 3.95337022, 9.75946515])
"""
if in_dygraph_mode():
return _C_ops.i1(x)
else:
check_variable_and_dtype(x, "x", ["float32", "float64"], "i1")
helper = LayerHelper("i1", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='i1', inputs={'x': x}, outputs={'out': out}, attrs={}
)
return out
def i1e(x, name=None):
"""
The function is used to calculate the exponentially scaled modified Bessel function of order 1.
Args:
x (Tensor): The input tensor, it's data type should be float32, float64.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
Returns:
- out (Tensor), A Tensor. The value of the exponentially scaled modified Bessel function of order 1 at x.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([0, 1, 2, 3, 4], dtype="float32")
print(paddle.i1e(x))
# Tensor(shape=[5], dtype=float32, place=Place(cpu), stop_gradient=True, [0., 0.20791042, 0.21526929, 0.19682671, 0.17875084])
"""
if in_dygraph_mode():
return _C_ops.i1e(x)
else:
check_variable_and_dtype(x, "x", ["float32", "float64"], "i1e")
helper = LayerHelper("i1e", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='i1e', inputs={'x': x}, outputs={'out': out}, attrs={}
)
return out
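A minimal dygraph sketch of the two new Python APIs (assuming a Paddle build that contains this commit); the backward call exercises the i1_grad kernel registered above:

import paddle

x = paddle.to_tensor([0.5, 1.0, 2.0, 4.0], dtype="float64", stop_gradient=False)
y = paddle.i1(x)    # modified Bessel function of order 1
z = paddle.i1e(x)   # exponentially scaled variant: i1e(x) = exp(-|x|) * i1(x)
print(y.numpy(), z.numpy())

# d/dx i1(x) = i0(x) - i1(x) / x
y.sum().backward()
print(x.grad.numpy())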