diff --git a/paddle/fluid/operators/argsort_op.cc b/paddle/fluid/operators/argsort_op.cc index aead4e2e00f75710d13cad52d0cd31e37e16eb0c..2943d409a2e42a01f3d6bbb28c8c89c2409c45f5 100644 --- a/paddle/fluid/operators/argsort_op.cc +++ b/paddle/fluid/operators/argsort_op.cc @@ -30,7 +30,7 @@ class ArgsortOp : public framework::OperatorWithKernel { "Output(Indices) of ArgsortOp should not be null."); auto in_dims = ctx->GetInputDim("X"); - int axis = static_cast(ctx->Attrs().Get("axis")); + int axis = ctx->Attrs().Get("axis"); auto num_dims = in_dims.size(); PADDLE_ENFORCE(axis < num_dims, diff --git a/paddle/fluid/operators/argsort_op.cu b/paddle/fluid/operators/argsort_op.cu index d1fbd28e1b6cf5e695ced57c0c49244e1b7eaf32..eac18ea3a0350697c8e1a96b71cc4a9068b26be8 100644 --- a/paddle/fluid/operators/argsort_op.cu +++ b/paddle/fluid/operators/argsort_op.cu @@ -26,6 +26,42 @@ namespace operators { using Tensor = framework::Tensor; using platform::PADDLE_CUDA_NUM_THREADS;
+// One thread per element (n total): maps flat row-major offset i of the
+// input to its offset in a layout where `axis` is the innermost dimension.
+//   in_dims: device array holding the dims_size extents of the input.
+//   axis:    dimension to move innermost; assumed to be in [0, dims_size).
+//   trg_idx: out, trg_idx[i] = permuted offset of input element i.
+//   med_ids: out, med_ids[trg_idx[i]] = index of element i along `axis`.
+__global__ void ComputeTargetIdx(const int64_t* in_dims, int dims_size,
+                                 int axis, int64_t n, int64_t* trg_idx,
+                                 int64_t* med_ids) {
+  int64_t index = threadIdx.x + blockDim.x * blockIdx.x;
+  if (index >= n) return;
+  // Off-axis coordinates are folded into `group` with a running stride while
+  // they are peeled off innermost-first, so no per-thread scratch arrays
+  // (and hence no device-heap allocation) are needed.
+  int64_t rest = index;
+  int64_t group = 0;
+  int64_t stride = 1;
+  int64_t dim_axis = 0;
+  int64_t pos_in_axis = 0;
+  for (int j = dims_size - 1; j >= 0; --j) {
+    const int64_t dim = in_dims[j];
+    const int64_t coord = rest % dim;
+    if (j == axis) {
+      dim_axis = dim;
+      pos_in_axis = coord;
+    } else {
+      group += coord * stride;
+      stride *= dim;
+    }
+    rest /= dim;
+  }
+  const int64_t target_idx = group * dim_axis + pos_in_axis;
+  trg_idx[index] = target_idx;
+  med_ids[target_idx] = pos_in_axis;
+}
+
 template __global__ void PermuteInData(const T* in, const int64_t* trg_idx, int64_t n, T* med_out) { @@ -76,50 +112,27 @@ class ArgsortOpCUDAKernel : public framework::OpKernel { int64_t numel = input->numel(); int64_t groups = numel / in_dims[axis]; - // Mediate tensor for sorting - Tensor mediate_output; + std::vector in_dims_vec = vectorize(in_dims); + thrust::device_vector in_dims_dev(in_dims_vec.begin(), + in_dims_vec.end()); + int64_t* in_dims_data = thrust::raw_pointer_cast(in_dims_dev.data()); + // Mediate tensor for sorting data and indices + Tensor mediate_output, mediate_indices; T* med_out_data = mediate_output.mutable_data(input->dims(), ctx.GetPlace()); - - // The target index of each elemement in mediate tensor - std::vector target_idx(numel, 0); - // To record the index along the given axis for the data in mediate tensor - std::vector mediate_indices(numel, 0); - std::vector in_dims_out_axis = vectorize(in_dims); - in_dims_out_axis.erase(in_dims_out_axis.begin() + axis); - for (int64_t index = 0; index < numel; ++index) { - int64_t tmp = index; - int64_t pos_in_axis = 0; - std::vector shape; - for (int64_t j = in_dims.size() - 1; j >= 0; --j) { - if (j != axis) { - shape.push_back(tmp % in_dims[j]); - } else { - pos_in_axis = tmp % in_dims[j]; - } - tmp /= in_dims[j]; - } - std::reverse(shape.begin(), shape.end()); - int64_t group = (shape.size() > 0) ? 
shape[0] : 0; - for (size_t j = 0; j < shape.size() - 1; ++j) { - group = group * in_dims_out_axis[j + 1] + shape[j + 1]; - } - - target_idx[index] = group * in_dims[axis] + pos_in_axis; - mediate_indices[target_idx[index]] = pos_in_axis; - } - - thrust::device_vector med_ids_dev(mediate_indices.begin(), - mediate_indices.end()); - int64_t* med_ids_data = thrust::raw_pointer_cast(med_ids_dev.data()); - thrust::device_vector trg_idx_dev(target_idx.begin(), - target_idx.end()); - int64_t* trg_idx = thrust::raw_pointer_cast(trg_idx_dev.data()); + int64_t* med_ids_data = + mediate_indices.mutable_data(in_dims, ctx.GetPlace()); + // Target index of each element along the given axis in the mediate tensors + Tensor trg_idx_t; + int64_t* trg_idx = trg_idx_t.mutable_data(in_dims, ctx.GetPlace()); auto stream = reinterpret_cast( ctx.device_context()) .stream(); - auto num_threads = PADDLE_CUDA_NUM_THREADS; + int num_threads = PADDLE_CUDA_NUM_THREADS; + + ComputeTargetIdx<<<(numel - 1) / num_threads + 1, num_threads, 0, stream>>>( + in_dims_data, in_dims.size(), axis, numel, trg_idx, med_ids_data); PermuteInData<<<(numel - 1) / num_threads + 1, num_threads, 0, stream>>>( in_data, trg_idx, numel, med_out_data);