diff --git a/paddle/fluid/operators/index_sample_op.cu b/paddle/fluid/operators/index_sample_op.cu index 1dc7a128edc472f277caad7b3ee842616ad17501..c8488eefb984f26559b32104be2fcc38068a5ae5 100644 --- a/paddle/fluid/operators/index_sample_op.cu +++ b/paddle/fluid/operators/index_sample_op.cu @@ -12,7 +12,170 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/index_sample_op.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/cuda_device_function.h" +#include "paddle/fluid/platform/cuda_primitives.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +template +__global__ void IndexSampleForward(const IndexT* index, const T* in_data, + T* out_data, size_t index_length, + size_t input_length, size_t batch_size) { + int index_i = blockDim.x * blockIdx.x + threadIdx.x; + int index_j = blockDim.y * blockIdx.y + threadIdx.y; + int index_idx = index_j * index_length + index_i; + int in_idx = index_j * input_length + index_i; + + if (index_i < index_length & index_j < batch_size) { + IndexT sample_idx = index[index_idx]; + out_data[index_idx] = in_data[in_idx - index_i + sample_idx]; + } +} + +template +__global__ void IndexSampleGrad(const IndexT* index, T* in_grad, + const T* out_grad, size_t index_length, + size_t input_length, size_t batch_size, + bool same_data_in_row = true) { + int index_i = blockDim.x * blockIdx.x + threadIdx.x; + int index_j = blockDim.y * blockIdx.y + threadIdx.y; + int index_idx = index_j * index_length + index_i; + int in_idx = index_j * input_length + index_i; + + if (index_i < index_length & index_j < batch_size) { + IndexT sample_idx = index[index_idx]; + if (same_data_in_row) { + platform::CudaAtomicAdd(&(in_grad[in_idx - index_i + sample_idx]), + out_grad[sample_idx]); + } else { + in_grad[in_idx - index_i + sample_idx] = out_grad[sample_idx]; + } + } +} + +template +class IndexSampleKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* index = ctx.Input("Index"); + auto* output = ctx.Output("Out"); + + const auto& index_type = index->type(); + bool index_type_match = index_type == framework::proto::VarType::INT64 || + index_type == framework::proto::VarType::INT32; + PADDLE_ENFORCE_EQ(index_type_match, true, + platform::errors::InvalidArgument( + "Input(Index) holds the wrong type, it holds %s, but " + "desires to be %s or %s", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT32), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT64))); + const auto* in_data = input->data(); + auto* out_data = output->mutable_data(ctx.GetPlace()); + auto stream = + ctx.template device_context().stream(); + + auto input_dim = input->dims(); + auto index_dim = index->dims(); + size_t batch_size = input_dim[0]; + size_t input_length = input_dim[1]; + size_t index_length = index_dim[1]; + + auto block_width = platform::RoundToPowerOfTwo(index_length); + int block_height = + platform::RoundToPowerOfTwo(index_length * batch_size) / block_width; + + dim3 block_dim(block_width, block_height); + dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x, + (batch_size + block_dim.y - 1) / block_dim.y); + + if (index_type == framework::proto::VarType::INT64) { + const int64_t* index_data = index->data(); + IndexSampleForward<<>>( + index_data, in_data, out_data, index_length, input_length, + batch_size); + } else if (index_type == framework::proto::VarType::INT32) { + const int* index_data = index->data(); + IndexSampleForward<<>>( + index_data, in_data, out_data, index_length, input_length, + batch_size); + } + } +}; + +template +class IndexSampleGradKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* output_grad = ctx.Input(framework::GradVarName("Out")); + auto* input_grad = ctx.Output(framework::GradVarName("X")); + auto* index = ctx.Input("Index"); + + const auto* output_grad_data = output_grad->data(); + auto* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); + + const auto& index_type = index->type(); + bool index_type_match = index_type == framework::proto::VarType::INT64 || + index_type == framework::proto::VarType::INT32; + PADDLE_ENFORCE_EQ(index_type_match, true, + platform::errors::InvalidArgument( + "Input(Index) holds the wrong type, it holds %s, but " + "desires to be %s or %s", + paddle::framework::DataTypeToString(index_type), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT32), + paddle::framework::DataTypeToString( + framework::proto::VarType::INT64))); + + auto stream = + ctx.template device_context().stream(); + auto input_num = input_grad->numel(); + auto input_dim = input_grad->dims(); + auto index_dim = index->dims(); + size_t batch_size = index_dim[0]; + size_t input_length = input_dim[1]; + size_t index_length = index_dim[1]; + bool same_data_in_index_row = index_length == 1 ? false : true; + + auto block_width = platform::RoundToPowerOfTwo(index_length); + auto block_height = + platform::RoundToPowerOfTwo(index_length * batch_size) / block_width; + dim3 block_dim(block_width, block_height); + dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x, + (batch_size + block_dim.y - 1) / block_dim.y); + + math::SetConstant set_zero; + auto& dev_ctx = ctx.template device_context(); + set_zero(dev_ctx, input_grad, static_cast(0)); + + if (index_type == framework::proto::VarType::INT64) { + const int64_t* index_data = index->data(); + IndexSampleGrad<<>>( + index_data, input_grad_data, output_grad_data, index_length, + input_length, batch_size, same_data_in_index_row); + } else if (index_type == framework::proto::VarType::INT32) { + const int* index_data = index->data(); + IndexSampleGrad<<>>( + index_data, input_grad_data, output_grad_data, index_length, + input_length, batch_size, same_data_in_index_row); + } + } +}; + +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( diff --git a/python/paddle/fluid/tests/unittests/test_index_sample_op.py b/python/paddle/fluid/tests/unittests/test_index_sample_op.py index f640c0531192d65a686e0f21be5bedb9eb0497fb..c1a8299592a2b4fc9d70ce760e0f277d3ed9664f 100644 --- a/python/paddle/fluid/tests/unittests/test_index_sample_op.py +++ b/python/paddle/fluid/tests/unittests/test_index_sample_op.py @@ -92,9 +92,9 @@ class TestCase4(TestIndexSampleOp): """ For int64 index type """ - self.x_shape = (10, 100) + self.x_shape = (10, 128) self.x_type = "float64" - self.index_shape = (10, 10) + self.index_shape = (10, 64) self.index_type = "int64"