diff --git a/paddle/platform/details/device_ptr_cast.h b/paddle/platform/details/device_ptr_cast.h new file mode 100644 index 0000000000000000000000000000000000000000..4015491fcdc3554029aa771ab7da1b2f3424321f --- /dev/null +++ b/paddle/platform/details/device_ptr_cast.h @@ -0,0 +1,56 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#ifndef __NVCC__ +#error device_ptr_cast must be include by .cu file +#endif + +#include + +namespace paddle { +namespace platform { +namespace details { +template +struct DevicePtrCast; + +template +struct DevicePtrCast { + using ELEM = typename std::remove_pointer::type; + using RTYPE = thrust::device_ptr; + + inline thrust::device_ptr operator()(ELEM* ele) const { + return thrust::device_pointer_cast(ele); + } +}; + +template +struct DevicePtrCast { + using RTYPE = T; + inline RTYPE operator()(RTYPE it) const { return it; } +}; + +// Cast T to thrust::device_ptr if T is a pointer. +// Otherwise, e.g., T is an iterator, return T itself. 
+template +auto DevPtrCast(T t) -> + typename DevicePtrCast::value>::RTYPE { + DevicePtrCast::value> cast; + return cast(t); +} + +} // namespace details +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/transform.h b/paddle/platform/transform.h index c80446b45c8201f094c039951c8b6f6bde70e43c..3ee4acd29660f201d318ce6d39baa6f3999ae274 100644 --- a/paddle/platform/transform.h +++ b/paddle/platform/transform.h @@ -21,41 +21,12 @@ #include #include #ifdef __NVCC__ -#include #include +#include "paddle/platform/details/device_ptr_cast.h" #endif namespace paddle { namespace platform { - -#ifdef __NVCC__ -template -struct DevicePtrCast; - -template -struct DevicePtrCast { - using ELEM = typename std::remove_pointer::type; - using RTYPE = thrust::device_ptr; - - inline thrust::device_ptr operator()(ELEM* ele) const { - return thrust::device_pointer_cast(ele); - } -}; - -template -struct DevicePtrCast { - using RTYPE = T; - inline RTYPE operator()(RTYPE it) const { return it; } -}; - -template -auto DevCast(T t) -> - typename DevicePtrCast::value>::RTYPE { - DevicePtrCast::value> cast; - return cast(t); -} -#endif - // Transform on host or device. It provides the same API in std library. 
template @@ -65,7 +36,9 @@ void Transform(Place place, InputIter first, InputIter last, OutputIter result, std::transform(first, last, result, op); } else { #ifdef __NVCC__ - thrust::transform(DevCast(first), DevCast(last), DevCast(result), op); + using namespace details; + thrust::transform(DevPtrCast(first), DevPtrCast(last), DevPtrCast(result), + op); #else PADDLE_THROW("Do not invoke `Transform` in .cc file"); #endif @@ -80,8 +53,9 @@ void Transform(Place place, InputIter1 first1, InputIter1 last1, std::transform(first1, last1, first2, result, op); } else { #ifdef __NVCC__ - thrust::transform(DevCast(first1), DevCast(last1), DevCast(first2), - DevCast(result), op); + using namespace details; + thrust::transform(DevPtrCast(first1), DevPtrCast(last1), DevPtrCast(first2), + DevPtrCast(result), op); #else PADDLE_THROW("Do not invoke `Transform` in .cc file"); #endif