//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/one_hot_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h"

namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;

template <typename InT, typename OutT>
__global__ void FillOutputKernel(const InT* p_in_data, OutT* p_out_data,
                                 const int64_t numel, const int depth) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  // One thread per input element; labels outside [0, depth) are skipped,
  // leaving their output rows all zeros.
  if (idx < numel && p_in_data[idx] >= 0 && p_in_data[idx] < depth) {
    *(p_out_data + (idx * depth) + p_in_data[idx]) = 1.0;
  }
}
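
// Example: with depth = 4, input labels {2, 0} yield the flat output
// {0, 0, 1, 0,  1, 0, 0, 0}, i.e. one one-hot row per input element.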

template <typename DeviceContext, typename InT>
struct OneHotOpCUDAFunctor {
  const framework::LoDTensor* in_;
  framework::LoDTensor* out_;
  const DeviceContext& ctx_;
  int depth_;

  OneHotOpCUDAFunctor(const framework::LoDTensor* in, framework::LoDTensor* out,
                      int depth, const DeviceContext& ctx)
      : in_(in), out_(out), depth_(depth), ctx_(ctx) {}

  template <typename OutT>
  void apply() const {
    auto* p_in_data = in_->data<InT>();
    auto numel = in_->numel();
    auto* p_out_data = out_->mutable_data<OutT>(ctx_.GetPlace());
    auto stream = ctx_.stream();
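    // Zero the whole output first; the kernel then writes a single 1 per
    // valid input element.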
    math::set_constant(ctx_, out_, 0.0);

    // Round the grid size up so every input element gets a thread.
    FillOutputKernel<<<(numel + PADDLE_CUDA_NUM_THREADS - 1) /
                           PADDLE_CUDA_NUM_THREADS,
                       PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
        p_in_data, p_out_data, numel, depth_);
  }
};

using LoDTensor = framework::LoDTensor;
template <typename DeviceContext, typename T>
class OneHotCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in = context.Input<LoDTensor>("X");
    auto* out = context.Output<LoDTensor>("Out");

    int depth = -1;
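    // "depth" may be provided at run time through the optional "depth_tensor"
    // input; if that tensor lives on the GPU, copy it to the host first.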
    if (context.HasInput("depth_tensor")) {
      auto* depth_tensor = context.Input<framework::Tensor>("depth_tensor");
      if (platform::is_gpu_place(depth_tensor->place())) {
        framework::Tensor temp;
        TensorCopySync(*depth_tensor, platform::CPUPlace(), &temp);
        depth = *temp.data<int32_t>();
      } else {
        depth = *depth_tensor->data<int32_t>();
      }

      // The depth is only known now, so fix up the output's last dimension.
      auto in_dims = in->dims();
      framework::DDim out_dims(in_dims);
      out_dims[out_dims.size() - 1] = depth;
      out->Resize(out_dims);
    } else {
      depth = context.Attr<int>("depth");
    }
    // Dispatch on the "dtype" attribute so the output is filled with the
    // requested element type.
    framework::VisitDataType(
        static_cast<framework::proto::VarType::Type>(
            context.Attr<int>("dtype")),
        OneHotOpCUDAFunctor<DeviceContext, T>(
            in, out, depth, context.template device_context<DeviceContext>()));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    one_hot, ops::OneHotCUDAKernel<paddle::platform::CUDADeviceContext, int>,
    ops::OneHotCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
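
// The kernel is registered for int and int64_t label inputs; the output
// element type is selected at run time by the "dtype" attribute above.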