From de5dec8412894d715a47993b231cd798a761caca Mon Sep 17 00:00:00 2001 From: chengduo <30176695+chengduoZH@users.noreply.github.com> Date: Thu, 15 Aug 2019 22:44:59 +0800 Subject: [PATCH] [Cherry-pick]Fix gather op bug (#19169) * fix gather op bug test=release/1.5 --- paddle/fluid/operators/gather.cu.h | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/gather.cu.h b/paddle/fluid/operators/gather.cu.h index fff817fbd0..86b3a25235 100644 --- a/paddle/fluid/operators/gather.cu.h +++ b/paddle/fluid/operators/gather.cu.h @@ -49,10 +49,16 @@ __global__ void GatherCUDAKernel(const T* params, const IndexT* indices, template void GPUGather(const platform::DeviceContext& ctx, const Tensor& src, const Tensor& index, Tensor* output) { - // PADDLE_ENFORCE(platform::is_gpu_place(place)); // check index of shape 1-D - PADDLE_ENFORCE(index.dims().size() == 1 || - (index.dims().size() == 2 && index.dims()[1] == 1)); + if (index.dims().size() == 1) { + PADDLE_ENFORCE_GT(index.dims()[0], 0, + "The index of gather_op should not be empty when the " + "index's rank is 1."); + } else if (index.dims().size() == 2) { + PADDLE_ENFORCE_EQ(index.dims()[1], 1, + " If the index's rank of gather_op is 2, the second " + "dimension should be 1."); + } int index_size = index.dims()[0]; -- GitLab