memcpy.cc 6.4 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/memory/memcpy.h"
16 17

#include <cstring>  // for memcpy
Z
Zeng Jinle 已提交
18
#include "paddle/fluid/platform/enforce.h"
19
#include "paddle/fluid/platform/profiler.h"
20 21 22 23 24 25 26 27

namespace paddle {
namespace memory {

template <>
void Copy<platform::CPUPlace, platform::CPUPlace>(platform::CPUPlace, void* dst,
                                                  platform::CPUPlace,
                                                  const void* src, size_t num) {
  // Host-to-host copy: both places are plain CPU memory, so this is a direct
  // std::memcpy. A zero-byte request returns before touching either pointer.
  if (UNLIKELY(num == 0)) {
    return;
  }
  std::memcpy(dst, src, num);
}

32
#ifdef PADDLE_WITH_CUDA
S
sneaxiy 已提交
33 34
// Size threshold (65536 bytes) at or below which the synchronous host<->device
// copies in this file additionally wait on the default stream.
// NOTE(review): presumably mirrors the small-transfer rule in CUDA's
// "API synchronization behavior" notes — confirm.
static constexpr size_t kMaxGpuAsyncCopyBytes = size_t{64} * 1024;  // 64 KiB

35 36 37 38 39 40 41 42 43 44 45 46
// Waits for the default CUDA stream (stream 0) to drain.
// Off Windows this blocks in cudaStreamSynchronize; on Windows it polls
// cudaStreamQuery instead. In both cases the CUDA status is not checked:
// on Windows the poll simply stops on success or on any status other than
// cudaErrorNotReady.
inline void SyncCUDAStream() {
#if defined(_WIN32)
  cudaError_t status = cudaSuccess;
  do {
    status = cudaStreamQuery(0);
  } while (status == cudaErrorNotReady);
#else
  cudaStreamSynchronize(0);
#endif
}

47 48 49 50 51 52
// NOTE(zcd): Avoid GpuMemcpySync whenever possible: it issues the copy on the
// default stream, which prevents commands on different streams from running
// concurrently.
// Reference:
// https://devblogs.nvidia.com/gpu-pro-tip-cuda-7-streams-simplify-concurrency/

53
template <>
D
dzhwinter 已提交
54 55
void Copy<platform::CPUPlace, platform::CUDAPlace>(
    platform::CPUPlace dst_place, void* dst, platform::CUDAPlace src_place,
F
fengjiayi 已提交
56
    const void* src, size_t num, cudaStream_t stream) {
Z
Zeng Jinle 已提交
57
  if (UNLIKELY(num == 0)) return;
58

59 60 61
  platform::SetDeviceId(src_place.device);
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by thream(" << stream << ")";
62
  if (stream) {
63
    platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CPU");
64 65
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
  } else {
66
    platform::RecordEvent record_event("GpuMemcpySync:GPU->CPU");
67
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
S
sneaxiy 已提交
68 69
    // FIXME(zjl): do we really need it?
    if (num <= kMaxGpuAsyncCopyBytes) {
70
      SyncCUDAStream();
S
sneaxiy 已提交
71
    }
72
  }
73 74 75
}

template <>
D
dzhwinter 已提交
76 77
void Copy<platform::CUDAPlace, platform::CPUPlace>(
    platform::CUDAPlace dst_place, void* dst, platform::CPUPlace src_place,
F
fengjiayi 已提交
78
    const void* src, size_t num, cudaStream_t stream) {
Z
Zeng Jinle 已提交
79 80
  if (UNLIKELY(num == 0)) return;

L
liaogang 已提交
81
  platform::SetDeviceId(dst_place.device);
82 83
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by thream(" << stream << ")";
84
  if (stream) {
85
    platform::RecordEvent record_event("GpuMemcpyAsync:CPU->GPU");
86 87
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
  } else {
88
    platform::RecordEvent record_event("GpuMemcpySync:CPU->GPU");
89
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
S
sneaxiy 已提交
90 91
    // FIXME(zjl): do we really need it?
    if (num <= kMaxGpuAsyncCopyBytes) {
92
      SyncCUDAStream();
S
sneaxiy 已提交
93
    }
94
  }
95 96 97
}

template <>
D
dzhwinter 已提交
98 99
void Copy<platform::CUDAPlace, platform::CUDAPlace>(
    platform::CUDAPlace dst_place, void* dst, platform::CUDAPlace src_place,
F
fengjiayi 已提交
100
    const void* src, size_t num, cudaStream_t stream) {
Z
Zeng Jinle 已提交
101 102
  if (UNLIKELY(num == 0)) return;

103
  if (dst_place == src_place) {
L
liaogang 已提交
104
    platform::SetDeviceId(src_place.device);
105
    if (stream) {
106
      platform::RecordEvent record_event("GpuMemcpyAsync(same_gpu):GPU->GPU");
107 108
      platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
    } else {
109
      platform::RecordEvent record_event("GpuMemcpySync(same_gpu):GPU->GPU");
110 111
      platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice);
    }
112
  } else {
113
    if (stream) {
114
      platform::RecordEvent record_event("GpuMemcpyPeerAsync:GPU->GPU");
115 116 117
      platform::GpuMemcpyPeerAsync(dst, dst_place.device, src, src_place.device,
                                   num, stream);
    } else {
118
      platform::RecordEvent record_event("GpuMemcpyPeerSync:GPU->GPU");
119
      platform::GpuMemcpyPeerSync(dst, dst_place.device, src, src_place.device,
F
fengjiayi 已提交
120
                                  num);
121
    }
122 123 124
  }
}

C
chengduoZH 已提交
125 126 127 128
template <>
void Copy<platform::CPUPlace, platform::CUDAPinnedPlace>(
    platform::CPUPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
  // The pinned source is accessed directly from the host: the copy is a plain
  // std::memcpy with no CUDA call. Zero-byte requests are a no-op.
  if (UNLIKELY(num == 0)) {
    return;
  }
  std::memcpy(dst, src, num);
}

template <>
void Copy<platform::CUDAPinnedPlace, platform::CPUPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CPUPlace src_place, const void* src, size_t num) {
  // The pinned destination is written directly from the host: the copy is a
  // plain std::memcpy with no CUDA call. Zero-byte requests are a no-op.
  if (UNLIKELY(num == 0)) {
    return;
  }
  std::memcpy(dst, src, num);
}

template <>
void Copy<platform::CUDAPinnedPlace, platform::CUDAPinnedPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
  // Both buffers are pinned host memory, so the host copies them with a plain
  // std::memcpy (no CUDA call). Zero-byte requests are a no-op.
  if (UNLIKELY(num == 0)) {
    return;
  }
  std::memcpy(dst, src, num);
}

// Copies `num` bytes from GPU memory (`src` on `src_place`) into pinned host
// memory (`dst`), with the source GPU made current first.
// A null `stream` selects a synchronous copy; otherwise the copy is enqueued
// asynchronously on `stream`.
template <>
void Copy<platform::CUDAPinnedPlace, platform::CUDAPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CUDAPlace src_place, const void* src, size_t num,
    cudaStream_t stream) {
  if (UNLIKELY(num == 0)) return;

  platform::SetDeviceId(src_place.device);

  if (!stream) {
    platform::RecordEvent record_event("GpuMemcpySync:GPU->CUDAPinned");
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
    return;
  }
  platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CUDAPinned");
  platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
}

// Copies `num` bytes from pinned host memory (`src`) into GPU memory (`dst` on
// `dst_place`), with the destination GPU made current first.
// A null `stream` selects a synchronous copy; otherwise the copy is enqueued
// asynchronously on `stream`.
template <>
void Copy<platform::CUDAPlace, platform::CUDAPinnedPlace>(
    platform::CUDAPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num,
    cudaStream_t stream) {
  if (UNLIKELY(num == 0)) return;

  platform::SetDeviceId(dst_place.device);

  if (!stream) {
    platform::RecordEvent record_event("GpuMemcpySync:CUDAPinned->GPU");
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
    return;
  }
  platform::RecordEvent record_event("GpuMemcpyAsync:CUDAPinned->GPU");
  platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
}

L
Luo Tao 已提交
182
#endif
Y
Yi Wang 已提交
183 184 185

}  // namespace memory
}  // namespace paddle