memcpy.cc 22.2 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/memory/memcpy.h"
16

17
#include "paddle/fluid/platform/device_context.h"
Z
Zeng Jinle 已提交
18
#include "paddle/fluid/platform/enforce.h"
19
#include "paddle/fluid/platform/profiler.h"
20

21 22 23 24
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/platform/xpu_header.h"
#endif

25 26 27 28 29 30 31
namespace paddle {
namespace memory {

// Host-to-host copy of `num` bytes from `src` to `dst` via std::memcpy.
// The two place arguments are unnamed and unused.
template <>
void Copy<platform::CPUPlace, platform::CPUPlace>(platform::CPUPlace, void* dst,
                                                  platform::CPUPlace,
                                                  const void* src, size_t num) {
  // Nothing to do for an empty request.
  if (UNLIKELY(num == 0)) {
    return;
  }
  std::memcpy(dst, src, num);
}

36 37 38 39 40 41 42
#ifdef PADDLE_WITH_XPU
// Copies `num` bytes from host memory `src` to memory `dst` on the XPU device
// `dst_place.device`.
//
// If the current XPU device differs from the destination device, the
// destination device is selected for the copy and the previous device is
// selected again afterwards. Every XPU call is checked; a return code other
// than XPU_SUCCESS raises platform::errors::External. An empty request is only
// logged.
template <>
void Copy<platform::XPUPlace, platform::CPUPlace>(platform::XPUPlace dst_place,
                                                  void* dst,
                                                  platform::CPUPlace src_place,
                                                  const void* src, size_t num) {
  // `num` is unsigned, so zero is the only value that means "nothing to copy".
  if (num == 0) {
    VLOG(1) << "memcpy XPU_HOST_TO_DEVICE size <= 0 (" << num << ")";
    return;
  }

  // Shared return-code check for the XPU calls below.
  const auto enforce_xpu_success = [](int ret) {
    PADDLE_ENFORCE_EQ(
        ret, XPU_SUCCESS,
        platform::errors::External(
            "XPU API return wrong value[%d], please check whether "
            "Baidu Kunlun Card is properly installed.",
            ret));
  };

  int current_dev = -1;
  enforce_xpu_success(xpu_current_device(&current_dev));
  // Ids >= 64 denote a simulator device; subtract 64 to get the real id.
  if (current_dev >= 64) {
    current_dev -= 64;
  }
  const bool switch_device = (current_dev != dst_place.device);

  if (switch_device) {
    enforce_xpu_success(xpu_set_device(dst_place.device));
  }
  enforce_xpu_success(
      xpu_memcpy(dst, src, num, XPUMemcpyKind::XPU_HOST_TO_DEVICE));
  if (switch_device) {
    // Select the device that was current on entry again.
    enforce_xpu_success(xpu_set_device(current_dev));
  }
}

// Copies `num` bytes from memory `src` on the XPU device `src_place.device` to
// host memory `dst`.
//
// If the current XPU device differs from the source device, the source device
// is selected for the copy and the previous device is selected again
// afterwards. Every XPU call is checked; a return code other than XPU_SUCCESS
// raises platform::errors::External. An empty request is only logged.
template <>
void Copy<platform::CPUPlace, platform::XPUPlace>(platform::CPUPlace dst_place,
                                                  void* dst,
                                                  platform::XPUPlace src_place,
                                                  const void* src, size_t num) {
  // `num` is unsigned, so zero is the only value that means "nothing to copy".
  if (num == 0) {
    VLOG(1) << "memcpy XPU_DEVICE_TO_HOST size <= 0 (" << num << ")";
    return;
  }

  // Shared return-code check for the XPU calls below.
  const auto enforce_xpu_success = [](int ret) {
    PADDLE_ENFORCE_EQ(
        ret, XPU_SUCCESS,
        platform::errors::External(
            "XPU API return wrong value[%d], please check whether "
            "Baidu Kunlun Card is properly installed.",
            ret));
  };

  int current_dev = -1;
  enforce_xpu_success(xpu_current_device(&current_dev));
  // Ids >= 64 denote a simulator device; subtract 64 to get the real id.
  if (current_dev >= 64) {
    current_dev -= 64;
  }
  const bool switch_device = (current_dev != src_place.device);

  if (switch_device) {
    enforce_xpu_success(xpu_set_device(src_place.device));
  }
  enforce_xpu_success(
      xpu_memcpy(dst, src, num, XPUMemcpyKind::XPU_DEVICE_TO_HOST));
  if (switch_device) {
    // Select the device that was current on entry again.
    enforce_xpu_success(xpu_set_device(current_dev));
  }
}

// Copies `num` bytes from XPU memory `src` (device `src_place.device`) to XPU
// memory `dst` (device `dst_place.device`).
//
// When both places are the current device, the device context is waited on and
// the copy is done with xpu::memcpy_device. Otherwise the data is staged
// through a temporary host buffer: device-to-host with the source device
// selected, then host-to-device with the destination device selected, after
// which the device that was current on entry is selected again.
//
// Raises platform::errors::ResourceExhausted if the staging buffer cannot be
// allocated, and platform::errors::External if an XPU call does not return
// XPU_SUCCESS. An empty request is only logged.
template <>
void Copy<platform::XPUPlace, platform::XPUPlace>(platform::XPUPlace dst_place,
                                                  void* dst,
                                                  platform::XPUPlace src_place,
                                                  const void* src, size_t num) {
  // `num` is unsigned, so zero is the only value that means "nothing to copy".
  if (num == 0) {
    VLOG(1) << "memcpy XPU_DEVICE_TO_DEVICE size <= 0 (" << num << ")";
    return;
  }

  // Shared return-code check for the XPU runtime calls below.
  const auto enforce_xpu_success = [](int ret) {
    PADDLE_ENFORCE_EQ(
        ret, XPU_SUCCESS,
        platform::errors::External(
            "XPU API return wrong value[%d], please check whether "
            "Baidu Kunlun Card is properly installed.",
            ret));
  };

  int dev_id = -1;
  enforce_xpu_success(xpu_current_device(&dev_id));
  if (dev_id >= 64) {
    // Ids >= 64 denote a simulator device; subtract 64 to get the real id.
    dev_id -= 64;
  }

  if (dev_id != src_place.device || dev_id != dst_place.device) {
    // Owns the host staging buffer so that it is released on every exit path,
    // including when one of the checks below throws.
    struct HostStagingBuffer {
      explicit HostStagingBuffer(size_t bytes) : data(malloc(bytes)) {}
      ~HostStagingBuffer() { free(data); }
      HostStagingBuffer(const HostStagingBuffer&) = delete;
      HostStagingBuffer& operator=(const HostStagingBuffer&) = delete;
      void* data;
    };

    // Allocate before touching the device selection so that an allocation
    // failure leaves the current device unchanged.
    HostStagingBuffer staging(num);
    PADDLE_ENFORCE_NOT_NULL(
        staging.data,
        platform::errors::ResourceExhausted(
            "Failed to allocate %d bytes of host memory for staging an XPU "
            "device-to-device copy.",
            num));

    enforce_xpu_success(xpu_set_device(src_place.device));
    enforce_xpu_success(
        xpu_memcpy(staging.data, src, num, XPUMemcpyKind::XPU_DEVICE_TO_HOST));
    enforce_xpu_success(xpu_set_device(dst_place.device));
    enforce_xpu_success(
        xpu_memcpy(dst, staging.data, num, XPUMemcpyKind::XPU_HOST_TO_DEVICE));
    // Select the device that was current on entry again.
    enforce_xpu_success(xpu_set_device(dev_id));
  } else {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    auto* dev_ctx = pool.GetByPlace(src_place);
    dev_ctx->Wait();
    int ret = xpu::memcpy_device(dev_ctx->x_context(), dst, src, num);
    PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS, platform::errors::External(
                                            "XPU API return wrong value[%d %s]",
                                            ret, XPUAPIErrorMsg[ret]));
  }
}
#endif

199 200 201 202 203 204 205 206 207 208
#ifdef PADDLE_WITH_ASCEND_CL
// Copies `num` bytes from host memory `src` to memory `dst` on the NPU device
// `dst_place.device`.
//
// The destination device is selected first. With a non-null `stream` the copy
// is issued through NPUMemcpyAsync on that stream; otherwise the destination
// device context is waited on and NPUMemcpySync is used. A zero-byte request
// returns immediately.
template <>
void Copy<platform::NPUPlace, platform::CPUPlace>(platform::NPUPlace dst_place,
                                                  void* dst,
                                                  platform::CPUPlace src_place,
                                                  const void* src, size_t num,
                                                  aclrtStream stream) {
  if (UNLIKELY(num == 0)) return;

  platform::SetNPUDeviceId(dst_place.device);

  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by stream(" << stream << ")";

  if (stream) {
    platform::RecordEvent record_event("NpuMemcpyAsync:CPU->NPU");
    platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE, stream);
  } else {
    // On NPU, an async operation issued after a sync operation is fine, but a
    // sync operation issued after an async one is not, because the async
    // operation may not be finished yet. So wait before the sync operation.
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    static_cast<platform::NPUDeviceContext*>(pool.Get(dst_place))->Wait();

    platform::RecordEvent record_event("NpuMemcpySync:CPU->NPU");
    platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE);
  }
}

// Copies `num` bytes from memory `src` on the NPU device `src_place.device` to
// host memory `dst`.
//
// The source device is selected first. With a non-null `stream` the copy is
// issued through NPUMemcpyAsync on that stream; otherwise the source device
// context is waited on and NPUMemcpySync is used. A zero-byte request returns
// immediately.
template <>
void Copy<platform::CPUPlace, platform::NPUPlace>(platform::CPUPlace dst_place,
                                                  void* dst,
                                                  platform::NPUPlace src_place,
                                                  const void* src, size_t num,
                                                  aclrtStream stream) {
  if (UNLIKELY(num == 0)) return;

  platform::SetNPUDeviceId(src_place.device);

  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by stream(" << stream << ")";

  if (stream) {
    platform::RecordEvent record_event("NpuMemcpyAsync:NPU->CPU");
    platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST, stream);
  } else {
    // Wait on the source device context before the synchronous copy.
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    static_cast<platform::NPUDeviceContext*>(pool.Get(src_place))->Wait();

    platform::RecordEvent record_event("NpuMemcpySync:NPU->CPU");
    platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST);
  }
}

// Copies `num` bytes from NPU memory `src` (place `src_place`) to NPU memory
// `dst` (place `dst_place`).
//
// Same place: the device is selected, then the copy goes through
// NPUMemcpyAsync when `stream` is non-null, or through NPUMemcpySync after
// waiting on the device context otherwise.
// Different places: throws platform::errors::Unavailable when
// NPUCanAccessPeer reports no peer access between the two devices; otherwise
// the copy is issued the same way as above, without selecting a device.
// A zero-byte request returns immediately.
template <>
void Copy<platform::NPUPlace, platform::NPUPlace>(platform::NPUPlace dst_place,
                                                  void* dst,
                                                  platform::NPUPlace src_place,
                                                  const void* src, size_t num,
                                                  aclrtStream stream) {
  if (UNLIKELY(num == 0)) return;

  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by stream(" << stream << ")";
  if (dst_place == src_place) {
    platform::SetNPUDeviceId(src_place.device);
    if (stream) {
      platform::RecordEvent record_event("NpuMemcpyAsync(same_npu):NPU->NPU");
      platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE,
                               stream);
    } else {
      // Wait on the device context before the synchronous copy.
      platform::DeviceContextPool& pool =
          platform::DeviceContextPool::Instance();
      static_cast<platform::NPUDeviceContext*>(pool.Get(dst_place))->Wait();

      platform::RecordEvent record_event("NpuMemcpySync(same_npu):NPU->NPU");
      platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE);
    }
  } else {
    // The check must involve both devices; comparing the destination device
    // with itself says nothing about access to the source device.
    if (!platform::NPUCanAccessPeer(dst_place.device, src_place.device)) {
      PADDLE_THROW(platform::errors::Unavailable(
          "Peer access between NPU places is not allowed."));
    }
    if (stream) {
      // TODO(zhiqiu): support peer access?
      platform::RecordEvent record_event("NpuMemcpyPeerAsync:NPU->NPU");
      platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE,
                               stream);
    } else {
      // Wait on the destination device context before the synchronous copy.
      platform::DeviceContextPool& pool =
          platform::DeviceContextPool::Instance();
      static_cast<platform::NPUDeviceContext*>(pool.Get(dst_place))->Wait();

      platform::RecordEvent record_event("NpuMemcpyPeerSync:NPU->NPU");
      platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE);
    }
  }
}
297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376

// Copies `num` bytes from `src` (NPU pinned place) to `dst` (CPU place) with
// std::memcpy. The request is logged before the empty-size check.
template <>
void Copy<platform::CPUPlace, platform::NPUPinnedPlace>(
    platform::CPUPlace dst_place, void* dst, platform::NPUPinnedPlace src_place,
    const void* src, size_t num) {
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  // Nothing to copy for an empty request.
  if (UNLIKELY(num == 0)) {
    return;
  }
  std::memcpy(dst, src, num);
}

// Copies `num` bytes from `src` (CPU place) to `dst` (NPU pinned place) with
// std::memcpy. The request is logged before the empty-size check.
template <>
void Copy<platform::NPUPinnedPlace, platform::CPUPlace>(
    platform::NPUPinnedPlace dst_place, void* dst, platform::CPUPlace src_place,
    const void* src, size_t num) {
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  // Nothing to copy for an empty request.
  if (UNLIKELY(num == 0)) {
    return;
  }
  std::memcpy(dst, src, num);
}

// Copies `num` bytes between two NPU pinned place buffers with std::memcpy.
// The request is logged before the empty-size check.
template <>
void Copy<platform::NPUPinnedPlace, platform::NPUPinnedPlace>(
    platform::NPUPinnedPlace dst_place, void* dst,
    platform::NPUPinnedPlace src_place, const void* src, size_t num) {
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  // Nothing to copy for an empty request.
  if (UNLIKELY(num == 0)) {
    return;
  }
  std::memcpy(dst, src, num);
}

// Copies `num` bytes from memory `src` on the NPU device `src_place.device` to
// the NPU pinned place buffer `dst`.
//
// The source device is selected first. With a non-null `stream` the copy is
// issued through NPUMemcpyAsync on that stream; otherwise the source device
// context is waited on and NPUMemcpySync is used. A zero-byte request returns
// immediately.
template <>
void Copy<platform::NPUPinnedPlace, platform::NPUPlace>(
    platform::NPUPinnedPlace dst_place, void* dst, platform::NPUPlace src_place,
    const void* src, size_t num, aclrtStream stream) {
  if (UNLIKELY(num == 0)) return;

  platform::SetNPUDeviceId(src_place.device);

  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by stream(" << stream << ")";

  if (stream) {
    platform::RecordEvent record_event("NpuMemcpyAsync:NPU->NPUPinned");
    platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST, stream);
  } else {
    // Wait on the source device context before the synchronous copy.
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    static_cast<platform::NPUDeviceContext*>(pool.Get(src_place))->Wait();

    platform::RecordEvent record_event("NpuMemcpySync:NPU->NPUPinned");
    platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST);
  }
}

// Copies `num` bytes from the NPU pinned place buffer `src` to memory `dst` on
// the NPU device `dst_place.device`.
//
// The destination device is selected first. With a non-null `stream` the copy
// is issued through NPUMemcpyAsync on that stream; otherwise the destination
// device context is waited on and NPUMemcpySync is used. A zero-byte request
// returns immediately.
template <>
void Copy<platform::NPUPlace, platform::NPUPinnedPlace>(
    platform::NPUPlace dst_place, void* dst, platform::NPUPinnedPlace src_place,
    const void* src, size_t num, aclrtStream stream) {
  if (UNLIKELY(num == 0)) return;

  platform::SetNPUDeviceId(dst_place.device);

  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by stream(" << stream << ")";

  if (stream) {
    platform::RecordEvent record_event("NpuMemcpyAsync:NPUPinned->NPU");
    platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE, stream);
  } else {
    // On NPU, an async operation issued after a sync operation is fine, but a
    // sync operation issued after an async one is not, because the async
    // operation may not be finished yet. So wait before the sync operation.
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    static_cast<platform::NPUDeviceContext*>(pool.Get(dst_place))->Wait();

    platform::RecordEvent record_event("NpuMemcpySync:NPUPinned->NPU");
    platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE);
  }
}

377 378
#endif

379
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
S
sneaxiy 已提交
380 381
// Size threshold in bytes used by the synchronous host<->GPU copies in this
// file: after GpuMemcpySync they additionally call SyncCUDAStream() when the
// copy is no larger than this value.
static constexpr size_t kMaxGpuAsyncCopyBytes = 64 * 1024;  // 64K

382 383 384 385 386 387 388 389 390 391 392 393 394
#ifdef PADDLE_WITH_HIP
inline void SyncCUDAStream() {
#if !defined(_WIN32)
  hipStreamSynchronize(0);
#else
  hipError_t e_sync = hipSuccess;
  while (e_sync = hipStreamQuery(0)) {
    if (e_sync == hipErrorNotReady) continue;
    break;
  }
#endif
}
#else
395 396 397 398 399 400 401 402 403 404 405
inline void SyncCUDAStream() {
#if !defined(_WIN32)
  cudaStreamSynchronize(0);
#else
  cudaError_t e_sync = cudaSuccess;
  while (e_sync = cudaStreamQuery(0)) {
    if (e_sync == cudaErrorNotReady) continue;
    break;
  }
#endif
}
406
#endif
407

408 409 410 411 412 413
// NOTE(zcd): Avoid GpuMemcpySync whenever possible, because it issues the
// copy command to the default stream, which prevents commands from different
// streams from running concurrently.
// Reference:
// https://devblogs.nvidia.com/gpu-pro-tip-cuda-7-streams-simplify-concurrency/

414
template <>
D
dzhwinter 已提交
415 416
void Copy<platform::CPUPlace, platform::CUDAPlace>(
    platform::CPUPlace dst_place, void* dst, platform::CUDAPlace src_place,
417
    const void* src, size_t num, gpuStream_t stream) {
Z
Zeng Jinle 已提交
418
  if (UNLIKELY(num == 0)) return;
419

420 421
  platform::SetDeviceId(src_place.device);
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
422
          << dst_place << " by stream(" << stream << ")";
423
  if (stream) {
424
    platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CPU");
425 426 427
#ifdef PADDLE_WITH_HIP
    platform::GpuMemcpyAsync(dst, src, num, hipMemcpyDeviceToHost, stream);
#else
428
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
429
#endif
430
  } else {
431
    platform::RecordEvent record_event("GpuMemcpySync:GPU->CPU");
432 433 434
#ifdef PADDLE_WITH_HIP
    platform::GpuMemcpySync(dst, src, num, hipMemcpyDeviceToHost);
#else
435
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
436
#endif
S
sneaxiy 已提交
437 438
    // FIXME(zjl): do we really need it?
    if (num <= kMaxGpuAsyncCopyBytes) {
439
      SyncCUDAStream();
S
sneaxiy 已提交
440
    }
441
  }
442 443 444
}

template <>
D
dzhwinter 已提交
445 446
void Copy<platform::CUDAPlace, platform::CPUPlace>(
    platform::CUDAPlace dst_place, void* dst, platform::CPUPlace src_place,
447
    const void* src, size_t num, gpuStream_t stream) {
Z
Zeng Jinle 已提交
448 449
  if (UNLIKELY(num == 0)) return;

L
liaogang 已提交
450
  platform::SetDeviceId(dst_place.device);
451 452
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by thream(" << stream << ")";
453
  if (stream) {
454
    platform::RecordEvent record_event("GpuMemcpyAsync:CPU->GPU");
455 456 457
#ifdef PADDLE_WITH_HIP
    platform::GpuMemcpyAsync(dst, src, num, hipMemcpyHostToDevice, stream);
#else
458
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
459
#endif
460
  } else {
461
    platform::RecordEvent record_event("GpuMemcpySync:CPU->GPU");
462 463 464
#ifdef PADDLE_WITH_HIP
    platform::GpuMemcpySync(dst, src, num, hipMemcpyHostToDevice);
#else
465
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
466
#endif
S
sneaxiy 已提交
467 468
    // FIXME(zjl): do we really need it?
    if (num <= kMaxGpuAsyncCopyBytes) {
469
      SyncCUDAStream();
S
sneaxiy 已提交
470
    }
471
  }
472 473 474
}

template <>
D
dzhwinter 已提交
475 476
void Copy<platform::CUDAPlace, platform::CUDAPlace>(
    platform::CUDAPlace dst_place, void* dst, platform::CUDAPlace src_place,
477
    const void* src, size_t num, gpuStream_t stream) {
Z
Zeng Jinle 已提交
478 479
  if (UNLIKELY(num == 0)) return;

480
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
481
          << dst_place << " by stream(" << stream << ")";
482
  if (dst_place == src_place) {
L
liaogang 已提交
483
    platform::SetDeviceId(src_place.device);
484
    if (stream) {
485
      platform::RecordEvent record_event("GpuMemcpyAsync(same_gpu):GPU->GPU");
486 487 488
#ifdef PADDLE_WITH_HIP
      platform::GpuMemcpyAsync(dst, src, num, hipMemcpyDeviceToDevice, stream);
#else
489
      platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
490
#endif
491
    } else {
492
      platform::RecordEvent record_event("GpuMemcpySync(same_gpu):GPU->GPU");
493 494 495
#ifdef PADDLE_WITH_HIP
      platform::GpuMemcpySync(dst, src, num, hipMemcpyDeviceToDevice);
#else
496
      platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice);
497
#endif
498
    }
499
  } else {
500
    if (stream) {
501
      platform::RecordEvent record_event("GpuMemcpyPeerAsync:GPU->GPU");
502 503 504
      platform::GpuMemcpyPeerAsync(dst, dst_place.device, src, src_place.device,
                                   num, stream);
    } else {
505
      platform::RecordEvent record_event("GpuMemcpyPeerSync:GPU->GPU");
506
      platform::GpuMemcpyPeerSync(dst, dst_place.device, src, src_place.device,
F
fengjiayi 已提交
507
                                  num);
508
    }
509 510 511
  }
}

C
chengduoZH 已提交
512 513 514 515
// Copies `num` bytes from `src` (CUDA pinned place) to `dst` (CPU place) with
// std::memcpy. The request is logged before the empty-size check.
template <>
void Copy<platform::CPUPlace, platform::CUDAPinnedPlace>(
    platform::CPUPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  // Nothing to copy for an empty request.
  if (UNLIKELY(num == 0)) {
    return;
  }
  std::memcpy(dst, src, num);
}

// Copies `num` bytes from `src` (CPU place) to `dst` (CUDA pinned place) with
// std::memcpy. The request is logged before the empty-size check.
template <>
void Copy<platform::CUDAPinnedPlace, platform::CPUPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CPUPlace src_place, const void* src, size_t num) {
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  // Nothing to copy for an empty request.
  if (UNLIKELY(num == 0)) {
    return;
  }
  std::memcpy(dst, src, num);
}

// Copies `num` bytes between two CUDA pinned place buffers with std::memcpy.
// The request is logged before the empty-size check.
template <>
void Copy<platform::CUDAPinnedPlace, platform::CUDAPinnedPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  // Nothing to copy for an empty request.
  if (UNLIKELY(num == 0)) {
    return;
  }
  std::memcpy(dst, src, num);
}

// Copies `num` bytes from memory `src` on GPU `src_place.device` to the CUDA
// pinned place buffer `dst`.
//
// The source device is selected first. With a non-null `stream` the copy goes
// through GpuMemcpyAsync on that stream; otherwise GpuMemcpySync is used.
// A zero-byte request returns immediately.
template <>
void Copy<platform::CUDAPinnedPlace, platform::CUDAPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CUDAPlace src_place, const void* src, size_t num,
    gpuStream_t stream) {
  if (UNLIKELY(num == 0)) return;

  // Copy direction constant of the runtime in use (HIP or CUDA).
#ifdef PADDLE_WITH_HIP
  const auto copy_kind = hipMemcpyDeviceToHost;
#else
  const auto copy_kind = cudaMemcpyDeviceToHost;
#endif

  platform::SetDeviceId(src_place.device);
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by stream(" << stream << ")";
  if (stream) {
    platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CUDAPinned");
    platform::GpuMemcpyAsync(dst, src, num, copy_kind, stream);
  } else {
    platform::RecordEvent record_event("GpuMemcpySync:GPU->CUDAPinned");
    platform::GpuMemcpySync(dst, src, num, copy_kind);
  }
}

// Copies `num` bytes from the CUDA pinned place buffer `src` to memory `dst`
// on GPU `dst_place.device`.
//
// The destination device is selected first. With a non-null `stream` the copy
// goes through GpuMemcpyAsync on that stream; otherwise GpuMemcpySync is used.
// A zero-byte request returns immediately.
template <>
void Copy<platform::CUDAPlace, platform::CUDAPinnedPlace>(
    platform::CUDAPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num,
    gpuStream_t stream) {
  if (UNLIKELY(num == 0)) return;

  // Copy direction constant of the runtime in use (HIP or CUDA).
#ifdef PADDLE_WITH_HIP
  const auto copy_kind = hipMemcpyHostToDevice;
#else
  const auto copy_kind = cudaMemcpyHostToDevice;
#endif

  platform::SetDeviceId(dst_place.device);
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by stream(" << stream << ")";
  if (stream) {
    platform::RecordEvent record_event("GpuMemcpyAsync:CUDAPinned->GPU");
    platform::GpuMemcpyAsync(dst, src, num, copy_kind, stream);
  } else {
    platform::RecordEvent record_event("GpuMemcpySync:CUDAPinned->GPU");
    platform::GpuMemcpySync(dst, src, num, copy_kind);
  }
}

L
Luo Tao 已提交
595
#endif
Y
Yi Wang 已提交
596 597 598

}  // namespace memory
}  // namespace paddle