/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/memory/memcpy.h"

#include <cstring>  // for memcpy
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace memory {

// CPU -> CPU: a plain host-side memcpy.
//
// Both place parameters are unnamed because the copy needs no device
// information on the host. `dst`/`src` must not overlap (std::memcpy
// semantics); `num` is the byte count.
template <>
void Copy<platform::CPUPlace, platform::CPUPlace>(platform::CPUPlace, void* dst,
                                                  platform::CPUPlace,
                                                  const void* src, size_t num) {
  // Return early on empty copies so callers may pass null pointers when
  // num == 0 without invoking UB in std::memcpy.
  if (UNLIKELY(num == 0)) return;
  std::memcpy(dst, src, num);
}

#ifdef PADDLE_WITH_CUDA
// Synchronous copies at or below this size additionally flush the default
// stream (see the FIXME in the sync branches below).
static constexpr size_t kMaxGpuAsyncCopyBytes = 64 * 1024;  // 64K

// NOTE(zcd): Do not use GpuMemcpySync as much as possible.
// because GpuMemcpySync issues the copying command to the default stream,
// which will make two commands from different streams cannot run concurrently.
// Reference:
// https://devblogs.nvidia.com/gpu-pro-tip-cuda-7-streams-simplify-concurrency/
template <>
D
dzhwinter 已提交
42 43
void Copy<platform::CPUPlace, platform::CUDAPlace>(
    platform::CPUPlace dst_place, void* dst, platform::CUDAPlace src_place,
F
fengjiayi 已提交
44
    const void* src, size_t num, cudaStream_t stream) {
Z
Zeng Jinle 已提交
45
  if (UNLIKELY(num == 0)) return;
L
liaogang 已提交
46
  platform::SetDeviceId(src_place.device);
47

48
  if (stream) {
49
    platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CPU");
50 51
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
  } else {
52
    platform::RecordEvent record_event("GpuMemcpySync:GPU->CPU");
53
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
S
sneaxiy 已提交
54 55 56 57
    // FIXME(zjl): do we really need it?
    if (num <= kMaxGpuAsyncCopyBytes) {
      cudaStreamSynchronize(0);
    }
58
  }
59 60 61
}

template <>
D
dzhwinter 已提交
62 63
void Copy<platform::CUDAPlace, platform::CPUPlace>(
    platform::CUDAPlace dst_place, void* dst, platform::CPUPlace src_place,
F
fengjiayi 已提交
64
    const void* src, size_t num, cudaStream_t stream) {
Z
Zeng Jinle 已提交
65 66
  if (UNLIKELY(num == 0)) return;

L
liaogang 已提交
67
  platform::SetDeviceId(dst_place.device);
68
  if (stream) {
69
    platform::RecordEvent record_event("GpuMemcpyAsync:CPU->GPU");
70 71
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
  } else {
72
    platform::RecordEvent record_event("GpuMemcpySync:CPU->GPU");
73
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
S
sneaxiy 已提交
74 75 76 77
    // FIXME(zjl): do we really need it?
    if (num <= kMaxGpuAsyncCopyBytes) {
      cudaStreamSynchronize(0);
    }
78
  }
79 80 81
}

template <>
D
dzhwinter 已提交
82 83
void Copy<platform::CUDAPlace, platform::CUDAPlace>(
    platform::CUDAPlace dst_place, void* dst, platform::CUDAPlace src_place,
F
fengjiayi 已提交
84
    const void* src, size_t num, cudaStream_t stream) {
Z
Zeng Jinle 已提交
85 86
  if (UNLIKELY(num == 0)) return;

87
  if (dst_place == src_place) {
L
liaogang 已提交
88
    platform::SetDeviceId(src_place.device);
89
    if (stream) {
90
      platform::RecordEvent record_event("GpuMemcpyAsync(same_gpu):GPU->GPU");
91 92
      platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
    } else {
93
      platform::RecordEvent record_event("GpuMemcpySync(same_gpu):GPU->GPU");
94 95
      platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice);
    }
96
  } else {
97
    if (stream) {
98
      platform::RecordEvent record_event("GpuMemcpyPeerAsync:GPU->GPU");
99 100 101
      platform::GpuMemcpyPeerAsync(dst, dst_place.device, src, src_place.device,
                                   num, stream);
    } else {
102
      platform::RecordEvent record_event("GpuMemcpyPeerSync:GPU->GPU");
103
      platform::GpuMemcpyPeerSync(dst, dst_place.device, src, src_place.device,
F
fengjiayi 已提交
104
                                  num);
105
    }
106 107 108
  }
}

// CUDAPinned -> CPU: pinned memory is ordinary host memory as far as the
// CPU is concerned, so this is a plain memcpy.
template <>
void Copy<platform::CPUPlace, platform::CUDAPinnedPlace>(
    platform::CPUPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
  // Empty copies return early so null pointers are tolerated when num == 0.
  if (UNLIKELY(num == 0)) return;
  std::memcpy(dst, src, num);
}

// CPU -> CUDAPinned: pinned memory is host-addressable, so a plain memcpy
// suffices.
template <>
void Copy<platform::CUDAPinnedPlace, platform::CPUPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CPUPlace src_place, const void* src, size_t num) {
  // Empty copies return early so null pointers are tolerated when num == 0.
  if (UNLIKELY(num == 0)) return;
  std::memcpy(dst, src, num);
}

// CUDAPinned -> CUDAPinned: both buffers are host-addressable, so a plain
// memcpy suffices.
template <>
void Copy<platform::CUDAPinnedPlace, platform::CUDAPinnedPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
  // Empty copies return early so null pointers are tolerated when num == 0.
  if (UNLIKELY(num == 0)) return;
  std::memcpy(dst, src, num);
}

// GPU -> CUDAPinned: copy `num` bytes of device memory at `src` (on
// `src_place`) into pinned host memory at `dst`.
//
// A non-null `stream` selects the asynchronous path; otherwise the copy
// blocks on the default stream. Note: unlike the GPU->CPU overload, the
// sync branch here does not flush the default stream.
template <>
void Copy<platform::CUDAPinnedPlace, platform::CUDAPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CUDAPlace src_place, const void* src, size_t num,
    cudaStream_t stream) {
  if (UNLIKELY(num == 0)) return;
  // The copy must be issued from the context of the source device.
  platform::SetDeviceId(src_place.device);
  if (stream) {
    platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CUDAPinned");
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
  } else {
    platform::RecordEvent record_event("GpuMemcpySync:GPU->CUDAPinned");
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
  }
}

// CUDAPinned -> GPU: copy `num` bytes of pinned host memory at `src` into
// device memory at `dst` (on `dst_place`).
//
// A non-null `stream` selects the asynchronous path; otherwise the copy
// blocks on the default stream.
template <>
void Copy<platform::CUDAPlace, platform::CUDAPinnedPlace>(
    platform::CUDAPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num,
    cudaStream_t stream) {
  if (UNLIKELY(num == 0)) return;

  // The copy must be issued from the context of the destination device.
  platform::SetDeviceId(dst_place.device);
  if (stream) {
    platform::RecordEvent record_event("GpuMemcpyAsync:CUDAPinned->GPU");
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
  } else {
    platform::RecordEvent record_event("GpuMemcpySync:CUDAPinned->GPU");
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
  }
}

#endif

}  // namespace memory
}  // namespace paddle