未验证 提交 b3e5b0c4 编写于 作者: H houj04 提交者: GitHub

[XPU] add int type for concat and split functor (#50200)

上级 35ce2bd9
...@@ -143,6 +143,12 @@ void ConcatDenseTensorWithType(const phi::XPUContext &dev_ctx, ...@@ -143,6 +143,12 @@ void ConcatDenseTensorWithType(const phi::XPUContext &dev_ctx,
case phi::DataType::FLOAT32: case phi::DataType::FLOAT32:
ConcatDenseTensor<phi::XPUContext, float>()(dev_ctx, t_list, p_out); ConcatDenseTensor<phi::XPUContext, float>()(dev_ctx, t_list, p_out);
break; break;
case phi::DataType::INT32:
ConcatDenseTensor<phi::XPUContext, int32_t>()(dev_ctx, t_list, p_out);
break;
case phi::DataType::INT64:
ConcatDenseTensor<phi::XPUContext, int64_t>()(dev_ctx, t_list, p_out);
break;
default: default:
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Data type (%s) is not supported when it concats tensors.", type)); "Data type (%s) is not supported when it concats tensors.", type));
...@@ -205,6 +211,12 @@ void SplitDenseTensorWithType(const phi::XPUContext &dev_ctx, ...@@ -205,6 +211,12 @@ void SplitDenseTensorWithType(const phi::XPUContext &dev_ctx,
case phi::DataType::FLOAT32: case phi::DataType::FLOAT32:
SplitDenseTensor<phi::XPUContext, float>()(dev_ctx, t_in, p_list); SplitDenseTensor<phi::XPUContext, float>()(dev_ctx, t_in, p_list);
break; break;
case phi::DataType::INT32:
SplitDenseTensor<phi::XPUContext, int32_t>()(dev_ctx, t_in, p_list);
break;
case phi::DataType::INT64:
SplitDenseTensor<phi::XPUContext, int64_t>()(dev_ctx, t_in, p_list);
break;
default: default:
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Data type (%s) is not supported when it splits tensors.", type)); "Data type (%s) is not supported when it splits tensors.", type));
......
...@@ -127,6 +127,8 @@ class SplitFunctor<XPUContext, T> { ...@@ -127,6 +127,8 @@ class SplitFunctor<XPUContext, T> {
DEFINE_XPU_FUNCTOR(float) DEFINE_XPU_FUNCTOR(float)
DEFINE_XPU_FUNCTOR(phi::dtype::float16) DEFINE_XPU_FUNCTOR(phi::dtype::float16)
DEFINE_XPU_FUNCTOR(int32_t)
DEFINE_XPU_FUNCTOR(int64_t)
} // namespace funcs } // namespace funcs
} // namespace phi } // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册