diff --git a/paddle/phi/kernels/kps/compare_kernel.cu b/paddle/phi/kernels/kps/compare_kernel.cu index b981d802255a2a4550034b69fe7d7fc10e341587..b882fcc2a6c960032936c040bb1604776f60de6a 100644 --- a/paddle/phi/kernels/kps/compare_kernel.cu +++ b/paddle/phi/kernels/kps/compare_kernel.cu @@ -114,7 +114,8 @@ PD_REGISTER_KERNEL(less_than, int64_t, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(less_equal, KPS, ALL_LAYOUT, @@ -125,7 +126,8 @@ PD_REGISTER_KERNEL(less_equal, int64_t, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(greater_than, KPS, ALL_LAYOUT, @@ -136,7 +138,8 @@ PD_REGISTER_KERNEL(greater_than, int64_t, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(greater_equal, KPS, ALL_LAYOUT, @@ -147,7 +150,8 @@ PD_REGISTER_KERNEL(greater_equal, int64_t, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(equal, KPS, ALL_LAYOUT, @@ -158,7 +162,8 @@ PD_REGISTER_KERNEL(equal, int64_t, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(not_equal, KPS, ALL_LAYOUT, @@ -169,7 +174,8 @@ PD_REGISTER_KERNEL(not_equal, int64_t, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(equal_all, KPS,