Unverified · commit 1670db5e authored by hutuxian, committed by GitHub

Gather Op Index Support int64_t datatype (#17610)

* gather_op support int64_t index by adding a template typename

* add UT and rename typename

test=develop
Parent febc07f0
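The diff below threads a second template parameter, IndexT (defaulting to int), through the gather and scatter kernels, then dispatches on the runtime dtype of the Index tensor. A minimal standalone sketch of that pattern follows; VarType, Gather, and GatherDispatch here are hypothetical stand-ins, not Paddle's actual helpers.

#include <cstddef>
#include <cstdint>
#include <stdexcept>

enum class VarType { INT32, INT64 };  // stand-in for framework::proto::VarType

// Gather rows: out[i][:] = src[index[i]][:], templated on the index type.
template <typename T, typename IndexT = int>
void Gather(const T* src, const IndexT* index, T* out, std::size_t index_size,
            std::size_t slice_size) {
  for (std::size_t i = 0; i < index_size; ++i) {
    const IndexT row = index[i];
    for (std::size_t j = 0; j < slice_size; ++j) {
      out[i * slice_size + j] = src[row * slice_size + j];
    }
  }
}

// Runtime dtype -> compile-time instantiation, the same if/else shape the
// op kernels below add around GPUGather/CPUGather.
template <typename T>
void GatherDispatch(const T* src, const void* index, VarType index_type,
                    T* out, std::size_t index_size, std::size_t slice_size) {
  if (index_type == VarType::INT32) {
    Gather<T, int32_t>(src, static_cast<const int32_t*>(index), out,
                       index_size, slice_size);
  } else if (index_type == VarType::INT64) {
    Gather<T, int64_t>(src, static_cast<const int64_t*>(index), out,
                       index_size, slice_size);
  } else {
    throw std::invalid_argument("index must be int32 or int64");
  }
}

int main() {
  const float src[6] = {0, 1, 2, 3, 4, 5};  // 3 rows x 2 columns
  const int64_t idx[2] = {2, 0};
  float out[4];
  GatherDispatch(src, idx, VarType::INT64, out, 2, 2);  // out = {4, 5, 0, 1}
  return 0;
}

Defaulting IndexT to int keeps pre-existing call sites such as GPUGather<T>(...) source-compatible; only the op kernels need the explicit dispatch.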
@@ -26,14 +26,15 @@ using platform::DeviceContext;
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
        i += blockDim.x * gridDim.x)
 
-template <typename T>
-__global__ void GatherCUDAKernel(const T* params, const int* indices, T* output,
-                                 size_t index_size, size_t slice_size) {
+template <typename T, typename IndexT = int>
+__global__ void GatherCUDAKernel(const T* params, const IndexT* indices,
+                                 T* output, size_t index_size,
+                                 size_t slice_size) {
   CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) {
     int indices_i = i / slice_size;
     int slice_i = i - indices_i * slice_size;  // offset inside the slice
-    int gather_i = indices[indices_i];
-    int params_i = gather_i * slice_size + slice_i;
+    IndexT gather_i = indices[indices_i];
+    IndexT params_i = gather_i * slice_size + slice_i;
     *(output + i) = *(params + params_i);
   }
 }
@@ -42,10 +43,10 @@ __global__ void GatherCUDAKernel(const T* params, const int* indices, T* output,
  * A thin wrapper on gpu tensor
  * Return a new tensor from source tensor, gathered according to index
  * input[src]: type-T source Tensor
- * input[index]: type-int index Tensor (1-D)
+ * input[index]: type-IndexT index Tensor (1-D)
  * return: output tensor
  */
-template <typename T>
+template <typename T, typename IndexT = int>
 void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
                const Tensor& index, Tensor* output) {
   // PADDLE_ENFORCE(platform::is_gpu_place(place));
@@ -64,15 +65,14 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
   for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
 
   const T* p_src = src.data<T>();
-  // why must be int?
-  const int* p_index = index.data<int>();
+  const IndexT* p_index = index.data<IndexT>();
   T* p_output = output->data<T>();
 
   int block = 512;
   int n = slice_size * index_size;
   int grid = (n + block - 1) / block;
 
-  GatherCUDAKernel<T><<<
+  GatherCUDAKernel<T, IndexT><<<
       grid, block, 0,
       reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
       p_src, p_index, p_output, index_size, slice_size);
......
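For concreteness, the kernel's flat-loop arithmetic can be replayed on the host. A small sketch under the shapes the unit test uses (a (10, 20) source gathered with index = {1, 3, 5}, so slice_size = 20 and index_size = 3); the numbers here are illustrative, not from the patch.

#include <cstdio>

int main() {
  const long long index[3] = {1, 3, 5};  // one gathered row per index entry
  const int slice_size = 20;
  const int n = 3 * slice_size;  // 60 output elements, one per loop iteration
  for (int i = 0; i < n; ++i) {
    int indices_i = i / slice_size;            // which index entry
    int slice_i = i - indices_i * slice_size;  // offset inside the slice
    long long params_i = index[indices_i] * slice_size + slice_i;
    if (i == 45) {
      // indices_i = 2, slice_i = 5, index[2] = 5 -> params_i = 105
      printf("output[%d] = params[%lld]\n", i, params_i);
    }
  }
  return 0;
}

Because gather_i and params_i are now computed in IndexT, an int64 index keeps the offset arithmetic in 64 bits instead of truncating to int.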
@@ -30,10 +30,10 @@ using framework::Tensor;
  * A thin wrapper for gathering on cpu tensor
  * Return a new tensor from source tensor, gathered according to index
  * input[src]: type-T source Tensor
- * input[index]: type-int index Tensor (1-D)
+ * input[index]: type-IndexT index Tensor (1-D)
  * return: output tensor
  */
-template <typename T>
+template <typename T, typename IndexT = int>
 void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
                const Tensor& index, Tensor* output) {
   PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
@@ -45,7 +45,7 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
   auto src_dims = src.dims();
 
   const T* p_src = src.data<T>();
-  const int* p_index = index.data<int>();
+  const IndexT* p_index = index.data<IndexT>();
   T* p_output = output->data<T>();
 
   // slice size
@@ -55,7 +55,7 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
   const size_t slice_bytes = slice_size * sizeof(T);
 
   for (int64_t i = 0; i < index_size; ++i) {
-    int index_ = p_index[i];
+    IndexT index_ = p_index[i];
     memcpy(p_output + i * slice_size, p_src + index_ * slice_size, slice_bytes);
   }
 }
......
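Both CPUGather and GPUGather treat everything after the first dimension as one contiguous slice and move slice_size elements per index entry. A quick standalone illustration of that flattening (illustrative dims, not from the patch):

#include <cstdio>

int main() {
  const int dims[3] = {4, 5, 6};  // e.g. a 4 x 5 x 6 source tensor
  int slice_size = 1;
  for (int i = 1; i < 3; ++i) slice_size *= dims[i];  // product of dims[1:]
  printf("slice_size = %d\n", slice_size);  // 30 elements copied per index
  return 0;
}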
@@ -32,7 +32,20 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
     output->mutable_data<T>(ctx.GetPlace());
     if (x->numel() == 0) return;
-    GPUGather<T>(ctx.device_context(), *x, *index, output);
+    const auto &index_type = index->type();
+    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
+                            index_type == framework::proto::VarType::INT64;
+    PADDLE_ENFORCE(
+        index_type_match,
+        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
+        paddle::framework::DataTypeToString(index_type),
+        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
+        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
+    if (index_type == framework::proto::VarType::INT32) {
+      GPUGather<T, int>(ctx.device_context(), *x, *index, output);
+    } else if (index_type == framework::proto::VarType::INT64) {
+      GPUGather<T, int64_t>(ctx.device_context(), *x, *index, output);
+    }
   }
 };
@@ -42,7 +55,7 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext &ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                    "This kernel only runs on GPU device.");
-    auto *Index = ctx.Input<Tensor>("Index");
+    auto *index = ctx.Input<Tensor>("Index");
     auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
@@ -52,7 +65,21 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
                       .eigen_device();
     dxt.device(place) = dxt.constant(static_cast<T>(0));
     if (dO->numel() == 0) return;
-    GPUScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);
+    const auto &index_type = index->type();
+    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
+                            index_type == framework::proto::VarType::INT64;
+    PADDLE_ENFORCE(
+        index_type_match,
+        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
+        paddle::framework::DataTypeToString(index_type),
+        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
+        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
+    if (index_type == framework::proto::VarType::INT32) {
+      GPUScatterAssign<T, int>(ctx.device_context(), *dO, *index, dX);
+    } else if (index_type == framework::proto::VarType::INT64) {
+      GPUScatterAssign<T, int64_t>(ctx.device_context(), *dO, *index, dX);
+    }
   }
 };
......
@@ -36,7 +36,21 @@ class GatherOpKernel : public framework::OpKernel<T> {
     output->mutable_data<T>(ctx.GetPlace());
     if (x->numel() == 0) return;
-    CPUGather<T>(ctx.device_context(), *x, *index, output);
+    const auto &index_type = index->type();
+    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
+                            index_type == framework::proto::VarType::INT64;
+    PADDLE_ENFORCE(
+        index_type_match,
+        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
+        paddle::framework::DataTypeToString(index_type),
+        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
+        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
+    if (index_type == framework::proto::VarType::INT32) {
+      CPUGather<T, int>(ctx.device_context(), *x, *index, output);
+    } else if (index_type == framework::proto::VarType::INT64) {
+      CPUGather<T, int64_t>(ctx.device_context(), *x, *index, output);
+    }
   }
 };
@@ -47,7 +61,7 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
                    "This kernel only runs on CPU.");
-    auto *Index = ctx.Input<Tensor>("Index");
+    auto *index = ctx.Input<Tensor>("Index");
     auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
@@ -57,7 +71,21 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
                       .eigen_device();
     dxt.device(place) = dxt.constant(static_cast<T>(0));
     if (dO->numel() == 0) return;
-    ScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);
+    const auto &index_type = index->type();
+    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
+                            index_type == framework::proto::VarType::INT64;
+    PADDLE_ENFORCE(
+        index_type_match,
+        "Index holds the wrong type, it holds %s, but desires to be %s or %s",
+        paddle::framework::DataTypeToString(index_type),
+        paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
+        paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
+    if (index_type == framework::proto::VarType::INT32) {
+      ScatterAssign<T, int>(ctx.device_context(), *dO, *index, dX);
+    } else if (index_type == framework::proto::VarType::INT64) {
+      ScatterAssign<T, int64_t>(ctx.device_context(), *dO, *index, dX);
+    }
   }
 };
......
@@ -25,15 +25,15 @@ using Tensor = framework::Tensor;
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
        i += blockDim.x * gridDim.x)
 
-template <typename T>
-__global__ void ScatterCUDAKernel(const T* params, const int* indices,
+template <typename T, typename IndexT = int>
+__global__ void ScatterCUDAKernel(const T* params, const IndexT* indices,
                                   T* output, size_t index_size,
                                   size_t slice_size) {
   CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) {
     int indices_i = i / slice_size;
     int slice_i = i - indices_i * slice_size;  // offset inside the slice
-    int scatter_i = indices[indices_i];
-    int out_i = scatter_i * slice_size + slice_i;
+    IndexT scatter_i = indices[indices_i];
+    IndexT out_i = scatter_i * slice_size + slice_i;
     *(output + out_i) = *(params + i);
   }
 }
@@ -43,10 +43,10 @@ __global__ void ScatterCUDAKernel(const T* params, const int* indices,
  * Return a new updated tensor from source tensor, scatter-assigned according to
  * index
  * input[src]: type-T source Tensor
- * input[index]: type-int index Tensor (1-D)
+ * input[index]: type-IndexT index Tensor (1-D)
  * return: output tensor
  */
-template <typename T>
+template <typename T, typename IndexT = int>
 void GPUScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
                       const Tensor& index, Tensor* output) {
   // PADDLE_ENFORCE(platform::is_gpu_place(place));
@@ -64,14 +64,14 @@ void GPUScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
   for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
 
   const T* p_src = src.data<T>();
-  const int* p_index = index.data<int>();
+  const IndexT* p_index = index.data<IndexT>();
   T* p_output = output->data<T>();
 
   int block = 512;
   int n = slice_size * index_size;
   int grid = (n + block - 1) / block;
 
-  ScatterCUDAKernel<T><<<
+  ScatterCUDAKernel<T, IndexT><<<
       grid, block, 0,
       reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
       p_src, p_index, p_output, index_size, slice_size);
......
@@ -29,10 +29,10 @@ using Tensor = framework::Tensor;
  * Return an updated tensor from source tensor, scattered according to index:
  * dst[index[i]] = src[i]
  * input[src]: type-T source Tensor
- * input[index]: type-int index Tensor (1-D)
+ * input[index]: type-IndexT index Tensor (1-D)
  * return: output tensor
  */
-template <typename T>
+template <typename T, typename IndexT = int>
 void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
                    const Tensor& index, Tensor* output) {
   PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
@@ -45,7 +45,7 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
   auto dst_dims = output->dims();
 
   const T* p_src = src.data<T>();
-  const int* p_index = index.data<int>();
+  const IndexT* p_index = index.data<IndexT>();
   T* p_output = output->data<T>();
 
   // check src shape and dst shape should match
@@ -59,7 +59,7 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
   const size_t slice_bytes = slice_size * sizeof(T);
 
   for (int i = 0; i < index_size; ++i) {
-    int index_ = p_index[i];
+    IndexT index_ = p_index[i];
     memcpy(p_output + index_ * slice_size, p_src + i * slice_size, slice_bytes);
   }
 }
......
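The gradient path relies on scatter-assign being the inverse access pattern of gather: gather reads out[i] = src[index[i]], while scatter-assign writes out[index[i]] = src[i]. A tiny self-contained illustration with an int64_t index on scalar slices (illustrative values, not Paddle code):

#include <cstdint>
#include <cstdio>

int main() {
  const float src[3] = {10.f, 30.f, 50.f};
  const int64_t index[3] = {1, 3, 5};
  float out[10] = {0};          // gradient buffer, zero-initialized
  for (int i = 0; i < 3; ++i) {
    out[index[i]] = src[i];     // dst[index[i]] = src[i]
  }
  for (int i = 0; i < 10; ++i) printf("%g ", out[i]);
  printf("\n");                 // 0 10 0 30 0 50 0 0 0 0
  return 0;
}

This mirrors why the gradient kernels above zero dX first (dxt.constant(0)) before scattering dO back through the same index.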
@@ -23,8 +23,11 @@ class TestGatherOp(OpTest):
     def setUp(self):
         self.op_type = "gather"
         self.config()
-        xnp = np.random.random(self.x_shape).astype("float32")
-        self.inputs = {'X': xnp, 'Index': np.array(self.index).astype("int32")}
+        xnp = np.random.random(self.x_shape).astype(self.x_type)
+        self.inputs = {
+            'X': xnp,
+            'Index': np.array(self.index).astype(self.index_type)
+        }
         self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]}
 
     def test_check_output(self):
@@ -34,14 +37,46 @@ class TestGatherOp(OpTest):
         self.check_grad(['X'], 'Out')
 
     def config(self):
+        """
+        For multi-dimension input
+        """
         self.x_shape = (10, 20)
+        self.x_type = "float32"
         self.index = [1, 3, 5]
+        self.index_type = "int32"
 
 
 class TestCase1(TestGatherOp):
     def config(self):
+        """
+        For one dimension input
+        """
         self.x_shape = (10)
+        self.x_type = "float32"
         self.index = [1, 3, 5]
+        self.index_type = "int32"
 
 
+class TestCase2(TestGatherOp):
+    def config(self):
+        """
+        For int64_t index type
+        """
+        self.x_shape = (10)
+        self.x_type = "float32"
+        self.index = [1, 3, 5]
+        self.index_type = "int64"
+
+
+class TestCase3(TestGatherOp):
+    def config(self):
+        """
+        For other input type
+        """
+        self.x_shape = (10, 20)
+        self.x_type = "double"
+        self.index = [1, 3, 5]
+        self.index_type = "int64"
+
+
 if __name__ == "__main__":
......
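TestCase2 and TestCase3 exercise the new int64 index path end to end against the reference output Out = X[Index]. A plain C++ analogue of that reference computation under TestCase3's shapes — a (10, 20) double source with an int64 index — as a standalone sketch, not the OpTest harness:

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  const int rows = 10, slice = 20;
  double x[rows * slice];
  for (int i = 0; i < rows * slice; ++i) x[i] = i;  // x[r][c] = 20 * r + c
  const int64_t index[3] = {1, 3, 5};
  double out[3 * slice];
  for (int i = 0; i < 3; ++i) {
    // out[i][:] = x[index[i]][:], one whole slice per index entry
    std::memcpy(out + i * slice, x + index[i] * slice, slice * sizeof(double));
  }
  printf("out[0][0] = %g, out[2][19] = %g\n", out[0], out[2 * slice + 19]);
  // out[0][0] = x[1][0] = 20, out[2][19] = x[5][19] = 119
  return 0;
}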