diff --git a/paddle/fluid/pybind/process_group_utils.h b/paddle/fluid/pybind/process_group_utils.h index d693092331672c6bbf0bc700ecb17825719f4757..4e5af25f26d92c566ac546ccf2490ee44e385db2 100644 --- a/paddle/fluid/pybind/process_group_utils.h +++ b/paddle/fluid/pybind/process_group_utils.h @@ -143,6 +143,12 @@ void ConcatDenseTensorWithType(const phi::XPUContext &dev_ctx, case phi::DataType::FLOAT32: ConcatDenseTensor()(dev_ctx, t_list, p_out); break; + case phi::DataType::INT32: + ConcatDenseTensor()(dev_ctx, t_list, p_out); + break; + case phi::DataType::INT64: + ConcatDenseTensor()(dev_ctx, t_list, p_out); + break; default: PADDLE_THROW(platform::errors::Unimplemented( "Data type (%s) is not supported when it concats tensors.", type)); @@ -205,6 +211,12 @@ void SplitDenseTensorWithType(const phi::XPUContext &dev_ctx, case phi::DataType::FLOAT32: SplitDenseTensor()(dev_ctx, t_in, p_list); break; + case phi::DataType::INT32: + SplitDenseTensor()(dev_ctx, t_in, p_list); + break; + case phi::DataType::INT64: + SplitDenseTensor()(dev_ctx, t_in, p_list); + break; default: PADDLE_THROW(platform::errors::Unimplemented( "Data type (%s) is not supported when it splits tensors.", type)); diff --git a/paddle/phi/kernels/xpu/concat_and_split_functor.cc b/paddle/phi/kernels/xpu/concat_and_split_functor.cc index 769458523a68ccd11a20b21f0cec32d4823196b1..8cb24859c7811fc7bcd0a27d5964cc6266dd9b42 100644 --- a/paddle/phi/kernels/xpu/concat_and_split_functor.cc +++ b/paddle/phi/kernels/xpu/concat_and_split_functor.cc @@ -127,6 +127,8 @@ class SplitFunctor { DEFINE_XPU_FUNCTOR(float) DEFINE_XPU_FUNCTOR(phi::dtype::float16) +DEFINE_XPU_FUNCTOR(int32_t) +DEFINE_XPU_FUNCTOR(int64_t) } // namespace funcs } // namespace phi