// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/phi/kernels/bincount_kernel.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace phi { using paddle::platform::PADDLE_CUDA_NUM_THREADS; inline int GET_BLOCKS(const int N) { return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS; } template __global__ void KernelBincount(const InputT* input, const int total_elements, const bool has_weights, const T* weights, OutT* output) { if (!has_weights) { for (int i = threadIdx.x; i < total_elements; i += blockDim.x) { paddle::platform::CudaAtomicAdd(&output[input[i]], 1L); } } else { for (int i = threadIdx.x; i < total_elements; i += blockDim.x) { paddle::platform::CudaAtomicAdd(&output[input[i]], static_cast(weights[i])); } } } template void BincountCUDAInner(const Context& dev_ctx, const DenseTensor& x, const paddle::optional& weights, int minlength, DenseTensor* out) { const DenseTensor* input = &x; DenseTensor* output = out; const InputT* input_data = input->data(); const int input_numel = input->numel(); if (input_data == nullptr) { phi::DDim out_dim{0}; output->Resize(out_dim); dev_ctx.template Alloc(output); return; } auto input_x = EigenVector::Flatten(*input); DenseTensor input_min_t, input_max_t; input_max_t.Resize({1}); auto* input_max_data = dev_ctx.template Alloc(&input_max_t); input_min_t.Resize({1}); auto* input_min_data = dev_ctx.template Alloc(&input_min_t); auto input_max_scala = EigenScalar::From(input_max_t); auto input_min_scala = EigenScalar::From(input_min_t); auto* place = dev_ctx.eigen_device(); input_max_scala.device(*place) = input_x.maximum(); input_min_scala.device(*place) = input_x.minimum(); DenseTensor input_min_cpu, input_max_cpu; paddle::framework::TensorCopySync( input_max_t, phi::CPUPlace(), &input_max_cpu); paddle::framework::TensorCopySync( input_min_t, phi::CPUPlace(), &input_min_cpu); InputT input_min = input_min_cpu.data()[0]; PADDLE_ENFORCE_GE( input_min, static_cast(0), phi::errors::InvalidArgument( "The elements in input tensor must be non-negative ints")); int64_t output_size = static_cast(input_max_cpu.data()[0]) + 1L; output_size = std::max(output_size, static_cast(minlength)); phi::DDim out_dim{output_size}; output->Resize(out_dim); bool has_weights = weights.is_initialized(); const T* weights_data = has_weights ? weights->data() : nullptr; auto stream = dev_ctx.stream(); if (!has_weights) { int64_t* output_data = dev_ctx.template Alloc(output); phi::funcs::SetConstant()(dev_ctx, output, 0L); KernelBincount <<>>( input_data, input_numel, has_weights, weights_data, output_data); } else { const auto& weights_type = paddle::framework::TransToProtoVarType(weights->dtype()); if (weights->dtype() == DataType::FLOAT32) { float* output_data = dev_ctx.template Alloc(output); phi::funcs::SetConstant()( dev_ctx, output, static_cast(0)); KernelBincount <<>>( input_data, input_numel, has_weights, weights_data, output_data); } else { double* output_data = dev_ctx.template Alloc(output); phi::funcs::SetConstant()( dev_ctx, output, static_cast(0)); KernelBincount <<>>( input_data, input_numel, has_weights, weights_data, output_data); } } } template void BincountKernel(const Context& dev_ctx, const DenseTensor& x, const paddle::optional& weights, int minlength, DenseTensor* out) { if (x.dtype() == DataType::INT32) { BincountCUDAInner(dev_ctx, x, weights, minlength, out); } else if (x.dtype() == DataType::INT64) { BincountCUDAInner(dev_ctx, x, weights, minlength, out); } } } // namespace phi PD_REGISTER_KERNEL(bincount, GPU, ALL_LAYOUT, phi::BincountKernel, float, double, int, int64_t) {}