未验证 提交 14ed2f54 编写于 作者: L Leo Chen 提交者: GitHub

[pten] update isnan registration (#39419)

* update isnan registration

* fix compile
上级 c7c1db33
......@@ -132,4 +132,33 @@ namespace ops = paddle::operators;
REGISTER_OP_MAKER(isinf, "isinf(X)");
REGISTER_OP_MAKER(isnan, "isnan(X)");
REGISTER_OP_MAKER(isfinite, "isfinite(X)");
FOR_EACH_KERNEL_FUNCTOR(REGISTER_OVERFLOW_CPU_KERNEL);
REGISTER_OP_CPU_KERNEL(isinf,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
int, ops::InfinityFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::InfinityFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
float, ops::InfinityFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
double, ops::InfinityFunctor>);
REGISTER_OP_CPU_KERNEL(isnan,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
int, ops::NANFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::NANFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
float, ops::NANFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
double, ops::NANFunctor>);
REGISTER_OP_CPU_KERNEL(isfinite,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
int, ops::IsfiniteFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::IsfiniteFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
float, ops::IsfiniteFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
double, ops::IsfiniteFunctor>);
......@@ -17,15 +17,32 @@
namespace ops = paddle::operators;
namespace plat = paddle::platform;
#define REGISTER_OVERFLOW_CUDA_KERNEL(op_type, functor) \
REGISTER_OP_CUDA_KERNEL( \
op_type, ops::OverflowKernel<paddle::platform::CUDADeviceContext, int, \
ops::functor>, \
ops::OverflowKernel<paddle::platform::CUDADeviceContext, float, \
ops::functor>, \
ops::OverflowKernel<paddle::platform::CUDADeviceContext, double, \
ops::functor>, \
ops::OverflowKernel<paddle::platform::CUDADeviceContext, plat::float16, \
ops::functor>);
REGISTER_OP_CUDA_KERNEL(
isinf, ops::OverflowKernel<paddle::platform::CUDADeviceContext, int,
ops::InfinityFunctor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext, float,
ops::InfinityFunctor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext, double,
ops::InfinityFunctor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext, plat::float16,
ops::InfinityFunctor>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_OVERFLOW_CUDA_KERNEL);
REGISTER_OP_CUDA_KERNEL(isnan,
ops::OverflowKernel<paddle::platform::CUDADeviceContext,
int, ops::NANFunctor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext,
float, ops::NANFunctor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext,
double, ops::NANFunctor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext,
plat::float16, ops::NANFunctor>);
REGISTER_OP_CUDA_KERNEL(
isfinite, ops::OverflowKernel<paddle::platform::CUDADeviceContext, int,
ops::IsfiniteFunctor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext, float,
ops::IsfiniteFunctor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext, double,
ops::IsfiniteFunctor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext, plat::float16,
ops::IsfiniteFunctor>);
......@@ -73,8 +73,3 @@ class OverflowKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
#define FOR_EACH_KERNEL_FUNCTOR(__macro) \
__macro(isinf, InfinityFunctor); \
__macro(isnan, NANFunctor); \
__macro(isfinite, IsfiniteFunctor);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册