From 9956763eade996b2cf37418eb84cefcbc3a72bf0 Mon Sep 17 00:00:00 2001
From: chentianyu03 <chentianyu03@baidu.com>
Date: Mon, 29 Nov 2021 15:04:21 +0800
Subject: [PATCH] [Pten] add cuda implement of cast kernel (#37610)

* add cuda implement of cast kernel

* remove bfloat16 when defined paddle_with_hip
---
 paddle/pten/kernels/cuda/manipulation.cu      | 48 ++++++-----
 .../kernels/functions/cuda/cast_kernel_impl.h | 79 +++++++++++++++++++
 2 files changed, 107 insertions(+), 20 deletions(-)
 create mode 100644 paddle/pten/kernels/functions/cuda/cast_kernel_impl.h

diff --git a/paddle/pten/kernels/cuda/manipulation.cu b/paddle/pten/kernels/cuda/manipulation.cu
index f4bf9322047..22ada75304f 100644
--- a/paddle/pten/kernels/cuda/manipulation.cu
+++ b/paddle/pten/kernels/cuda/manipulation.cu
@@ -16,8 +16,8 @@
 #include "paddle/pten/infermeta/unary.h"
 #include "paddle/pten/kernels/cuda/manipulation.h"
 #include "paddle/pten/kernels/cuda/utils.h"
+#include "paddle/pten/kernels/functions/cuda/cast_kernel_impl.h"
 #include "paddle/pten/kernels/functions/general/manipulation.h"
-#include "paddle/pten/kernels/functions/math/cast_func.h"
 
 namespace pten {
 
@@ -123,8 +123,7 @@ void Cast(const CUDAContext& dev_ctx,
           DataType in_dtype,
           DenseTensor* out) {
   PD_VISIT_ALL_TYPES(out_dtype, "CastKernelImpl", ([&] {
-                       math::CastKernelImpl<CUDAContext, T, data_t>(
-                           dev_ctx, x, out);
+                       detail::CastCUDAKernelImpl<T, data_t>(dev_ctx, x, out);
                      }));
 }
 
@@ -158,23 +157,32 @@ PT_REGISTER_KERNEL("flatten_contiguous_range.mid",
                    int8_t,
                    int,
                    int64_t) {}
-// todo: Hip need support bfloat16
-PT_REGISTER_KERNEL("cast",
-                   CUDA,
-                   ANY,
-                   pten::Cast,
-                   float,
-                   double,
-                   int,
-                   int64_t,
-                   int16_t,
-                   bool,
-                   uint8_t,
-                   paddle::platform::float16,
-                   paddle::platform::complex<float>,
-                   paddle::platform::complex<double>) {
-  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
-}
+
+#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \
+  PT_REGISTER_KERNEL("cast",                            \
+                     CUDA,                              \
+                     ANY,                               \
+                     pten::Cast,                        \
+                     float,                             \
+                     double,                            \
+                     int,                               \
+                     int64_t,                           \
+                     int16_t,                           \
+                     bool,                              \
+                     uint8_t,                           \
+                     paddle::platform::float16,         \
+                     paddle::platform::complex<float>,  \
+                     paddle::platform::complex<double>, \
+                     ##__VA_ARGS__) {                   \
+    kernel->OutputAt(0).SetDataType(                    \
+        paddle::experimental::DataType::UNDEFINED);     \
+  }
+
+#if !defined(PADDLE_WITH_HIP)
+PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, paddle::platform::bfloat16)
+#else
+PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast)
+#endif
 
 PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape2",
                                 CUDA,
diff --git a/paddle/pten/kernels/functions/cuda/cast_kernel_impl.h b/paddle/pten/kernels/functions/cuda/cast_kernel_impl.h
new file mode 100644
index 00000000000..435da644356
--- /dev/null
+++ b/paddle/pten/kernels/functions/cuda/cast_kernel_impl.h
@@ -0,0 +1,79 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/fluid/platform/cuda_helper.h"
+#include "paddle/fluid/platform/float16.h"
+#include "paddle/pten/core/dense_tensor.h"
+
+#include "paddle/fluid/platform/aligned_vector.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
+namespace pten {
+namespace detail {
+using CUDAContext = paddle::platform::CUDADeviceContext;
+
+template <typename InT, typename OutT, int VecSize>
+__global__ void VecCastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
+  using LoadT = paddle::platform::AlignedVector<InT, VecSize>;
+  using StoreT = paddle::platform::AlignedVector<OutT, VecSize>;
+
+  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
+  for (int64_t i = idx * VecSize; i < N;
+       i += blockDim.x * gridDim.x * VecSize) {
+    LoadT in_val;
+    paddle::platform::Load<InT, VecSize>(&in[i], &in_val);
+
+    StoreT out_val;
+#pragma unroll
+    for (int j = 0; j < VecSize; j++) {
+      out_val[j] = static_cast<OutT>(in_val[j]);
+    }
+
+    paddle::platform::Store<OutT, VecSize>(out_val, &out[i]);
+  }
+}
+
+template <typename InT, typename OutT>
+__global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
+  CUDA_KERNEL_LOOP(index, N) { out[index] = static_cast<OutT>(in[index]); }
+}
+
+template <typename InT, typename OutT>
+void CastCUDAKernelImpl(const CUDAContext& dev_ctx,
+                        const DenseTensor& x,
+                        DenseTensor* out) {
+  auto* in_data = x.data<InT>();
+  auto size = x.numel();
+  auto* out_data = out->mutable_data<OutT>();
+
+  paddle::platform::GpuLaunchConfig config =
+      paddle::platform::GetGpuLaunchConfig1D(dev_ctx, size);
+  int vec_size = paddle::platform::GetVectorizedSize<OutT>(out_data);
+  if (!std::is_same<InT, OutT>::value && vec_size == 4 && size % 4 == 0) {
+    VecCastCUDAKernel<InT, OutT, 4><<<config.block_per_grid,
+                                      config.thread_per_block,
+                                      0,
+                                      dev_ctx.stream()>>>(
+        in_data, size, out_data);
+  } else {
+    CastCUDAKernel<InT, OutT><<<config.block_per_grid,
+                                config.thread_per_block,
+                                0,
+                                dev_ctx.stream()>>>(in_data, size, out_data);
+  }
+}
+
+}  // namespace detail
+
+}  // namespace pten
-- 
GitLab