提交 6fbf097b 编写于 作者: Y Yu Yang

Mark thrust::device_ptr in transform

Fix TravisCI
上级 c5fa417c
...@@ -19,13 +19,43 @@ ...@@ -19,13 +19,43 @@
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
#include <algorithm> #include <algorithm>
#include <type_traits>
#ifdef __NVCC__ #ifdef __NVCC__
#include <thrust/device_ptr.h>
#include <thrust/transform.h> #include <thrust/transform.h>
#endif #endif
namespace paddle { namespace paddle {
namespace platform { namespace platform {
#ifdef __NVCC__
template <typename T, bool is_ptr>
struct DevicePtrCast;
template <typename T>
struct DevicePtrCast<T, true> {
using ELEM = typename std::remove_pointer<T>::type;
using RTYPE = thrust::device_ptr<ELEM>;
inline thrust::device_ptr<ELEM> operator()(ELEM* ele) const {
return thrust::device_pointer_cast(ele);
}
};
template <typename T>
struct DevicePtrCast<T, false> {
using RTYPE = T;
inline RTYPE operator()(RTYPE it) const { return it; }
};
template <typename T>
auto DevCast(T t) ->
typename DevicePtrCast<T, std::is_pointer<T>::value>::RTYPE {
DevicePtrCast<T, std::is_pointer<T>::value> cast;
return cast(t);
}
#endif
// Transform on host or device. It provides the same API in std library. // Transform on host or device. It provides the same API in std library.
template <typename Place, typename InputIter, typename OutputIter, template <typename Place, typename InputIter, typename OutputIter,
typename UnaryOperation> typename UnaryOperation>
...@@ -35,7 +65,7 @@ void Transform(Place place, InputIter first, InputIter last, OutputIter result, ...@@ -35,7 +65,7 @@ void Transform(Place place, InputIter first, InputIter last, OutputIter result,
std::transform(first, last, result, op); std::transform(first, last, result, op);
} else { } else {
#ifdef __NVCC__ #ifdef __NVCC__
thrust::transform(first, last, result, op); thrust::transform(DevCast(first), DevCast(last), DevCast(result), op);
#else #else
PADDLE_THROW("Do not invoke `Transform<GPUPlace>` in .cc file"); PADDLE_THROW("Do not invoke `Transform<GPUPlace>` in .cc file");
#endif #endif
...@@ -50,7 +80,8 @@ void Transform(Place place, InputIter1 first1, InputIter1 last1, ...@@ -50,7 +80,8 @@ void Transform(Place place, InputIter1 first1, InputIter1 last1,
std::transform(first1, last1, first2, result, op); std::transform(first1, last1, first2, result, op);
} else { } else {
#ifdef __NVCC__ #ifdef __NVCC__
thrust::transform(first1, last1, first2, result, op); thrust::transform(DevCast(first1), DevCast(last1), DevCast(first2),
DevCast(result), op);
#else #else
PADDLE_THROW("Do not invoke `Transform<GPUPlace>` in .cc file"); PADDLE_THROW("Do not invoke `Transform<GPUPlace>` in .cc file");
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册