diff --git a/paddle/memory/CMakeLists.txt b/paddle/memory/CMakeLists.txt index fac442cca56b81f56a750bd3b1c2c0911e79e468..a5c4420ac0649f07670b794d6736eb214fe19445 100644 --- a/paddle/memory/CMakeLists.txt +++ b/paddle/memory/CMakeLists.txt @@ -1,6 +1,7 @@ add_subdirectory(detail) cc_library(memory SRCS memory.cc) +cc_library(memcpy SRCS memcpy.cc) cc_library(paddle_memory DEPS diff --git a/paddle/memory/memcpy.cc b/paddle/memory/memcpy.cc new file mode 100644 index 0000000000000000000000000000000000000000..804369de8240b8354291c85e5f3dcdec04f65e78 --- /dev/null +++ b/paddle/memory/memcpy.cc @@ -0,0 +1,67 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/memory/memcpy.h" + +#include // for memcpy + +#include "paddle/platform/device_context.h" + +namespace paddle { +namespace memory { + +template <> +void Copy(platform::CPUPlace, void* dst, + platform::CPUPlace, + const void* src, size_t num) { + std::memcpy(dst, src, num); +} + +#ifndef PADDLE_ONLY_CPU +template <> +void Copy(platform::CPUPlace dst_place, + void* dst, + platform::GPUPlace src_place, + const void* src, size_t num, + cudaStream_t stream) { + platform::GPUPlaceGuard g(src_place.device); + platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream); +} + +template <> +void Copy(platform::GPUPlace dst_place, + void* dst, + platform::CPUPlace src_place, + const void* src, size_t num, + cudaStream_t stream) { + platform::GPUPlaceGuard g(dst_place.device); + platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream); +} + +template <> +void Copy(platform::GPUPlace dst_place, + void* dst, + platform::GPUPlace src_place, + const void* src, size_t num, + cudaStream_t stream) { + if (dst_place == src_place) { + platform::GPUPlaceGuard g(src_place.device); + platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream); + } else { + platform::GpuMemcpyPeer(dst, dst_place.device, src, src_place.device, num, + stream); + } +} + +#endif // PADDLE_ONLY_CPU diff --git a/paddle/memory/memcpy.h b/paddle/memory/memcpy.h new file mode 100644 index 0000000000000000000000000000000000000000..99b1c2e1c3e5ae4facaeb4fd0b773a7531448f03 --- /dev/null +++ b/paddle/memory/memcpy.h @@ -0,0 +1,33 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/platform/gpu_info.h" +#include "paddle/platform/place.h" + +namespace paddle { +namespace memory { + +template +void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num); + +#ifndef PADDLE_ONLY_CPU +template +void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num, + cudaStream_t stream); +#endif // PADDLE_ONLY_CPU + +} // namespace memory +} // namespace paddle diff --git a/paddle/memory/memory.cc b/paddle/memory/memory.cc index 78443cc35a400bceac77b99c3468daf16d8a4690..c2e046926fafd8f4cfc4cd81d8f32e3882ff02ec 100644 --- a/paddle/memory/memory.cc +++ b/paddle/memory/memory.cc @@ -46,13 +46,6 @@ size_t Used(platform::CPUPlace place) { return GetCPUBuddyAllocator()->Used(); } -template <> -void Copy(platform::CPUPlace, void* dst, - platform::CPUPlace, - const void* src, size_t num) { - std::memcpy(dst, src, num); -} - #ifndef PADDLE_ONLY_CPU detail::BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { @@ -85,41 +78,6 @@ size_t Used(platform::GPUPlace place) { return GetGPUBuddyAllocator(place.device)->Used(); } -template <> -void Copy(platform::CPUPlace dst_place, - void* dst, - platform::GPUPlace src_place, - const void* src, size_t num, - cudaStream_t stream) { - platform::SetDeviceId(src_place.device); - platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream); -} - -template <> -void Copy(platform::GPUPlace dst_place, - void* dst, - platform::CPUPlace src_place, - const void* src, size_t num, - cudaStream_t stream) { - platform::SetDeviceId(dst_place.device); - platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream); -} - -template <> -void Copy(platform::GPUPlace dst_place, - void* dst, - platform::GPUPlace src_place, - const void* src, size_t num, - cudaStream_t stream) { - if (dst_place == src_place) { - platform::SetDeviceId(src_place.device); - platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream); - } else { - platform::GpuMemcpyPeer(dst, dst_place.device, src, src_place.device, num, - stream); - } -} - #endif // PADDLE_ONLY_CPU } // namespace memory diff --git a/paddle/memory/memory.h b/paddle/memory/memory.h index 7ef7a73bc8b25e6a637a5e89c87e3eef06174b92..5e0d64707299acb22aacff0fad237c135f614d9c 100644 --- a/paddle/memory/memory.h +++ b/paddle/memory/memory.h @@ -29,15 +29,6 @@ void Free(Place, void*); template size_t Used(Place); -template -void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num); - -#ifndef PADDLE_ONLY_CPU -template -void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num, - cudaStream_t stream); -#endif // PADDLE_ONLY_CPU - template ::value>::type* = nullptr>