diff --git a/paddle/memory/memcpy.cc b/paddle/memory/memcpy.cc index a19a3e3675e3e2e7cc0c3594f21191f932d6379f..19ec9ba9b26f5919796181a19a048b7edb508bdd 100644 --- a/paddle/memory/memcpy.cc +++ b/paddle/memory/memcpy.cc @@ -62,6 +62,24 @@ void Copy(platform::GPUPlace dst_place, } } +template <> +void Copy(platform::CPUPlace dst_place, + void* dst, + platform::GPUPlace src_place, + const void* src, size_t num) { + platform::SetDeviceId(src_place.device); + platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost); +} + +template <> +void Copy(platform::GPUPlace dst_place, + void* dst, + platform::CPUPlace src_place, + const void* src, size_t num) { + platform::SetDeviceId(dst_place.device); + platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice); +} + #endif // PADDLE_ONLY_CPU } // namespace memory diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index 17bdac8749e31565b119b2cb84aed199fac0f441..8b605e51c3f4ea38fc358ce054bb36fcc82063c4 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -24,3 +24,4 @@ cc_library(device_context SRCS device_context.cc DEPS memory buddy_allocator nv_test(device_context_test SRCS device_context_test.cc DEPS device_context gpu_info) nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda) +nv_test(transform_test SRCS transform_test.cu DEPS paddle_memory place) diff --git a/paddle/platform/details/device_ptr_cast.h b/paddle/platform/details/device_ptr_cast.h new file mode 100644 index 0000000000000000000000000000000000000000..4015491fcdc3554029aa771ab7da1b2f3424321f --- /dev/null +++ b/paddle/platform/details/device_ptr_cast.h @@ -0,0 +1,56 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#ifndef __NVCC__ +#error device_ptr_cast must be include by .cu file +#endif + +#include + +namespace paddle { +namespace platform { +namespace details { +template +struct DevicePtrCast; + +template +struct DevicePtrCast { + using ELEM = typename std::remove_pointer::type; + using RTYPE = thrust::device_ptr; + + inline thrust::device_ptr operator()(ELEM* ele) const { + return thrust::device_pointer_cast(ele); + } +}; + +template +struct DevicePtrCast { + using RTYPE = T; + inline RTYPE operator()(RTYPE it) const { return it; } +}; + +// Cast T to thrust::device_ptr if T is a pointer. +// Otherwise, e.g., T is a iterator, return T itself. +template +auto DevPtrCast(T t) -> + typename DevicePtrCast::value>::RTYPE { + DevicePtrCast::value> cast; + return cast(t); +} + +} // namespace details +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/transform.h b/paddle/platform/transform.h new file mode 100644 index 0000000000000000000000000000000000000000..3ee4acd29660f201d318ce6d39baa6f3999ae274 --- /dev/null +++ b/paddle/platform/transform.h @@ -0,0 +1,66 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/platform/enforce.h" +#include "paddle/platform/hostdevice.h" +#include "paddle/platform/place.h" + +#include +#include +#ifdef __NVCC__ +#include +#include "paddle/platform/details/device_ptr_cast.h" +#endif + +namespace paddle { +namespace platform { +// Transform on host or device. It provides the same API in std library. +template +void Transform(Place place, InputIter first, InputIter last, OutputIter result, + UnaryOperation op) { + if (is_cpu_place(place)) { + std::transform(first, last, result, op); + } else { +#ifdef __NVCC__ + using namespace details; + thrust::transform(DevPtrCast(first), DevPtrCast(last), DevPtrCast(result), + op); +#else + PADDLE_THROW("Do not invoke `Transform` in .cc file"); +#endif + } +} + +template +void Transform(Place place, InputIter1 first1, InputIter1 last1, + InputIter2 first2, OutputIter result, BinaryOperation op) { + if (is_cpu_place(place)) { + std::transform(first1, last1, first2, result, op); + } else { +#ifdef __NVCC__ + using namespace details; + thrust::transform(DevPtrCast(first1), DevPtrCast(last1), DevPtrCast(first2), + DevPtrCast(result), op); +#else + PADDLE_THROW("Do not invoke `Transform` in .cc file"); +#endif + } +}; + +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/transform_test.cu b/paddle/platform/transform_test.cu new file mode 100644 index 0000000000000000000000000000000000000000..600fed8f45077a6fee91f295aa854153c9cf9c01 --- /dev/null +++ b/paddle/platform/transform_test.cu @@ -0,0 +1,84 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include "paddle/memory/memcpy.h" +#include "paddle/memory/memory.h" +#include "paddle/platform/transform.h" + +template +class Scale { + public: + explicit Scale(const T& scale) : scale_(scale) {} + + HOSTDEVICE T operator()(const T& a) const { return a * scale_; } + + private: + T scale_; +}; + +template +class Multiply { + public: + HOSTDEVICE T operator()(const T& a, const T& b) const { return a * b; } +}; + +TEST(Transform, CPUUnary) { + using namespace paddle::platform; + float buf[4] = {0.1, 0.2, 0.3, 0.4}; + Transform(CPUPlace(), buf, buf + 4, buf, Scale(10)); + for (int i = 0; i < 4; ++i) { + ASSERT_NEAR(buf[i], static_cast(i + 1), 1e-5); + } +} + +TEST(Transform, GPUUnary) { + using namespace paddle::platform; + using namespace paddle::memory; + GPUPlace gpu0(0); + float cpu_buf[4] = {0.1, 0.2, 0.3, 0.4}; + float* gpu_buf = static_cast(Alloc(gpu0, sizeof(float) * 4)); + Copy(gpu0, gpu_buf, CPUPlace(), cpu_buf, sizeof(cpu_buf)); + Transform(gpu0, gpu_buf, gpu_buf + 4, gpu_buf, Scale(10)); + Copy(CPUPlace(), cpu_buf, gpu0, gpu_buf, sizeof(cpu_buf)); + Free(gpu0, gpu_buf); + for (int i = 0; i < 4; ++i) { + ASSERT_NEAR(cpu_buf[i], static_cast(i + 1), 1e-5); + } +} + +TEST(Transform, CPUBinary) { + using namespace paddle::platform; + using namespace paddle::memory; + int buf[4] = {1, 2, 3, 4}; + Transform(CPUPlace(), buf, buf + 4, buf, buf, Multiply()); + for (int i = 0; i < 4; ++i) { + ASSERT_EQ((i + 1) * (i + 1), buf[i]); + } +} + +TEST(Transform, GPUBinary) { + using namespace paddle::platform; + using namespace paddle::memory; + int buf[4] = {1, 2, 3, 4}; + GPUPlace gpu0(0); + int* gpu_buf = static_cast(Alloc(gpu0, sizeof(buf))); + Copy(gpu0, gpu_buf, CPUPlace(), buf, sizeof(buf)); + Transform(gpu0, gpu_buf, gpu_buf + 4, gpu_buf, gpu_buf, Multiply()); + Copy(CPUPlace(), buf, gpu0, gpu_buf, sizeof(buf)); + Free(gpu0, gpu_buf); + for (int i = 0; i < 4; ++i) { + ASSERT_EQ((i + 1) * (i + 1), buf[i]); + } +} \ No newline at end of file