From 03d6d98c97e90e83e841087a151bca88fc9a2cc6 Mon Sep 17 00:00:00 2001 From: houj04 <35131887+houj04@users.noreply.github.com> Date: Sun, 25 Jun 2023 13:27:16 +0800 Subject: [PATCH] [XPU] fix 0-dim of SplitFunctor. (#54816) --- .../phi/kernels/xpu/concat_and_split_functor.cc | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/xpu/concat_and_split_functor.cc b/paddle/phi/kernels/xpu/concat_and_split_functor.cc index edff1e90143..a1335f33b67 100644 --- a/paddle/phi/kernels/xpu/concat_and_split_functor.cc +++ b/paddle/phi/kernels/xpu/concat_and_split_functor.cc @@ -85,12 +85,21 @@ class SplitFunctor { int num = ins.size(); auto input_dims = ins[0]->dims(); + // special for 0-dim shape + if (input_dims.size() == 0) { + input_dims = {1}; + } std::vector split_list(num); std::vector xdims_list(input_dims.size()); int total_length = 0; for (int i = 0; i < num; ++i) { - split_list[i] = ins[i]->dims()[axis]; - total_length += ins[i]->dims()[axis]; + auto ins_i_dims = ins[i]->dims(); + // special for 0-dim shape + if (ins_i_dims.size() == 0) { + ins_i_dims = {1}; + } + split_list[i] = ins_i_dims[axis]; + total_length += ins_i_dims[axis]; } for (int i = 0; i < input_dims.size(); ++i) { @@ -110,6 +119,9 @@ class SplitFunctor { context.template Alloc(&tmp_data); } + // int split(Context* ctx, const T* x, const std::vector& y_list, const + // std::vector& xshape, const std::vector& split_list, + // int64_t axis); auto r = xpu::split( context.x_context(), reinterpret_cast(tmp_data.data()), -- GitLab