/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/memory/memcpy.h"

#include <cstring>  // for memcpy

#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace memory {

template <>
void Copy<platform::CPUPlace, platform::CPUPlace>(platform::CPUPlace, void* dst,
                                                  platform::CPUPlace,
                                                  const void* src, size_t num) {
  // Host-to-host transfer: a plain memcpy is all that is needed. An empty
  // range is a no-op, so only copy when there is something to move.
  if (num != 0) {
    std::memcpy(dst, src, num);
  }
}

#ifdef PADDLE_WITH_CUDA
// Synchronous GPU copies at or below this size are followed by an extra
// cudaStreamSynchronize(0) in the Copy specializations below (see the
// FIXME notes there).
static constexpr size_t kMaxGpuAsyncCopyBytes = 64 * 1024;  // 64K

// NOTE(zcd): Avoid GpuMemcpySync where possible: it issues the copy on the
// default stream, which prevents commands from different streams from
// running concurrently.
// Reference:
// https://devblogs.nvidia.com/gpu-pro-tip-cuda-7-streams-simplify-concurrency/
41
template <>
D
dzhwinter 已提交
42 43
void Copy<platform::CPUPlace, platform::CUDAPlace>(
    platform::CPUPlace dst_place, void* dst, platform::CUDAPlace src_place,
F
fengjiayi 已提交
44
    const void* src, size_t num, cudaStream_t stream) {
Z
Zeng Jinle 已提交
45
  if (UNLIKELY(num == 0)) return;
46

47 48 49
  platform::SetDeviceId(src_place.device);
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by thream(" << stream << ")";
50
  if (stream) {
51
    platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CPU");
52 53
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
  } else {
54
    platform::RecordEvent record_event("GpuMemcpySync:GPU->CPU");
55
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
S
sneaxiy 已提交
56 57 58 59
    // FIXME(zjl): do we really need it?
    if (num <= kMaxGpuAsyncCopyBytes) {
      cudaStreamSynchronize(0);
    }
60
  }
61 62 63
}

template <>
D
dzhwinter 已提交
64 65
void Copy<platform::CUDAPlace, platform::CPUPlace>(
    platform::CUDAPlace dst_place, void* dst, platform::CPUPlace src_place,
F
fengjiayi 已提交
66
    const void* src, size_t num, cudaStream_t stream) {
Z
Zeng Jinle 已提交
67 68
  if (UNLIKELY(num == 0)) return;

L
liaogang 已提交
69
  platform::SetDeviceId(dst_place.device);
70 71
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by thream(" << stream << ")";
72
  if (stream) {
73
    platform::RecordEvent record_event("GpuMemcpyAsync:CPU->GPU");
74 75
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
  } else {
76
    platform::RecordEvent record_event("GpuMemcpySync:CPU->GPU");
77
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
S
sneaxiy 已提交
78 79 80 81
    // FIXME(zjl): do we really need it?
    if (num <= kMaxGpuAsyncCopyBytes) {
      cudaStreamSynchronize(0);
    }
82
  }
83 84 85
}

template <>
D
dzhwinter 已提交
86 87
void Copy<platform::CUDAPlace, platform::CUDAPlace>(
    platform::CUDAPlace dst_place, void* dst, platform::CUDAPlace src_place,
F
fengjiayi 已提交
88
    const void* src, size_t num, cudaStream_t stream) {
Z
Zeng Jinle 已提交
89 90
  if (UNLIKELY(num == 0)) return;

91
  if (dst_place == src_place) {
L
liaogang 已提交
92
    platform::SetDeviceId(src_place.device);
93
    if (stream) {
94
      platform::RecordEvent record_event("GpuMemcpyAsync(same_gpu):GPU->GPU");
95 96
      platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
    } else {
97
      platform::RecordEvent record_event("GpuMemcpySync(same_gpu):GPU->GPU");
98 99
      platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice);
    }
100
  } else {
101
    if (stream) {
102
      platform::RecordEvent record_event("GpuMemcpyPeerAsync:GPU->GPU");
103 104 105
      platform::GpuMemcpyPeerAsync(dst, dst_place.device, src, src_place.device,
                                   num, stream);
    } else {
106
      platform::RecordEvent record_event("GpuMemcpyPeerSync:GPU->GPU");
107
      platform::GpuMemcpyPeerSync(dst, dst_place.device, src, src_place.device,
F
fengjiayi 已提交
108
                                  num);
109
    }
110 111 112
  }
}

C
chengduoZH 已提交
113 114 115 116
template <>
void Copy<platform::CPUPlace, platform::CUDAPinnedPlace>(
    platform::CPUPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
  // Pinned (page-locked) host memory is ordinary addressable memory, so a
  // plain memcpy handles the pinned->pageable transfer.
  if (num != 0) {
    std::memcpy(dst, src, num);
  }
}

template <>
void Copy<platform::CUDAPinnedPlace, platform::CPUPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CPUPlace src_place, const void* src, size_t num) {
  // Pageable->pinned host copy; both sides are plain host memory, so a
  // memcpy suffices. Empty ranges are a no-op.
  if (num != 0) {
    std::memcpy(dst, src, num);
  }
}

template <>
void Copy<platform::CUDAPinnedPlace, platform::CUDAPinnedPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
  // Pinned->pinned host copy; handled entirely on the CPU with memcpy.
  if (num != 0) {
    std::memcpy(dst, src, num);
  }
}

template <>
void Copy<platform::CUDAPinnedPlace, platform::CUDAPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CUDAPlace src_place, const void* src, size_t num,
    cudaStream_t stream) {
  // Copies `num` bytes from device memory into page-locked (pinned) host
  // memory. A null stream selects the synchronous default-stream path.
  if (UNLIKELY(num == 0)) return;
  platform::SetDeviceId(src_place.device);
  if (stream == nullptr) {
    platform::RecordEvent record_event("GpuMemcpySync:GPU->CUDAPinned");
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
  } else {
    platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CUDAPinned");
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
  }
}

template <>
void Copy<platform::CUDAPlace, platform::CUDAPinnedPlace>(
    platform::CUDAPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num,
    cudaStream_t stream) {
  // Copies `num` bytes from page-locked (pinned) host memory into device
  // memory. A null stream selects the synchronous default-stream path.
  if (UNLIKELY(num == 0)) return;

  platform::SetDeviceId(dst_place.device);
  if (stream == nullptr) {
    platform::RecordEvent record_event("GpuMemcpySync:CUDAPinned->GPU");
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
  } else {
    platform::RecordEvent record_event("GpuMemcpyAsync:CUDAPinned->GPU");
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
  }
}

L
Luo Tao 已提交
170
#endif
Y
Yi Wang 已提交
171 172 173

}  // namespace memory
}  // namespace paddle