未验证 提交 6849d33b 编写于 作者: Z Zhong Hui 提交者: GitHub

[Ops] segment pool op support for int int64 kernel. (#40577)

* segment pool support for int int64 kernel.

* add support in python api
上级 2dec25db
......@@ -23,4 +23,6 @@ PD_REGISTER_KERNEL(segment_pool_grad,
ALL_LAYOUT,
phi::SegmentPoolGradKernel,
float,
double) {}
double,
int,
int64_t) {}
......@@ -18,5 +18,11 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(
segment_pool, CPU, ALL_LAYOUT, phi::SegmentPoolKernel, float, double) {}
PD_REGISTER_KERNEL(segment_pool,
CPU,
ALL_LAYOUT,
phi::SegmentPoolKernel,
float,
double,
int,
int64_t) {}
......@@ -149,10 +149,19 @@ template class SegmentPoolFunctor<CPU, float, int>;
template class SegmentPoolFunctor<CPU, float, int64_t>;
template class SegmentPoolFunctor<CPU, double, int>;
template class SegmentPoolFunctor<CPU, double, int64_t>;
template class SegmentPoolFunctor<CPU, int, int>;
template class SegmentPoolFunctor<CPU, int, int64_t>;
template class SegmentPoolFunctor<CPU, int64_t, int>;
template class SegmentPoolFunctor<CPU, int64_t, int64_t>;
template class SegmentPoolGradFunctor<CPU, float, int>;
template class SegmentPoolGradFunctor<CPU, float, int64_t>;
template class SegmentPoolGradFunctor<CPU, double, int>;
template class SegmentPoolGradFunctor<CPU, double, int64_t>;
template class SegmentPoolGradFunctor<CPU, int, int>;
template class SegmentPoolGradFunctor<CPU, int, int64_t>;
template class SegmentPoolGradFunctor<CPU, int64_t, int>;
template class SegmentPoolGradFunctor<CPU, int64_t, int64_t>;
} // namespace funcs
} // namespace phi
......@@ -453,10 +453,19 @@ template class SegmentPoolFunctor<GPU, float, int>;
template class SegmentPoolFunctor<GPU, float, int64_t>;
template class SegmentPoolFunctor<GPU, double, int>;
template class SegmentPoolFunctor<GPU, double, int64_t>;
template class SegmentPoolFunctor<GPU, int, int>;
template class SegmentPoolFunctor<GPU, int, int64_t>;
template class SegmentPoolFunctor<GPU, int64_t, int>;
template class SegmentPoolFunctor<GPU, int64_t, int64_t>;
template class SegmentPoolGradFunctor<GPU, float, int>;
template class SegmentPoolGradFunctor<GPU, float, int64_t>;
template class SegmentPoolGradFunctor<GPU, double, int>;
template class SegmentPoolGradFunctor<GPU, double, int64_t>;
template class SegmentPoolGradFunctor<GPU, int, int>;
template class SegmentPoolGradFunctor<GPU, int, int64_t>;
template class SegmentPoolGradFunctor<GPU, int64_t, int>;
template class SegmentPoolGradFunctor<GPU, int64_t, int64_t>;
} // namespace funcs
} // namespace phi
......@@ -24,4 +24,6 @@ PD_REGISTER_KERNEL(segment_pool_grad,
ALL_LAYOUT,
phi::SegmentPoolGradKernel,
float,
double) {}
double,
int,
int64_t) {}
......@@ -19,5 +19,11 @@
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(
segment_pool, GPU, ALL_LAYOUT, phi::SegmentPoolKernel, float, double) {}
PD_REGISTER_KERNEL(segment_pool,
GPU,
ALL_LAYOUT,
phi::SegmentPoolKernel,
float,
double,
int,
int64_t) {}
......@@ -29,7 +29,7 @@ def segment_sum(data, segment_ids, name=None):
where sum is over j such that `segment_ids[j] == i`.
Args:
data (Tensor): A tensor, available data type float32, float64.
data (Tensor): A tensor, available data type float32, float64, int32, int64.
segment_ids (Tensor): A 1-D tensor, which have the same size
with the first dimension of input data.
Available data type is int32, int64.
......@@ -54,7 +54,8 @@ def segment_sum(data, segment_ids, name=None):
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "SUM")
return out
check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
check_variable_and_dtype(data, "X", ("float32", "float64", "int32",
"int64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool")
......@@ -82,7 +83,7 @@ def segment_mean(data, segment_ids, name=None):
of all index 'segment_ids[j] == i'.
Args:
data (tensor): a tensor, available data type float32, float64.
data (tensor): a tensor, available data type float32, float64, int32, int64.
segment_ids (tensor): a 1-d tensor, which have the same size
with the first dimension of input data.
available data type is int32, int64.
......@@ -107,7 +108,8 @@ def segment_mean(data, segment_ids, name=None):
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MEAN")
return out
check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
check_variable_and_dtype(data, "X", ("float32", "float64", "int32",
"int64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool")
......@@ -134,7 +136,7 @@ def segment_min(data, segment_ids, name=None):
where min is over j such that `segment_ids[j] == i`.
Args:
data (tensor): a tensor, available data type float32, float64.
data (tensor): a tensor, available data type float32, float64, int32, int64.
segment_ids (tensor): a 1-d tensor, which have the same size
with the first dimension of input data.
available data type is int32, int64.
......@@ -159,7 +161,8 @@ def segment_min(data, segment_ids, name=None):
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MIN")
return out
check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
check_variable_and_dtype(data, "X", ("float32", "float64", "int32",
"int64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool")
......@@ -186,7 +189,7 @@ def segment_max(data, segment_ids, name=None):
where max is over j such that `segment_ids[j] == i`.
Args:
data (tensor): a tensor, available data type float32, float64.
data (tensor): a tensor, available data type float32, float64, int32, int64.
segment_ids (tensor): a 1-d tensor, which have the same size
with the first dimension of input data.
available data type is int32, int64.
......@@ -211,7 +214,8 @@ def segment_max(data, segment_ids, name=None):
out, tmp = _C_ops.segment_pool(data, segment_ids, 'pooltype', "MAX")
return out
check_variable_and_dtype(data, "X", ("float32", "float64"), "segment_pool")
check_variable_and_dtype(data, "X", ("float32", "float64", "int32",
"int64"), "segment_pool")
check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"),
"segment_pool")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册