/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/gpu_launch_config.h"

namespace paddle {
namespace operators {

template <typename InT, typename OutT, int VecSize>
// Vectorized cast kernel. Each loop iteration loads VecSize consecutive InT
// values as one platform::AlignedVector, converts them element-wise to OutT,
// and stores them as one platform::AlignedVector<OutT, VecSize>.
//
// Preconditions (established by the caller, NOT checked here):
//   * N is a multiple of VecSize. There is no scalar tail, so a remainder
//     would be read and written past the end of the buffers.
//   * `in` and `out` are suitably aligned for VecSize-wide vector access.
//     NOTE(review): the launching functor gates this on
//     platform::GetVectorizedSize.
//
// Any grid/block shape is valid. Threads advance through the data in a
// grid-stride loop whose stride is blockDim.x * gridDim.x * VecSize elements.
// The index arithmetic is widened to int64_t *before* multiplying, because
// the built-in blockDim/blockIdx/gridDim components are 32-bit unsigned and
// their products would otherwise wrap for very large launches.
__global__ void VecCastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
  using LoadT = platform::AlignedVector<InT, VecSize>;
  using StoreT = platform::AlignedVector<OutT, VecSize>;

  const int64_t idx =
      static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  const int64_t stride =
      static_cast<int64_t>(blockDim.x) * gridDim.x * VecSize;

  for (int64_t i = idx * VecSize; i < N; i += stride) {
    LoadT in_val;
    platform::Load<InT, VecSize>(&in[i], &in_val);

    StoreT out_val;
#pragma unroll
    for (int j = 0; j < VecSize; j++) {
      out_val[j] = static_cast<OutT>(in_val[j]);
    }

    platform::Store<OutT, VecSize>(out_val, &out[i]);
  }
}

template <typename InT, typename OutT>
// Scalar cast kernel. Converts each of the N elements of `in` to OutT and
// writes it to the same position in `out`. Used when the vectorized kernel's
// size/alignment preconditions do not hold.
//
// Any grid/block shape is valid: each thread walks the data in a grid-stride
// loop. The index is kept in int64_t so it spans the full int64_t range of N.
// The stride is widened before multiplying so blockDim.x * gridDim.x cannot
// wrap in 32-bit unsigned arithmetic.
// NOTE(review): CUDA_KERNEL_LOOP's index type is not visible in this file;
// the explicit loop avoids depending on it being 64-bit.
__global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t index =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       index < N; index += stride) {
    out[index] = static_cast<OutT>(in[index]);
  }
}

template <typename InT>
// Visitor passed to framework::VisitDataType. apply<OutT>() does three things:
//   * allocates `out_` as OutT on ctx_'s place;
//   * launches a kernel on ctx_'s stream;
//   * that kernel converts every InT element of `in_` to OutT.
// It picks the vectorized kernel when size and pointer alignment allow,
// otherwise the scalar kernel.
struct CastCUDAOpFunctor {
  const framework::Tensor* in_;
  framework::Tensor* out_;
  const platform::CUDADeviceContext& ctx_;
  CastCUDAOpFunctor(const framework::Tensor* in, framework::Tensor* out,
                    const platform::CUDADeviceContext& ctx)
      : in_(in), out_(out), ctx_(ctx) {}

  template <typename OutT>
  void apply() const {
    // Elements handled per thread per iteration by the vectorized kernel.
    constexpr int kVecSize = 4;

    auto* in = in_->data<InT>();
    auto size = in_->numel();
    auto* out = out_->mutable_data<OutT>(ctx_.GetPlace());

    // Nothing to convert. Returning here also avoids requesting a launch
    // configuration for zero elements and launching an empty grid.
    if (size == 0) return;

    // The vectorized kernel does kVecSize-wide aligned loads from `in` *and*
    // kVecSize-wide aligned stores to `out`, so BOTH pointers must qualify.
    // Checking only `out` is not enough. NOTE(review): `in_` may be a view
    // whose data pointer is offset inside a larger allocation (e.g. a sliced
    // tensor), so its alignment cannot be assumed.
    const bool in_aligned = platform::GetVectorizedSize<InT>(in) >= kVecSize;
    const bool out_aligned =
        platform::GetVectorizedSize<OutT>(out) >= kVecSize;

    // Same-type casts always take the scalar path.
    if (!std::is_same<InT, OutT>::value && in_aligned && out_aligned &&
        size % kVecSize == 0) {
      // Each thread covers kVecSize elements per iteration, so size the grid
      // for size / kVecSize work items. The kernel's grid-stride loop stays
      // correct for any grid.
      platform::GpuLaunchConfig config =
          platform::GetGpuLaunchConfig1D(ctx_, size / kVecSize);
      VecCastCUDAKernel<InT, OutT, kVecSize>
          <<<config.block_per_grid, config.thread_per_block, 0,
             ctx_.stream()>>>(in, size, out);
    } else {
      platform::GpuLaunchConfig config =
          platform::GetGpuLaunchConfig1D(ctx_, size);
      CastCUDAKernel<InT, OutT><<<config.block_per_grid,
                                  config.thread_per_block, 0, ctx_.stream()>>>(
          in, size, out);
    }
  }
};

template <typename InT>
// CUDA kernel for the `cast` op, instantiated once per input element type InT.
// It reads tensor "X", writes tensor "Out", and selects the output element
// type from the integer attribute "out_dtype". The work is handed to
// framework::VisitDataType with a CastCUDAOpFunctor<InT>, which is expected
// to invoke the functor's apply<OutT>() for the chosen type.
class CastCUDAOpKernel : public framework::OpKernel<InT> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const auto* x = context.Input<framework::Tensor>("X");
    auto* y = context.Output<framework::Tensor>("Out");
    const auto& dev_ctx =
        context.template device_context<platform::CUDADeviceContext>();
    const auto dst_dtype = static_cast<framework::proto::VarType::Type>(
        context.Attr<int>("out_dtype"));

    CastCUDAOpFunctor<InT> visitor(x, y, dev_ctx);
    framework::VisitDataType(dst_dtype, visitor);
  }
};

}  // namespace operators
}  // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Registers the CUDA `cast` kernels for the input element types shared by all
// GPU builds. Any extra kernel instantiations passed as variadic arguments are
// appended to the end of the list (GNU `##__VA_ARGS__` drops the trailing
// comma when none are given).
#define REGISTER_CAST_CUDA_BASE(op_name, ...)                   \
  REGISTER_OP_CUDA_KERNEL(                                      \
      op_name, ops::CastCUDAOpKernel<float>,                    \
      ops::CastCUDAOpKernel<double>, ops::CastCUDAOpKernel<int>, \
      ops::CastCUDAOpKernel<int64_t>,                           \
      ops::CastCUDAOpKernel<int16_t>,                           \
      ops::CastCUDAOpKernel<bool>,                              \
      ops::CastCUDAOpKernel<uint8_t>,                           \
      ops::CastCUDAOpKernel<plat::float16>,                     \
      ops::CastCUDAOpKernel<plat::complex<float>>,              \
      ops::CastCUDAOpKernel<plat::complex<double>>, ##__VA_ARGS__);

#if defined(PADDLE_WITH_HIP)
// HIP builds register only the common element types (no bfloat16 kernel).
REGISTER_CAST_CUDA_BASE(cast)
#else
// Non-HIP builds additionally register the bfloat16 input kernel.
REGISTER_CAST_CUDA_BASE(cast, ops::CastCUDAOpKernel<plat::bfloat16>)
#endif