From 1670db5e86ef2c1fead409b519dd16a471b7ab8b Mon Sep 17 00:00:00 2001
From: hutuxian
Date: Sat, 25 May 2019 22:06:41 +0800
Subject: [PATCH] Gather Op Index Support int64_t datatype (#17610)

* gather_op support int64_t index by adding a template typename

* add UT and rename typename
test=develop
---
 paddle/fluid/operators/gather.cu.h           | 20 +++++-----
 paddle/fluid/operators/gather.h              |  8 ++--
 paddle/fluid/operators/gather_op.cu          | 33 ++++++++++++++--
 paddle/fluid/operators/gather_op.h           | 34 ++++++++++++++--
 paddle/fluid/operators/scatter.cu.h          | 16 ++++----
 paddle/fluid/operators/scatter.h             |  8 ++--
 .../fluid/tests/unittests/test_gather_op.py  | 39 ++++++++++++++++++-
 7 files changed, 124 insertions(+), 34 deletions(-)

diff --git a/paddle/fluid/operators/gather.cu.h b/paddle/fluid/operators/gather.cu.h
index 5bc2e63757f..fff817fbd02 100644
--- a/paddle/fluid/operators/gather.cu.h
+++ b/paddle/fluid/operators/gather.cu.h
@@ -26,14 +26,15 @@ using platform::DeviceContext;
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
        i += blockDim.x * gridDim.x)
 
-template <typename T>
-__global__ void GatherCUDAKernel(const T* params, const int* indices, T* output,
-                                 size_t index_size, size_t slice_size) {
+template <typename T, typename IndexT = int>
+__global__ void GatherCUDAKernel(const T* params, const IndexT* indices,
+                                 T* output, size_t index_size,
+                                 size_t slice_size) {
   CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) {
     int indices_i = i / slice_size;
     int slice_i = i - indices_i * slice_size;  // offset inside the slice
-    int gather_i = indices[indices_i];
-    int params_i = gather_i * slice_size + slice_i;
+    IndexT gather_i = indices[indices_i];
+    IndexT params_i = gather_i * slice_size + slice_i;
     *(output + i) = *(params + params_i);
   }
 }
@@ -42,10 +43,10 @@ __global__ void GatherCUDAKernel(const T* params, const int* indices, T* output,
  * A thin wrapper on gpu tensor
  * Return a new tensor from source tensor, gathered according to index
  * input[src]: type-T source Tensor
- * input[index]: type-int index Tensor (1-D)
+ * input[index]: type-IndexT index Tensor (1-D)
  * return: output tensor
  */
-template <typename T>
+template <typename T, typename IndexT = int>
 void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
                const Tensor& index, Tensor* output) {
   // PADDLE_ENFORCE(platform::is_gpu_place(place));
@@ -64,15 +65,14 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
   for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
 
   const T* p_src = src.data<T>();
-  // why must be int?
-  const int* p_index = index.data<int>();
+  const IndexT* p_index = index.data<IndexT>();
   T* p_output = output->data<T>();
 
   int block = 512;
   int n = slice_size * index_size;
   int grid = (n + block - 1) / block;
 
-  GatherCUDAKernel<T><<<
+  GatherCUDAKernel<T, IndexT><<<
       grid, block, 0,
       reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
       p_src, p_index, p_output, index_size, slice_size);
diff --git a/paddle/fluid/operators/gather.h b/paddle/fluid/operators/gather.h
index dc08ee5efac..1e02c036e35 100644
--- a/paddle/fluid/operators/gather.h
+++ b/paddle/fluid/operators/gather.h
@@ -30,10 +30,10 @@ using framework::Tensor;
  * A thin wrapper for gathering on cpu tensor
  * Return a new tensor from source tensor, gathered according to index
  * input[src]: type-T source Tensor
- * input[index]: type-int index Tensor (1-D)
+ * input[index]: type-IndexT index Tensor (1-D)
  * return: output tensor
 */
-template <typename T>
+template <typename T, typename IndexT = int>
 void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
                const Tensor& index, Tensor* output) {
   PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
@@ -45,7 +45,7 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
   auto src_dims = src.dims();
 
   const T* p_src = src.data<T>();
-  const int* p_index = index.data<int>();
+  const IndexT* p_index = index.data<IndexT>();
   T* p_output = output->data<T>();
 
   // slice size
@@ -55,7 +55,7 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
   const size_t slice_bytes = slice_size * sizeof(T);
 
   for (int64_t i = 0; i < index_size; ++i) {
-    int index_ = p_index[i];
+    IndexT index_ = p_index[i];
     memcpy(p_output + i * slice_size, p_src + index_ * slice_size, slice_bytes);
   }
 }
diff --git a/paddle/fluid/operators/gather_op.cu b/paddle/fluid/operators/gather_op.cu
index 490ba9a585e..7a0b290ec86 100644
--- a/paddle/fluid/operators/gather_op.cu
+++ b/paddle/fluid/operators/gather_op.cu
@@ -32,7 +32,20 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
 
     output->mutable_data<T>(ctx.GetPlace());
     if (x->numel() == 0) return;
-    GPUGather<T>(ctx.device_context(), *x, *index, output);
+    const auto &index_type = index->type();
+    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
+                            index_type == framework::proto::VarType::INT64;
+    PADDLE_ENFORCE(
+        index_type_match,
+        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
+        paddle::framework::DataTypeToString(index_type),
+        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
+        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
+    if (index_type == framework::proto::VarType::INT32) {
+      GPUGather<T, int>(ctx.device_context(), *x, *index, output);
+    } else if (index_type == framework::proto::VarType::INT64) {
+      GPUGather<T, int64_t>(ctx.device_context(), *x, *index, output);
+    }
   }
 };
 
@@ -42,7 +55,7 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext &ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                    "This kernel only runs on GPU device.");
-    auto *Index = ctx.Input<Tensor>("Index");
+    auto *index = ctx.Input<Tensor>("Index");
     auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
 
@@ -52,7 +65,21 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
                        .eigen_device();
     dxt.device(place) = dxt.constant(static_cast<T>(0));
     if (dO->numel() == 0) return;
-    GPUScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);
+
+    const auto &index_type = index->type();
+    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
+                            index_type == framework::proto::VarType::INT64;
+    PADDLE_ENFORCE(
+        index_type_match,
+        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
+        paddle::framework::DataTypeToString(index_type),
+        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
+        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
+    if (index_type == framework::proto::VarType::INT32) {
+      GPUScatterAssign<T, int>(ctx.device_context(), *dO, *index, dX);
+    } else if (index_type == framework::proto::VarType::INT64) {
+      GPUScatterAssign<T, int64_t>(ctx.device_context(), *dO, *index, dX);
+    }
   }
 };
 
diff --git a/paddle/fluid/operators/gather_op.h b/paddle/fluid/operators/gather_op.h
index 2e18298cf8e..a58f794efa9 100644
--- a/paddle/fluid/operators/gather_op.h
+++ b/paddle/fluid/operators/gather_op.h
@@ -36,7 +36,21 @@ class GatherOpKernel : public framework::OpKernel<T> {
 
     output->mutable_data<T>(ctx.GetPlace());
     if (x->numel() == 0) return;
-    CPUGather<T>(ctx.device_context(), *x, *index, output);
+
+    const auto &index_type = index->type();
+    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
+                            index_type == framework::proto::VarType::INT64;
+    PADDLE_ENFORCE(
+        index_type_match,
+        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
+        paddle::framework::DataTypeToString(index_type),
+        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
+        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
+    if (index_type == framework::proto::VarType::INT32) {
+      CPUGather<T, int>(ctx.device_context(), *x, *index, output);
+    } else if (index_type == framework::proto::VarType::INT64) {
+      CPUGather<T, int64_t>(ctx.device_context(), *x, *index, output);
+    }
   }
 };
 
@@ -47,7 +61,7 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
                    "This kernel only runs on CPU.");
 
-    auto *Index = ctx.Input<Tensor>("Index");
+    auto *index = ctx.Input<Tensor>("Index");
     auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
 
@@ -57,7 +71,21 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
                        .eigen_device();
     dxt.device(place) = dxt.constant(static_cast<T>(0));
     if (dO->numel() == 0) return;
-    ScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);
+
+    const auto &index_type = index->type();
+    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
+                            index_type == framework::proto::VarType::INT64;
+    PADDLE_ENFORCE(
+        index_type_match,
+        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
+        paddle::framework::DataTypeToString(index_type),
+        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
+        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
+    if (index_type == framework::proto::VarType::INT32) {
+      ScatterAssign<T, int>(ctx.device_context(), *dO, *index, dX);
+    } else if (index_type == framework::proto::VarType::INT64) {
+      ScatterAssign<T, int64_t>(ctx.device_context(), *dO, *index, dX);
+    }
   }
 };
 
diff --git a/paddle/fluid/operators/scatter.cu.h b/paddle/fluid/operators/scatter.cu.h
index b2e79f6c82b..030719baa8f 100644
--- a/paddle/fluid/operators/scatter.cu.h
+++ b/paddle/fluid/operators/scatter.cu.h
@@ -25,15 +25,15 @@ using Tensor = framework::Tensor;
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
        i += blockDim.x * gridDim.x)
 
-template <typename T>
-__global__ void ScatterCUDAKernel(const T* params, const int* indices,
+template <typename T, typename IndexT = int>
+__global__ void ScatterCUDAKernel(const T* params, const IndexT* indices,
                                   T* output, size_t index_size,
                                   size_t slice_size) {
   CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) {
     int indices_i = i / slice_size;
     int slice_i = i - indices_i * slice_size;  // offset inside the slice
-    int scatter_i = indices[indices_i];
-    int out_i = scatter_i * slice_size + slice_i;
+    IndexT scatter_i = indices[indices_i];
+    IndexT out_i = scatter_i * slice_size + slice_i;
     *(output + out_i) = *(params + i);
   }
 }
@@ -43,10 +43,10 @@ __global__ void ScatterCUDAKernel(const T* params, const int* indices,
 * Return a new updated tensor from source tensor, scatter-assigned according to
 * index
 * input[src]: type-T source Tensor
-* input[index]: type-int index Tensor (1-D)
+* input[index]: type-IndexT index Tensor (1-D)
 * return: output tensor
 */
-template <typename T>
+template <typename T, typename IndexT = int>
 void GPUScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
                       const Tensor& index, Tensor* output) {
   // PADDLE_ENFORCE(platform::is_gpu_place(place));
@@ -64,14 +64,14 @@ void GPUScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
   for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
 
   const T* p_src = src.data<T>();
-  const int* p_index = index.data<int>();
+  const IndexT* p_index = index.data<IndexT>();
   T* p_output = output->data<T>();
 
   int block = 512;
   int n = slice_size * index_size;
   int grid = (n + block - 1) / block;
 
-  ScatterCUDAKernel<T><<<
+  ScatterCUDAKernel<T, IndexT><<<
       grid, block, 0,
       reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
       p_src, p_index, p_output, index_size, slice_size);
diff --git a/paddle/fluid/operators/scatter.h b/paddle/fluid/operators/scatter.h
index 8bae6606c94..17d7d82144d 100644
--- a/paddle/fluid/operators/scatter.h
+++ b/paddle/fluid/operators/scatter.h
@@ -29,10 +29,10 @@ using Tensor = framework::Tensor;
 * Return a updated tensor from source tensor, scattered according to index:
 * dst[i] = src[index[i]]
 * input[src]: type-T source Tensor
-* input[index]: type-int index Tensor (1-D)
+* input[index]: type-IndexT index Tensor (1-D)
 * return: output tensor
 */
-template <typename T>
+template <typename T, typename IndexT = int>
 void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
                    const Tensor& index, Tensor* output) {
   PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
@@ -45,7 +45,7 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
   auto dst_dims = output->dims();
 
   const T* p_src = src.data<T>();
-  const int* p_index = index.data<int>();
+  const IndexT* p_index = index.data<IndexT>();
   T* p_output = output->data<T>();
 
   // check src shape and dst shape should match
@@ -59,7 +59,7 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
   const size_t slice_bytes = slice_size * sizeof(T);
 
   for (int i = 0; i < index_size; ++i) {
-    int index_ = p_index[i];
+    IndexT index_ = p_index[i];
     memcpy(p_output + index_ * slice_size, p_src + i * slice_size, slice_bytes);
   }
 }
diff --git a/python/paddle/fluid/tests/unittests/test_gather_op.py b/python/paddle/fluid/tests/unittests/test_gather_op.py
index bd5785aa55a..daa5e60498e 100644
--- a/python/paddle/fluid/tests/unittests/test_gather_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gather_op.py
@@ -23,8 +23,11 @@ class TestGatherOp(OpTest):
     def setUp(self):
         self.op_type = "gather"
         self.config()
-        xnp = np.random.random(self.x_shape).astype("float32")
-        self.inputs = {'X': xnp, 'Index': np.array(self.index).astype("int32")}
+        xnp = np.random.random(self.x_shape).astype(self.x_type)
+        self.inputs = {
+            'X': xnp,
+            'Index': np.array(self.index).astype(self.index_type)
+        }
         self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]}
 
     def test_check_output(self):
@@ -34,14 +37,46 @@ class TestGatherOp(OpTest):
         self.check_grad(['X'], 'Out')
 
     def config(self):
+        """
+        For multi-dimension input
+        """
         self.x_shape = (10, 20)
+        self.x_type = "float32"
         self.index = [1, 3, 5]
+        self.index_type = "int32"
 
 
 class TestCase1(TestGatherOp):
     def config(self):
+        """
+        For one dimension input
+        """
         self.x_shape = (10)
+        self.x_type = "float32"
         self.index = [1, 3, 5]
+        self.index_type = "int32"
+
+
+class TestCase2(TestGatherOp):
+    def config(self):
+        """
+        For int64_t index type
+        """
+        self.x_shape = (10)
+        self.x_type = "float32"
+        self.index = [1, 3, 5]
+        self.index_type = "int64"
+
+
+class TestCase3(TestGatherOp):
+    def config(self):
+        """
+        For other input type
+        """
+        self.x_shape = (10, 20)
+        self.x_type = "double"
+        self.index = [1, 3, 5]
+        self.index_type = "int64"
 
 
 if __name__ == "__main__":
-- 
GitLab
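
For reference, a minimal NumPy sketch (not part of the patch; names are illustrative) of the gather semantics the new int64 index path is checked against in the unit test above, i.e. Out = X[Index]:

    import numpy as np

    # X is the source tensor, Index the 1-D index tensor; with this patch the
    # index may be int64 as well as int32.
    x = np.random.random((10, 20)).astype("float32")
    index = np.array([1, 3, 5]).astype("int64")
    out = x[index]  # reference gather result, shape (3, 20)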