From b5ba801ef065f04ef20e781acedf8fe53d611855 Mon Sep 17 00:00:00 2001 From: chengduo <30176695+chengduoZH@users.noreply.github.com> Date: Wed, 14 Aug 2019 22:18:12 +0800 Subject: [PATCH] Fix gather op bug (#19168) * fix gather op bug test=develop --- paddle/fluid/operators/gather.cu.h | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/gather.cu.h b/paddle/fluid/operators/gather.cu.h index fff817fbd02..86b3a25235c 100644 --- a/paddle/fluid/operators/gather.cu.h +++ b/paddle/fluid/operators/gather.cu.h @@ -49,10 +49,16 @@ __global__ void GatherCUDAKernel(const T* params, const IndexT* indices, template void GPUGather(const platform::DeviceContext& ctx, const Tensor& src, const Tensor& index, Tensor* output) { - // PADDLE_ENFORCE(platform::is_gpu_place(place)); // check index of shape 1-D - PADDLE_ENFORCE(index.dims().size() == 1 || - (index.dims().size() == 2 && index.dims()[1] == 1)); + if (index.dims().size() == 1) { + PADDLE_ENFORCE_GT(index.dims()[0], 0, + "The index of gather_op should not be empty when the " + "index's rank is 1."); + } else if (index.dims().size() == 2) { + PADDLE_ENFORCE_EQ(index.dims()[1], 1, + " If the index's rank of gather_op is 2, the second " + "dimension should be 1."); + } int index_size = index.dims()[0]; -- GitLab