Unverified commit 15cbf81b, authored by sneaxiy, committed by GitHub

try to expose cast with ptr function (#38598)

Parent de26b88b
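Summary of the change below: the tensor-based CastCUDAKernelImpl is split in two. A new CastCUDAKernelImplWithPtr launches the CUDA cast kernel directly from a raw input pointer, output pointer, and element count, while CastCUDAKernelImpl becomes a thin wrapper that extracts those values from the DenseTensor and forwards them.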
@@ -54,13 +54,10 @@ __global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
 }
 
 template <typename InT, typename OutT>
-void CastCUDAKernelImpl(const GPUContext& dev_ctx,
-                        const DenseTensor& x,
-                        DenseTensor* out) {
-  auto* in_data = x.data<InT>();
-  auto size = x.numel();
-  auto* out_data = out->mutable_data<OutT>();
+void CastCUDAKernelImplWithPtr(const GPUContext& dev_ctx,
+                               const InT* in_data,
+                               OutT* out_data,
+                               int64_t size) {
   paddle::platform::GpuLaunchConfig config =
       paddle::platform::GetGpuLaunchConfig1D(dev_ctx, size);
   int vec_size = paddle::platform::GetVectorizedSize<OutT>(out_data);
@@ -78,6 +75,16 @@ void CastCUDAKernelImpl(const GPUContext& dev_ctx,
   }
 }
 
+template <typename InT, typename OutT>
+void CastCUDAKernelImpl(const GPUContext& dev_ctx,
+                        const DenseTensor& x,
+                        DenseTensor* out) {
+  auto* in_data = x.data<InT>();
+  auto size = x.numel();
+  auto* out_data = out->mutable_data<OutT>();
+  CastCUDAKernelImplWithPtr(dev_ctx, in_data, out_data, size);
+}
+
 template <typename T, typename ContextT>
 void Cast(const ContextT& dev_ctx,
           const DenseTensor& x,
...
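For context, here is a minimal standalone sketch of the pointer-based cast pattern this change exposes: a templated element-wise cast driven by raw device pointers and an element count rather than by DenseTensor objects. The names CastKernel and CastWithPtr and the plain grid-stride launch are illustrative assumptions for this sketch only, not Paddle's API; the real CastCUDAKernelImplWithPtr additionally selects a vectorized kernel via GetGpuLaunchConfig1D and GetVectorizedSize.

#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical standalone analogue of CastCUDAKernelImplWithPtr:
// element-wise cast of n values from InT to OutT via raw device pointers.
template <typename InT, typename OutT>
__global__ void CastKernel(const InT* in, int64_t n, OutT* out) {
  int64_t idx = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  // Grid-stride loop so any grid size covers all n elements.
  for (int64_t i = idx; i < n; i += (int64_t)blockDim.x * gridDim.x) {
    out[i] = static_cast<OutT>(in[i]);
  }
}

// Host-side wrapper taking raw pointers, mirroring the "with ptr" entry point.
template <typename InT, typename OutT>
void CastWithPtr(const InT* d_in, OutT* d_out, int64_t n, cudaStream_t stream) {
  const int threads = 256;
  const int blocks = static_cast<int>((n + threads - 1) / threads);
  CastKernel<InT, OutT><<<blocks, threads, 0, stream>>>(d_in, n, d_out);
}

int main() {
  const int64_t n = 4;
  float h_in[n] = {1.5f, 2.5f, -3.5f, 4.0f};
  double h_out[n] = {};

  float* d_in = nullptr;
  double* d_out = nullptr;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, n * sizeof(double));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

  // Cast float -> double entirely through raw device pointers.
  CastWithPtr<float, double>(d_in, d_out, n, /*stream=*/0);

  cudaMemcpy(h_out, d_out, n * sizeof(double), cudaMemcpyDeviceToHost);
  for (int64_t i = 0; i < n; ++i) {
    printf("%f -> %f\n", h_in[i], h_out[i]);
  }

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}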