// bincount_op.cu (6.1 KB) - last committed by smallv0221.
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/bincount_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
#include "paddle/fluid/platform/hostdevice.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using platform::PADDLE_CUDA_NUM_THREADS;

// Number of PADDLE_CUDA_NUM_THREADS-sized blocks needed to cover N elements,
// i.e. ceil(N / PADDLE_CUDA_NUM_THREADS).
//
// Computed as quotient plus "is there a remainder" rather than with the usual
// `(N + threads - 1) / threads`, because the addition overflows `int` when N
// is close to INT_MAX.
//
// Returns 0 for N == 0. Callers must not launch a kernel with a zero-sized
// grid, because CUDA rejects that as an invalid configuration.
inline int GET_BLOCKS(const int N) {
  return N / PADDLE_CUDA_NUM_THREADS +
         (N % PADDLE_CUDA_NUM_THREADS > 0 ? 1 : 0);
}

// Accumulates a histogram of `input` into `output`.
//
// For every i in [0, total_elements), the kernel atomically adds to
// output[input[i]] either:
//   - 1, when has_weights is false, or
//   - static_cast<OutT>(weights[i]), when has_weights is true.
//
// Preconditions (established by the host code that launches this kernel):
//   - `output` is zero-initialized.
//   - `output` has at least max(input) + 1 entries.
//   - Every input value is non-negative.
//
// Launch: a 1-D grid of 1-D blocks of any size; no shared memory. A
// grid-stride loop gives each element to exactly one thread of the whole
// grid, so the result does not depend on the launch configuration. The
// previous per-block loop ignored blockIdx and made every block process
// every element, multiplying the counts by gridDim.x.
template <typename T, typename InputT, typename OutT>
__global__ void KernelBincount(const InputT* input, const int total_elements,
                               const bool has_weights, const T* weights,
                               OutT* output) {
  // 64-bit indices: `i += stride` must not overflow when total_elements is
  // close to INT_MAX.
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  const int64_t first =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  if (!has_weights) {
    for (int64_t i = first; i < total_elements; i += stride) {
      paddle::platform::CudaAtomicAdd(&output[input[i]],
                                      static_cast<OutT>(1));
    }
  } else {
    for (int64_t i = first; i < total_elements; i += stride) {
      paddle::platform::CudaAtomicAdd(&output[input[i]],
                                      static_cast<OutT>(weights[i]));
    }
  }
}

// Launches KernelBincount for one output dtype.
//
// - Zero-fills `output` as OutT on the context's place.
// - Launches KernelBincount over the `input_numel` elements of `input_data`
//   on the device context's stream.
// - `output` must already be resized to the number of bins.
// - `input_numel` must be > 0, so that the grid is not empty.
template <typename DeviceContext, typename T, typename InputT, typename OutT>
static void BincountCUDALaunch(const framework::ExecutionContext& context,
                               const InputT* input_data, const int input_numel,
                               const bool has_weights, const T* weights_data,
                               Tensor* output) {
  auto& dev_ctx = context.template device_context<DeviceContext>();
  OutT* output_data = output->mutable_data<OutT>(context.GetPlace());
  math::SetConstant<DeviceContext, OutT>()(dev_ctx, output,
                                           static_cast<OutT>(0));

  KernelBincount<T, InputT, OutT><<<GET_BLOCKS(input_numel),
                                    PADDLE_CUDA_NUM_THREADS, 0,
                                    dev_ctx.stream()>>>(
      input_data, input_numel, has_weights, weights_data, output_data);
}

// Computes Out = bincount(X, Weights, minlength) on the GPU.
//
// Template parameters:
//   T      - the registered kernel dtype; Input(Weights) is read as T.
//   InputT - the dtype of Input(X): int or int64_t, chosen by the caller.
//
// Behavior:
// - Every element of X must be non-negative; otherwise InvalidArgument is
//   raised.
// - Out has max(max(X) + 1, minlength) bins.
// - Out dtype without Weights: int64 occurrence counts.
// - Out dtype with FP32 Weights: float sums of the weights.
// - Out dtype with any other Weights dtype: double sums of the weights.
// - An empty X yields an empty (0-element) Out of dtype T.
template <typename DeviceContext, typename T, typename InputT>
void BincountCUDAInner(const framework::ExecutionContext& context) {
  const Tensor* input = context.Input<framework::Tensor>("X");
  const Tensor* weights = context.Input<framework::Tensor>("Weights");
  Tensor* output = context.Output<framework::Tensor>("Out");
  const int minlength = context.Attr<int>("minlength");

  const int input_numel = input->numel();

  // Handle an empty input before doing anything else:
  // - The max/min reductions below do not produce a meaningful result over
  //   an empty range.
  // - GET_BLOCKS(0) would request a zero-block launch, which CUDA rejects.
  // NOTE(review): this returns 0 bins even when minlength > 0, preserving the
  // previous behavior; confirm whether minlength should apply here.
  const InputT* input_data =
      input_numel == 0 ? nullptr : input->data<InputT>();
  if (input_data == nullptr) {
    framework::DDim out_dim{0};
    output->Resize(out_dim);
    output->mutable_data<T>(context.GetPlace());
    return;
  }

  auto& dev_ctx = context.template device_context<DeviceContext>();
  auto input_x = framework::EigenVector<InputT>::Flatten(*input);

  // Reduce max(X) and min(X) on the device, then copy both scalars to the
  // host. They are used to size Out and to validate the bin indices.
  framework::Tensor input_min_t, input_max_t;
  input_max_t.mutable_data<InputT>({1}, context.GetPlace());
  input_min_t.mutable_data<InputT>({1}, context.GetPlace());

  auto input_max_scala = framework::EigenScalar<InputT>::From(input_max_t);
  auto input_min_scala = framework::EigenScalar<InputT>::From(input_min_t);

  auto* place = dev_ctx.eigen_device();
  input_max_scala.device(*place) = input_x.maximum();
  input_min_scala.device(*place) = input_x.minimum();

  Tensor input_min_cpu, input_max_cpu;
  TensorCopySync(input_max_t, platform::CPUPlace(), &input_max_cpu);
  TensorCopySync(input_min_t, platform::CPUPlace(), &input_min_cpu);

  const InputT input_min = input_min_cpu.data<InputT>()[0];
  PADDLE_ENFORCE_GE(
      input_min, static_cast<InputT>(0),
      platform::errors::InvalidArgument(
          "The elements in input tensor must be non-negative ints"));

  // One bin per value in [0, max(X)], widened to at least `minlength`.
  const int64_t max_bins =
      static_cast<int64_t>(input_max_cpu.data<InputT>()[0]) + 1;
  const int64_t output_size =
      std::max(max_bins, static_cast<int64_t>(minlength));
  framework::DDim out_dim{output_size};
  output->Resize(out_dim);

  const bool has_weights = (weights != nullptr);
  const T* weights_data = has_weights ? weights->data<T>() : nullptr;

  if (!has_weights) {
    BincountCUDALaunch<DeviceContext, T, InputT, int64_t>(
        context, input_data, input_numel, has_weights, weights_data, output);
  } else if (weights->type() == framework::proto::VarType::FP32) {
    BincountCUDALaunch<DeviceContext, T, InputT, float>(
        context, input_data, input_numel, has_weights, weights_data, output);
  } else {
    BincountCUDALaunch<DeviceContext, T, InputT, double>(
        context, input_data, input_numel, has_weights, weights_data, output);
  }
}

// CUDA kernel class for the `bincount` op.
//
// T is the registered kernel dtype, which BincountCUDAInner uses to read
// Input(Weights). The dtype of Input(X) is dispatched at runtime in
// Compute().
template <typename DeviceContext, typename T>
class BincountCUDAKernel : public framework::OpKernel<T> {
 public:
  // Selects the BincountCUDAInner instantiation that matches Input(X).
  // Only int32 and int64 bin indices are supported. Any other dtype raises
  // InvalidArgument instead of silently returning with Out unwritten.
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<framework::Tensor>("X");
    const auto& input_type = input->type();

    if (input_type == framework::proto::VarType::INT32) {
      BincountCUDAInner<DeviceContext, T, int>(context);
    } else if (input_type == framework::proto::VarType::INT64) {
      BincountCUDAInner<DeviceContext, T, int64_t>(context);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Input(X) of bincount op must be of type int32 or int64, but "
          "received %s.",
          framework::DataTypeToString(input_type)));
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Registers BincountCUDAKernel on CUDADeviceContext with T in
// {int, int64_t, float, double}.
// - T is the dtype used to read Input(Weights).
// - The dtype of Input(X) (int32/int64) is dispatched separately inside
//   Compute().
// NOTE(review): which T is selected at runtime is decided by the op's
// kernel-type logic outside this file; confirm against the op definition.
REGISTER_OP_CUDA_KERNEL(
    bincount, ops::BincountCUDAKernel<paddle::platform::CUDADeviceContext, int>,
    ops::BincountCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>,
    ops::BincountCUDAKernel<paddle::platform::CUDADeviceContext, float>,
    ops::BincountCUDAKernel<paddle::platform::CUDADeviceContext, double>);