// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/bincount_kernel.h"

#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {

using paddle::platform::PADDLE_CUDA_NUM_THREADS;

// Number of PADDLE_CUDA_NUM_THREADS-sized blocks needed so that each of the
// `N` elements can be assigned its own thread, i.e.
// ceil(N / PADDLE_CUDA_NUM_THREADS).
// Computed as quotient + "is there a remainder" instead of
// (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS so the
// intermediate sum cannot overflow `int` when N is close to INT_MAX.
inline int GET_BLOCKS(const int N) {
  return N / PADDLE_CUDA_NUM_THREADS +
         (N % PADDLE_CUDA_NUM_THREADS != 0 ? 1 : 0);
}

// Accumulates a histogram of `input` into `output` using atomic adds.
//
// Work distribution: each thread starts at its global index
// (blockIdx.x * blockDim.x + threadIdx.x) and advances by the total number of
// launched threads (gridDim.x * blockDim.x). Every element is therefore
// visited by exactly one thread for any 1-D launch configuration. The previous
// version strided only within a block, so with more than one block every
// block re-counted the whole input.
//
// input          : device pointer to `total_elements` bin indices; each value
//                  is used directly as an offset into `output`. The host
//                  wrapper in this file checks min(input) >= 0 and sizes
//                  `output` to at least max(input) + 1.
// total_elements : number of entries in `input` (and in `weights` when used).
// has_weights    : false -> each occurrence adds 1;
//                  true  -> each occurrence adds weights[i] converted to OutT.
// weights        : device pointer, only read when has_weights is true.
// output         : device accumulation buffer; the caller zero-fills it before
//                  the launch.
template <typename T, typename InputT, typename OutT>
__global__ void KernelBincount(const InputT* input,
                               const int total_elements,
                               const bool has_weights,
                               const T* weights,
                               OutT* output) {
  // 64-bit index math: blockIdx.x * blockDim.x can exceed INT_MAX on large
  // grids even when total_elements itself fits in an int.
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  const int64_t start =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  if (!has_weights) {
    for (int64_t i = start; i < total_elements; i += stride) {
      // static_cast<OutT>(1) instead of `1L`: `long` is 32-bit on LLP64
      // platforms, and this branch is also instantiated for float/double.
      paddle::platform::CudaAtomicAdd(&output[input[i]],
                                      static_cast<OutT>(1));
    }
  } else {
    for (int64_t i = start; i < total_elements; i += stride) {
      paddle::platform::CudaAtomicAdd(&output[input[i]],
                                      static_cast<OutT>(weights[i]));
    }
  }
}

// Computes bincount of `x` on the GPU for a concrete index type InputT.
//
// Output length is max(max(x) + 1, minlength). Output element type:
//   - no weights                -> int64 occurrence counts
//   - weights of dtype FLOAT32  -> float sums of weights
//   - any other weights dtype   -> double sums of weights
//
// Raises InvalidArgument (via PADDLE_ENFORCE_GE) if `x` contains a negative
// value.
template <typename Context, typename T, typename InputT>
void BincountCUDAInner(const Context& dev_ctx,
                       const DenseTensor& x,
                       const paddle::optional<DenseTensor>& weights,
                       int minlength,
                       DenseTensor* out) {
  const DenseTensor* input = &x;
  DenseTensor* output = out;
  const int input_numel = input->numel();

  // Treat zero elements as empty before touching the data. The Eigen min/max
  // below would otherwise reduce over nothing and produce the type's sentinel
  // limits, and the launch would be configured with zero blocks. A null data
  // pointer is still treated as empty, as before.
  const InputT* input_data =
      input_numel == 0 ? nullptr : input->data<InputT>();
  if (input_data == nullptr) {
    // NOTE(review): `minlength` is not applied here, and the result is
    // allocated as T rather than the int64/float/double used below. This
    // matches the previous behavior; confirm against the CPU kernel before
    // changing it.
    phi::DDim out_dim{0};
    output->Resize(out_dim);
    dev_ctx.template Alloc<T>(output);
    return;
  }

  // Reduce max(x) and min(x) on the device, then copy both scalars to the
  // host. min(x) validates the input; max(x) determines the output length.
  auto input_x = EigenVector<InputT>::Flatten(*input);
  DenseTensor input_min_t, input_max_t;
  input_max_t.Resize({1});
  dev_ctx.template Alloc<InputT>(&input_max_t);
  input_min_t.Resize({1});
  dev_ctx.template Alloc<InputT>(&input_min_t);

  auto input_max_scala = EigenScalar<InputT>::From(input_max_t);
  auto input_min_scala = EigenScalar<InputT>::From(input_min_t);

  auto* place = dev_ctx.eigen_device();
  input_max_scala.device(*place) = input_x.maximum();
  input_min_scala.device(*place) = input_x.minimum();

  DenseTensor input_min_cpu, input_max_cpu;
  paddle::framework::TensorCopySync(
      input_max_t, phi::CPUPlace(), &input_max_cpu);
  paddle::framework::TensorCopySync(
      input_min_t, phi::CPUPlace(), &input_min_cpu);

  InputT input_min = input_min_cpu.data<InputT>()[0];

  PADDLE_ENFORCE_GE(
      input_min,
      static_cast<InputT>(0),
      phi::errors::InvalidArgument(
          "The elements in input tensor must be non-negative ints"));

  // Bins cover 0..max(x); minlength can only lengthen the output.
  int64_t output_size =
      static_cast<int64_t>(input_max_cpu.data<InputT>()[0]) + 1L;
  output_size = std::max(output_size, static_cast<int64_t>(minlength));
  phi::DDim out_dim{output_size};
  output->Resize(out_dim);

  const bool has_weights = weights.is_initialized();
  const T* weights_data = has_weights ? weights->data<T>() : nullptr;
  auto stream = dev_ctx.stream();
  const int blocks = GET_BLOCKS(input_numel);

  // In each branch the output is zero-filled first, because the kernel only
  // accumulates into it.
  if (!has_weights) {
    int64_t* output_data = dev_ctx.template Alloc<int64_t>(output);
    phi::funcs::SetConstant<Context, int64_t>()(
        dev_ctx, output, static_cast<int64_t>(0));
    KernelBincount<T, InputT, int64_t>
        <<<blocks, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
            input_data, input_numel, has_weights, weights_data, output_data);
  } else if (weights->dtype() == DataType::FLOAT32) {
    float* output_data = dev_ctx.template Alloc<float>(output);
    phi::funcs::SetConstant<Context, float>()(
        dev_ctx, output, static_cast<float>(0));
    KernelBincount<T, InputT, float>
        <<<blocks, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
            input_data, input_numel, has_weights, weights_data, output_data);
  } else {
    // Every non-FLOAT32 weights dtype is accumulated in double.
    double* output_data = dev_ctx.template Alloc<double>(output);
    phi::funcs::SetConstant<Context, double>()(
        dev_ctx, output, static_cast<double>(0));
    KernelBincount<T, InputT, double>
        <<<blocks, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
            input_data, input_numel, has_weights, weights_data, output_data);
  }
}

// GPU `bincount` kernel entry point registered with phi.
//
// Counts occurrences of each non-negative integer in `x` into `out`. When
// `weights` is given, it sums the corresponding weights instead of counting.
//
// minlength : lower bound on the output length; must be >= 0, otherwise
//             InvalidArgument is raised.
// x         : only INT32 and INT64 are supported. Other dtypes now raise
//             InvalidArgument instead of silently leaving `out` unwritten.
template <typename T, typename Context>
void BincountKernel(const Context& dev_ctx,
                    const DenseTensor& x,
                    const paddle::optional<DenseTensor>& weights,
                    const Scalar& minlength,
                    DenseTensor* out) {
  int int_minlength = minlength.to<int>();
  PADDLE_ENFORCE_GE(int_minlength,
                    0,
                    phi::errors::InvalidArgument(
                        "The minlength should be greater than or equal to 0. "
                        "But received minlength is %d",
                        int_minlength));

  if (x.dtype() == DataType::INT32) {
    BincountCUDAInner<Context, T, int>(dev_ctx, x, weights, int_minlength, out);
  } else if (x.dtype() == DataType::INT64) {
    BincountCUDAInner<Context, T, int64_t>(
        dev_ctx, x, weights, int_minlength, out);
  } else {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "The dtype of the input tensor of bincount must be int32 or int64, "
        "but received %s.",
        x.dtype()));
  }
}
}  // namespace phi

// Registers phi::BincountKernel as the GPU implementation of the `bincount`
// op for all layouts. The kernel template parameter T is instantiated for
// float, double, int and int64_t. Inside BincountCUDAInner, T is also the type
// used to read `weights` when they are present.
PD_REGISTER_KERNEL(bincount,
                   GPU,
                   ALL_LAYOUT,
                   phi::BincountKernel,
                   float,
                   double,
                   int,
                   int64_t) {}