From 6fbf097bccf77f74927e7a19aa879182088558ca Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 11 Sep 2017 20:11:56 -0700 Subject: [PATCH] Mark thrust::device_ptr in transform Fix TravisCI --- paddle/platform/transform.h | 35 +++++++++++++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/paddle/platform/transform.h b/paddle/platform/transform.h index fcd300f2d98..c80446b45c8 100644 --- a/paddle/platform/transform.h +++ b/paddle/platform/transform.h @@ -19,13 +19,43 @@ #include "paddle/platform/place.h" #include +#include #ifdef __NVCC__ +#include #include #endif namespace paddle { namespace platform { +#ifdef __NVCC__ +template +struct DevicePtrCast; + +template +struct DevicePtrCast { + using ELEM = typename std::remove_pointer::type; + using RTYPE = thrust::device_ptr; + + inline thrust::device_ptr operator()(ELEM* ele) const { + return thrust::device_pointer_cast(ele); + } +}; + +template +struct DevicePtrCast { + using RTYPE = T; + inline RTYPE operator()(RTYPE it) const { return it; } +}; + +template +auto DevCast(T t) -> + typename DevicePtrCast::value>::RTYPE { + DevicePtrCast::value> cast; + return cast(t); +} +#endif + // Transform on host or device. It provides the same API in std library. template @@ -35,7 +65,7 @@ void Transform(Place place, InputIter first, InputIter last, OutputIter result, std::transform(first, last, result, op); } else { #ifdef __NVCC__ - thrust::transform(first, last, result, op); + thrust::transform(DevCast(first), DevCast(last), DevCast(result), op); #else PADDLE_THROW("Do not invoke `Transform` in .cc file"); #endif @@ -50,7 +80,8 @@ void Transform(Place place, InputIter1 first1, InputIter1 last1, std::transform(first1, last1, first2, result, op); } else { #ifdef __NVCC__ - thrust::transform(first1, last1, first2, result, op); + thrust::transform(DevCast(first1), DevCast(last1), DevCast(first2), + DevCast(result), op); #else PADDLE_THROW("Do not invoke `Transform` in .cc file"); #endif -- GitLab