memcpy.cc 5.6 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/memory/memcpy.h"
16 17

#include <cstring>  // for memcpy
18
#include "paddle/fluid/platform/profiler.h"
19 20 21 22 23 24 25 26 27 28 29

namespace paddle {
namespace memory {

// Host-to-host copy of `num` bytes from `src` to `dst`.
//
// A zero-byte request returns without touching either pointer: std::memcpy
// has undefined behavior for null pointers even when the count is zero, and
// callers may pass null data pointers for empty buffers.
template <>
void Copy<platform::CPUPlace, platform::CPUPlace>(platform::CPUPlace, void* dst,
                                                  platform::CPUPlace,
                                                  const void* src, size_t num) {
  if (num == 0) {
    return;
  }
  std::memcpy(dst, src, num);
}

30
#ifdef PADDLE_WITH_CUDA
S
sneaxiy 已提交
31 32
// Size threshold (64 KiB) used by the synchronous host<->device copies in this
// file: a GpuMemcpySync of at most this many bytes is followed by an explicit
// cudaStreamSynchronize(0).
static constexpr size_t kMaxGpuAsyncCopyBytes = size_t{64} * 1024;

33 34 35 36 37 38
// NOTE(zcd): Do not use GpuMemcpySync as much as possible.
// because GpuMemcpySync issues the copying command to the default stream,
// which will make two commands from different streams cannot run concurrently.
// Reference:
// https://devblogs.nvidia.com/gpu-pro-tip-cuda-7-streams-simplify-concurrency/

39
template <>
D
dzhwinter 已提交
40 41
void Copy<platform::CPUPlace, platform::CUDAPlace>(
    platform::CPUPlace dst_place, void* dst, platform::CUDAPlace src_place,
F
fengjiayi 已提交
42
    const void* src, size_t num, cudaStream_t stream) {
L
liaogang 已提交
43
  platform::SetDeviceId(src_place.device);
44

45
  if (stream) {
46
    platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CPU");
47 48
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
  } else {
49
    platform::RecordEvent record_event("GpuMemcpySync:GPU->CPU");
50
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
S
sneaxiy 已提交
51 52 53 54
    // FIXME(zjl): do we really need it?
    if (num <= kMaxGpuAsyncCopyBytes) {
      cudaStreamSynchronize(0);
    }
55
  }
56 57 58
}

template <>
D
dzhwinter 已提交
59 60
void Copy<platform::CUDAPlace, platform::CPUPlace>(
    platform::CUDAPlace dst_place, void* dst, platform::CPUPlace src_place,
F
fengjiayi 已提交
61
    const void* src, size_t num, cudaStream_t stream) {
L
liaogang 已提交
62
  platform::SetDeviceId(dst_place.device);
63
  if (stream) {
64
    platform::RecordEvent record_event("GpuMemcpyAsync:CPU->GPU");
65 66
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
  } else {
67
    platform::RecordEvent record_event("GpuMemcpySync:CPU->GPU");
68
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
S
sneaxiy 已提交
69 70 71 72
    // FIXME(zjl): do we really need it?
    if (num <= kMaxGpuAsyncCopyBytes) {
      cudaStreamSynchronize(0);
    }
73
  }
74 75 76
}

template <>
D
dzhwinter 已提交
77 78
void Copy<platform::CUDAPlace, platform::CUDAPlace>(
    platform::CUDAPlace dst_place, void* dst, platform::CUDAPlace src_place,
F
fengjiayi 已提交
79
    const void* src, size_t num, cudaStream_t stream) {
80
  if (dst_place == src_place) {
L
liaogang 已提交
81
    platform::SetDeviceId(src_place.device);
82
    if (stream) {
83
      platform::RecordEvent record_event("GpuMemcpyAsync(same_gpu):GPU->GPU");
84 85
      platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
    } else {
86
      platform::RecordEvent record_event("GpuMemcpySync(same_gpu):GPU->GPU");
87 88
      platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice);
    }
89
  } else {
90
    if (stream) {
91
      platform::RecordEvent record_event("GpuMemcpyPeerAsync:GPU->GPU");
92 93 94
      platform::GpuMemcpyPeerAsync(dst, dst_place.device, src, src_place.device,
                                   num, stream);
    } else {
95
      platform::RecordEvent record_event("GpuMemcpyPeerSync:GPU->GPU");
96
      platform::GpuMemcpyPeerSync(dst, dst_place.device, src, src_place.device,
F
fengjiayi 已提交
97
                                  num);
98
    }
99 100 101
  }
}

C
chengduoZH 已提交
102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
// Copies `num` bytes from CUDA-pinned host memory to ordinary host memory
// using a plain std::memcpy.
//
// A zero-byte request returns without touching either pointer: std::memcpy
// has undefined behavior for null pointers even when the count is zero.
template <>
void Copy<platform::CPUPlace, platform::CUDAPinnedPlace>(
    platform::CPUPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
  if (num == 0) {
    return;
  }
  std::memcpy(dst, src, num);
}

// Copies `num` bytes from ordinary host memory into CUDA-pinned host memory
// using a plain std::memcpy.
//
// A zero-byte request returns without touching either pointer: std::memcpy
// has undefined behavior for null pointers even when the count is zero.
template <>
void Copy<platform::CUDAPinnedPlace, platform::CPUPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CPUPlace src_place, const void* src, size_t num) {
  if (num == 0) {
    return;
  }
  std::memcpy(dst, src, num);
}

// Copies `num` bytes between two CUDA-pinned host buffers using a plain
// std::memcpy.
//
// A zero-byte request returns without touching either pointer: std::memcpy
// has undefined behavior for null pointers even when the count is zero.
template <>
void Copy<platform::CUDAPinnedPlace, platform::CUDAPinnedPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
  if (num == 0) {
    return;
  }
  std::memcpy(dst, src, num);
}

// Device-to-pinned-host copy of `num` bytes from the GPU named by `src_place`.
//
// A non-null `stream` issues GpuMemcpyAsync on that stream, and a null one
// issues GpuMemcpySync. Unlike the CPUPlace destination variant, no
// cudaStreamSynchronize(0) follows the synchronous copy.
template <>
void Copy<platform::CUDAPinnedPlace, platform::CUDAPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CUDAPlace src_place, const void* src, size_t num,
    cudaStream_t stream) {
  // Select the source GPU before issuing the copy.
  platform::SetDeviceId(src_place.device);

  if (!stream) {
    platform::RecordEvent record_event("GpuMemcpySync:GPU->CUDAPinned");
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
    return;
  }

  platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CUDAPinned");
  platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
}

// Pinned-host-to-device copy of `num` bytes onto the GPU named by `dst_place`.
//
// A non-null `stream` issues GpuMemcpyAsync on that stream, and a null one
// issues GpuMemcpySync. Unlike the CPUPlace source variant, no
// cudaStreamSynchronize(0) follows the synchronous copy.
template <>
void Copy<platform::CUDAPlace, platform::CUDAPinnedPlace>(
    platform::CUDAPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num,
    cudaStream_t stream) {
  // Select the destination GPU before issuing the copy.
  platform::SetDeviceId(dst_place.device);

  if (!stream) {
    platform::RecordEvent record_event("GpuMemcpySync:CUDAPinned->GPU");
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
    return;
  }

  platform::RecordEvent record_event("GpuMemcpyAsync:CUDAPinned->GPU");
  platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
}

L
Luo Tao 已提交
153
#endif
Y
Yi Wang 已提交
154 155 156

}  // namespace memory
}  // namespace paddle