diff --git a/paddle/fluid/operators/where_index_op.cu b/paddle/fluid/operators/where_index_op.cu
index bb968743585f7d3574d477ab54cf657ef2646873..b1cd172923ee6dc421cc09b27163422207ea099c 100644
--- a/paddle/fluid/operators/where_index_op.cu
+++ b/paddle/fluid/operators/where_index_op.cu
@@ -12,7 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <thrust/device_vector.h>
+#ifdef __NVCC__
+#include "cub/cub.cuh"
+#endif
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
+
+#include <algorithm>
 #include "paddle/fluid/framework/ddim.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/where_index_op.h"
@@ -25,52 +33,124 @@ namespace operators {
 using CUDADeviceContext = paddle::platform::CUDADeviceContext;
 
 template <typename T>
-class CUDAWhereIndexKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* condition = context.Input<framework::Tensor>("Condition");
-    auto* out = context.Output<framework::Tensor>("Out");
-
-    // TODO(zhoukunsheng): Should optimize to ensure GPU is faster than CPU.
-    framework::Tensor cond_cpu;
-    framework::TensorCopy(*condition, platform::CPUPlace(), &cond_cpu);
-
-    const T* cond_data = cond_cpu.data<T>();
-    int64_t numel = cond_cpu.numel();
-    auto dims = cond_cpu.dims();
-    int rank = dims.size();
-
-    thrust::host_vector<int64_t> h_true_index;
-    for (int64_t i = 0; i < numel; i++) {
-      if (static_cast<bool>(cond_data[i])) {
-        h_true_index.push_back(i);
+__global__ void GetTrueNum(const T *cond_data, const int64_t numel,
+                           int64_t *true_num_array) {
+  const int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+  for (int64_t idx = tid; idx < numel; idx += gridDim.x * blockDim.x) {
+    true_num_array[idx] =
+        static_cast<int64_t>(static_cast<bool>(cond_data[idx]));
+  }
+}
+
+template <typename T>
+__global__ void SetTrueIndex(int64_t *out_ptr, const T *cond_data,
+                             const int64_t numel, const int64_t *stride_array,
+                             const int64_t rank,
+                             const int64_t *true_num_array) {
+  const int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+  for (int64_t idx = tid; idx < numel; idx += gridDim.x * blockDim.x) {
+    // true_num_array holds the inclusive prefix sum of the condition flags,
+    // so true_num_array[idx] counts the true elements in cond_data[0..idx];
+    // subtract 1 to get this element's zero-based output row.
+    const int64_t true_index = true_num_array[idx] - 1;
+    if (static_cast<bool>(cond_data[idx])) {
+      int64_t rank_index = idx;
+      for (int j = 0; j < rank; j++) {
+        const int64_t out_index = rank_index / stride_array[j];
+        out_ptr[true_index * rank + j] = out_index;
+        rank_index -= out_index * stride_array[j];
       }
     }
-    thrust::device_vector<int64_t> d_true_index = h_true_index;
-    int64_t* ptr_true_index = thrust::raw_pointer_cast(d_true_index.data());
-
-    size_t true_num = h_true_index.size();
+  }
+}
 
+template <typename T>
+class CUDAWhereIndexKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &context) const override {
+    auto *condition = context.Input<framework::Tensor>("Condition");
+    auto *out = context.Output<framework::Tensor>("Out");
+    auto &dev_ctx = context.template device_context<CUDADeviceContext>();
+
+    const T *cond_data = condition->data<T>();
+    const int64_t numel = condition->numel();
+    auto dims = condition->dims();
+    const int rank = dims.size();
+
+    auto d_array_mem = memory::Alloc(dev_ctx, (numel + rank) * sizeof(int64_t));
+    auto h_array_mem =
+        memory::Alloc(platform::CPUPlace(), (rank + 1) * sizeof(int64_t));
+
+    // "stride_array" has rank elements; element i is the stride of
+    // dimension i, i.e. the distance between consecutive indices along it.
+    int64_t *h_stride_array = reinterpret_cast<int64_t *>(h_array_mem->ptr());
+    int64_t *d_stride_array = reinterpret_cast<int64_t *>(d_array_mem->ptr());
+
+    // "true_num_array" has numel elements. It is first set to 1 where
+    // condition[i] == true and 0 elsewhere, and is then replaced by its
+    // inclusive prefix sum via cub::DeviceScan::InclusiveSum, so that
+    // each element holds the number of true values up to and including
+    // itself, which gives every true element its row in "out".
+    int64_t *d_true_num_array = d_stride_array + rank;
+
+    // h_total_true_num receives the total number of condition[i] == true
+    int64_t *h_total_true_num = h_stride_array + rank;
+
+    // allocate temporary storage for the cub scan
+    size_t cub_size = 0;
+    cub::DeviceScan::InclusiveSum(nullptr, cub_size, d_true_num_array,
+                                  d_true_num_array, numel, dev_ctx.stream());
+    auto cub_mem = memory::Alloc(dev_ctx, cub_size * sizeof(int64_t));
+    void *cub_data = cub_mem->ptr();
+
+    // set d_true_num_array[i] = 1 if cond_data[i] == true else 0
+    const int threads = std::min(numel, static_cast<int64_t>(128));
+    const int64_t need_grids = (numel + threads - 1) / threads;
+    const int grids = std::min(need_grids, static_cast<int64_t>(256));
+    GetTrueNum<T><<<grids, threads, 0, dev_ctx.stream()>>>(cond_data, numel,
+                                                           d_true_num_array);
+
+    // Run the inclusive prefix sum over "true_num_array" to obtain each
+    // true element's row in the "out" tensor as well as the total number
+    // of elements with cond_data[i] == true.
+    // Example:
+    //   condition: F T T F F F T T
+    //   before:    0 1 1 0 0 0 1 1
+    //   after:     0 1 2 2 2 2 3 4
+    //   out:       1 2 6 7
+    cub::DeviceScan::InclusiveSum(cub_data, cub_size, d_true_num_array,
+                                  d_true_num_array, numel, dev_ctx.stream());
+
+    // calculate each dimension's stride
+    h_stride_array[rank - 1] = 1;
+    for (int i = rank - 2; i >= 0; i--) {
+      h_stride_array[i] = h_stride_array[i + 1] * dims[i + 1];
+    }
+    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
+                 d_stride_array, platform::CPUPlace(), h_stride_array,
+                 rank * sizeof(int64_t), dev_ctx.stream());
+
+    // get the total true number and set the output size;
+    // the last element of the inclusive prefix sum is the total count
+    memory::Copy(platform::CPUPlace(), h_total_true_num,
+                 BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
+                 d_true_num_array + numel - 1, sizeof(int64_t),
+                 dev_ctx.stream());
+    dev_ctx.Wait();
+
+    int64_t true_num = *h_total_true_num;
     out->Resize(framework::make_ddim({static_cast<int64_t>(true_num), rank}));
-    auto out_ptr = out->mutable_data<int64_t>(context.GetPlace());
+    auto out_data = out->mutable_data<int64_t>(context.GetPlace());
 
     if (true_num == 0) {
       return;
     }
 
-    thrust::host_vector<int64_t> h_stride(rank, 0);
-    h_stride[rank - 1] = 1;
-    for (int i = rank - 2; i >= 0; i--) {
-      h_stride[i] = h_stride[i + 1] * dims[i + 1];
-    }
-    thrust::device_vector<int64_t> d_stride = h_stride;
-    int64_t* ptr_stride = thrust::raw_pointer_cast(d_stride.data());
-
-    auto& dev_ctx = context.template device_context<CUDADeviceContext>();
-    WhereIndexFunctor<int64_t> functor(ptr_true_index, true_num, ptr_stride,
-                                       rank, out_ptr);
-    platform::ForRange<CUDADeviceContext> for_range(dev_ctx, true_num);
-    for_range(functor);
+    // use true_num_array and stride_array to compute the output indices
+    SetTrueIndex<T><<<grids, threads, 0, dev_ctx.stream()>>>(
+        out_data, cond_data, numel, d_stride_array, rank, d_true_num_array);
   }
 };
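Note: the new kernel is an instance of the standard GPU stream-compaction pattern: write a 0/1 flag per element, inclusive-scan the flags on the device, then scatter each true element to output row flags[i] - 1, so the old host-side loop and host/device round trip disappear. The sketch below is not part of the PR; it is a minimal 1-D standalone illustration of the same three steps against the example in the comments above, with hypothetical names (FlagTrue, Scatter) and error checking omitted.

#include <cub/cub.cuh>
#include <cstdio>
#include <vector>

// step 1: flags[i] = 1 if cond[i] is true, else 0 (grid-stride loop)
__global__ void FlagTrue(const int *cond, int64_t n, int64_t *flags) {
  const int64_t tid = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  for (int64_t i = tid; i < n; i += (int64_t)gridDim.x * blockDim.x) {
    flags[i] = cond[i] != 0 ? 1 : 0;
  }
}

// step 3: after the inclusive scan, prefix[i] - 1 is the output row of
// each true element, so scatter the flat index i there
__global__ void Scatter(const int *cond, int64_t n, const int64_t *prefix,
                        int64_t *out) {
  const int64_t tid = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  for (int64_t i = tid; i < n; i += (int64_t)gridDim.x * blockDim.x) {
    if (cond[i] != 0) out[prefix[i] - 1] = i;
  }
}

int main() {
  const int64_t n = 8;
  const int h_cond[n] = {0, 1, 1, 0, 0, 0, 1, 1};  // F T T F F F T T

  int *d_cond;
  int64_t *d_flags, *d_out;
  cudaMalloc(&d_cond, n * sizeof(int));
  cudaMalloc(&d_flags, n * sizeof(int64_t));
  cudaMalloc(&d_out, n * sizeof(int64_t));
  cudaMemcpy(d_cond, h_cond, n * sizeof(int), cudaMemcpyHostToDevice);

  // size, then allocate, cub's temporary storage (first call with nullptr)
  size_t tmp_bytes = 0;
  cub::DeviceScan::InclusiveSum(nullptr, tmp_bytes, d_flags, d_flags, n);
  void *d_tmp;
  cudaMalloc(&d_tmp, tmp_bytes);

  FlagTrue<<<1, 128>>>(d_cond, n, d_flags);
  // step 2: in-place inclusive prefix sum of the flags
  cub::DeviceScan::InclusiveSum(d_tmp, tmp_bytes, d_flags, d_flags, n);
  Scatter<<<1, 128>>>(d_cond, n, d_flags, d_out);

  // the last prefix-sum element is the number of true values
  int64_t true_num = 0;
  cudaMemcpy(&true_num, d_flags + n - 1, sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  std::vector<int64_t> h_out(true_num);
  cudaMemcpy(h_out.data(), d_out, true_num * sizeof(int64_t),
             cudaMemcpyDeviceToHost);

  for (int64_t i = 0; i < true_num; ++i) printf("%lld ", (long long)h_out[i]);
  printf("\n");  // expected: 1 2 6 7

  cudaFree(d_cond);
  cudaFree(d_flags);
  cudaFree(d_out);
  cudaFree(d_tmp);
  return 0;
}

The PR's kernel generalizes this 1-D sketch by dividing each flat index by the per-dimension strides to emit rank coordinates per true element, which is why SetTrueIndex writes rank values per output row.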