diff --git a/paddle/fluid/operators/index_select_op.cu b/paddle/fluid/operators/index_select_op.cu index 36a91d98a2aded6cda5db6c294c8ee78bb5dd070..752e8b277da75f738d4b3abc3ab72d1140011f06 100644 --- a/paddle/fluid/operators/index_select_op.cu +++ b/paddle/fluid/operators/index_select_op.cu @@ -12,18 +12,185 @@ // See the License for the specific language governing permissions and // limitations under the License. +#pragma once +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/index_select_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" + +namespace paddle { +namespace operators { + +using platform::PADDLE_CUDA_NUM_THREADS; +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +template +__global__ void index_select_cuda_kernel(const T* input, T* output, + const IndexT* index, int64_t N, + int64_t stride, int64_t size, + int64_t delta) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= N) { + return; + } + + int64_t pre_idx = idx / (stride * size); + int64_t dim_idx = idx % (stride * size) / stride; + IndexT src_dim_idx = index[dim_idx]; + int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride; + output[idx] = input[input_idx]; +} + +template +__global__ void index_select_grad_cuda_kernel(const T* output_grad, + T* input_grad, + const IndexT* index, int64_t nums, + int64_t N, int64_t stride, + int64_t size, int64_t delta) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= N) { + return; + } + + int64_t pre_idx = idx / (stride * size); + int64_t dim_idx = idx % (stride * size) / stride; + int64_t begin_idx = idx + (delta * pre_idx - dim_idx) * stride; + + input_grad[idx] = 0.0; + for (int64_t i = 0; i < nums; i++) { + if (index[i] == dim_idx) { + input_grad[idx] += output_grad[begin_idx + i * stride]; + } + } +} + +template +class IndexSelectCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in = context.Input("X"); + auto* index = context.Input("Index"); + auto* out = context.Output("Out"); + int dim = context.Attr("dim"); + auto input_dim = in->dims(); + auto output_dim = out->dims(); + dim = dim >= 0 ? dim : dim + input_dim.size(); + auto stride_dim = framework::stride(input_dim); + int64_t stride = stride_dim[dim]; + int64_t size = output_dim[dim]; + int64_t delta = input_dim[dim] - size; + + const auto& index_type = index->type(); + bool index_type_match = index_type == framework::proto::VarType::INT64 || + index_type == framework::proto::VarType::INT32; + PADDLE_ENFORCE_EQ(index_type_match, true, + platform::errors::InvalidArgument( + "Input(Index) holds the wrong type, it holds %s, but " + "desires to be %s or %s", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT32), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT64))); + + auto* in_data = in->data(); + auto* out_data = out->mutable_data(context.GetPlace()); + int64_t numel = out->numel(); + + auto stream = + context.template device_context().stream(); + + if (index_type == framework::proto::VarType::INT64) { + const int64_t* index_data = index->data(); + index_select_cuda_kernel<<< + (numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_data, out_data, index_data, + numel, stride, size, delta); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); + } else { + const int* index_data = index->data(); + index_select_cuda_kernel<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / + PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>( + in_data, out_data, index_data, numel, stride, size, delta); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); + } + } +}; + +template +class IndexSelectGradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* output_grad = context.Input(framework::GradVarName("Out")); + auto* in_grad = context.Output(framework::GradVarName("X")); + auto* index = context.Input("Index"); + + auto* output_grad_data = output_grad->data(); + auto* in_grad_data = in_grad->mutable_data(context.GetPlace()); + + int dim = context.Attr("dim"); + auto input_dim = in_grad->dims(); + auto output_dim = output_grad->dims(); + dim = dim >= 0 ? dim : dim + input_dim.size(); + auto stride_dim = framework::stride(input_dim); + int64_t stride = stride_dim[dim]; + int64_t size = input_dim[dim]; + int64_t delta = output_dim[dim] - size; + + const auto& index_type = index->type(); + bool index_type_match = index_type == framework::proto::VarType::INT64 || + index_type == framework::proto::VarType::INT32; + PADDLE_ENFORCE_EQ(index_type_match, true, + platform::errors::InvalidArgument( + "Input(Index) holds the wrong type, it holds %s, but " + "desires to be %s or %s", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT32), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT64))); + + int64_t numel = in_grad->numel(); + int64_t index_nums = index->numel(); + + auto stream = + context.template device_context().stream(); + + if (index_type == framework::proto::VarType::INT64) { + const int64_t* index_data = index->data(); + index_select_grad_cuda_kernel<<< + (numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data, + index_data, index_nums, numel, + stride, size, delta); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); + } else { + const int* index_data = index->data(); + index_select_grad_cuda_kernel<<< + (numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data, + index_data, index_nums, numel, + stride, size, delta); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream)); + } + } +}; + +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( index_select, - ops::IndexSelectKernel, - ops::IndexSelectKernel, - ops::IndexSelectKernel, - ops::IndexSelectKernel); + ops::IndexSelectCUDAKernel, + ops::IndexSelectCUDAKernel, + ops::IndexSelectCUDAKernel, + ops::IndexSelectCUDAKernel); REGISTER_OP_CUDA_KERNEL( index_select_grad, - ops::IndexSelectGradKernel, - ops::IndexSelectGradKernel, - ops::IndexSelectGradKernel, - ops::IndexSelectGradKernel); + ops::IndexSelectGradCUDAKernel, + ops::IndexSelectGradCUDAKernel, + ops::IndexSelectGradCUDAKernel, + ops::IndexSelectGradCUDAKernel);