未验证 提交 de5dec84 编写于 作者: C chengduo 提交者: GitHub

[Cherry-pick]Fix gather op bug (#19169)

* fix gather op bug
test=release/1.5
上级 cc3ba765
......@@ -49,10 +49,16 @@ __global__ void GatherCUDAKernel(const T* params, const IndexT* indices,
template <typename T, typename IndexT = int>
void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
const Tensor& index, Tensor* output) {
// PADDLE_ENFORCE(platform::is_gpu_place(place));
// check index of shape 1-D
PADDLE_ENFORCE(index.dims().size() == 1 ||
(index.dims().size() == 2 && index.dims()[1] == 1));
if (index.dims().size() == 1) {
PADDLE_ENFORCE_GT(index.dims()[0], 0,
"The index of gather_op should not be empty when the "
"index's rank is 1.");
} else if (index.dims().size() == 2) {
PADDLE_ENFORCE_EQ(index.dims()[1], 1,
" If the index's rank of gather_op is 2, the second "
"dimension should be 1.");
}
int index_size = index.dims()[0];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册