未验证 提交 47609ab2 编写于 作者: Y Yi Wang 提交者: GitHub

Document transform.h and fix cpplint errors (#9913)

上级 b43d87c9
......@@ -18,16 +18,22 @@ limitations under the License. */
#error device_ptr_cast must be include by .cu file
#endif
#include <thrust/device_ptr.h>
#include <type_traits> // For std::remove_pointer and std::is_pointer.
#include "thrust/device_ptr.h"
namespace paddle {
namespace platform {
namespace details {
// PointerToThrustDevicePtr has two speicalizations, one casts a (CUDA
// device) pointer into thrust::device_ptr, the other keeps rest types
// un-casted.
template <typename T, bool is_ptr>
struct DevicePtrCast;
struct PointerToThrustDevicePtr;
template <typename T>
struct DevicePtrCast<T, true> {
struct PointerToThrustDevicePtr<T, true> {
using ELEM = typename std::remove_pointer<T>::type;
using RTYPE = thrust::device_ptr<ELEM>;
......@@ -37,17 +43,26 @@ struct DevicePtrCast<T, true> {
};
template <typename T>
struct DevicePtrCast<T, false> {
struct PointerToThrustDevicePtr<T, false> {
using RTYPE = T;
inline RTYPE operator()(RTYPE it) const { return it; }
};
// Cast T to thrust::device_ptr if T is a pointer.
// Otherwise, e.g., T is a iterator, return T itself.
// CastToCUDATransformIterator casts a pointer to thrust::device_ptr
// so it could be used as the iterator of thrust::transform. It
// doesn't cast other types.
//
// We need CastToCUDATransformIterator because it is often that we
// want to use device memory pointers as transform iterators, e.g., to
// transform a block of float32 to float16. In this case, we want
// CastToCUDATransformIterator to cast float16/32 pointers to
// thrust::device_ptr, otherwise they cannot work as the iterator
// required by thrust::transform. At the same time, we don't want to
// cast thrust::device_ptr to thrust::device_ptr repeatedly.
template <typename T>
auto DevPtrCast(T t) ->
typename DevicePtrCast<T, std::is_pointer<T>::value>::RTYPE {
DevicePtrCast<T, std::is_pointer<T>::value> cast;
auto CastToCUDATransformIterator(T t) ->
typename PointerToThrustDevicePtr<T, std::is_pointer<T>::value>::RTYPE {
PointerToThrustDevicePtr<T, std::is_pointer<T>::value> cast;
return cast(t);
}
......
......@@ -14,29 +14,44 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <type_traits>
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/hostdevice.h"
#include "paddle/fluid/platform/place.h"
#include <algorithm>
#include <type_traits>
#ifdef __NVCC__
#include <thrust/execution_policy.h>
#include <thrust/transform.h>
#include "paddle/fluid/platform/details/device_ptr_cast.h"
#include "paddle/fluid/platform/details/cuda_transform_iterator_cast.h"
#endif
namespace paddle {
namespace platform {
// Transform on host or device. It provides the same API in std library.
// Transform applys a unary or a binary functor on each element in a
// range defined by a pair of iterators.
//
// - The specialization for CPU calls std::transform.
// - The specialization for CUDA calls thrust::tranform.
//
// NOTE: We need to define InputIter and OutputIter defined as
// different types, because the InputIter points op's inputs and
// OutputIter pints to op's outputs.
//
// NOTE: We don't assume that InputIter to be const InputType* and
// OutputIter to be OutputType*, because we might use a iterator
// class, paddle::fluid::operators::RowwiseTRansformIterator.
template <typename DeviceContext>
struct Transform {
// The unary version.
template <typename InputIter, typename OutputIter, typename UnaryOperation>
void operator()(const DeviceContext& context, InputIter first, InputIter last,
OutputIter result, UnaryOperation op);
// The binary version.
template <typename InputIter1, typename InputIter2, typename OutputIter,
typename BinaryOperation>
void operator()(const DeviceContext& context, InputIter1 first1,
......@@ -70,8 +85,9 @@ struct Transform<platform::CUDADeviceContext> {
auto place = context.GetPlace();
PADDLE_ENFORCE(is_gpu_place(place), "It must use GPU place.");
thrust::transform(thrust::cuda::par.on(context.stream()),
details::DevPtrCast(first), details::DevPtrCast(last),
details::DevPtrCast(result), op);
details::CastToCUDATransformIterator(first),
details::CastToCUDATransformIterator(last),
details::CastToCUDATransformIterator(result), op);
}
template <typename InputIter1, typename InputIter2, typename OutputIter,
......@@ -82,9 +98,10 @@ struct Transform<platform::CUDADeviceContext> {
auto place = context.GetPlace();
PADDLE_ENFORCE(is_gpu_place(place), "It must use GPU place.");
thrust::transform(thrust::cuda::par.on(context.stream()),
details::DevPtrCast(first1), details::DevPtrCast(last1),
details::DevPtrCast(first2), details::DevPtrCast(result),
op);
details::CastToCUDATransformIterator(first1),
details::CastToCUDATransformIterator(last1),
details::CastToCUDATransformIterator(first2),
details::CastToCUDATransformIterator(result), op);
}
};
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册