/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/cast_op.h" #include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/gpu_launch_config.h" namespace paddle { namespace operators { // aligned vector generates vectorized load/store on CUDA template struct alignas(sizeof(T) * Size) AlignedVector { T val[Size]; }; template inline int VectorizedSize(const T* pointer) { uint64_t address = reinterpret_cast(pointer); constexpr int vec4 = std::alignment_of>::value; // NOLINT if (address % vec4 == 0) { return 4; } return 1; } template __global__ void VecCastCUDAKernel(const InT* in, const int64_t N, OutT* out) { int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; using LoadT = AlignedVector; using StoreT = AlignedVector; for (int i = idx * VecSize; i < N; i += blockDim.x * gridDim.x * VecSize) { InT in_vec[VecSize]; LoadT* in_value = reinterpret_cast(&in_vec); *in_value = *reinterpret_cast(&in[i]); OutT out_vec[VecSize]; #pragma unroll for (int ii = 0; ii < VecSize; ii++) { out_vec[ii] = static_cast(in_vec[ii]); } *(reinterpret_cast(&out[i])) = *reinterpret_cast(&out_vec[0]); } } template __global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) { CUDA_KERNEL_LOOP(index, N) { out[index] = static_cast(in[index]); } } template struct CastOpFunctor { const framework::Tensor* in_; framework::Tensor* out_; const platform::CUDADeviceContext& ctx_; CastOpFunctor(const framework::Tensor* in, framework::Tensor* out, const platform::CUDADeviceContext& ctx) : in_(in), out_(out), ctx_(ctx) {} template void apply() const { auto* in = in_->data(); auto size = in_->numel(); auto* out = out_->mutable_data(ctx_.GetPlace()); platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D(ctx_, size); int vec_size = VectorizedSize(out); if (!std::is_same::value && vec_size == 4 && size % 4 == 0) { VecCastCUDAKernel<<< config.block_per_grid, config.thread_per_block, 0, ctx_.stream()>>>( in, size, out); } else { CastCUDAKernel<<>>( in, size, out); } } }; } // namespace operators } // namespace paddle namespace ops = paddle::operators; #ifdef PADDLE_WITH_HIP REGISTER_OP_CUDA_KERNEL( cast, ops::CastOpKernel, ops::CastOpKernel, ops::CastOpKernel, ops::CastOpKernel, ops::CastOpKernel, ops::CastOpKernel, ops::CastOpKernel, ops::CastOpKernel, ops::CastOpKernel); #else REGISTER_OP_CUDA_KERNEL( cast, ops::CastOpKernel, ops::CastOpKernel, ops::CastOpKernel, ops::CastOpKernel, ops::CastOpKernel, ops::CastOpKernel, ops::CastOpKernel, ops::CastOpKernel, ops::CastOpKernel, ops::CastOpKernel); #endif