cast_kernel.cu 4.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16 17
#include "paddle/pten/kernels/cast_kernel.h"

#include <cstdint>
#include <type_traits>

#include "paddle/pten/api/ext/dispatch.h"
18
#include "paddle/pten/backends/gpu/gpu_context.h"
19
#include "paddle/pten/core/kernel_registry.h"
20

21
// See Note [ Why still include the fluid headers? ]
22
#include "paddle/fluid/platform/aligned_vector.h"
23 24
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/device/gpu/gpu_helper.h"
25
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
26 27
#include "paddle/fluid/platform/float16.h"

28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56
namespace pten {

// Vectorized element-wise cast: per loop iteration each thread loads VecSize
// consecutive InT values as one AlignedVector, converts them with
// static_cast, and stores VecSize OutT values as one AlignedVector.
//
// Launch: any 1-D grid. The loop is grid-stride (step = total launched
// threads * VecSize), so correctness does not depend on the grid size.
//
// Preconditions (checked by the host-side launcher, not here):
//   * N is a multiple of VecSize -- only the first index of each vector is
//     compared against N, so a partial tail vector would be accessed out of
//     bounds.
//   * `in` / `out` are suitably aligned for AlignedVector<InT/OutT, VecSize>
//     (assumed requirement of paddle::platform::Load/Store -- see
//     aligned_vector.h).
template <typename InT, typename OutT, int VecSize>
__global__ void VecCastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
  using LoadT = paddle::platform::AlignedVector<InT, VecSize>;
  using StoreT = paddle::platform::AlignedVector<OutT, VecSize>;

  // Widen to 64 bits *before* multiplying: blockIdx/blockDim/gridDim are
  // 32-bit unsigned, so their products could otherwise wrap on very large
  // launches and produce a bogus start index or a zero/short stride.
  const int64_t idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride =
      static_cast<int64_t>(blockDim.x) * gridDim.x * VecSize;

  for (int64_t i = idx * VecSize; i < N; i += stride) {
    LoadT in_val;
    paddle::platform::Load<InT, VecSize>(&in[i], &in_val);

    StoreT out_val;
#pragma unroll
    for (int j = 0; j < VecSize; j++) {
      out_val[j] = static_cast<OutT>(in_val[j]);
    }

    paddle::platform::Store<OutT, VecSize>(out_val, &out[i]);
  }
}

// Scalar element-wise cast: out[i] = static_cast<OutT>(in[i]) for every
// i in [0, N). Used when the vectorized kernel's preconditions do not hold.
//
// Launch: any 1-D grid. This is a grid-stride loop, so every element is
// covered regardless of the grid size.
//
// The loop index is an explicit int64_t instead of going through the
// CUDA_KERNEL_LOOP helper macro. NOTE(review): that macro (fluid GPU helper
// headers) may declare its loop variable as `int`, which would truncate
// indices for tensors with more than 2^31 - 1 elements even though N is
// int64_t. Spelling the loop out keeps the index 64-bit by construction.
template <typename InT, typename OutT>
__global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
  // Widen before multiplying so the 32-bit unsigned built-ins cannot wrap.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < N;
       i += stride) {
    out[i] = static_cast<OutT>(in[i]);
  }
}

// Returns true when `ptr` meets the alignment of
// paddle::platform::AlignedVector<T, VecSize>, i.e. the address can be used
// by VecCastCUDAKernel's vector Load/Store at that width.
template <typename T, int VecSize>
inline bool IsAlignedForCastVec(const T* ptr) {
  return reinterpret_cast<std::uintptr_t>(ptr) %
             alignof(paddle::platform::AlignedVector<T, VecSize>) ==
         0;
}

// Converts `size` elements from `in_data` (InT) into `out_data` (OutT) by
// launching a kernel on dev_ctx.stream(). The launch is asynchronous with
// respect to the host.
//
// The 4-wide vectorized kernel is chosen only when all of its preconditions
// hold:
//   * InT and OutT differ,
//   * `size` is a multiple of 4 (the vector kernel has no scalar tail),
//   * GetVectorizedSize reports width 4 for `out_data`, and
//   * `in_data` is aligned for a 4-wide AlignedVector<InT> load.
// Checking only the output pointer is not enough: nothing here guarantees
// the input pointer is vector-aligned (e.g. a tensor view with a non-zero
// offset), and a misaligned vector load faults on the device.
// In every other case the scalar kernel is launched.
template <typename InT, typename OutT>
void CastCUDAKernelImplWithPtr(const GPUContext& dev_ctx,
                               const InT* in_data,
                               OutT* out_data,
                               int64_t size) {
  // Nothing to convert. Returning here also avoids building a 1-D launch
  // config (and launching) for zero elements, which would describe an empty
  // grid.
  if (size <= 0) {
    return;
  }

  constexpr int kVecSize = 4;
  paddle::platform::GpuLaunchConfig config =
      paddle::platform::GetGpuLaunchConfig1D(dev_ctx, size);

  const bool use_vectorized =
      !std::is_same<InT, OutT>::value && size % kVecSize == 0 &&
      paddle::platform::GetVectorizedSize<OutT>(out_data) == kVecSize &&
      IsAlignedForCastVec<InT, kVecSize>(in_data);

  if (use_vectorized) {
    VecCastCUDAKernel<InT, OutT, kVecSize><<<config.block_per_grid,
                                             config.thread_per_block,
                                             0,
                                             dev_ctx.stream()>>>(
        in_data, size, out_data);
  } else {
    CastCUDAKernel<InT, OutT><<<config.block_per_grid,
                                config.thread_per_block,
                                0,
                                dev_ctx.stream()>>>(in_data, size, out_data);
  }
}

78 79 80 81 82 83 84 85 86 87
// Type-resolved cast of a whole tensor. Reads `x` as InT, obtains a writable
// OutT buffer from `out` via mutable_data, and converts all x.numel()
// elements through CastCUDAKernelImplWithPtr.
template <typename InT, typename OutT>
void CastCUDAKernelImpl(const GPUContext& dev_ctx,
                        const DenseTensor& x,
                        DenseTensor* out) {
  // Source pointer and element count are taken before touching `out`,
  // matching the original evaluation order.
  const InT* src = x.data<InT>();
  const int64_t numel = x.numel();
  OutT* dst = out->mutable_data<OutT>();

  CastCUDAKernelImplWithPtr<InT, OutT>(dev_ctx, src, dst, numel);
}

88 89 90 91 92
// GPU cast kernel entry point.
//
// `T` is the input element type, fixed per registered instantiation. The
// output element type is selected at runtime from `out_dtype`.
// PD_VISIT_ALL_TYPES dispatches on out_dtype and runs the lambda with
// `data_t` naming the corresponding C++ type; the lambda then forwards to
// CastCUDAKernelImpl<T, data_t>.
// NOTE(review): behavior for a dtype outside the visited set is determined
// by PD_VISIT_ALL_TYPES (dispatch.h), not by this function.
template <typename T, typename Context>
void CastKernel(const Context& dev_ctx,
                const DenseTensor& x,
                DataType out_dtype,
                DenseTensor* out) {
  PD_VISIT_ALL_TYPES(out_dtype, "CastCUDAKernelImpl", ([&] {
                       CastCUDAKernelImpl<T, data_t>(dev_ctx, x, out);
                     }));
}
97 98

}  // namespace pten
99 100 101 102 103

// Registers pten::CastKernel under the kernel name `op_name` for the GPU
// backend and all layouts. It is instantiated for each input element type
// listed below plus any extra types passed through the variadic arguments.
// The output data type is registered as UNDEFINED, presumably because
// CastKernel only chooses its output type at runtime from `out_dtype`.
// (Comments cannot go inside the backslash-continued body below.)
#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...)     \
  PT_REGISTER_CTX_KERNEL(op_name,                           \
                         GPU,                               \
                         ALL_LAYOUT,                        \
                         pten::CastKernel,                  \
                         float,                             \
                         double,                            \
                         int,                               \
                         int64_t,                           \
                         int16_t,                           \
                         bool,                              \
                         uint8_t,                           \
                         paddle::platform::float16,         \
                         paddle::platform::complex<float>,  \
                         paddle::platform::complex<double>, \
                         ##__VA_ARGS__) {                   \
    kernel->OutputAt(0).SetDataType(                        \
        paddle::experimental::DataType::UNDEFINED);         \
  }

// HIP builds register only the base type list; non-HIP (CUDA) builds also
// register bfloat16 as an input type.
#if defined(PADDLE_WITH_HIP)
PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast)
#else
PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, paddle::platform::bfloat16)
#endif