未验证 提交 6f7aca9e 编写于 作者: Z Zeng Jinle 提交者: GitHub

Fix scatter and gather bug (#35595)

* fix scatter gather bug:

* fix windows ci
上级 42847d2e
......@@ -36,7 +36,7 @@ __global__ void GatherCUDAKernel(const T* params, const IndexT* indices,
int64_t indices_i = i / slice_size;
int64_t slice_i = i - indices_i * slice_size; // offset inside the slice
IndexT gather_i = indices[indices_i];
IndexT params_i = gather_i * slice_size + slice_i;
int64_t params_i = gather_i * slice_size + slice_i;
*(output + i) = *(params + params_i);
}
}
......@@ -49,7 +49,7 @@ __global__ void GatherNdCUDAKernel(const T* input, const int64_t* input_dims,
CUDA_KERNEL_LOOP_TYPE(i, remain_size * slice_size, int64_t) {
int64_t indices_i = i / slice_size;
int64_t slice_i = i - indices_i * slice_size; // offset inside the slice
IndexT gather_i = 0;
int64_t gather_i = 0;
int64_t temp = slice_size;
for (int64_t j = end_size - 1; j >= 0; --j) {
auto index_value = indices[indices_i * end_size + j];
......@@ -63,7 +63,7 @@ __global__ void GatherNdCUDAKernel(const T* input, const int64_t* input_dims,
gather_i += (index_value * temp);
temp *= input_dims[j];
}
IndexT input_i = gather_i + slice_i;
int64_t input_i = gather_i + slice_i;
*(output + i) = *(input + input_i);
}
}
......@@ -78,13 +78,7 @@ __global__ void GatherNdCUDAKernel(const T* input, const int64_t* input_dims,
template <typename T, typename IndexT = int>
void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
const Tensor& index, Tensor* output) {
// check index of shape 1-D
if (index.dims().size() == 1) {
PADDLE_ENFORCE_GT(index.dims()[0], 0,
platform::errors::InvalidArgument(
"The index of gather_op should not be empty"
"when the index's rank is 1."));
} else if (index.dims().size() == 2) {
if (index.dims().size() == 2) {
PADDLE_ENFORCE_EQ(index.dims()[1], 1,
platform::errors::InvalidArgument(
"If the index's rank of gather_op is 2,"
......@@ -93,6 +87,7 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
// index size
int64_t index_size = index.dims()[0];
if (index_size == 0) return;
auto src_dims = src.dims();
framework::DDim output_dims(src_dims);
......@@ -248,6 +243,7 @@ void GatherV2CUDAFunction(const Tensor* input, const Tensor* index,
out->Resize(out_dim);
auto* out_data = out->mutable_data<T>(place);
int64_t out_size = out->numel();
if (out_size == 0) return;
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), out_size);
......
......@@ -29,9 +29,9 @@ using Tensor = framework::Tensor;
template <typename T, typename IndexT = int>
__global__ void ScatterInitCUDAKernel(const IndexT* indices, T* output,
size_t index_size, size_t slice_size) {
CUDA_KERNEL_LOOP(i, index_size * slice_size) {
int indices_i = i / slice_size;
int slice_i = i - indices_i * slice_size; // offset inside the slice
CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
int64_t indices_i = i / slice_size;
int64_t slice_i = i - indices_i * slice_size; // offset inside the slice
IndexT scatter_i = indices[indices_i];
PADDLE_ENFORCE(scatter_i >= 0,
......@@ -41,7 +41,7 @@ __global__ void ScatterInitCUDAKernel(const IndexT* indices, T* output,
"be greater than or equal to 0, but received [%d]",
scatter_i);
IndexT out_i = scatter_i * slice_size + slice_i;
int64_t out_i = scatter_i * slice_size + slice_i;
*(output + out_i) = static_cast<T>(0);
}
}
......@@ -50,9 +50,9 @@ template <typename T, typename IndexT = int>
__global__ void ScatterCUDAKernel(const T* params, const IndexT* indices,
T* output, size_t index_size,
size_t slice_size, bool overwrite) {
CUDA_KERNEL_LOOP(i, index_size * slice_size) {
int indices_i = i / slice_size;
int slice_i = i - indices_i * slice_size; // offset inside the slice
CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
int64_t indices_i = i / slice_size;
int64_t slice_i = i - indices_i * slice_size; // offset inside the slice
IndexT scatter_i = indices[indices_i];
PADDLE_ENFORCE(scatter_i >= 0,
......@@ -62,7 +62,7 @@ __global__ void ScatterCUDAKernel(const T* params, const IndexT* indices,
"be greater than or equal to 0, but received [%d]",
scatter_i);
IndexT out_i = scatter_i * slice_size + slice_i;
int64_t out_i = scatter_i * slice_size + slice_i;
if (overwrite) {
*(output + out_i) = *(params + i);
} else {
......@@ -73,13 +73,13 @@ __global__ void ScatterCUDAKernel(const T* params, const IndexT* indices,
template <typename T, typename IndexT = int>
__global__ void ScatterNdCUDAKernel(const T* update, const IndexT* indices,
T* output, const int* output_dims,
T* output, const int64_t* output_dims,
size_t remain_size, size_t slice_size,
size_t end_size) {
CUDA_KERNEL_LOOP(i, remain_size * slice_size) {
int indices_i = i / slice_size;
int slice_i = i - indices_i * slice_size; // offset inside the slice
IndexT gather_i = 0;
CUDA_KERNEL_LOOP_TYPE(i, remain_size * slice_size, int64_t) {
int64_t indices_i = i / slice_size;
int64_t slice_i = i - indices_i * slice_size; // offset inside the slice
int64_t gather_i = 0;
int64_t temp = slice_size;
for (int64_t j = end_size - 1; j >= 0; --j) {
IndexT index_value = indices[indices_i * end_size + j];
......@@ -95,7 +95,7 @@ __global__ void ScatterNdCUDAKernel(const T* update, const IndexT* indices,
gather_i += (index_value * temp);
temp *= output_dims[j];
}
IndexT output_i = gather_i + slice_i;
int64_t output_i = gather_i + slice_i;
paddle::platform::CudaAtomicAdd(output + output_i, *(update + i));
}
}
......@@ -128,14 +128,14 @@ void GPUScatterAssign(const framework::ExecutionContext& context,
"But received value is [%d]",
index.dims().size()));
}
int index_size = index.dims()[0];
int64_t index_size = index.dims()[0];
auto src_dims = src.dims();
framework::DDim output_dims(src_dims);
output_dims[0] = index_size;
// slice size
int slice_size = 1;
int64_t slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
const T* p_src = src.data<T>();
......@@ -145,8 +145,8 @@ void GPUScatterAssign(const framework::ExecutionContext& context,
// set block and grid num
int block = 512;
int n = slice_size * index_size;
int grid = (n + block - 1) / block;
int64_t n = slice_size * index_size;
int64_t grid = (n + block - 1) / block;
// if not overwrite mode, init data
if (!overwrite) {
......@@ -167,10 +167,10 @@ void GPUScatterAssign(const framework::ExecutionContext& context,
template <typename T, typename IndexT = int>
void GPUScatterGradForX(const platform::DeviceContext& ctx, const Tensor& index,
Tensor* output) {
IndexT index_size = index.dims()[0];
int64_t index_size = index.dims()[0];
auto dst_dims = output->dims();
// slice size
IndexT slice_size = 1;
int64_t slice_size = 1;
for (int i = 1; i < dst_dims.size(); ++i) slice_size *= dst_dims[i];
const IndexT* p_index = index.data<IndexT>();
T* p_output = output->data<T>();
......@@ -224,20 +224,20 @@ void GPUScatterNdAdd(const framework::ExecutionContext& context,
const auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
auto cplace = platform::CPUPlace();
std::vector<int> v_output_dims(output_dims_size);
std::vector<int64_t> v_output_dims(output_dims_size);
for (int i = 0; i < output_dims_size; ++i) {
v_output_dims[i] = static_cast<int>(output_dims[i]);
v_output_dims[i] = output_dims[i];
}
auto& dev_ctx = context.cuda_device_context();
int bytes = output_dims_size * sizeof(int);
int64_t bytes = output_dims_size * sizeof(int64_t);
auto output_dims_ptr = memory::Alloc(dev_ctx, bytes);
int* g_output_dims = reinterpret_cast<int*>(output_dims_ptr->ptr());
int64_t* g_output_dims = reinterpret_cast<int64_t*>(output_dims_ptr->ptr());
memory::Copy(gplace, g_output_dims, cplace, v_output_dims.data(), bytes,
ctx.stream());
int block = 512;
int n = slice_size * remain_numel;
int grid = (n + block - 1) / block;
int64_t n = slice_size * remain_numel;
int64_t grid = (n + block - 1) / block;
ScatterNdCUDAKernel<T, IndexT><<<
grid, block, 0,
......
......@@ -112,7 +112,7 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
const size_t slice_bytes = slice_size * sizeof(T);
for (int i = 0; i < index_size; ++i) {
for (int64_t i = 0; i < index_size; ++i) {
IndexT index_ = p_index[i];
PADDLE_ENFORCE_GE(index_, 0,
......@@ -175,7 +175,7 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
}
// if not in overwrite mode, need to init output data
for (int i = 0; i < index_size; ++i) {
for (int64_t i = 0; i < index_size; ++i) {
const IndexT& index_val = p_index[i];
PADDLE_ENFORCE_GE(index_val, 0,
......
......@@ -248,6 +248,17 @@ class API_TestDygraphGather(unittest.TestCase):
self.assertTrue(np.allclose(output_np, expected_output))
paddle.enable_static()
def test_zero_index(self):
paddle.disable_static()
x = paddle.to_tensor([[1, 2], [3, 4]])
index = paddle.to_tensor(np.array([]).astype('int64'))
for axis in range(len(x.shape)):
out = paddle.gather(x, index, axis)
expected_shape = list(x.shape)
expected_shape[axis] = 0
self.assertEqual(list(out.shape), expected_shape)
paddle.enable_static()
def test_large_data(self):
if not paddle.is_compiled_with_cuda():
return
......@@ -340,4 +351,5 @@ class TestCheckOutType(unittest.TestCase):
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
......@@ -16,10 +16,12 @@ from __future__ import print_function
import unittest
import numpy as np
import os
import paddle
import paddle.fluid as fluid
from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.dygraph.base import switch_to_static_graph
class TestScatterOp(OpTest):
......@@ -228,6 +230,44 @@ class TestScatterAPI(unittest.TestCase):
self.assertEqual((output1.numpy() == \
np.array([[3., 3.],[6., 6.],[1., 1.]])).all(), True)
def test_large_data(self):
if os.name == "nt" or not paddle.is_compiled_with_cuda():
return
x = np.random.rand(183826, 256).astype("float32")
index = np.ones(10759233, dtype="int64")
updates = np.ones(shape=[10759233, 256], dtype="float32")
def test_dygraph():
with fluid.dygraph.guard():
gpu_out = paddle.scatter(
paddle.to_tensor(x),
paddle.to_tensor(index), paddle.to_tensor(updates))
return gpu_out.numpy()
@switch_to_static_graph
def test_static_graph():
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
x_t = paddle.static.data(name="x", dtype=x.dtype, shape=x.shape)
index_t = paddle.static.data(
name="index", dtype=index.dtype, shape=index.shape)
updates_t = paddle.static.data(
name="updates", dtype=updates.dtype, shape=updates.shape)
out_t = paddle.scatter(x_t, index_t, updates_t)
feed = {
x_t.name: x,
index_t.name: index,
updates_t.name: updates
}
fetch = [out_t]
gpu_exe = paddle.static.Executor(paddle.CUDAPlace(0))
gpu_value = gpu_exe.run(feed=feed, fetch_list=fetch)[0]
return gpu_value
self.assertTrue(np.array_equal(test_dygraph(), test_static_graph()))
class TestScatterInplaceAPI(TestScatterAPI):
def executed_api(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册