未验证 提交 03d6d98c 编写于 作者: H houj04 提交者: GitHub

[XPU] fix 0-dim of SplitFunctor. (#54816)

上级 a702e170
......@@ -85,12 +85,21 @@ class SplitFunctor<XPUContext, T> {
int num = ins.size();
auto input_dims = ins[0]->dims();
// special for 0-dim shape
if (input_dims.size() == 0) {
input_dims = {1};
}
std::vector<int> split_list(num);
std::vector<int> xdims_list(input_dims.size());
int total_length = 0;
for (int i = 0; i < num; ++i) {
split_list[i] = ins[i]->dims()[axis];
total_length += ins[i]->dims()[axis];
auto ins_i_dims = ins[i]->dims();
// special for 0-dim shape
if (ins_i_dims.size() == 0) {
ins_i_dims = {1};
}
split_list[i] = ins_i_dims[axis];
total_length += ins_i_dims[axis];
}
for (int i = 0; i < input_dims.size(); ++i) {
......@@ -110,6 +119,9 @@ class SplitFunctor<XPUContext, T> {
context.template Alloc<T>(&tmp_data);
}
// int split(Context* ctx, const T* x, const std::vector<T*>& y_list, const
// std::vector<int64_t>& xshape, const std::vector<int64_t>& split_list,
// int64_t axis);
auto r = xpu::split<XPUType>(
context.x_context(),
reinterpret_cast<const XPUType*>(tmp_data.data<T>()),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册