未验证 提交 c077a6d5 编写于 作者: Y Yu Yang 提交者: GitHub

Feature/support int64 for sum (#5832)

* Support int64 for sum op

* Refine code
上级 e800c0d3
......@@ -145,6 +145,8 @@ struct SelectedRowsAddTo<platform::CPUPlace, T> {
template struct SelectedRowsAddTo<platform::CPUPlace, float>;
template struct SelectedRowsAddTo<platform::CPUPlace, double>;
template struct SelectedRowsAddTo<platform::CPUPlace, int>;
template struct SelectedRowsAddTo<platform::CPUPlace, int64_t>;
template <typename T>
struct SelectedRowsAddToTensor<platform::CPUPlace, T> {
......@@ -175,6 +177,8 @@ struct SelectedRowsAddToTensor<platform::CPUPlace, T> {
template struct SelectedRowsAddToTensor<platform::CPUPlace, float>;
template struct SelectedRowsAddToTensor<platform::CPUPlace, double>;
template struct SelectedRowsAddToTensor<platform::CPUPlace, int>;
template struct SelectedRowsAddToTensor<platform::CPUPlace, int64_t>;
} // namespace math
} // namespace operators
......
......@@ -173,6 +173,8 @@ struct SelectedRowsAddTo<platform::GPUPlace, T> {
template struct SelectedRowsAddTo<platform::GPUPlace, float>;
template struct SelectedRowsAddTo<platform::GPUPlace, double>;
template struct SelectedRowsAddTo<platform::GPUPlace, int>;
template struct SelectedRowsAddTo<platform::GPUPlace, int64_t>;
namespace {
template <typename T, int block_size>
......@@ -223,6 +225,8 @@ struct SelectedRowsAddToTensor<platform::GPUPlace, T> {
template struct SelectedRowsAddToTensor<platform::GPUPlace, float>;
template struct SelectedRowsAddToTensor<platform::GPUPlace, double>;
template struct SelectedRowsAddToTensor<platform::GPUPlace, int>;
template struct SelectedRowsAddToTensor<platform::GPUPlace, int64_t>;
} // namespace math
} // namespace operators
......
......@@ -176,4 +176,6 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradMaker,
ops::SumOpVarTypeInference);
REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel<paddle::platform::CPUPlace, float>,
ops::SumKernel<paddle::platform::CPUPlace, double>);
ops::SumKernel<paddle::platform::CPUPlace, double>,
ops::SumKernel<paddle::platform::CPUPlace, int>,
ops::SumKernel<paddle::platform::CPUPlace, int64_t>);
......@@ -14,4 +14,6 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(sum, ops::SumKernel<paddle::platform::GPUPlace, float>,
ops::SumKernel<paddle::platform::GPUPlace, double>);
ops::SumKernel<paddle::platform::GPUPlace, double>,
ops::SumKernel<paddle::platform::GPUPlace, int>,
ops::SumKernel<paddle::platform::GPUPlace, int64_t>);
......@@ -31,6 +31,16 @@ constexpr int PADDLE_CUDA_NUM_THREADS = 512;
// For atomicAdd.
USE_CUDA_ATOMIC(Add, float);
USE_CUDA_ATOMIC(Add, int);
USE_CUDA_ATOMIC(Add, unsigned int);
USE_CUDA_ATOMIC(Add, unsigned long long int);
CUDA_ATOMIC_WRAPPER(Add, int64_t) {
static_assert(sizeof(int64_t) == sizeof(long long int),
"long long should be int64");
return CudaAtomicAdd(reinterpret_cast<unsigned long long int*>(address),
static_cast<unsigned long long int>(val));
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
USE_CUDA_ATOMIC(Add, double);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册