未验证 提交 6849d33b 编写于 作者: Z Zhong Hui 提交者: GitHub

[Ops] segment pool op support for int int64 kernel. (#40577)

* segment pool support for int int64 kernel.

* add support in python api
上级 2dec25db
...@@ -23,4 +23,6 @@ PD_REGISTER_KERNEL(segment_pool_grad, ...@@ -23,4 +23,6 @@ PD_REGISTER_KERNEL(segment_pool_grad,
ALL_LAYOUT, ALL_LAYOUT,
phi::SegmentPoolGradKernel, phi::SegmentPoolGradKernel,
float, float,
double) {} double,
int,
int64_t) {}
...@@ -18,5 +18,11 @@ ...@@ -18,5 +18,11 @@
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(segment_pool,
segment_pool, CPU, ALL_LAYOUT, phi::SegmentPoolKernel, float, double) {} CPU,
ALL_LAYOUT,
phi::SegmentPoolKernel,
float,
double,
int,
int64_t) {}
...@@ -149,10 +149,19 @@ template class SegmentPoolFunctor<CPU, float, int>; ...@@ -149,10 +149,19 @@ template class SegmentPoolFunctor<CPU, float, int>;
template class SegmentPoolFunctor<CPU, float, int64_t>; template class SegmentPoolFunctor<CPU, float, int64_t>;
template class SegmentPoolFunctor<CPU, double, int>; template class SegmentPoolFunctor<CPU, double, int>;
template class SegmentPoolFunctor<CPU, double, int64_t>; template class SegmentPoolFunctor<CPU, double, int64_t>;
template class SegmentPoolFunctor<CPU, int, int>;
template class SegmentPoolFunctor<CPU, int, int64_t>;
template class SegmentPoolFunctor<CPU, int64_t, int>;
template class SegmentPoolFunctor<CPU, int64_t, int64_t>;
template class SegmentPoolGradFunctor<CPU, float, int>; template class SegmentPoolGradFunctor<CPU, float, int>;
template class SegmentPoolGradFunctor<CPU, float, int64_t>; template class SegmentPoolGradFunctor<CPU, float, int64_t>;
template class SegmentPoolGradFunctor<CPU, double, int>; template class SegmentPoolGradFunctor<CPU, double, int>;
template class SegmentPoolGradFunctor<CPU, double, int64_t>; template class SegmentPoolGradFunctor<CPU, double, int64_t>;
template class SegmentPoolGradFunctor<CPU, int, int>;
template class SegmentPoolGradFunctor<CPU, int, int64_t>;
template class SegmentPoolGradFunctor<CPU, int64_t, int>;
template class SegmentPoolGradFunctor<CPU, int64_t, int64_t>;
} // namespace funcs } // namespace funcs
} // namespace phi } // namespace phi
...@@ -453,10 +453,19 @@ template class SegmentPoolFunctor<GPU, float, int>; ...@@ -453,10 +453,19 @@ template class SegmentPoolFunctor<GPU, float, int>;
template class SegmentPoolFunctor<GPU, float, int64_t>; template class SegmentPoolFunctor<GPU, float, int64_t>;
template class SegmentPoolFunctor<GPU, double, int>; template class SegmentPoolFunctor<GPU, double, int>;
template class SegmentPoolFunctor<GPU, double, int64_t>; template class SegmentPoolFunctor<GPU, double, int64_t>;
template class SegmentPoolFunctor<GPU, int, int>;
template class SegmentPoolFunctor<GPU, int, int64_t>;
template class SegmentPoolFunctor<GPU, int64_t, int>;
template class SegmentPoolFunctor<GPU, int64_t, int64_t>;
template class SegmentPoolGradFunctor<GPU, float, int>; template class SegmentPoolGradFunctor<GPU, float, int>;
template class SegmentPoolGradFunctor<GPU, float, int64_t>; template class SegmentPoolGradFunctor<GPU, float, int64_t>;
template class SegmentPoolGradFunctor<GPU, double, int>; template class SegmentPoolGradFunctor<GPU, double, int>;
template class SegmentPoolGradFunctor<GPU, double, int64_t>; template class SegmentPoolGradFunctor<GPU, double, int64_t>;
template class SegmentPoolGradFunctor<GPU, int, int>;
template class SegmentPoolGradFunctor<GPU, int, int64_t>;
template class SegmentPoolGradFunctor<GPU, int64_t, int>;
template class SegmentPoolGradFunctor<GPU, int64_t, int64_t>;
} // namespace funcs } // namespace funcs
} // namespace phi } // namespace phi
...@@ -24,4 +24,6 @@ PD_REGISTER_KERNEL(segment_pool_grad, ...@@ -24,4 +24,6 @@ PD_REGISTER_KERNEL(segment_pool_grad,
ALL_LAYOUT, ALL_LAYOUT,
phi::SegmentPoolGradKernel, phi::SegmentPoolGradKernel,
float, float,
double) {} double,
int,
int64_t) {}
...@@ -19,5 +19,11 @@ ...@@ -19,5 +19,11 @@
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(segment_pool,
segment_pool, GPU, ALL_LAYOUT, phi::SegmentPoolKernel, float, double) {} GPU,
ALL_LAYOUT,
phi::SegmentPoolKernel,
float,
double,
int,
int64_t) {}
...@@ -29,7 +29,7 @@ def segment_sum(data, segment_ids, name=None): ...@@ -29,7 +29,7 @@ def segment_sum(data, segment_ids, name=None):
where sum is over j such that `segment_ids[j] == i`. where sum is over j such that `segment_ids[j] == i`.
Args: Args:
data (Tensor): A tensor, available data type float32, float64. data (Tensor): A tensor, available data type float32, float64, int32, int64.
segment_ids (Tensor): A 1-D tensor, which have the same size segment_ids (Tensor): A 1-D tensor, which have the same size
with the first dimension of input data. with the first dimension of input data.
Available data type is int32, int64. Available data type is int32, int64.
...@@ -54,7 +54,8 @@ def segment_sum(data, segment_ids, name=None): ...@@ -54,7 +54,8 @@ def segment_sum(data, segment_ids, name=None):
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "SUM") out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "SUM")
return out return out
check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool") check_variable_and_dtype(data, "X", ("float32", "float64", "int32",
"int64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool") "segment_pool")
...@@ -82,7 +83,7 @@ def segment_mean(data, segment_ids, name=None): ...@@ -82,7 +83,7 @@ def segment_mean(data, segment_ids, name=None):
of all index 'segment_ids[j] == i'. of all index 'segment_ids[j] == i'.
Args: Args:
data (tensor): a tensor, available data type float32, float64. data (tensor): a tensor, available data type float32, float64, int32, int64.
segment_ids (tensor): a 1-d tensor, which have the same size segment_ids (tensor): a 1-d tensor, which have the same size
with the first dimension of input data. with the first dimension of input data.
available data type is int32, int64. available data type is int32, int64.
...@@ -107,7 +108,8 @@ def segment_mean(data, segment_ids, name=None): ...@@ -107,7 +108,8 @@ def segment_mean(data, segment_ids, name=None):
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MEAN") out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MEAN")
return out return out
check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool") check_variable_and_dtype(data, "X", ("float32", "float64", "int32",
"int64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool") "segment_pool")
...@@ -134,7 +136,7 @@ def segment_min(data, segment_ids, name=None): ...@@ -134,7 +136,7 @@ def segment_min(data, segment_ids, name=None):
where min is over j such that `segment_ids[j] == i`. where min is over j such that `segment_ids[j] == i`.
Args: Args:
data (tensor): a tensor, available data type float32, float64. data (tensor): a tensor, available data type float32, float64, int32, int64.
segment_ids (tensor): a 1-d tensor, which have the same size segment_ids (tensor): a 1-d tensor, which have the same size
with the first dimension of input data. with the first dimension of input data.
available data type is int32, int64. available data type is int32, int64.
...@@ -159,7 +161,8 @@ def segment_min(data, segment_ids, name=None): ...@@ -159,7 +161,8 @@ def segment_min(data, segment_ids, name=None):
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MIN") out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MIN")
return out return out
check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool") check_variable_and_dtype(data, "X", ("float32", "float64", "int32",
"int64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool") "segment_pool")
...@@ -186,7 +189,7 @@ def segment_max(data, segment_ids, name=None): ...@@ -186,7 +189,7 @@ def segment_max(data, segment_ids, name=None):
where max is over j such that `segment_ids[j] == i`. where max is over j such that `segment_ids[j] == i`.
Args: Args:
data (tensor): a tensor, available data type float32, float64. data (tensor): a tensor, available data type float32, float64, int32, int64.
segment_ids (tensor): a 1-d tensor, which have the same size segment_ids (tensor): a 1-d tensor, which have the same size
with the first dimension of input data. with the first dimension of input data.
available data type is int32, int64. available data type is int32, int64.
...@@ -211,7 +214,8 @@ def segment_max(data, segment_ids, name=None): ...@@ -211,7 +214,8 @@ def segment_max(data, segment_ids, name=None):
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MAX") out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MAX")
return out return out
check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool") check_variable_and_dtype(data, "X", ("float32", "float64", "int32",
"int64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool") "segment_pool")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册