diff --git a/paddle/pten/kernels/cuda/manipulation.cu b/paddle/pten/kernels/cuda/manipulation.cu index f4bf9322047b442b4a78a1d67faa51cfabadfb56..22ada75304f24559c559b9fe10ce32477608a44b 100644 --- a/paddle/pten/kernels/cuda/manipulation.cu +++ b/paddle/pten/kernels/cuda/manipulation.cu @@ -16,8 +16,8 @@ #include "paddle/pten/infermeta/unary.h" #include "paddle/pten/kernels/cuda/manipulation.h" #include "paddle/pten/kernels/cuda/utils.h" +#include "paddle/pten/kernels/functions/cuda/cast_kernel_impl.h" #include "paddle/pten/kernels/functions/general/manipulation.h" -#include "paddle/pten/kernels/functions/math/cast_func.h" namespace pten { @@ -123,8 +123,7 @@ void Cast(const CUDAContext& dev_ctx, DataType in_dtype, DenseTensor* out) { PD_VISIT_ALL_TYPES(out_dtype, "CastKernelImpl", ([&] { - math::CastKernelImpl( - dev_ctx, x, out); + detail::CastCUDAKernelImpl(dev_ctx, x, out); })); } @@ -158,23 +157,32 @@ PT_REGISTER_KERNEL("flatten_contiguous_range.mid", int8_t, int, int64_t) {} -// todo: Hip need support bfloat16 -PT_REGISTER_KERNEL("cast", - CUDA, - ANY, - pten::Cast, - float, - double, - int, - int64_t, - int16_t, - bool, - uint8_t, - paddle::platform::float16, - paddle::platform::complex, - paddle::platform::complex) { - kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); -} + +#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \ + PT_REGISTER_KERNEL("cast", \ + CUDA, \ + ANY, \ + pten::Cast, \ + float, \ + double, \ + int, \ + int64_t, \ + int16_t, \ + bool, \ + uint8_t, \ + paddle::platform::float16, \ + paddle::platform::complex, \ + paddle::platform::complex, \ + ##__VA_ARGS__) { \ + kernel->OutputAt(0).SetDataType( \ + paddle::experimental::DataType::UNDEFINED); \ + } + +#if !defined(PADDLE_WITH_HIP) +PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, paddle::platform::bfloat16) +#else +PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast) +#endif PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape2", CUDA, diff --git a/paddle/pten/kernels/functions/cuda/cast_kernel_impl.h b/paddle/pten/kernels/functions/cuda/cast_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..435da644356f9467930f4954837a53ed6454529a --- /dev/null +++ b/paddle/pten/kernels/functions/cuda/cast_kernel_impl.h @@ -0,0 +1,79 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/platform/cuda_helper.h" +#include "paddle/fluid/platform/float16.h" +#include "paddle/pten/core/dense_tensor.h" + +#include "paddle/fluid/platform/aligned_vector.h" +#include "paddle/fluid/platform/gpu_launch_config.h" +namespace pten { +namespace detail { +using CUDAContext = paddle::platform::CUDADeviceContext; + +template +__global__ void VecCastCUDAKernel(const InT* in, const int64_t N, OutT* out) { + using LoadT = paddle::platform::AlignedVector; + using StoreT = paddle::platform::AlignedVector; + + int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; + for (int64_t i = idx * VecSize; i < N; + i += blockDim.x * gridDim.x * VecSize) { + LoadT in_val; + paddle::platform::Load(&in[i], &in_val); + + StoreT out_val; +#pragma unroll + for (int j = 0; j < VecSize; j++) { + out_val[j] = static_cast(in_val[j]); + } + + paddle::platform::Store(out_val, &out[i]); + } +} + +template +__global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) { + CUDA_KERNEL_LOOP(index, N) { out[index] = static_cast(in[index]); } +} + +template +void CastCUDAKernelImpl(const CUDAContext& dev_ctx, + const DenseTensor& x, + DenseTensor* out) { + auto* in_data = x.data(); + auto size = x.numel(); + auto* out_data = out->mutable_data(); + + paddle::platform::GpuLaunchConfig config = + paddle::platform::GetGpuLaunchConfig1D(dev_ctx, size); + int vec_size = paddle::platform::GetVectorizedSize(out_data); + if (!std::is_same::value && vec_size == 4 && size % 4 == 0) { + VecCastCUDAKernel<<>>( + in_data, size, out_data); + } else { + CastCUDAKernel<<>>(in_data, size, out_data); + } +} + +} // namespace detail + +} // namespace pten