未验证 提交 15cbf81b 编写于 作者: S sneaxiy 提交者: GitHub

try to expose cast with ptr function (#38598)

上级 de26b88b
......@@ -54,13 +54,10 @@ __global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
}
template <typename InT, typename OutT>
void CastCUDAKernelImpl(const GPUContext& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
auto* in_data = x.data<InT>();
auto size = x.numel();
auto* out_data = out->mutable_data<OutT>();
void CastCUDAKernelImplWithPtr(const GPUContext& dev_ctx,
const InT* in_data,
OutT* out_data,
int64_t size) {
paddle::platform::GpuLaunchConfig config =
paddle::platform::GetGpuLaunchConfig1D(dev_ctx, size);
int vec_size = paddle::platform::GetVectorizedSize<OutT>(out_data);
......@@ -78,6 +75,16 @@ void CastCUDAKernelImpl(const GPUContext& dev_ctx,
}
}
template <typename InT, typename OutT>
void CastCUDAKernelImpl(const GPUContext& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
auto* in_data = x.data<InT>();
auto size = x.numel();
auto* out_data = out->mutable_data<OutT>();
CastCUDAKernelImplWithPtr(dev_ctx, in_data, out_data, size);
}
template <typename T, typename ContextT>
void Cast(const ContextT& dev_ctx,
const DenseTensor& x,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册