diff --git a/paddle/phi/kernels/cpu/segment_pool_grad_kernel.cc b/paddle/phi/kernels/cpu/segment_pool_grad_kernel.cc index 585c27bdcec97e11a68cdc536c829f76c000a8df..a5c9dc4c55e495833f40ec7499e6c0373594d319 100644 --- a/paddle/phi/kernels/cpu/segment_pool_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/segment_pool_grad_kernel.cc @@ -23,4 +23,6 @@ PD_REGISTER_KERNEL(segment_pool_grad, ALL_LAYOUT, phi::SegmentPoolGradKernel, float, - double) {} + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/cpu/segment_pool_kernel.cc b/paddle/phi/kernels/cpu/segment_pool_kernel.cc index d0413457f8177338aa450211539dc16d0880c74c..ad76a7a86bcb28f291288418c43740ed0b7adb97 100644 --- a/paddle/phi/kernels/cpu/segment_pool_kernel.cc +++ b/paddle/phi/kernels/cpu/segment_pool_kernel.cc @@ -18,5 +18,11 @@ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" -PD_REGISTER_KERNEL( - segment_pool, CPU, ALL_LAYOUT, phi::SegmentPoolKernel, float, double) {} +PD_REGISTER_KERNEL(segment_pool, + CPU, + ALL_LAYOUT, + phi::SegmentPoolKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/funcs/segment_pooling.cc b/paddle/phi/kernels/funcs/segment_pooling.cc index bf4a21f37223dab5a67649406496e9828b0bcf3f..fbd744430aa11ab1a5a17c76b6d37c10c3085556 100644 --- a/paddle/phi/kernels/funcs/segment_pooling.cc +++ b/paddle/phi/kernels/funcs/segment_pooling.cc @@ -149,10 +149,19 @@ template class SegmentPoolFunctor; template class SegmentPoolFunctor; template class SegmentPoolFunctor; template class SegmentPoolFunctor; +template class SegmentPoolFunctor; +template class SegmentPoolFunctor; +template class SegmentPoolFunctor; +template class SegmentPoolFunctor; + template class SegmentPoolGradFunctor; template class SegmentPoolGradFunctor; template class SegmentPoolGradFunctor; template class SegmentPoolGradFunctor; +template class SegmentPoolGradFunctor; +template class SegmentPoolGradFunctor; +template class SegmentPoolGradFunctor; +template class SegmentPoolGradFunctor; } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/segment_pooling.cu b/paddle/phi/kernels/funcs/segment_pooling.cu index 305cd39f077bc359543b399a8775b5a92a2eb00d..95606b152672916116813c97cbbc0856d33e49a7 100644 --- a/paddle/phi/kernels/funcs/segment_pooling.cu +++ b/paddle/phi/kernels/funcs/segment_pooling.cu @@ -453,10 +453,19 @@ template class SegmentPoolFunctor; template class SegmentPoolFunctor; template class SegmentPoolFunctor; template class SegmentPoolFunctor; +template class SegmentPoolFunctor; +template class SegmentPoolFunctor; +template class SegmentPoolFunctor; +template class SegmentPoolFunctor; + template class SegmentPoolGradFunctor; template class SegmentPoolGradFunctor; template class SegmentPoolGradFunctor; template class SegmentPoolGradFunctor; +template class SegmentPoolGradFunctor; +template class SegmentPoolGradFunctor; +template class SegmentPoolGradFunctor; +template class SegmentPoolGradFunctor; } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/gpu/segment_pool_grad_kernel.cu b/paddle/phi/kernels/gpu/segment_pool_grad_kernel.cu index d9618dc159a6d3f5b24bdfcfdb219ec649e051f9..9d1769e18b4b809fbc353513a05553e0ccd97572 100644 --- a/paddle/phi/kernels/gpu/segment_pool_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/segment_pool_grad_kernel.cu @@ -24,4 +24,6 @@ PD_REGISTER_KERNEL(segment_pool_grad, ALL_LAYOUT, phi::SegmentPoolGradKernel, float, - double) {} + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/segment_pool_kernel.cu b/paddle/phi/kernels/gpu/segment_pool_kernel.cu index c38e935adf837ef00c48fa31bc1e37eea2948673..3128e534166acba6ca136331ad8efea66b18621f 100644 --- a/paddle/phi/kernels/gpu/segment_pool_kernel.cu +++ b/paddle/phi/kernels/gpu/segment_pool_kernel.cu @@ -19,5 +19,11 @@ #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" -PD_REGISTER_KERNEL( - segment_pool, GPU, ALL_LAYOUT, phi::SegmentPoolKernel, float, double) {} +PD_REGISTER_KERNEL(segment_pool, + GPU, + ALL_LAYOUT, + phi::SegmentPoolKernel, + float, + double, + int, + int64_t) {} diff --git a/python/paddle/incubate/tensor/math.py b/python/paddle/incubate/tensor/math.py index 9f577d5ff38024fe9264deec2980ff091996a1d8..2d0b079ee9280e2dd0cc0e62c6c5932565ba9dfd 100644 --- a/python/paddle/incubate/tensor/math.py +++ b/python/paddle/incubate/tensor/math.py @@ -29,7 +29,7 @@ def segment_sum(data, segment_ids, name=None): where sum is over j such that `segment_ids[j] == i`. Args: - data (Tensor): A tensor, available data type float32, float64. + data (Tensor): A tensor, available data type float32, float64, int32, int64. segment_ids (Tensor): A 1-D tensor, which have the same size with the first dimension of input data. Available data type is int32, int64. @@ -54,7 +54,8 @@ def segment_sum(data, segment_ids, name=None): out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "SUM") return out - check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool") + check_variable_and_dtype(data, "X", ("float32", "float64", "int32", + "int64"), "segment_pool") check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), "segment_pool") @@ -82,7 +83,7 @@ def segment_mean(data, segment_ids, name=None): of all index 'segment_ids[j] == i'. Args: - data (tensor): a tensor, available data type float32, float64. + data (tensor): a tensor, available data type float32, float64, int32, int64. segment_ids (tensor): a 1-d tensor, which have the same size with the first dimension of input data. available data type is int32, int64. @@ -107,7 +108,8 @@ def segment_mean(data, segment_ids, name=None): out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MEAN") return out - check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool") + check_variable_and_dtype(data, "X", ("float32", "float64", "int32", + "int64"), "segment_pool") check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), "segment_pool") @@ -134,7 +136,7 @@ def segment_min(data, segment_ids, name=None): where min is over j such that `segment_ids[j] == i`. Args: - data (tensor): a tensor, available data type float32, float64. + data (tensor): a tensor, available data type float32, float64, int32, int64. segment_ids (tensor): a 1-d tensor, which have the same size with the first dimension of input data. available data type is int32, int64. @@ -159,7 +161,8 @@ def segment_min(data, segment_ids, name=None): out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MIN") return out - check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool") + check_variable_and_dtype(data, "X", ("float32", "float64", "int32", + "int64"), "segment_pool") check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), "segment_pool") @@ -186,7 +189,7 @@ def segment_max(data, segment_ids, name=None): where max is over j such that `segment_ids[j] == i`. Args: - data (tensor): a tensor, available data type float32, float64. + data (tensor): a tensor, available data type float32, float64, int32, int64. segment_ids (tensor): a 1-d tensor, which have the same size with the first dimension of input data. available data type is int32, int64. @@ -211,7 +214,8 @@ def segment_max(data, segment_ids, name=None): out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MAX") return out - check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool") + check_variable_and_dtype(data, "X", ("float32", "float64", "int32", + "int64"), "segment_pool") check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), "segment_pool")