未验证 提交 808bf2b4 编写于 作者: Z zhaoyingli 提交者: GitHub

fix shard_index kernel (#46491)

上级 d150c3a6
@@ -33,7 +33,15 @@ __global__ void ShardIndexInner(const T* in_data,
   int shard_size = (index_num + nshards - 1) / nshards;
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (idx < numel) {
-    assert(in_data[idx] >= 0 && in_data[idx] < index_num);
+    PADDLE_ENFORCE(in_data[idx] >= 0,
+                   "The input_index for Op(shard_index) must be "
+                   "greater or equal to 0, but the value given is %d.",
+                   in_data[idx]);
+    PADDLE_ENFORCE(in_data[idx] < index_num,
+                   "The input_index for Op(shard_index) must be less "
+                   "than index_num (%d), but the value given is %d.",
+                   index_num,
+                   in_data[idx]);
     if (in_data[idx] / shard_size == shard_id) {
       out_data[idx] = in_data[idx] % shard_size;
     } else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册