未验证 提交 47609ab2 编写于 作者: Y Yi Wang 提交者: GitHub

Document transform.h and fix cpplint errors (#9913)

上级 b43d87c9
...@@ -18,16 +18,22 @@ limitations under the License. */ ...@@ -18,16 +18,22 @@ limitations under the License. */
#error device_ptr_cast must be include by .cu file #error device_ptr_cast must be include by .cu file
#endif #endif
#include <thrust/device_ptr.h> #include <type_traits> // For std::remove_pointer and std::is_pointer.
#include "thrust/device_ptr.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
namespace details { namespace details {
// PointerToThrustDevicePtr has two speicalizations, one casts a (CUDA
// device) pointer into thrust::device_ptr, the other keeps rest types
// un-casted.
template <typename T, bool is_ptr> template <typename T, bool is_ptr>
struct DevicePtrCast; struct PointerToThrustDevicePtr;
template <typename T> template <typename T>
struct DevicePtrCast<T, true> { struct PointerToThrustDevicePtr<T, true> {
using ELEM = typename std::remove_pointer<T>::type; using ELEM = typename std::remove_pointer<T>::type;
using RTYPE = thrust::device_ptr<ELEM>; using RTYPE = thrust::device_ptr<ELEM>;
...@@ -37,17 +43,26 @@ struct DevicePtrCast<T, true> { ...@@ -37,17 +43,26 @@ struct DevicePtrCast<T, true> {
}; };
template <typename T> template <typename T>
struct DevicePtrCast<T, false> { struct PointerToThrustDevicePtr<T, false> {
using RTYPE = T; using RTYPE = T;
inline RTYPE operator()(RTYPE it) const { return it; } inline RTYPE operator()(RTYPE it) const { return it; }
}; };
// Cast T to thrust::device_ptr if T is a pointer. // CastToCUDATransformIterator casts a pointer to thrust::device_ptr
// Otherwise, e.g., T is a iterator, return T itself. // so it could be used as the iterator of thrust::transform. It
// doesn't cast other types.
//
// We need CastToCUDATransformIterator because it is often that we
// want to use device memory pointers as transform iterators, e.g., to
// transform a block of float32 to float16. In this case, we want
// CastToCUDATransformIterator to cast float16/32 pointers to
// thrust::device_ptr, otherwise they cannot work as the iterator
// required by thrust::transform. At the same time, we don't want to
// cast thrust::device_ptr to thrust::device_ptr repeatedly.
template <typename T> template <typename T>
auto DevPtrCast(T t) -> auto CastToCUDATransformIterator(T t) ->
typename DevicePtrCast<T, std::is_pointer<T>::value>::RTYPE { typename PointerToThrustDevicePtr<T, std::is_pointer<T>::value>::RTYPE {
DevicePtrCast<T, std::is_pointer<T>::value> cast; PointerToThrustDevicePtr<T, std::is_pointer<T>::value> cast;
return cast(t); return cast(t);
} }
......
...@@ -14,29 +14,44 @@ limitations under the License. */ ...@@ -14,29 +14,44 @@ limitations under the License. */
#pragma once #pragma once
#include <algorithm>
#include <type_traits>
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/hostdevice.h" #include "paddle/fluid/platform/hostdevice.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include <algorithm>
#include <type_traits>
#ifdef __NVCC__ #ifdef __NVCC__
#include <thrust/execution_policy.h> #include <thrust/execution_policy.h>
#include <thrust/transform.h> #include <thrust/transform.h>
#include "paddle/fluid/platform/details/device_ptr_cast.h" #include "paddle/fluid/platform/details/cuda_transform_iterator_cast.h"
#endif #endif
namespace paddle { namespace paddle {
namespace platform { namespace platform {
// Transform on host or device. It provides the same API in std library. // Transform applys a unary or a binary functor on each element in a
// range defined by a pair of iterators.
//
// - The specialization for CPU calls std::transform.
// - The specialization for CUDA calls thrust::tranform.
//
// NOTE: We need to define InputIter and OutputIter defined as
// different types, because the InputIter points op's inputs and
// OutputIter pints to op's outputs.
//
// NOTE: We don't assume that InputIter to be const InputType* and
// OutputIter to be OutputType*, because we might use a iterator
// class, paddle::fluid::operators::RowwiseTRansformIterator.
template <typename DeviceContext> template <typename DeviceContext>
struct Transform { struct Transform {
// The unary version.
template <typename InputIter, typename OutputIter, typename UnaryOperation> template <typename InputIter, typename OutputIter, typename UnaryOperation>
void operator()(const DeviceContext& context, InputIter first, InputIter last, void operator()(const DeviceContext& context, InputIter first, InputIter last,
OutputIter result, UnaryOperation op); OutputIter result, UnaryOperation op);
// The binary version.
template <typename InputIter1, typename InputIter2, typename OutputIter, template <typename InputIter1, typename InputIter2, typename OutputIter,
typename BinaryOperation> typename BinaryOperation>
void operator()(const DeviceContext& context, InputIter1 first1, void operator()(const DeviceContext& context, InputIter1 first1,
...@@ -70,8 +85,9 @@ struct Transform<platform::CUDADeviceContext> { ...@@ -70,8 +85,9 @@ struct Transform<platform::CUDADeviceContext> {
auto place = context.GetPlace(); auto place = context.GetPlace();
PADDLE_ENFORCE(is_gpu_place(place), "It must use GPU place."); PADDLE_ENFORCE(is_gpu_place(place), "It must use GPU place.");
thrust::transform(thrust::cuda::par.on(context.stream()), thrust::transform(thrust::cuda::par.on(context.stream()),
details::DevPtrCast(first), details::DevPtrCast(last), details::CastToCUDATransformIterator(first),
details::DevPtrCast(result), op); details::CastToCUDATransformIterator(last),
details::CastToCUDATransformIterator(result), op);
} }
template <typename InputIter1, typename InputIter2, typename OutputIter, template <typename InputIter1, typename InputIter2, typename OutputIter,
...@@ -82,9 +98,10 @@ struct Transform<platform::CUDADeviceContext> { ...@@ -82,9 +98,10 @@ struct Transform<platform::CUDADeviceContext> {
auto place = context.GetPlace(); auto place = context.GetPlace();
PADDLE_ENFORCE(is_gpu_place(place), "It must use GPU place."); PADDLE_ENFORCE(is_gpu_place(place), "It must use GPU place.");
thrust::transform(thrust::cuda::par.on(context.stream()), thrust::transform(thrust::cuda::par.on(context.stream()),
details::DevPtrCast(first1), details::DevPtrCast(last1), details::CastToCUDATransformIterator(first1),
details::DevPtrCast(first2), details::DevPtrCast(result), details::CastToCUDATransformIterator(last1),
op); details::CastToCUDATransformIterator(first2),
details::CastToCUDATransformIterator(result), op);
} }
}; };
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册