From b3e5b0c469d15bbbaf6548f2a48df8c9e24a0712 Mon Sep 17 00:00:00 2001 From: houj04 <35131887+houj04@users.noreply.github.com> Date: Mon, 6 Feb 2023 11:34:00 +0800 Subject: [PATCH] [XPU] add int type for concat and split functor (#50200) --- paddle/fluid/pybind/process_group_utils.h | 12 ++++++++++++ paddle/phi/kernels/xpu/concat_and_split_functor.cc | 2 ++ 2 files changed, 14 insertions(+) diff --git a/paddle/fluid/pybind/process_group_utils.h b/paddle/fluid/pybind/process_group_utils.h index d6930923316..4e5af25f26d 100644 --- a/paddle/fluid/pybind/process_group_utils.h +++ b/paddle/fluid/pybind/process_group_utils.h @@ -143,6 +143,12 @@ void ConcatDenseTensorWithType(const phi::XPUContext &dev_ctx, case phi::DataType::FLOAT32: ConcatDenseTensor()(dev_ctx, t_list, p_out); break; + case phi::DataType::INT32: + ConcatDenseTensor()(dev_ctx, t_list, p_out); + break; + case phi::DataType::INT64: + ConcatDenseTensor()(dev_ctx, t_list, p_out); + break; default: PADDLE_THROW(platform::errors::Unimplemented( "Data type (%s) is not supported when it concats tensors.", type)); @@ -205,6 +211,12 @@ void SplitDenseTensorWithType(const phi::XPUContext &dev_ctx, case phi::DataType::FLOAT32: SplitDenseTensor()(dev_ctx, t_in, p_list); break; + case phi::DataType::INT32: + SplitDenseTensor()(dev_ctx, t_in, p_list); + break; + case phi::DataType::INT64: + SplitDenseTensor()(dev_ctx, t_in, p_list); + break; default: PADDLE_THROW(platform::errors::Unimplemented( "Data type (%s) is not supported when it splits tensors.", type)); diff --git a/paddle/phi/kernels/xpu/concat_and_split_functor.cc b/paddle/phi/kernels/xpu/concat_and_split_functor.cc index 769458523a6..8cb24859c78 100644 --- a/paddle/phi/kernels/xpu/concat_and_split_functor.cc +++ b/paddle/phi/kernels/xpu/concat_and_split_functor.cc @@ -127,6 +127,8 @@ class SplitFunctor { DEFINE_XPU_FUNCTOR(float) DEFINE_XPU_FUNCTOR(phi::dtype::float16) +DEFINE_XPU_FUNCTOR(int32_t) +DEFINE_XPU_FUNCTOR(int64_t) } // namespace funcs } // namespace phi -- GitLab