memcpy.cc 17.9 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/memory/memcpy.h"
16 17

#include <cstring>  // for memcpy
18

19
#include "paddle/fluid/platform/device_context.h"
Z
Zeng Jinle 已提交
20
#include "paddle/fluid/platform/enforce.h"
21
#include "paddle/fluid/platform/profiler.h"
22

23 24 25 26
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/platform/xpu_header.h"
#endif

27 28 29 30 31 32 33
namespace paddle {
namespace memory {

template <>
void Copy<platform::CPUPlace, platform::CPUPlace>(platform::CPUPlace, void* dst,
                                                  platform::CPUPlace,
                                                  const void* src, size_t num) {
Z
Zeng Jinle 已提交
34
  if (UNLIKELY(num == 0)) return;
35 36 37
  std::memcpy(dst, src, num);
}

38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189
#ifdef PADDLE_WITH_XPU
template <>
void Copy<platform::XPUPlace, platform::CPUPlace>(platform::XPUPlace dst_place,
                                                  void* dst,
                                                  platform::CPUPlace src_place,
                                                  const void* src, size_t num) {
  if (num <= 0) {
    VLOG(0) << "memcpy XPU_HOST_TO_DEVICE size <= 0 (" << num << ")";
    return;
  }
  int dev_id = -1;
  int ret = xpu_current_device(&dev_id);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
  if (dev_id >= 64) {
    // if dev_id >= 64, the device is a simulator device, -64 to get real dev_id
    dev_id -= 64;
  }
  if (dev_id != dst_place.device) {
    ret = xpu_set_device(dst_place.device);
    PADDLE_ENFORCE_EQ(
        ret, XPU_SUCCESS,
        platform::errors::External(
            "XPU API return wrong value[%d], please check whether "
            "Baidu Kunlun Card is properly installed.",
            ret));
  }
  ret = xpu_memcpy(dst, src, num, XPUMemcpyKind::XPU_HOST_TO_DEVICE);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
  if (dev_id != dst_place.device) {
    ret = xpu_set_device(dev_id);
    PADDLE_ENFORCE_EQ(
        ret, XPU_SUCCESS,
        platform::errors::External(
            "XPU API return wrong value[%d], please check whether "
            "Baidu Kunlun Card is properly installed.",
            ret));
  }
}

template <>
void Copy<platform::CPUPlace, platform::XPUPlace>(platform::CPUPlace dst_place,
                                                  void* dst,
                                                  platform::XPUPlace src_place,
                                                  const void* src, size_t num) {
  if (num <= 0) {
    VLOG(0) << "memcpy XPU_DEVICE_TO_HOST size <= 0 (" << num << ")";
    return;
  }
  int dev_id = -1;
  int ret = xpu_current_device(&dev_id);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
  if (dev_id >= 64) {
    // if dev_id >= 64, the device is a simulator device, -64 to get real dev_id
    dev_id -= 64;
  }
  if (dev_id != src_place.device) {
    ret = xpu_set_device(src_place.device);
    PADDLE_ENFORCE_EQ(
        ret, XPU_SUCCESS,
        platform::errors::External(
            "XPU API return wrong value[%d], please check whether "
            "Baidu Kunlun Card is properly installed.",
            ret));
  }
  ret = xpu_memcpy(dst, src, num, XPUMemcpyKind::XPU_DEVICE_TO_HOST);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
  if (dev_id != src_place.device) {
    ret = xpu_set_device(dev_id);
    PADDLE_ENFORCE_EQ(
        ret, XPU_SUCCESS,
        platform::errors::External(
            "XPU API return wrong value[%d], please check whether "
            "Baidu Kunlun Card is properly installed.",
            ret));
  }
}

template <>
void Copy<platform::XPUPlace, platform::XPUPlace>(platform::XPUPlace dst_place,
                                                  void* dst,
                                                  platform::XPUPlace src_place,
                                                  const void* src, size_t num) {
  if (num <= 0) {
    VLOG(0) << "memcpy XPU_DEVICE_TO_DEVICE size <= 0 (" << num << ")";
    return;
  }
  int dev_id = -1;
  int ret = xpu_current_device(&dev_id);
  PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
                    platform::errors::External(
                        "XPU API return wrong value[%d], please check whether "
                        "Baidu Kunlun Card is properly installed.",
                        ret));
  if (dev_id >= 64) {
    // if dev_id >= 64, the device is a simulator device, -64 to get real dev_id
    dev_id -= 64;
  }
  if (dev_id != src_place.device || dev_id != dst_place.device) {
    ret = xpu_set_device(src_place.device);
    PADDLE_ENFORCE_EQ(
        ret, XPU_SUCCESS,
        platform::errors::External(
            "XPU API return wrong value[%d], please check whether "
            "Baidu Kunlun Card is properly installed.",
            ret));
    void* tmp = malloc(num);
    ret = xpu_memcpy(tmp, src, num, XPUMemcpyKind::XPU_DEVICE_TO_HOST);
    PADDLE_ENFORCE_EQ(
        ret, XPU_SUCCESS,
        platform::errors::External(
            "XPU API return wrong value[%d], please check whether "
            "Baidu Kunlun Card is properly installed.",
            ret));
    ret = xpu_set_device(dst_place.device);
    PADDLE_ENFORCE_EQ(
        ret, XPU_SUCCESS,
        platform::errors::External(
            "XPU API return wrong value[%d], please check whether "
            "Baidu Kunlun Card is properly installed.",
            ret));
    ret = xpu_memcpy(dst, tmp, num, XPUMemcpyKind::XPU_HOST_TO_DEVICE);
    PADDLE_ENFORCE_EQ(
        ret, XPU_SUCCESS,
        platform::errors::External(
            "XPU API return wrong value[%d], please check whether "
            "Baidu Kunlun Card is properly installed.",
            ret));
    ret = xpu_set_device(dev_id);
    PADDLE_ENFORCE_EQ(
        ret, XPU_SUCCESS,
        platform::errors::External(
            "XPU API return wrong value[%d], please check whether "
            "Baidu Kunlun Card is properly installed.",
            ret));
    free(tmp);
  } else {
190 191 192 193 194 195 196
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    auto* dev_ctx = pool.GetByPlace(src_place);
    dev_ctx->Wait();
    int ret = xpu::memcpy_device(dev_ctx->x_context(), dst, src, num);
    PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS, platform::errors::External(
                                            "XPU API return wrong value[%d %s]",
                                            ret, XPUAPIErrorMsg[ret]));
197 198 199 200
  }
}
#endif

201 202 203 204 205 206 207 208 209 210
#ifdef PADDLE_WITH_ASCEND_CL
template <>
void Copy<platform::NPUPlace, platform::CPUPlace>(platform::NPUPlace dst_place,
                                                  void* dst,
                                                  platform::CPUPlace src_place,
                                                  const void* src, size_t num,
                                                  aclrtStream stream) {
  if (UNLIKELY(num == 0)) return;

  platform::SetNPUDeviceId(dst_place.device);
211

212 213
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by thream(" << stream << ")";
214

215 216 217 218
  if (stream) {
    platform::RecordEvent record_event("NpuMemcpyAsync:CPU->NPU");
    platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE, stream);
  } else {
219 220 221 222 223 224
    // On NPU, async operation after sync operation is ok, while sync operation
    // after async is not ok, since the async operation may not done.
    // So, its needed to do wait before sync operation.
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    static_cast<platform::NPUDeviceContext*>(pool.Get(dst_place))->Wait();

225 226 227 228 229 230 231 232 233 234 235 236 237 238
    platform::RecordEvent record_event("NpuMemcpySync:CPU->NPU");
    platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE);
  }
}

template <>
void Copy<platform::CPUPlace, platform::NPUPlace>(platform::CPUPlace dst_place,
                                                  void* dst,
                                                  platform::NPUPlace src_place,
                                                  const void* src, size_t num,
                                                  aclrtStream stream) {
  if (UNLIKELY(num == 0)) return;

  platform::SetNPUDeviceId(src_place.device);
239

240 241
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by thream(" << stream << ")";
242

243 244 245 246
  if (stream) {
    platform::RecordEvent record_event("NpuMemcpyAsync:NPU->CPU");
    platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST, stream);
  } else {
247
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
L
Leo Chen 已提交
248
    static_cast<platform::NPUDeviceContext*>(pool.Get(src_place))->Wait();
249

250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271
    platform::RecordEvent record_event("GpuMemcpySync:NPU->CPU");
    platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST);
  }
}

template <>
void Copy<platform::NPUPlace, platform::NPUPlace>(platform::NPUPlace dst_place,
                                                  void* dst,
                                                  platform::NPUPlace src_place,
                                                  const void* src, size_t num,
                                                  aclrtStream stream) {
  if (UNLIKELY(num == 0)) return;

  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by stream(" << stream << ")";
  if (dst_place == src_place) {
    platform::SetNPUDeviceId(src_place.device);
    if (stream) {
      platform::RecordEvent record_event("NpuMemcpyAsync(same_npu):NPU->NPU");
      platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE,
                               stream);
    } else {
272 273 274 275
      platform::DeviceContextPool& pool =
          platform::DeviceContextPool::Instance();
      static_cast<platform::NPUDeviceContext*>(pool.Get(dst_place))->Wait();

276 277 278 279 280 281 282 283 284 285 286 287 288 289
      platform::RecordEvent record_event("NpuMemcpySync(same_npu):NPU->NPU");
      platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE);
    }
  } else {
    if (!platform::NPUCanAccessPeer(dst_place.device, dst_place.device)) {
      PADDLE_THROW(platform::errors::Unavailable(
          "Peer access between NPU places is not allowed."));
    }
    if (stream) {
      // TODO(zhiqiu): support peer access?
      platform::RecordEvent record_event("NpuMemcpyPeerAsync:NPU->NPU");
      platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE,
                               stream);
    } else {
290 291 292 293
      platform::DeviceContextPool& pool =
          platform::DeviceContextPool::Instance();
      static_cast<platform::NPUDeviceContext*>(pool.Get(dst_place))->Wait();

294 295 296 297 298 299 300
      platform::RecordEvent record_event("NpuMemcpyPeerSync:NPU->NPU");
      platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE);
    }
  }
}
#endif

301
#ifdef PADDLE_WITH_CUDA
S
sneaxiy 已提交
302 303
static constexpr size_t kMaxGpuAsyncCopyBytes = 64 * 1024;  // 64K

304 305 306 307 308 309 310 311 312 313 314 315
inline void SyncCUDAStream() {
#if !defined(_WIN32)
  cudaStreamSynchronize(0);
#else
  cudaError_t e_sync = cudaSuccess;
  while (e_sync = cudaStreamQuery(0)) {
    if (e_sync == cudaErrorNotReady) continue;
    break;
  }
#endif
}

316 317 318 319 320 321
// NOTE(zcd): Do not use GpuMemcpySync as much as possible.
// because GpuMemcpySync issues the copying command to the default stream,
// which will make two commands from different streams cannot run concurrently.
// Reference:
// https://devblogs.nvidia.com/gpu-pro-tip-cuda-7-streams-simplify-concurrency/

322
template <>
D
dzhwinter 已提交
323 324
void Copy<platform::CPUPlace, platform::CUDAPlace>(
    platform::CPUPlace dst_place, void* dst, platform::CUDAPlace src_place,
F
fengjiayi 已提交
325
    const void* src, size_t num, cudaStream_t stream) {
Z
Zeng Jinle 已提交
326
  if (UNLIKELY(num == 0)) return;
327

328 329 330
  platform::SetDeviceId(src_place.device);
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by thream(" << stream << ")";
331
  if (stream) {
332
    platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CPU");
333 334
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
  } else {
335
    platform::RecordEvent record_event("GpuMemcpySync:GPU->CPU");
336
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
S
sneaxiy 已提交
337 338
    // FIXME(zjl): do we really need it?
    if (num <= kMaxGpuAsyncCopyBytes) {
339
      SyncCUDAStream();
S
sneaxiy 已提交
340
    }
341
  }
342 343 344
}

template <>
D
dzhwinter 已提交
345 346
void Copy<platform::CUDAPlace, platform::CPUPlace>(
    platform::CUDAPlace dst_place, void* dst, platform::CPUPlace src_place,
F
fengjiayi 已提交
347
    const void* src, size_t num, cudaStream_t stream) {
Z
Zeng Jinle 已提交
348 349
  if (UNLIKELY(num == 0)) return;

L
liaogang 已提交
350
  platform::SetDeviceId(dst_place.device);
351 352
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by thream(" << stream << ")";
353
  if (stream) {
354
    platform::RecordEvent record_event("GpuMemcpyAsync:CPU->GPU");
355 356
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
  } else {
357
    platform::RecordEvent record_event("GpuMemcpySync:CPU->GPU");
358
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
S
sneaxiy 已提交
359 360
    // FIXME(zjl): do we really need it?
    if (num <= kMaxGpuAsyncCopyBytes) {
361
      SyncCUDAStream();
S
sneaxiy 已提交
362
    }
363
  }
364 365 366
}

template <>
D
dzhwinter 已提交
367 368
void Copy<platform::CUDAPlace, platform::CUDAPlace>(
    platform::CUDAPlace dst_place, void* dst, platform::CUDAPlace src_place,
F
fengjiayi 已提交
369
    const void* src, size_t num, cudaStream_t stream) {
Z
Zeng Jinle 已提交
370 371
  if (UNLIKELY(num == 0)) return;

372
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
373
          << dst_place << " by stream(" << stream << ")";
374
  if (dst_place == src_place) {
L
liaogang 已提交
375
    platform::SetDeviceId(src_place.device);
376
    if (stream) {
377
      platform::RecordEvent record_event("GpuMemcpyAsync(same_gpu):GPU->GPU");
378 379
      platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
    } else {
380
      platform::RecordEvent record_event("GpuMemcpySync(same_gpu):GPU->GPU");
381 382
      platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice);
    }
383
  } else {
384
    if (stream) {
385
      platform::RecordEvent record_event("GpuMemcpyPeerAsync:GPU->GPU");
386 387 388
      platform::GpuMemcpyPeerAsync(dst, dst_place.device, src, src_place.device,
                                   num, stream);
    } else {
389
      platform::RecordEvent record_event("GpuMemcpyPeerSync:GPU->GPU");
390
      platform::GpuMemcpyPeerSync(dst, dst_place.device, src, src_place.device,
F
fengjiayi 已提交
391
                                  num);
392
    }
393 394 395
  }
}

C
chengduoZH 已提交
396 397 398 399
template <>
void Copy<platform::CPUPlace, platform::CUDAPinnedPlace>(
    platform::CPUPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
400 401
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
Z
Zeng Jinle 已提交
402
  if (UNLIKELY(num == 0)) return;
C
chengduoZH 已提交
403 404 405 406 407 408 409
  std::memcpy(dst, src, num);
}

template <>
void Copy<platform::CUDAPinnedPlace, platform::CPUPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CPUPlace src_place, const void* src, size_t num) {
410 411
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
Z
Zeng Jinle 已提交
412
  if (UNLIKELY(num == 0)) return;
C
chengduoZH 已提交
413 414 415 416 417 418 419
  std::memcpy(dst, src, num);
}

template <>
void Copy<platform::CUDAPinnedPlace, platform::CUDAPinnedPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
420 421
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
Z
Zeng Jinle 已提交
422
  if (UNLIKELY(num == 0)) return;
C
chengduoZH 已提交
423 424 425 426 427 428 429
  std::memcpy(dst, src, num);
}

template <>
void Copy<platform::CUDAPinnedPlace, platform::CUDAPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CUDAPlace src_place, const void* src, size_t num,
F
fengjiayi 已提交
430
    cudaStream_t stream) {
Z
Zeng Jinle 已提交
431
  if (UNLIKELY(num == 0)) return;
C
chengduoZH 已提交
432
  platform::SetDeviceId(src_place.device);
433 434
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by thream(" << stream << ")";
435
  if (stream) {
436
    platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CUDAPinned");
437 438
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
  } else {
439
    platform::RecordEvent record_event("GpuMemcpySync:GPU->CUDAPinned");
440 441
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
  }
C
chengduoZH 已提交
442 443 444 445 446 447
}

template <>
void Copy<platform::CUDAPlace, platform::CUDAPinnedPlace>(
    platform::CUDAPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num,
F
fengjiayi 已提交
448
    cudaStream_t stream) {
Z
Zeng Jinle 已提交
449 450
  if (UNLIKELY(num == 0)) return;

C
chengduoZH 已提交
451
  platform::SetDeviceId(dst_place.device);
452 453
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by thream(" << stream << ")";
454
  if (stream) {
455
    platform::RecordEvent record_event("GpuMemcpyAsync:CUDAPinned->GPU");
456 457
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
  } else {
458
    platform::RecordEvent record_event("GpuMemcpySync:CUDAPinned->GPU");
459 460
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
  }
C
chengduoZH 已提交
461 462
}

L
Luo Tao 已提交
463
#endif
Y
Yi Wang 已提交
464 465 466

}  // namespace memory
}  // namespace paddle