diff --git a/paddle/phi/kernels/xpu/concat_and_split_functor.cc b/paddle/phi/kernels/xpu/concat_and_split_functor.cc index edff1e90143b63796825364c79c0636d30170786..a1335f33b67007c0182869eb2f463fb673511a5b 100644 --- a/paddle/phi/kernels/xpu/concat_and_split_functor.cc +++ b/paddle/phi/kernels/xpu/concat_and_split_functor.cc @@ -85,12 +85,21 @@ class SplitFunctor { int num = ins.size(); auto input_dims = ins[0]->dims(); + // special for 0-dim shape + if (input_dims.size() == 0) { + input_dims = {1}; + } std::vector split_list(num); std::vector xdims_list(input_dims.size()); int total_length = 0; for (int i = 0; i < num; ++i) { - split_list[i] = ins[i]->dims()[axis]; - total_length += ins[i]->dims()[axis]; + auto ins_i_dims = ins[i]->dims(); + // special for 0-dim shape + if (ins_i_dims.size() == 0) { + ins_i_dims = {1}; + } + split_list[i] = ins_i_dims[axis]; + total_length += ins_i_dims[axis]; } for (int i = 0; i < input_dims.size(); ++i) { @@ -110,6 +119,9 @@ class SplitFunctor { context.template Alloc(&tmp_data); } + // int split(Context* ctx, const T* x, const std::vector& y_list, const + // std::vector& xshape, const std::vector& split_list, + // int64_t axis); auto r = xpu::split( context.x_context(), reinterpret_cast(tmp_data.data()),