未验证 提交 9956763e 编写于 作者: C chentianyu03 提交者: GitHub

[Pten] add cuda implement of cast kernel (#37610)

* add cuda implement of cast kernel

* remove bfloat16 when defined paddle_with_hip
上级 2bb3f0b5
......@@ -16,8 +16,8 @@
#include "paddle/pten/infermeta/unary.h"
#include "paddle/pten/kernels/cuda/manipulation.h"
#include "paddle/pten/kernels/cuda/utils.h"
#include "paddle/pten/kernels/functions/cuda/cast_kernel_impl.h"
#include "paddle/pten/kernels/functions/general/manipulation.h"
#include "paddle/pten/kernels/functions/math/cast_func.h"
namespace pten {
......@@ -123,8 +123,7 @@ void Cast(const CUDAContext& dev_ctx,
DataType in_dtype,
DenseTensor* out) {
PD_VISIT_ALL_TYPES(out_dtype, "CastKernelImpl", ([&] {
math::CastKernelImpl<CUDAContext, T, data_t>(
dev_ctx, x, out);
detail::CastCUDAKernelImpl<T, data_t>(dev_ctx, x, out);
}));
}
......@@ -158,23 +157,32 @@ PT_REGISTER_KERNEL("flatten_contiguous_range.mid",
int8_t,
int,
int64_t) {}
// todo: Hip need support bfloat16
PT_REGISTER_KERNEL("cast",
CUDA,
ANY,
pten::Cast,
float,
double,
int,
int64_t,
int16_t,
bool,
uint8_t,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \
PT_REGISTER_KERNEL("cast", \
CUDA, \
ANY, \
pten::Cast, \
float, \
double, \
int, \
int64_t, \
int16_t, \
bool, \
uint8_t, \
paddle::platform::float16, \
paddle::platform::complex<float>, \
paddle::platform::complex<double>, \
##__VA_ARGS__) { \
kernel->OutputAt(0).SetDataType( \
paddle::experimental::DataType::UNDEFINED); \
}
#if !defined(PADDLE_WITH_HIP)
PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, paddle::platform::bfloat16)
#else
PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast)
#endif
PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape2",
CUDA,
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/platform/cuda_helper.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
namespace pten {
namespace detail {
using CUDAContext = paddle::platform::CUDADeviceContext;
template <typename InT, typename OutT, int VecSize>
__global__ void VecCastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
using LoadT = paddle::platform::AlignedVector<InT, VecSize>;
using StoreT = paddle::platform::AlignedVector<OutT, VecSize>;
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
for (int64_t i = idx * VecSize; i < N;
i += blockDim.x * gridDim.x * VecSize) {
LoadT in_val;
paddle::platform::Load<InT, VecSize>(&in[i], &in_val);
StoreT out_val;
#pragma unroll
for (int j = 0; j < VecSize; j++) {
out_val[j] = static_cast<OutT>(in_val[j]);
}
paddle::platform::Store<OutT, VecSize>(out_val, &out[i]);
}
}
template <typename InT, typename OutT>
__global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
CUDA_KERNEL_LOOP(index, N) { out[index] = static_cast<OutT>(in[index]); }
}
template <typename InT, typename OutT>
void CastCUDAKernelImpl(const CUDAContext& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
auto* in_data = x.data<InT>();
auto size = x.numel();
auto* out_data = out->mutable_data<OutT>();
paddle::platform::GpuLaunchConfig config =
paddle::platform::GetGpuLaunchConfig1D(dev_ctx, size);
int vec_size = paddle::platform::GetVectorizedSize<OutT>(out_data);
if (!std::is_same<InT, OutT>::value && vec_size == 4 && size % 4 == 0) {
VecCastCUDAKernel<InT, OutT, 4><<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(
in_data, size, out_data);
} else {
CastCUDAKernel<InT, OutT><<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(in_data, size, out_data);
}
}
} // namespace detail
} // namespace pten
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册