/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/memory/memcpy.h"
16

17
#include "paddle/fluid/platform/device/device_wrapper.h"
18
#include "paddle/fluid/platform/device_context.h"
19
#include "paddle/fluid/platform/profiler/event_tracing.h"
20
#include "paddle/phi/common/place.h"
21

22 23 24 25 26 27 28 29
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
#endif

#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/mlu_info.h"
#endif

30 31 32
namespace paddle {
namespace memory {

33 34 35
#ifdef PADDLE_WITH_CUSTOM_DEVICE
// Custom-device -> host copy: moves `num` bytes from `src` on `src_place`
// into `dst`. `stream` is wrapped in a phi::stream::Stream bound to
// `src_place` and passed to the device's MemoryCopyD2H. A zero-byte request
// returns before any device call is made.
template <>
void Copy<platform::CPUPlace, platform::CustomPlace>(
    platform::CPUPlace dst_place,
    void* dst,
    platform::CustomPlace src_place,
    const void* src,
    size_t num,
    void* stream) {
  if (UNLIKELY(num == 0)) {
    return;
  }

  // Label the RecordEvent with the device types of both endpoints.
  const auto src_dev_type = platform::PlaceHelper::GetDeviceType(src_place);
  const auto dst_dev_type = platform::PlaceHelper::GetDeviceType(dst_place);
  const std::string event_name =
      "Memcpy:" + src_dev_type + "->" + dst_dev_type;
  platform::RecordEvent record_event(event_name);
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << ", stream=" << stream;

  // The source place is the device side of this transfer, so it is the one
  // that gets selected and that owns the stream wrapper.
  phi::DeviceManager::SetDevice(src_place);
  phi::stream::Stream device_stream(src_place, stream);
  phi::DeviceManager::GetDeviceWithPlace(src_place)->MemoryCopyD2H(
      dst, src, num, &device_stream);
}

// Host -> custom-device copy: moves `num` bytes from host memory `src` into
// `dst` on `dst_place`. `stream` is wrapped in a phi::stream::Stream bound to
// `dst_place` and passed to the device's MemoryCopyH2D. A zero-byte request
// returns before any device call is made.
template <>
void Copy<platform::CustomPlace, platform::CPUPlace>(
    platform::CustomPlace dst_place,
    void* dst,
    platform::CPUPlace src_place,
    const void* src,
    size_t num,
    void* stream) {
  if (UNLIKELY(num == 0)) {
    return;
  }

  // Label the RecordEvent with the device types of both endpoints.
  const auto src_dev_type = platform::PlaceHelper::GetDeviceType(src_place);
  const auto dst_dev_type = platform::PlaceHelper::GetDeviceType(dst_place);
  const std::string event_name =
      "Memcpy:" + src_dev_type + "->" + dst_dev_type;
  platform::RecordEvent record_event(event_name);
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << ", stream=" << stream;

  // The destination place is the device side of this transfer.
  phi::DeviceManager::SetDevice(dst_place);
  phi::stream::Stream device_stream(dst_place, stream);
  phi::DeviceManager::GetDeviceWithPlace(dst_place)->MemoryCopyH2D(
      dst, src, num, &device_stream);
}

// Custom-device -> custom-device copy of `num` bytes.
//  * Different device types: throws platform::errors::Unavailable.
//  * Same type, same device id: MemoryCopyD2D on the source device.
//  * Same type, different device id: MemoryCopyP2P from the source device to
//    `dst_place`.
// A zero-byte request returns before any device call is made.
template <>
void Copy<platform::CustomPlace, platform::CustomPlace>(
    platform::CustomPlace dst_place,
    void* dst,
    platform::CustomPlace src_place,
    const void* src,
    size_t num,
    void* stream) {
  if (UNLIKELY(num == 0)) {
    return;
  }

  auto src_dev_type = platform::PlaceHelper::GetDeviceType(src_place);
  auto dst_dev_type = platform::PlaceHelper::GetDeviceType(dst_place);
  const std::string event_name =
      "Memcpy:" + src_dev_type + "->" + dst_dev_type;
  platform::RecordEvent record_event(event_name);
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << ", stream=" << stream;

  // Reject copies across different device types up front.
  if (!(src_dev_type == dst_dev_type)) {
    PADDLE_THROW(platform::errors::Unavailable(
        "Copy between %s and %s is not supported.",
        src_dev_type,
        dst_dev_type));
  }

  // Both copy flavours are issued from the source device.
  phi::DeviceManager::SetDevice(src_place);
  phi::stream::Stream device_stream(src_place, stream);

  const auto src_dev_id = platform::PlaceHelper::GetDeviceId(src_place);
  const auto dst_dev_id = platform::PlaceHelper::GetDeviceId(dst_place);
  if (src_dev_id == dst_dev_id) {
    phi::DeviceManager::GetDeviceWithPlace(src_place)->MemoryCopyD2D(
        dst, src, num, &device_stream);
    return;
  }
  phi::DeviceManager::GetDeviceWithPlace(src_place)->MemoryCopyP2P(
      dst_place, dst, src, num, &device_stream);
}
#endif  // PADDLE_WITH_CUSTOM_DEVICE

116
template <>
117 118
void Copy<platform::CPUPlace, platform::CPUPlace>(platform::CPUPlace,
                                                  void* dst,
119
                                                  platform::CPUPlace,
120 121
                                                  const void* src,
                                                  size_t num) {
Z
Zeng Jinle 已提交
122
  if (UNLIKELY(num == 0)) return;
123
  VLOG(4) << "src: " << src << ", dst: " << dst << ", num: " << num;
124 125
  std::memcpy(dst, src, num);
}
126

J
jianghaicheng 已提交
127 128 129 130 131
#ifdef PADDLE_WITH_IPU
// CPU -> IPU copy, implemented as a std::memcpy of `num` bytes from `src` to
// `dst`. The place arguments are not consulted. Zero-byte requests are
// skipped.
template <>
void Copy<platform::IPUPlace, platform::CPUPlace>(platform::IPUPlace dst_place,
                                                  void* dst,
                                                  platform::CPUPlace src_place,
                                                  const void* src,
                                                  size_t num) {
  if (UNLIKELY(num == 0)) {
    return;
  }
  std::memcpy(dst, src, num);
}
// IPU -> CPU copy, implemented as a std::memcpy of `num` bytes from `src` to
// `dst`. The place arguments are not consulted. Zero-byte requests are
// skipped.
template <>
void Copy<platform::CPUPlace, platform::IPUPlace>(platform::CPUPlace dst_place,
                                                  void* dst,
                                                  platform::IPUPlace src_place,
                                                  const void* src,
                                                  size_t num) {
  if (UNLIKELY(num == 0)) {
    return;
  }
  std::memcpy(dst, src, num);
}
// IPU -> IPU copy, implemented as a std::memcpy of `num` bytes from `src` to
// `dst`. The place arguments are not consulted. Zero-byte requests are
// skipped.
template <>
void Copy<platform::IPUPlace, platform::IPUPlace>(platform::IPUPlace dst_place,
                                                  void* dst,
                                                  platform::IPUPlace src_place,
                                                  const void* src,
                                                  size_t num) {
  if (UNLIKELY(num == 0)) {
    return;
  }
  std::memcpy(dst, src, num);
}
155 156 157

// NOTE: only for (CPUPlace and IPUPlace) -> (IPUPlace).
template <>
158 159 160 161
void Copy<phi::IPUPlace, phi::Place>(phi::IPUPlace dst_place,
                                     void* dst,
                                     phi::Place src_place,
                                     const void* src,
162 163
                                     size_t num) {
  if (src_place.GetType() == phi::AllocationType::CPU) {
164 165
    platform::CPUPlace place_src;
    return Copy(dst_place, dst, place_src, src, num);
166
  } else if (src_place.GetType() == phi::AllocationType::IPU) {
167 168 169 170 171 172 173
    platform::IPUPlace place_src(src_place.GetDeviceId());
    return Copy(dst_place, dst, place_src, src, num);
  }
}

// NOTE: only for (IPUPlace) -> (CPUPlace and IPUPlace).
template <>
174 175 176 177
void Copy<phi::Place, phi::IPUPlace>(phi::Place dst_place,
                                     void* dst,
                                     phi::IPUPlace src_place,
                                     const void* src,
178 179
                                     size_t num) {
  if (dst_place.GetType() == phi::AllocationType::CPU) {
180 181
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, src_place, src, num);
182
  } else if (dst_place.GetType() == phi::AllocationType::IPU) {
183 184 185 186
    platform::IPUPlace place_dst(dst_place.GetDeviceId());
    return Copy(place_dst, dst, src_place, src, num);
  }
}
J
jianghaicheng 已提交
187
#endif
188

189 190 191 192 193
#ifdef PADDLE_WITH_XPU
// Host -> XPU copy of `num` bytes via platform::MemcpySyncH2D, targeting
// `dst_place`. An empty request is logged at VLOG(1) and skipped.
template <>
void Copy<platform::XPUPlace, platform::CPUPlace>(platform::XPUPlace dst_place,
                                                  void* dst,
                                                  platform::CPUPlace src_place,
                                                  const void* src,
                                                  size_t num) {
  // `num` is unsigned, so zero is the only value "<= 0" can match.
  if (num == 0) {
    VLOG(1) << "memcpy XPU_HOST_TO_DEVICE size <= 0 (" << num << ")";
    return;
  }
  platform::MemcpySyncH2D(dst, src, num, dst_place);
}

// XPU -> host copy of `num` bytes via platform::MemcpySyncD2H, reading from
// `src_place`. An empty request is logged at VLOG(1) and skipped.
template <>
void Copy<platform::CPUPlace, platform::XPUPlace>(platform::CPUPlace dst_place,
                                                  void* dst,
                                                  platform::XPUPlace src_place,
                                                  const void* src,
                                                  size_t num) {
  // `num` is unsigned, so zero is the only value "<= 0" can match.
  if (num == 0) {
    VLOG(1) << "memcpy XPU_DEVICE_TO_HOST size <= 0 (" << num << ")";
    return;
  }
  platform::MemcpySyncD2H(dst, src, num, src_place);
}

// XPU -> XPU copy of `num` bytes via platform::MemcpySyncD2D, which receives
// both places. An empty request is logged at VLOG(1) and skipped.
template <>
void Copy<platform::XPUPlace, platform::XPUPlace>(platform::XPUPlace dst_place,
                                                  void* dst,
                                                  platform::XPUPlace src_place,
                                                  const void* src,
                                                  size_t num) {
  // `num` is unsigned, so zero is the only value "<= 0" can match.
  if (num == 0) {
    VLOG(1) << "memcpy XPU_DEVICE_TO_DEVICE size <= 0 (" << num << ")";
    return;
  }
  platform::MemcpySyncD2D(dst, dst_place, src, src_place, num);
}
228 229 230

// NOTE: only for (CPUPlace and XPUPlace) -> (XPUPlace).
template <>
231 232 233 234
void Copy<phi::XPUPlace, phi::Place>(phi::XPUPlace dst_place,
                                     void* dst,
                                     phi::Place src_place,
                                     const void* src,
235 236
                                     size_t num) {
  if (src_place.GetType() == phi::AllocationType::CPU) {
237 238
    platform::CPUPlace place_src;
    return Copy(dst_place, dst, place_src, src, num);
239
  } else if (src_place.GetType() == phi::AllocationType::XPU) {
240 241 242 243 244 245 246
    platform::XPUPlace place_src(src_place.GetDeviceId());
    return Copy(dst_place, dst, place_src, src, num);
  }
}

// NOTE: only for (XPUPlace) -> (CPUPlace and XPUPlace).
template <>
247 248 249 250
void Copy<phi::Place, phi::XPUPlace>(phi::Place dst_place,
                                     void* dst,
                                     phi::XPUPlace src_place,
                                     const void* src,
251 252
                                     size_t num) {
  if (dst_place.GetType() == phi::AllocationType::CPU) {
253 254
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, src_place, src, num);
255
  } else if (dst_place.GetType() == phi::AllocationType::XPU) {
256 257 258 259
    platform::XPUPlace place_dst(dst_place.GetDeviceId());
    return Copy(place_dst, dst, src_place, src, num);
  }
}
260 261
#endif

262 263 264 265 266
#ifdef PADDLE_WITH_ASCEND_CL
// Host -> NPU copy of `num` bytes onto device `dst_place.device`.
//  * Non-null `stream`: NPUMemcpyAsync on that stream.
//  * Null `stream`: waits on the NPUDeviceContext of `dst_place`, then
//    NPUMemcpySync.
// Zero-byte requests return before the device is selected.
template <>
void Copy<platform::NPUPlace, platform::CPUPlace>(platform::NPUPlace dst_place,
                                                  void* dst,
                                                  platform::CPUPlace src_place,
                                                  const void* src,
                                                  size_t num,
                                                  void* stream) {
  if (UNLIKELY(num == 0)) {
    return;
  }

  platform::SetNPUDeviceId(dst_place.device);

  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by thream(" << stream << ")";

  if (stream != nullptr) {
    platform::RecordEvent record_event(
        "NpuMemcpyAsync:CPU->NPU", platform::TracerEventType::UserDefined, 1);
    platform::NPUMemcpyAsync(dst,
                             src,
                             num,
                             ACL_MEMCPY_HOST_TO_DEVICE,
                             reinterpret_cast<aclrtStream>(stream));
    return;
  }

  // On NPU an async operation issued after a sync one is fine, but a sync
  // operation issued after an async one is not, because the async work may
  // still be in flight. Wait on the device context before the sync copy.
  platform::DeviceContextPool& ctx_pool =
      platform::DeviceContextPool::Instance();
  static_cast<platform::NPUDeviceContext*>(ctx_pool.Get(dst_place))->Wait();

  platform::RecordEvent record_event(
      "NpuMemcpySync:CPU->NPU", platform::TracerEventType::UserDefined, 1);
  platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE);
}

// NPU -> host copy of `num` bytes from device `src_place.device`.
//  * Non-null `stream`: NPUMemcpyAsync on that stream.
//  * Null `stream`: waits on the NPUDeviceContext of `src_place`, then
//    NPUMemcpySync.
// Zero-byte requests return before the device is selected.
template <>
void Copy<platform::CPUPlace, platform::NPUPlace>(platform::CPUPlace dst_place,
                                                  void* dst,
                                                  platform::NPUPlace src_place,
                                                  const void* src,
                                                  size_t num,
                                                  void* stream) {
  if (UNLIKELY(num == 0)) {
    return;
  }

  platform::SetNPUDeviceId(src_place.device);

  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by thream(" << stream << ")";

  if (stream != nullptr) {
    platform::RecordEvent record_event(
        "NpuMemcpyAsync:NPU->CPU", platform::TracerEventType::UserDefined, 1);
    platform::NPUMemcpyAsync(dst,
                             src,
                             num,
                             ACL_MEMCPY_DEVICE_TO_HOST,
                             reinterpret_cast<aclrtStream>(stream));
    return;
  }

  // Wait on the source device's context before issuing the sync copy.
  platform::DeviceContextPool& ctx_pool =
      platform::DeviceContextPool::Instance();
  static_cast<platform::NPUDeviceContext*>(ctx_pool.Get(src_place))->Wait();

  platform::RecordEvent record_event(
      "NpuMemcpySync:NPU->CPU", platform::TracerEventType::UserDefined, 1);
  platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST);
}

// NPU -> NPU copy of `num` bytes.
//  * Same place: selects the device, then copies asynchronously on `stream`
//    if it is non-null, otherwise waits on the device context and copies
//    synchronously.
//  * Different places: throws platform::errors::Unavailable unless
//    NPUCanAccessPeer allows access between the source and destination
//    devices, then copies the same way (async with a stream, sync without).
// Zero-byte requests return immediately.
template <>
void Copy<platform::NPUPlace, platform::NPUPlace>(platform::NPUPlace dst_place,
                                                  void* dst,
                                                  platform::NPUPlace src_place,
                                                  const void* src,
                                                  size_t num,
                                                  void* stream) {
  if (UNLIKELY(num == 0)) return;

  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by stream(" << stream << ")";
  if (dst_place == src_place) {
    platform::SetNPUDeviceId(src_place.device);
    if (stream) {
      platform::RecordEvent record_event("NpuMemcpyAsync(same_npu):NPU->NPU",
                                         platform::TracerEventType::UserDefined,
                                         1);
      platform::NPUMemcpyAsync(dst,
                               src,
                               num,
                               ACL_MEMCPY_DEVICE_TO_DEVICE,
                               reinterpret_cast<aclrtStream>(stream));
    } else {
      // A sync copy must not overtake async work still in flight, so wait on
      // the device context first.
      platform::DeviceContextPool& pool =
          platform::DeviceContextPool::Instance();
      static_cast<platform::NPUDeviceContext*>(pool.Get(dst_place))->Wait();

      platform::RecordEvent record_event("NpuMemcpySync(same_npu):NPU->NPU",
                                         platform::TracerEventType::UserDefined,
                                         1);
      platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE);
    }
  } else {
    // The check must involve both devices. It previously passed
    // dst_place.device twice, so the source device was never checked.
    // NOTE(review): argument order assumed to be (src, dst) -- confirm
    // against the NPUCanAccessPeer declaration.
    if (!platform::NPUCanAccessPeer(src_place.device, dst_place.device)) {
      PADDLE_THROW(platform::errors::Unavailable(
          "Peer access between NPU places is not allowed."));
    }
    if (stream) {
      // TODO(zhiqiu): support peer access?
      platform::RecordEvent record_event("NpuMemcpyPeerAsync:NPU->NPU",
                                         platform::TracerEventType::UserDefined,
                                         1);
      platform::NPUMemcpyAsync(dst,
                               src,
                               num,
                               ACL_MEMCPY_DEVICE_TO_DEVICE,
                               reinterpret_cast<aclrtStream>(stream));
    } else {
      platform::DeviceContextPool& pool =
          platform::DeviceContextPool::Instance();
      static_cast<platform::NPUDeviceContext*>(pool.Get(dst_place))->Wait();

      platform::RecordEvent record_event("NpuMemcpyPeerSync:NPU->NPU",
                                         platform::TracerEventType::UserDefined,
                                         1);
      platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE);
    }
  }
}
389 390 391

// NPU-pinned -> CPU copy: logs the request, then performs a std::memcpy of
// `num` bytes. The log line is emitted even for a zero-byte request; the copy
// itself is skipped in that case.
template <>
void Copy<platform::CPUPlace, platform::NPUPinnedPlace>(
    platform::CPUPlace dst_place,
    void* dst,
    platform::NPUPinnedPlace src_place,
    const void* src,
    size_t num) {
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  if (UNLIKELY(num == 0)) {
    return;
  }
  std::memcpy(dst, src, num);
}

// CPU -> NPU-pinned copy: logs the request, then performs a std::memcpy of
// `num` bytes. The log line is emitted even for a zero-byte request; the copy
// itself is skipped in that case.
template <>
void Copy<platform::NPUPinnedPlace, platform::CPUPlace>(
    platform::NPUPinnedPlace dst_place,
    void* dst,
    platform::CPUPlace src_place,
    const void* src,
    size_t num) {
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  if (UNLIKELY(num == 0)) {
    return;
  }
  std::memcpy(dst, src, num);
}

// NPU-pinned -> NPU-pinned copy: logs the request, then performs a
// std::memcpy of `num` bytes. The log line is emitted even for a zero-byte
// request; the copy itself is skipped in that case.
template <>
void Copy<platform::NPUPinnedPlace, platform::NPUPinnedPlace>(
    platform::NPUPinnedPlace dst_place,
    void* dst,
    platform::NPUPinnedPlace src_place,
    const void* src,
    size_t num) {
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  if (UNLIKELY(num == 0)) {
    return;
  }
  std::memcpy(dst, src, num);
}

// NPU -> NPU-pinned copy of `num` bytes from device `src_place.device`, using
// the ACL_MEMCPY_DEVICE_TO_HOST direction.
//  * Non-null `stream`: NPUMemcpyAsync on that stream.
//  * Null `stream`: waits on the NPUDeviceContext of `src_place`, then
//    NPUMemcpySync.
// Zero-byte requests return before the device is selected.
template <>
void Copy<platform::NPUPinnedPlace, platform::NPUPlace>(
    platform::NPUPinnedPlace dst_place,
    void* dst,
    platform::NPUPlace src_place,
    const void* src,
    size_t num,
    void* stream) {
  if (UNLIKELY(num == 0)) {
    return;
  }

  platform::SetNPUDeviceId(src_place.device);

  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by thream(" << stream << ")";

  if (stream != nullptr) {
    platform::RecordEvent record_event("NpuMemcpyAsync:NPU->NPUPinned",
                                       platform::TracerEventType::UserDefined,
                                       1);
    platform::NPUMemcpyAsync(dst,
                             src,
                             num,
                             ACL_MEMCPY_DEVICE_TO_HOST,
                             reinterpret_cast<aclrtStream>(stream));
    return;
  }

  // Wait on the source device's context before issuing the sync copy.
  platform::DeviceContextPool& ctx_pool =
      platform::DeviceContextPool::Instance();
  static_cast<platform::NPUDeviceContext*>(ctx_pool.Get(src_place))->Wait();

  platform::RecordEvent record_event("NpuMemcpySync:NPU->NPUPinned",
                                     platform::TracerEventType::UserDefined,
                                     1);
  platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST);
}

// NPU-pinned -> NPU copy of `num` bytes onto device `dst_place.device`, using
// the ACL_MEMCPY_HOST_TO_DEVICE direction.
//  * Non-null `stream`: NPUMemcpyAsync on that stream.
//  * Null `stream`: waits on the NPUDeviceContext of `dst_place`, then
//    NPUMemcpySync.
// Zero-byte requests return before the device is selected.
template <>
void Copy<platform::NPUPlace, platform::NPUPinnedPlace>(
    platform::NPUPlace dst_place,
    void* dst,
    platform::NPUPinnedPlace src_place,
    const void* src,
    size_t num,
    void* stream) {
  if (UNLIKELY(num == 0)) {
    return;
  }

  platform::SetNPUDeviceId(dst_place.device);

  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by thream(" << stream << ")";

  if (stream != nullptr) {
    platform::RecordEvent record_event("NpuMemcpyAsync:NPUPinned->NPU",
                                       platform::TracerEventType::UserDefined,
                                       1);
    platform::NPUMemcpyAsync(dst,
                             src,
                             num,
                             ACL_MEMCPY_HOST_TO_DEVICE,
                             reinterpret_cast<aclrtStream>(stream));
    return;
  }

  // On NPU an async operation issued after a sync one is fine, but a sync
  // operation issued after an async one is not, because the async work may
  // still be in flight. Wait on the device context before the sync copy.
  platform::DeviceContextPool& ctx_pool =
      platform::DeviceContextPool::Instance();
  static_cast<platform::NPUDeviceContext*>(ctx_pool.Get(dst_place))->Wait();

  platform::RecordEvent record_event("NpuMemcpySync:NPUPinned->NPU",
                                     platform::TracerEventType::UserDefined,
                                     1);
  platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE);
}

502 503
// NOTE: only for CPUPlace, NPUPlace and NPUPinnedPlace.
template <>
504 505 506 507 508 509
void Copy<phi::Place, phi::Place>(phi::Place dst_place,
                                  void* dst,
                                  phi::Place src_place,
                                  const void* src,
                                  size_t num,
                                  aclrtStream stream) {
510 511
  if (src_place.GetType() == phi::AllocationType::CPU &&
      dst_place.GetType() == phi::AllocationType::CPU) {
512 513
    platform::CPUPlace place_dst, place_src;
    return Copy(place_dst, dst, place_src, src, num);
514 515
  } else if (src_place.GetType() == phi::AllocationType::CPU &&
             dst_place.GetType() == phi::AllocationType::NPU) {
516 517 518
    platform::NPUPlace place_dst(dst_place.GetDeviceId());
    platform::CPUPlace place_src;
    return Copy(place_dst, dst, place_src, src, num, stream);
519 520
  } else if (src_place.GetType() == phi::AllocationType::NPU &&
             dst_place.GetType() == phi::AllocationType::CPU) {
521 522 523
    platform::NPUPlace place_src(src_place.GetDeviceId());
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num, stream);
524 525
  } else if (src_place.GetType() == phi::AllocationType::NPU &&
             dst_place.GetType() == phi::AllocationType::NPU) {
526 527 528
    platform::NPUPlace place_src(src_place.GetDeviceId());
    platform::NPUPlace place_dst(dst_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num, stream);
529 530
  } else if (src_place.GetType() == phi::AllocationType::CPU &&
             dst_place.GetType() == phi::AllocationType::NPUPINNED) {
531 532 533
    platform::CPUPlace place_src;
    platform::NPUPinnedPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num);
534 535
  } else if (src_place.GetType() == phi::AllocationType::NPUPINNED &&
             dst_place.GetType() == phi::AllocationType::CPU) {
536 537 538
    platform::CPUPlace place_dst;
    platform::NPUPinnedPlace place_src;
    return Copy(place_dst, dst, place_src, src, num);
539 540
  } else if (src_place.GetType() == phi::AllocationType::NPUPINNED &&
             dst_place.GetType() == phi::AllocationType::NPUPINNED) {
541 542 543
    platform::NPUPinnedPlace place_dst;
    platform::NPUPinnedPlace place_src;
    return Copy(place_dst, dst, place_src, src, num);
544 545
  } else if (src_place.GetType() == phi::AllocationType::NPUPINNED &&
             dst_place.GetType() == phi::AllocationType::NPU) {
546 547 548
    platform::NPUPinnedPlace place_src;
    platform::NPUPlace place_dst(dst_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num, stream);
549 550
  } else if (src_place.GetType() == phi::AllocationType::NPU &&
             dst_place.GetType() == phi::AllocationType::NPUPINNED) {
551 552 553
    platform::NPUPinnedPlace place_dst;
    platform::NPUPlace place_src(src_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num, stream);
554
#ifdef PADDLE_WITH_CUSTOM_DEVICE
555 556
  } else if (src_place.GetType() == phi::AllocationType::CPU &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CUSTOM) {
557 558 559
    platform::CPUPlace place_src;
    platform::CustomPlace place_dst(dst_place);
    return Copy(place_dst, dst, place_src, src, num, stream);
560 561
  } else if (src_place.GetType() == phi::AllocationType::CUSTOM &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CPU) {
562 563 564
    platform::CustomPlace place_src(src_place);
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num, stream);
565 566
  } else if (src_place.GetType() == phi::AllocationType::CUSTOM &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CUSTOM) {
567 568 569 570
    platform::CustomPlace place_src(src_place);
    platform::CustomPlace place_dst(dst_place);
    return Copy(place_dst, dst, place_src, src, num, stream);
#endif
571 572 573 574 575
  }
}

// NOTE: only for (CPUPlace, NPUPlace and NPUPinnedPlace) -> (CPUPlace).
template <>
576 577 578 579 580 581
void Copy<phi::CPUPlace, phi::Place>(phi::CPUPlace dst_place,
                                     void* dst,
                                     phi::Place src_place,
                                     const void* src,
                                     size_t num,
                                     aclrtStream stream) {
582
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num, stream);
583 584 585 586
}

// NOTE: only for (CPUPlace) -> (CPUPlace, NPUPlace and NPUPinnedPlace).
template <>
587 588 589 590 591 592
void Copy<phi::Place, phi::CPUPlace>(phi::Place dst_place,
                                     void* dst,
                                     phi::CPUPlace src_place,
                                     const void* src,
                                     size_t num,
                                     aclrtStream stream) {
593
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num, stream);
594 595 596 597
}

// NOTE: only for (CPUPlace, NPUPlace and NPUPinnedPlace) -> (NPUPlace)
template <>
598 599 600 601 602 603 604 605 606 607 608 609
void Copy<phi::NPUPlace, phi::Place>(phi::NPUPlace dst_place,
                                     void* dst,
                                     phi::Place src_place,
                                     const void* src,
                                     size_t num,
                                     aclrtStream stream) {
  Copy(phi::Place(dst_place.GetType(), dst_place.GetDeviceId()),
       dst,
       src_place,
       src,
       num,
       stream);
610 611 612 613
}

// NOTE: only for (NPUPlace) -> (CPUPlace, NPUPlace and NPUPinnedPlace)
template <>
614 615 616 617 618 619 620 621 622 623 624 625
void Copy<phi::Place, phi::NPUPlace>(phi::Place dst_place,
                                     void* dst,
                                     phi::NPUPlace src_place,
                                     const void* src,
                                     size_t num,
                                     aclrtStream stream) {
  Copy(dst_place,
       dst,
       phi::Place(src_place.GetType(), src_place.GetDeviceId()),
       src,
       num,
       stream);
626 627 628 629
}

// NOTE: only for (CPUPlace, NPUPlace and NPUPinnedPlace) -> (NPUPinnedPlace)
template <>
630
void Copy<phi::NPUPinnedPlace, phi::Place>(phi::NPUPinnedPlace dst_place,
631 632 633 634
                                           void* dst,
                                           phi::Place src_place,
                                           const void* src,
                                           size_t num,
635 636
                                           aclrtStream stream) {
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num, stream);
637 638 639 640
}

// NOTE: only for (NPUPinnedPlace) -> (CPUPlace, NPUPlace and NPUPinnedPlace)
template <>
641 642
void Copy<phi::Place, phi::NPUPinnedPlace>(phi::Place dst_place,
                                           void* dst,
643
                                           phi::NPUPinnedPlace src_place,
644 645
                                           const void* src,
                                           size_t num,
646 647
                                           aclrtStream stream) {
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num, stream);
648 649 650 651
}

// NOTE: only for (CPUPlace) -> (NPUPinnedPlace)
template <>
652
void Copy<phi::NPUPinnedPlace, phi::Place>(phi::NPUPinnedPlace dst_place,
653 654 655 656
                                           void* dst,
                                           phi::Place src_place,
                                           const void* src,
                                           size_t num) {
657
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num, nullptr);
658 659 660 661
}

// NOTE: only for (NPUPinnedPlace) -> (CPUPlace)
template <>
662 663
void Copy<phi::Place, phi::NPUPinnedPlace>(phi::Place dst_place,
                                           void* dst,
664
                                           phi::NPUPinnedPlace src_place,
665 666
                                           const void* src,
                                           size_t num) {
667
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num, nullptr);
668
}
669 670
#endif

671
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
S
sneaxiy 已提交
672 673
// Byte threshold (64 * 1024). The synchronous GPU->CPU Copy in this file
// compares the copy size against it and calls SyncCUDAStream() after the copy
// when the size does not exceed it.
static constexpr size_t kMaxGpuAsyncCopyBytes = 64 * 1024;  // 64K

674 675 676 677 678 679 680 681 682 683 684 685 686
#ifdef PADDLE_WITH_HIP
// Blocks until stream 0 has finished its queued work (HIP build).
// Outside Windows this is a single hipStreamSynchronize call. On Windows the
// stream is polled with hipStreamQuery until it reports anything other than
// hipErrorNotReady. Return codes are not propagated to the caller.
inline void SyncCUDAStream() {
#if defined(_WIN32)
  for (;;) {
    const hipError_t query_status = hipStreamQuery(0);
    if (query_status != hipErrorNotReady) {
      break;
    }
  }
#else
  hipStreamSynchronize(0);
#endif
}
#else
687 688 689 690 691 692 693 694 695 696 697
// Blocks until stream 0 has finished its queued work (CUDA build).
// Outside Windows this is a single cudaStreamSynchronize call. On Windows the
// stream is polled with cudaStreamQuery until it reports anything other than
// cudaErrorNotReady. Return codes are not propagated to the caller.
inline void SyncCUDAStream() {
#if defined(_WIN32)
  for (;;) {
    const cudaError_t query_status = cudaStreamQuery(0);
    if (query_status != cudaErrorNotReady) {
      break;
    }
  }
#else
  cudaStreamSynchronize(0);
#endif
}
698
#endif
699

700 701 702 703 704 705
// NOTE(zcd): Avoid GpuMemcpySync whenever possible. GpuMemcpySync issues the
// copy on the default stream, which prevents commands from different streams
// from running concurrently.
// Reference:
// https://devblogs.nvidia.com/gpu-pro-tip-cuda-7-streams-simplify-concurrency/

706
template <>
D
dzhwinter 已提交
707
void Copy<platform::CPUPlace, platform::CUDAPlace>(
708 709 710 711 712 713
    platform::CPUPlace dst_place,
    void* dst,
    platform::CUDAPlace src_place,
    const void* src,
    size_t num,
    void* stream) {
Z
Zeng Jinle 已提交
714
  if (UNLIKELY(num == 0)) return;
715

716 717
  platform::SetDeviceId(src_place.device);
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
718
          << dst_place << " by stream(" << stream << ")";
719
  if (stream) {
720 721
    platform::RecordEvent record_event(
        "GpuMemcpyAsync:GPU->CPU", platform::TracerEventType::UserDefined, 1);
722
#ifdef PADDLE_WITH_HIP
723 724 725 726
    platform::GpuMemcpyAsync(dst,
                             src,
                             num,
                             hipMemcpyDeviceToHost,
727
                             reinterpret_cast<gpuStream_t>(stream));
728
#else
729 730 731 732
    platform::GpuMemcpyAsync(dst,
                             src,
                             num,
                             cudaMemcpyDeviceToHost,
733
                             reinterpret_cast<gpuStream_t>(stream));
734
#endif
735
  } else {
736 737
    platform::RecordEvent record_event(
        "GpuMemcpySync:GPU->CPU", platform::TracerEventType::UserDefined, 1);
738 739 740
#ifdef PADDLE_WITH_HIP
    platform::GpuMemcpySync(dst, src, num, hipMemcpyDeviceToHost);
#else
741
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
742
#endif
S
sneaxiy 已提交
743 744
    // FIXME(zjl): do we really need it?
    if (num <= kMaxGpuAsyncCopyBytes) {
745
      SyncCUDAStream();
S
sneaxiy 已提交
746
    }
747
  }
748 749 750
}

template <>
D
dzhwinter 已提交
751
void Copy<platform::CUDAPlace, platform::CPUPlace>(
752 753 754 755 756 757
    platform::CUDAPlace dst_place,
    void* dst,
    platform::CPUPlace src_place,
    const void* src,
    size_t num,
    void* stream) {
Z
Zeng Jinle 已提交
758 759
  if (UNLIKELY(num == 0)) return;

L
liaogang 已提交
760
  platform::SetDeviceId(dst_place.device);
761 762
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by thream(" << stream << ")";
763
  if (stream) {
764 765
    platform::RecordEvent record_event(
        "GpuMemcpyAsync:CPU->GPU", platform::TracerEventType::UserDefined, 1);
766
#ifdef PADDLE_WITH_HIP
767 768 769 770
    platform::GpuMemcpyAsync(dst,
                             src,
                             num,
                             hipMemcpyHostToDevice,
771
                             reinterpret_cast<gpuStream_t>(stream));
772
#else
773 774 775 776
    platform::GpuMemcpyAsync(dst,
                             src,
                             num,
                             cudaMemcpyHostToDevice,
777
                             reinterpret_cast<gpuStream_t>(stream));
778
#endif
779
  } else {
780 781
    platform::RecordEvent record_event(
        "GpuMemcpySync:CPU->GPU", platform::TracerEventType::UserDefined, 1);
782 783 784
#ifdef PADDLE_WITH_HIP
    platform::GpuMemcpySync(dst, src, num, hipMemcpyHostToDevice);
#else
785
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
786
#endif
S
sneaxiy 已提交
787 788
    // FIXME(zjl): do we really need it?
    if (num <= kMaxGpuAsyncCopyBytes) {
789
      SyncCUDAStream();
S
sneaxiy 已提交
790
    }
791
  }
792 793 794
}

template <>
D
dzhwinter 已提交
795
void Copy<platform::CUDAPlace, platform::CUDAPlace>(
796 797 798 799 800 801
    platform::CUDAPlace dst_place,
    void* dst,
    platform::CUDAPlace src_place,
    const void* src,
    size_t num,
    void* stream) {
Z
Zeng Jinle 已提交
802 803
  if (UNLIKELY(num == 0)) return;

804
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
805
          << dst_place << " by stream(" << stream << ")";
806
  if (dst_place == src_place) {
L
liaogang 已提交
807
    platform::SetDeviceId(src_place.device);
808
    if (stream) {
809 810 811
      platform::RecordEvent record_event("GpuMemcpyAsync(same_gpu):GPU->GPU",
                                         platform::TracerEventType::UserDefined,
                                         1);
812
#ifdef PADDLE_WITH_HIP
813 814 815 816
      platform::GpuMemcpyAsync(dst,
                               src,
                               num,
                               hipMemcpyDeviceToDevice,
817
                               reinterpret_cast<gpuStream_t>(stream));
818
#else
819 820 821 822
      platform::GpuMemcpyAsync(dst,
                               src,
                               num,
                               cudaMemcpyDeviceToDevice,
823
                               reinterpret_cast<gpuStream_t>(stream));
824
#endif
825
    } else {
826 827 828
      platform::RecordEvent record_event("GpuMemcpySync(same_gpu):GPU->GPU",
                                         platform::TracerEventType::UserDefined,
                                         1);
829 830 831
#ifdef PADDLE_WITH_HIP
      platform::GpuMemcpySync(dst, src, num, hipMemcpyDeviceToDevice);
#else
832
      platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice);
833
#endif
834
    }
835
  } else {
836
    if (stream) {
837 838 839
      platform::RecordEvent record_event("GpuMemcpyPeerAsync:GPU->GPU",
                                         platform::TracerEventType::UserDefined,
                                         1);
840 841 842 843 844 845
      platform::GpuMemcpyPeerAsync(dst,
                                   dst_place.device,
                                   src,
                                   src_place.device,
                                   num,
                                   reinterpret_cast<gpuStream_t>(stream));
846
    } else {
847 848 849
      platform::RecordEvent record_event("GpuMemcpyPeerSync:GPU->GPU",
                                         platform::TracerEventType::UserDefined,
                                         1);
850 851
      platform::GpuMemcpyPeerSync(
          dst, dst_place.device, src, src_place.device, num);
852
    }
853 854 855
  }
}

C
chengduoZH 已提交
856 857
// Copies `num` bytes from the CUDA-pinned buffer `src` to the CPU buffer `dst`
// with a plain std::memcpy. The transfer is logged before the empty-copy check.
template <>
void Copy<platform::CPUPlace, platform::CUDAPinnedPlace>(
    platform::CPUPlace dst_place,
    void* dst,
    platform::CUDAPinnedPlace src_place,
    const void* src,
    size_t num) {
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  if (UNLIKELY(num == 0)) {
    // Nothing to transfer.
    return;
  }
  std::memcpy(dst, src, num);
}

// Copies `num` bytes from the CPU buffer `src` to the CUDA-pinned buffer `dst`
// with a plain std::memcpy. The transfer is logged before the empty-copy check.
template <>
void Copy<platform::CUDAPinnedPlace, platform::CPUPlace>(
    platform::CUDAPinnedPlace dst_place,
    void* dst,
    platform::CPUPlace src_place,
    const void* src,
    size_t num) {
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  if (UNLIKELY(num == 0)) {
    // Nothing to transfer.
    return;
  }
  std::memcpy(dst, src, num);
}

// Copies `num` bytes between two CUDA-pinned buffers with a plain std::memcpy.
// The transfer is logged before the empty-copy check.
template <>
void Copy<platform::CUDAPinnedPlace, platform::CUDAPinnedPlace>(
    platform::CUDAPinnedPlace dst_place,
    void* dst,
    platform::CUDAPinnedPlace src_place,
    const void* src,
    size_t num) {
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  if (UNLIKELY(num == 0)) {
    // Nothing to transfer.
    return;
  }
  std::memcpy(dst, src, num);
}

// Device-to-pinned-host copy of `num` bytes from `src` (on GPU `src_place`)
// into the CUDA-pinned buffer `dst`. A non-null `stream` selects the
// asynchronous copy on that stream; a null `stream` selects the synchronous
// copy.
template <>
void Copy<platform::CUDAPinnedPlace, platform::CUDAPlace>(
    platform::CUDAPinnedPlace dst_place,
    void* dst,
    platform::CUDAPlace src_place,
    const void* src,
    size_t num,
    void* stream) {
  if (UNLIKELY(num == 0)) return;

#ifdef PADDLE_WITH_HIP
  const auto copy_kind = hipMemcpyDeviceToHost;
#else
  const auto copy_kind = cudaMemcpyDeviceToHost;
#endif

  platform::SetDeviceId(src_place.device);
  // Log text fixed: it previously read "by thream(".
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by stream(" << stream << ")";

  if (stream) {
    platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CUDAPinned",
                                       platform::TracerEventType::UserDefined,
                                       1);
    platform::GpuMemcpyAsync(
        dst, src, num, copy_kind, reinterpret_cast<gpuStream_t>(stream));
  } else {
    platform::RecordEvent record_event("GpuMemcpySync:GPU->CUDAPinned",
                                       platform::TracerEventType::UserDefined,
                                       1);
    platform::GpuMemcpySync(dst, src, num, copy_kind);
  }
}

// Pinned-host-to-device copy of `num` bytes from the CUDA-pinned buffer `src`
// into `dst` on GPU `dst_place`. A non-null `stream` selects the asynchronous
// copy on that stream; a null `stream` selects the synchronous copy.
template <>
void Copy<platform::CUDAPlace, platform::CUDAPinnedPlace>(
    platform::CUDAPlace dst_place,
    void* dst,
    platform::CUDAPinnedPlace src_place,
    const void* src,
    size_t num,
    void* stream) {
  if (UNLIKELY(num == 0)) return;

#ifdef PADDLE_WITH_HIP
  const auto copy_kind = hipMemcpyHostToDevice;
#else
  const auto copy_kind = cudaMemcpyHostToDevice;
#endif

  platform::SetDeviceId(dst_place.device);
  // Log text fixed: it previously read "by thream(".
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by stream(" << stream << ")";

  if (stream) {
    platform::RecordEvent record_event("GpuMemcpyAsync:CUDAPinned->GPU",
                                       platform::TracerEventType::UserDefined,
                                       1);
    platform::GpuMemcpyAsync(
        dst, src, num, copy_kind, reinterpret_cast<gpuStream_t>(stream));
  } else {
    platform::RecordEvent record_event("GpuMemcpySync:CUDAPinned->GPU",
                                       platform::TracerEventType::UserDefined,
                                       1);
    platform::GpuMemcpySync(dst, src, num, copy_kind);
  }
}

978 979
// NOTE: only for CPUPlace、CUDAPlace and CUDAPinnedPlace.
template <>
980 981 982 983 984 985
void Copy<phi::Place, phi::Place>(phi::Place dst_place,
                                  void* dst,
                                  phi::Place src_place,
                                  const void* src,
                                  size_t num,
                                  void* stream) {
986 987
  if (src_place.GetType() == phi::AllocationType::CPU &&
      dst_place.GetType() == phi::AllocationType::CPU) {
988 989
    platform::CPUPlace place_dst, place_src;
    return Copy(place_dst, dst, place_src, src, num);
990 991
  } else if (src_place.GetType() == phi::AllocationType::CPU &&
             dst_place.GetType() == phi::AllocationType::GPU) {
992 993 994
    platform::CUDAPlace place_dst(dst_place.GetDeviceId());
    platform::CPUPlace place_src;
    return Copy(place_dst, dst, place_src, src, num, stream);
995 996
  } else if (src_place.GetType() == phi::AllocationType::GPU &&
             dst_place.GetType() == phi::AllocationType::CPU) {
997 998 999
    platform::CUDAPlace place_src(src_place.GetDeviceId());
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num, stream);
1000 1001
  } else if (src_place.GetType() == phi::AllocationType::GPU &&
             dst_place.GetType() == phi::AllocationType::GPU) {
1002 1003 1004
    platform::CUDAPlace place_src(src_place.GetDeviceId());
    platform::CUDAPlace place_dst(dst_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num, stream);
1005 1006
  } else if (src_place.GetType() == phi::AllocationType::CPU &&
             dst_place.GetType() == phi::AllocationType::GPUPINNED) {
1007 1008 1009
    platform::CPUPlace place_src;
    platform::CUDAPinnedPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num);
1010 1011
  } else if (src_place.GetType() == phi::AllocationType::GPUPINNED &&
             dst_place.GetType() == phi::AllocationType::CPU) {
1012 1013 1014
    platform::CPUPlace place_dst;
    platform::CUDAPinnedPlace place_src;
    return Copy(place_dst, dst, place_src, src, num);
1015 1016
  } else if (src_place.GetType() == phi::AllocationType::GPUPINNED &&
             dst_place.GetType() == phi::AllocationType::GPUPINNED) {
1017 1018 1019
    platform::CUDAPinnedPlace place_dst;
    platform::CUDAPinnedPlace place_src;
    return Copy(place_dst, dst, place_src, src, num);
1020 1021
  } else if (src_place.GetType() == phi::AllocationType::GPUPINNED &&
             dst_place.GetType() == phi::AllocationType::GPU) {
1022 1023 1024
    platform::CUDAPinnedPlace place_src;
    platform::CUDAPlace place_dst(dst_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num, stream);
1025 1026
  } else if (src_place.GetType() == phi::AllocationType::GPU &&
             dst_place.GetType() == phi::AllocationType::GPUPINNED) {
1027 1028 1029
    platform::CUDAPinnedPlace place_dst;
    platform::CUDAPlace place_src(src_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num, stream);
1030
#ifdef PADDLE_WITH_CUSTOM_DEVICE
1031 1032
  } else if (src_place.GetType() == phi::AllocationType::CPU &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CUSTOM) {
1033 1034 1035
    platform::CPUPlace place_src;
    platform::CustomPlace place_dst(dst_place);
    return Copy(place_dst, dst, place_src, src, num, stream);
1036 1037
  } else if (src_place.GetType() == phi::AllocationType::CUSTOM &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CPU) {
1038 1039 1040
    platform::CustomPlace place_src(src_place);
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num, stream);
1041 1042
  } else if (src_place.GetType() == phi::AllocationType::CUSTOM &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CUSTOM) {
1043 1044 1045 1046
    platform::CustomPlace place_src(src_place);
    platform::CustomPlace place_dst(dst_place);
    return Copy(place_dst, dst, place_src, src, num, stream);
#endif
1047 1048 1049 1050 1051
  }
}

// NOTE: only for (CPUPlace, CUDAPlace and CUDAPinnedPlace) -> (CPUPlace).
template <>
1052 1053 1054 1055 1056 1057
void Copy<phi::CPUPlace, phi::Place>(phi::CPUPlace dst_place,
                                     void* dst,
                                     phi::Place src_place,
                                     const void* src,
                                     size_t num,
                                     void* stream) {
1058
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num, stream);
1059 1060 1061 1062
}

// NOTE: only for (CPUPlace) -> (CPUPlace, CUDAPlace and CUDAPinnedPlace).
template <>
1063 1064 1065 1066 1067 1068
void Copy<phi::Place, phi::CPUPlace>(phi::Place dst_place,
                                     void* dst,
                                     phi::CPUPlace src_place,
                                     const void* src,
                                     size_t num,
                                     void* stream) {
1069
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num, stream);
1070 1071 1072 1073
}

// NOTE: only for (CPUPlace, CUDAPlace and CUDAPinnedPlace) -> (CUDAPlace)
template <>
1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085
void Copy<phi::GPUPlace, phi::Place>(phi::GPUPlace dst_place,
                                     void* dst,
                                     phi::Place src_place,
                                     const void* src,
                                     size_t num,
                                     void* stream) {
  Copy(phi::Place(dst_place.GetType(), dst_place.GetDeviceId()),
       dst,
       src_place,
       src,
       num,
       stream);
1086 1087 1088 1089
}

// NOTE: only for (CUDAPlace) -> (CPUPlace, CUDAPlace and CUDAPinnedPlace)
template <>
1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101
void Copy<phi::Place, phi::GPUPlace>(phi::Place dst_place,
                                     void* dst,
                                     phi::GPUPlace src_place,
                                     const void* src,
                                     size_t num,
                                     void* stream) {
  Copy(dst_place,
       dst,
       phi::Place(src_place.GetType(), src_place.GetDeviceId()),
       src,
       num,
       stream);
1102 1103 1104 1105
}

// NOTE: only for (CPUPlace, CUDAPlace and CUDAPinnedPlace) -> (CUDAPinnedPlace)
template <>
1106
void Copy<phi::GPUPinnedPlace, phi::Place>(phi::GPUPinnedPlace dst_place,
1107 1108 1109 1110
                                           void* dst,
                                           phi::Place src_place,
                                           const void* src,
                                           size_t num,
1111 1112
                                           void* stream) {
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num, stream);
1113 1114 1115 1116
}

// NOTE: only for (CUDAPinnedPlace) -> (CPUPlace, CUDAPlace and CUDAPinnedPlace)
template <>
1117 1118
void Copy<phi::Place, phi::GPUPinnedPlace>(phi::Place dst_place,
                                           void* dst,
1119
                                           phi::GPUPinnedPlace src_place,
1120 1121
                                           const void* src,
                                           size_t num,
1122 1123
                                           void* stream) {
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num, stream);
1124 1125 1126 1127
}

// NOTE: only for (CPUPlace) -> (CUDAPinnedPlace)
template <>
1128
void Copy<phi::GPUPinnedPlace, phi::Place>(phi::GPUPinnedPlace dst_place,
1129 1130 1131 1132
                                           void* dst,
                                           phi::Place src_place,
                                           const void* src,
                                           size_t num) {
1133
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num, nullptr);
1134 1135 1136 1137
}

// NOTE: only for (CUDAPinnedPlace) -> (CPUPlace)
template <>
1138 1139
void Copy<phi::Place, phi::GPUPinnedPlace>(phi::Place dst_place,
                                           void* dst,
1140
                                           phi::GPUPinnedPlace src_place,
1141 1142
                                           const void* src,
                                           size_t num) {
1143
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num, nullptr);
1144
}
L
Luo Tao 已提交
1145
#endif
Y
Yi Wang 已提交
1146

F
fwenguang 已提交
1147 1148 1149 1150 1151
#ifdef PADDLE_WITH_MLU
// Device-to-host copy of `num` bytes from `src` (on MLU `src_place`) into the
// host buffer `dst`. With a non-null `stream` the copy is issued
// asynchronously on it; otherwise the source device context is waited on
// before a synchronous copy.
template <>
void Copy<platform::CPUPlace, platform::MLUPlace>(platform::CPUPlace dst_place,
                                                  void* dst,
                                                  platform::MLUPlace src_place,
                                                  const void* src,
                                                  size_t num,
                                                  void* stream) {
  if (UNLIKELY(num == 0)) return;

  platform::SetMLUDeviceId(src_place.device);

  if (stream == nullptr) {
    auto& ctx_pool = platform::DeviceContextPool::Instance();
    auto* mlu_ctx =
        static_cast<platform::MLUDeviceContext*>(ctx_pool.Get(src_place));
    mlu_ctx->Wait();

    VLOG(4) << "Sync memory::Copy " << num << " Bytes from " << src_place
            << " to " << dst_place;
    platform::RecordEvent record_event(
        "MLUMemcpyD2HSync:MLU->CPU", platform::TracerEventType::UserDefined, 1);
    platform::MLUMemcpyD2HSync(dst, src, num);
    return;
  }

  VLOG(4) << "Async memory::Copy " << num << " Bytes from " << src_place
          << " to " << dst_place << " by mlu stream(" << stream << ")";
  platform::RecordEvent record_event("MLUMemcpyD2HAsync:MLU->CPU",
                                     platform::TracerEventType::UserDefined,
                                     1);
  platform::MLUMemcpyD2HAsync(
      dst, src, num, reinterpret_cast<mluStream>(stream));
}

// Host-to-device copy of `num` bytes from the host buffer `src` into `dst` on
// MLU `dst_place`. With a non-null `stream` the copy is issued asynchronously
// on it; otherwise the destination MLU device context is waited on before a
// synchronous copy.
template <>
void Copy<platform::MLUPlace, platform::CPUPlace>(platform::MLUPlace dst_place,
                                                  void* dst,
                                                  platform::CPUPlace src_place,
                                                  const void* src,
                                                  size_t num,
                                                  void* stream) {
  if (UNLIKELY(num == 0)) return;

  platform::SetMLUDeviceId(dst_place.device);
  if (stream) {
    VLOG(4) << "Async memory::Copy " << num << " Bytes from " << src_place
            << " to " << dst_place << " by mlu stream(" << stream << ")";
    platform::RecordEvent record_event("MLUMemcpyH2DAsync:CPU->MLU",
                                       platform::TracerEventType::UserDefined,
                                       1);
    platform::MLUMemcpyH2DAsync(
        dst, src, num, reinterpret_cast<mluStream>(stream));
  } else {
    // The MLU side of this copy is the destination. Looking the context up by
    // `src_place` (a CPUPlace) and casting it to MLUDeviceContext, as the code
    // previously did, does not wait on the MLU device.
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    static_cast<platform::MLUDeviceContext*>(pool.Get(dst_place))->Wait();

    VLOG(4) << "Sync memory::Copy " << num << " Bytes from " << src_place
            << " to " << dst_place;
    platform::RecordEvent record_event(
        "MLUMemcpyH2DSync:CPU->MLU", platform::TracerEventType::UserDefined, 1);
    platform::MLUMemcpyH2DSync(dst, src, num);
  }
}

// MLU-to-MLU copy of `num` bytes. When both places are equal a
// device-to-device copy is issued on that device; otherwise a peer copy
// between the two devices is used. A non-null `stream` selects the
// asynchronous variant. The same-device synchronous path waits on the device
// context before copying.
template <>
void Copy<platform::MLUPlace, platform::MLUPlace>(platform::MLUPlace dst_place,
                                                  void* dst,
                                                  platform::MLUPlace src_place,
                                                  const void* src,
                                                  size_t num,
                                                  void* stream) {
  if (UNLIKELY(num == 0)) return;

  const bool same_device = (dst_place == src_place);
  if (same_device) {
    platform::SetMLUDeviceId(dst_place.device);
    if (stream == nullptr) {
      auto& ctx_pool = platform::DeviceContextPool::Instance();
      auto* mlu_ctx =
          static_cast<platform::MLUDeviceContext*>(ctx_pool.Get(src_place));
      mlu_ctx->Wait();

      VLOG(4) << "Sync memory::Copy " << num << " Bytes from " << src_place
              << " to " << dst_place;
      platform::RecordEvent record_event("MLUMemcpyD2DSync(same_mlu):MLU->MLU",
                                         platform::TracerEventType::UserDefined,
                                         1);
      platform::MLUMemcpyD2DSync(dst, src, num);
      return;
    }
    VLOG(4) << "Async memory::Copy " << num << " Bytes from " << src_place
            << " to " << dst_place << " by mlu stream(" << stream << ")";
    platform::RecordEvent record_event("MLUMemcpyD2DAsync(same_mlu):MLU->MLU",
                                       platform::TracerEventType::UserDefined,
                                       1);
    platform::MLUMemcpyD2DAsync(
        dst, src, num, reinterpret_cast<mluStream>(stream));
    return;
  }

  // Different devices: peer copy.
  if (stream == nullptr) {
    VLOG(4) << "Sync memory::Copy " << num << " Bytes from " << src_place
            << " to " << dst_place;
    platform::RecordEvent record_event("MLUMemcpyPeerSync:MLU->MLU",
                                       platform::TracerEventType::UserDefined,
                                       1);
    platform::MLUMemcpyPeerSync(
        dst, dst_place.device, src, src_place.device, num);
    return;
  }
  VLOG(4) << "Async memory::Copy " << num << " Bytes from " << src_place
          << " to " << dst_place << " by mlu stream(" << stream << ")";
  platform::RecordEvent record_event("MLUMemcpyPeerAsync:MLU->MLU",
                                     platform::TracerEventType::UserDefined,
                                     1);
  platform::MLUMemcpyPeerAsync(dst,
                               dst_place.device,
                               src,
                               src_place.device,
                               num,
                               reinterpret_cast<mluStream>(stream));
}

1264 1265
// NOTE: only for CPUPlace and MLUPlace.
template <>
1266 1267 1268 1269 1270 1271
void Copy<phi::Place, phi::Place>(phi::Place dst_place,
                                  void* dst,
                                  phi::Place src_place,
                                  const void* src,
                                  size_t num,
                                  void* stream) {
1272 1273
  if (src_place.GetType() == phi::AllocationType::CPU &&
      dst_place.GetType() == phi::AllocationType::CPU) {
1274 1275
    platform::CPUPlace place_dst, place_src;
    return Copy(place_dst, dst, place_src, src, num);
1276 1277
  } else if (src_place.GetType() == phi::AllocationType::CPU &&
             dst_place.GetType() == phi::AllocationType::MLU) {
1278 1279 1280
    platform::MLUPlace place_dst(dst_place.GetDeviceId());
    platform::CPUPlace place_src;
    return Copy(place_dst, dst, place_src, src, num, stream);
1281 1282
  } else if (src_place.GetType() == phi::AllocationType::MLU &&
             dst_place.GetType() == phi::AllocationType::CPU) {
1283 1284 1285
    platform::MLUPlace place_src(src_place.GetDeviceId());
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num, stream);
1286 1287
  } else if (src_place.GetType() == phi::AllocationType::MLU &&
             dst_place.GetType() == phi::AllocationType::MLU) {
1288 1289 1290
    platform::MLUPlace place_src(src_place.GetDeviceId());
    platform::MLUPlace place_dst(dst_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num, stream);
1291
#ifdef PADDLE_WITH_CUSTOM_DEVICE
1292 1293
  } else if (src_place.GetType() == phi::AllocationType::CPU &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CUSTOM) {
1294 1295 1296
    platform::CPUPlace place_src;
    platform::CustomPlace place_dst(dst_place);
    return Copy(place_dst, dst, place_src, src, num, stream);
1297 1298
  } else if (src_place.GetType() == phi::AllocationType::CUSTOM &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CPU) {
1299 1300 1301
    platform::CustomPlace place_src(src_place);
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num, stream);
1302 1303
  } else if (src_place.GetType() == phi::AllocationType::CUSTOM &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CUSTOM) {
1304 1305 1306 1307
    platform::CustomPlace place_src(src_place);
    platform::CustomPlace place_dst(dst_place);
    return Copy(place_dst, dst, place_src, src, num, stream);
#endif
1308 1309 1310 1311 1312
  }
}

// NOTE: only for (CPUPlace and MLUPlace) -> (MLUPlace)
template <>
1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324
void Copy<phi::MLUPlace, phi::Place>(phi::MLUPlace dst_place,
                                     void* dst,
                                     phi::Place src_place,
                                     const void* src,
                                     size_t num,
                                     void* stream) {
  Copy(phi::Place(dst_place.GetType(), dst_place.GetDeviceId()),
       dst,
       src_place,
       src,
       num,
       stream);
1325 1326 1327 1328
}

// NOTE: only for (MLUPlace) -> (CPUPlace and MLUPlace)
template <>
1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340
void Copy<phi::Place, phi::MLUPlace>(phi::Place dst_place,
                                     void* dst,
                                     phi::MLUPlace src_place,
                                     const void* src,
                                     size_t num,
                                     void* stream) {
  Copy(dst_place,
       dst,
       phi::Place(src_place.GetType(), src_place.GetDeviceId()),
       src,
       num,
       stream);
1341 1342
}

F
fwenguang 已提交
1343 1344
// NOTE: only for (MLUPlace) -> (CPUPlace) with mluStream.
template <>
1345 1346 1347 1348 1349 1350
void Copy<phi::CPUPlace, phi::Place>(phi::CPUPlace dst_place,
                                     void* dst,
                                     phi::Place src_place,
                                     const void* src,
                                     size_t num,
                                     void* stream) {
1351
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num, stream);
F
fwenguang 已提交
1352 1353 1354 1355
}

// NOTE: only for (CPUPlace) -> (MLUPlace) with mluStream.
template <>
1356 1357 1358 1359 1360 1361
void Copy<phi::Place, phi::CPUPlace>(phi::Place dst_place,
                                     void* dst,
                                     phi::CPUPlace src_place,
                                     const void* src,
                                     size_t num,
                                     void* stream) {
1362
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num, stream);
F
fwenguang 已提交
1363 1364
}

F
fwenguang 已提交
1365 1366
#endif  // PADDLE_WITH_MLU

1367 1368
// NOTE: Only for CPUPlace, XPUPlace and PinnedPlace.
template <>
1369 1370 1371 1372
void Copy<phi::Place, phi::Place>(phi::Place dst_place,
                                  void* dst,
                                  phi::Place src_place,
                                  const void* src,
1373
                                  size_t num) {
1374 1375 1376
  if (UNLIKELY(num == 0)) return;
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
1377 1378
  if (src_place.GetType() == phi::AllocationType::CPU &&
      dst_place.GetType() == phi::AllocationType::CPU) {
1379 1380 1381
    std::memcpy(dst, src, num);
  }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
1382 1383
  else if (src_place.GetType() == phi::AllocationType::CPU &&  // NOLINT
           dst_place.GetType() == phi::AllocationType::GPUPINNED) {
1384
    std::memcpy(dst, src, num);
1385 1386
  } else if (src_place.GetType() == phi::AllocationType::GPUPINNED &&
             dst_place.GetType() == phi::AllocationType::CPU) {
1387
    std::memcpy(dst, src, num);
1388 1389
  } else if (src_place.GetType() == phi::AllocationType::GPUPINNED &&
             dst_place.GetType() == phi::AllocationType::GPUPINNED) {
1390 1391 1392 1393
    std::memcpy(dst, src, num);
  }
#endif
#ifdef PADDLE_WITH_ASCEND_CL
1394 1395
  else if (src_place.GetType() == phi::AllocationType::CPU &&  // NOLINT
           dst_place.GetType() == phi::AllocationType::NPUPINNED) {
1396
    std::memcpy(dst, src, num);
1397 1398
  } else if (src_place.GetType() == phi::AllocationType::NPUPINNED &&
             dst_place.GetType() == phi::AllocationType::CPU) {
1399
    std::memcpy(dst, src, num);
1400 1401
  } else if (src_place.GetType() == phi::AllocationType::NPUPINNED &&
             dst_place.GetType() == phi::AllocationType::NPUPINNED) {
1402 1403 1404 1405
    std::memcpy(dst, src, num);
  }
#endif
#ifdef PADDLE_WITH_XPU
1406 1407
  else if (src_place.GetType() == phi::AllocationType::CPU &&  // NOLINT
           dst_place.GetType() == phi::AllocationType::CPU) {
1408 1409
    platform::CPUPlace place_dst, place_src;
    return Copy(place_dst, dst, place_src, src, num);
1410 1411
  } else if (src_place.GetType() == phi::AllocationType::CPU &&
             dst_place.GetType() == phi::AllocationType::XPU) {
1412 1413 1414
    platform::XPUPlace place_dst(dst_place.GetDeviceId());
    platform::CPUPlace place_src;
    return Copy(place_dst, dst, place_src, src, num);
1415 1416
  } else if (src_place.GetType() == phi::AllocationType::XPU &&
             dst_place.GetType() == phi::AllocationType::CPU) {
1417 1418 1419
    platform::XPUPlace place_src(src_place.GetDeviceId());
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num);
1420 1421
  } else if (src_place.GetType() == phi::AllocationType::XPU &&
             dst_place.GetType() == phi::AllocationType::XPU) {
1422 1423 1424 1425 1426
    platform::XPUPlace place_src(src_place.GetDeviceId());
    platform::XPUPlace place_dst(dst_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num);
  }
#endif
A
Allen Guo 已提交
1427
#ifdef PADDLE_WITH_IPU
1428 1429
  else if (src_place.GetType() == phi::AllocationType::CPU &&  // NOLINT
           dst_place.GetType() == phi::AllocationType::IPU) {
A
Allen Guo 已提交
1430 1431 1432
    platform::IPUPlace place_dst(dst_place.GetDeviceId());
    platform::CPUPlace place_src;
    return Copy(place_dst, dst, place_src, src, num);
1433 1434
  } else if (src_place.GetType() == phi::AllocationType::IPU &&
             dst_place.GetType() == phi::AllocationType::CPU) {
A
Allen Guo 已提交
1435 1436 1437
    platform::IPUPlace place_src(src_place.GetDeviceId());
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num);
1438 1439
  } else if (src_place.GetType() == phi::AllocationType::IPU &&
             dst_place.GetType() == phi::AllocationType::IPU) {
A
Allen Guo 已提交
1440 1441 1442 1443 1444
    platform::IPUPlace place_src(src_place.GetDeviceId());
    platform::IPUPlace place_dst(dst_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num);
  }
#endif
1445 1446 1447 1448
}

// NOTE: Only for (CPUPlace) -> (CPUPlace and PinnedPlace).
template <>
1449 1450 1451 1452
void Copy<phi::Place, phi::CPUPlace>(phi::Place dst_place,
                                     void* dst,
                                     phi::CPUPlace src_place,
                                     const void* src,
1453 1454
                                     size_t num) {
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num);
1455 1456 1457 1458
}

// NOTE: Only for (CPUPlace and PinnedPlace) -> (CPUPlace).
template <>
1459 1460 1461 1462
void Copy<phi::CPUPlace, phi::Place>(phi::CPUPlace dst_place,
                                     void* dst,
                                     phi::Place src_place,
                                     const void* src,
1463 1464
                                     size_t num) {
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num);
1465 1466
}

1467 1468 1469 1470 1471
#if defined(PADDLE_WITH_CUSTOM_DEVICE) && !defined(PADDLE_WITH_CUDA) && \
    !defined(PADDLE_WITH_ASCEND_CL) && !defined(PADDLE_WITH_HIP) &&     \
    !defined(PADDLE_WITH_MLU)

template <>
1472 1473 1474 1475 1476 1477
void Copy<phi::Place, phi::Place>(phi::Place dst_place,
                                  void* dst,
                                  phi::Place src_place,
                                  const void* src,
                                  size_t num,
                                  void* stream) {
1478 1479
  if (src_place.GetType() == phi::AllocationType::CPU &&  // NOLINT
      dst_place.GetType() == phi::AllocationType::CUSTOM) {
1480 1481 1482
    platform::CPUPlace place_src;
    platform::CustomPlace place_dst(dst_place);
    return Copy(place_dst, dst, place_src, src, num, stream);
1483 1484
  } else if (src_place.GetType() == phi::AllocationType::CUSTOM &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CPU) {
1485 1486 1487
    platform::CustomPlace place_src(src_place);
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num, stream);
1488 1489
  } else if (src_place.GetType() == phi::AllocationType::CUSTOM &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CUSTOM) {
1490 1491 1492 1493 1494 1495 1496
    platform::CustomPlace place_src(src_place);
    platform::CustomPlace place_dst(dst_place);
    return Copy(place_dst, dst, place_src, src, num, stream);
  }
}

template <>
1497 1498 1499 1500 1501 1502
void Copy<phi::CPUPlace, phi::Place>(phi::CPUPlace dst_place,
                                     void* dst,
                                     phi::Place src_place,
                                     const void* src,
                                     size_t num,
                                     void* stream) {
1503
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num, stream);
1504 1505 1506 1507
}

// NOTE: only for (CPUPlace) -> (CPUPlace, CUDAPlace and CUDAPinnedPlace).
template <>
1508 1509 1510 1511 1512 1513
void Copy<phi::Place, phi::CPUPlace>(phi::Place dst_place,
                                     void* dst,
                                     phi::CPUPlace src_place,
                                     const void* src,
                                     size_t num,
                                     void* stream) {
1514
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num, stream);
1515 1516 1517
}
#endif

Y
Yi Wang 已提交
1518 1519
}  // namespace memory
}  // namespace paddle