未验证 提交 14ed2f54 编写于 作者: L Leo Chen 提交者: GitHub

[pten] update isnan registration (#39419)

* update isnan registration

* fix compile
上级 c7c1db33
...@@ -132,4 +132,33 @@ namespace ops = paddle::operators; ...@@ -132,4 +132,33 @@ namespace ops = paddle::operators;
REGISTER_OP_MAKER(isinf, "isinf(X)"); REGISTER_OP_MAKER(isinf, "isinf(X)");
REGISTER_OP_MAKER(isnan, "isnan(X)"); REGISTER_OP_MAKER(isnan, "isnan(X)");
REGISTER_OP_MAKER(isfinite, "isfinite(X)"); REGISTER_OP_MAKER(isfinite, "isfinite(X)");
FOR_EACH_KERNEL_FUNCTOR(REGISTER_OVERFLOW_CPU_KERNEL);
REGISTER_OP_CPU_KERNEL(isinf,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
int, ops::InfinityFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::InfinityFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
float, ops::InfinityFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
double, ops::InfinityFunctor>);
REGISTER_OP_CPU_KERNEL(isnan,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
int, ops::NANFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::NANFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
float, ops::NANFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
double, ops::NANFunctor>);
REGISTER_OP_CPU_KERNEL(isfinite,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
int, ops::IsfiniteFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::IsfiniteFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
float, ops::IsfiniteFunctor>,
ops::OverflowKernel<paddle::platform::CPUDeviceContext,
double, ops::IsfiniteFunctor>);
...@@ -17,15 +17,32 @@ ...@@ -17,15 +17,32 @@
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
#define REGISTER_OVERFLOW_CUDA_KERNEL(op_type, functor) \ REGISTER_OP_CUDA_KERNEL(
REGISTER_OP_CUDA_KERNEL( \ isinf, ops::OverflowKernel<paddle::platform::CUDADeviceContext, int,
op_type, ops::OverflowKernel<paddle::platform::CUDADeviceContext, int, \ ops::InfinityFunctor>,
ops::functor>, \ ops::OverflowKernel<paddle::platform::CUDADeviceContext, float,
ops::OverflowKernel<paddle::platform::CUDADeviceContext, float, \ ops::InfinityFunctor>,
ops::functor>, \ ops::OverflowKernel<paddle::platform::CUDADeviceContext, double,
ops::OverflowKernel<paddle::platform::CUDADeviceContext, double, \ ops::InfinityFunctor>,
ops::functor>, \ ops::OverflowKernel<paddle::platform::CUDADeviceContext, plat::float16,
ops::OverflowKernel<paddle::platform::CUDADeviceContext, plat::float16, \ ops::InfinityFunctor>);
ops::functor>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_OVERFLOW_CUDA_KERNEL); REGISTER_OP_CUDA_KERNEL(isnan,
ops::OverflowKernel<paddle::platform::CUDADeviceContext,
int, ops::NANFunctor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext,
float, ops::NANFunctor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext,
double, ops::NANFunctor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext,
plat::float16, ops::NANFunctor>);
REGISTER_OP_CUDA_KERNEL(
isfinite, ops::OverflowKernel<paddle::platform::CUDADeviceContext, int,
ops::IsfiniteFunctor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext, float,
ops::IsfiniteFunctor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext, double,
ops::IsfiniteFunctor>,
ops::OverflowKernel<paddle::platform::CUDADeviceContext, plat::float16,
ops::IsfiniteFunctor>);
...@@ -73,8 +73,3 @@ class OverflowKernel : public framework::OpKernel<T> { ...@@ -73,8 +73,3 @@ class OverflowKernel : public framework::OpKernel<T> {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#define FOR_EACH_KERNEL_FUNCTOR(__macro) \
__macro(isinf, InfinityFunctor); \
__macro(isnan, NANFunctor); \
__macro(isfinite, IsfiniteFunctor);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册