/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/memory/memcpy.h"

#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/phi/common/place.h"

#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
#endif

#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/mlu_info.h"
#endif

namespace paddle {
namespace memory {

#ifdef PADDLE_WITH_CUSTOM_DEVICE
// Custom (plugin) device -> host copy. A null `stream` means the device
// runtime picks its default stream inside the wrapper.
template <>
void Copy<platform::CPUPlace, platform::CustomPlace>(
    platform::CPUPlace dst_place, void* dst, platform::CustomPlace src_place,
    const void* src, size_t num, void* stream) {
  if (UNLIKELY(num == 0)) return;

  const auto src_type = platform::PlaceHelper::GetDeviceType(src_place);
  const auto dst_type = platform::PlaceHelper::GetDeviceType(dst_place);
  const std::string event_name = "Memcpy:" + src_type + "->" + dst_type;
  platform::RecordEvent record_event(event_name);
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << ", stream=" << stream;

  phi::DeviceManager::SetDevice(src_place);
  phi::stream::Stream wrapped_stream(src_place, stream);
  phi::DeviceManager::GetDeviceWithPlace(src_place)->MemoryCopyD2H(
      dst, src, num, &wrapped_stream);
}

// Host -> custom (plugin) device copy.
template <>
void Copy<platform::CustomPlace, platform::CPUPlace>(
    platform::CustomPlace dst_place, void* dst, platform::CPUPlace src_place,
    const void* src, size_t num, void* stream) {
  if (UNLIKELY(num == 0)) return;
  const auto src_type = platform::PlaceHelper::GetDeviceType(src_place);
  const auto dst_type = platform::PlaceHelper::GetDeviceType(dst_place);
  const std::string event_name = "Memcpy:" + src_type + "->" + dst_type;
  platform::RecordEvent record_event(event_name);
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << ", stream=" << stream;

  phi::DeviceManager::SetDevice(dst_place);
  phi::stream::Stream wrapped_stream(dst_place, stream);
  phi::DeviceManager::GetDeviceWithPlace(dst_place)->MemoryCopyH2D(
      dst, src, num, &wrapped_stream);
}

// Custom device -> custom device copy. Same-device pairs use a D2D copy;
// different devices of the same type go through the P2P path. Copies across
// *different* custom device types are not supported and throw.
template <>
void Copy<platform::CustomPlace, platform::CustomPlace>(
    platform::CustomPlace dst_place, void* dst, platform::CustomPlace src_place,
    const void* src, size_t num, void* stream) {
  if (UNLIKELY(num == 0)) return;

  const auto src_type = platform::PlaceHelper::GetDeviceType(src_place);
  const auto dst_type = platform::PlaceHelper::GetDeviceType(dst_place);
  const std::string event_name = "Memcpy:" + src_type + "->" + dst_type;
  platform::RecordEvent record_event(event_name);
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << ", stream=" << stream;

  if (src_type == dst_type) {
    phi::DeviceManager::SetDevice(src_place);
    phi::stream::Stream wrapped_stream(src_place, stream);

    const auto src_id = platform::PlaceHelper::GetDeviceId(src_place);
    const auto dst_id = platform::PlaceHelper::GetDeviceId(dst_place);
    if (src_id == dst_id) {
      phi::DeviceManager::GetDeviceWithPlace(src_place)->MemoryCopyD2D(
          dst, src, num, &wrapped_stream);
    } else {
      phi::DeviceManager::GetDeviceWithPlace(src_place)->MemoryCopyP2P(
          dst_place, dst, src, num, &wrapped_stream);
    }
  } else {
    PADDLE_THROW(platform::errors::Unavailable(
        "Copy between %s and %s is not supported.", src_type, dst_type));
  }
}
#endif  // PADDLE_WITH_CUSTOM_DEVICE

// Plain host-to-host copy; the two CPU places carry no information, so the
// parameters are intentionally unnamed.
template <>
void Copy<platform::CPUPlace, platform::CPUPlace>(
    platform::CPUPlace /*dst_place*/, void* dst,
    platform::CPUPlace /*src_place*/, const void* src, size_t num) {
  if (UNLIKELY(num == 0)) return;
  VLOG(4) << "src: " << src << ", dst: " << dst << ", num: " << num;
  std::memcpy(dst, src, num);
}
#ifdef PADDLE_WITH_IPU
// IPU buffers addressed here are host-resident, so every IPU copy in this
// section reduces to a plain memcpy.
template <>
void Copy<platform::IPUPlace, platform::CPUPlace>(
    platform::IPUPlace dst_place, void* dst, platform::CPUPlace src_place,
    const void* src, size_t num) {
  if (UNLIKELY(num == 0)) return;
  std::memcpy(dst, src, num);
}
template <>
void Copy<platform::CPUPlace, platform::IPUPlace>(
    platform::CPUPlace dst_place, void* dst, platform::IPUPlace src_place,
    const void* src, size_t num) {
  if (UNLIKELY(num == 0)) return;
  std::memcpy(dst, src, num);
}
template <>
void Copy<platform::IPUPlace, platform::IPUPlace>(
    platform::IPUPlace dst_place, void* dst, platform::IPUPlace src_place,
    const void* src, size_t num) {
  if (UNLIKELY(num == 0)) return;
  std::memcpy(dst, src, num);
}

// NOTE: only for (CPUPlace and IPUPlace) -> (IPUPlace).
template <>
void Copy<phi::IPUPlace, phi::Place>(phi::IPUPlace dst_place, void* dst,
                                     phi::Place src_place, const void* src,
                                     size_t num) {
  const auto src_kind = src_place.GetType();
  if (src_kind == phi::AllocationType::CPU) {
    Copy(dst_place, dst, platform::CPUPlace(), src, num);
  } else if (src_kind == phi::AllocationType::IPU) {
    Copy(dst_place, dst, platform::IPUPlace(src_place.GetDeviceId()), src, num);
  }
}

// NOTE: only for (IPUPlace) -> (CPUPlace and IPUPlace).
template <>
void Copy<phi::Place, phi::IPUPlace>(phi::Place dst_place, void* dst,
                                     phi::IPUPlace src_place, const void* src,
                                     size_t num) {
  const auto dst_kind = dst_place.GetType();
  if (dst_kind == phi::AllocationType::CPU) {
    Copy(platform::CPUPlace(), dst, src_place, src, num);
  } else if (dst_kind == phi::AllocationType::IPU) {
    Copy(platform::IPUPlace(dst_place.GetDeviceId()), dst, src_place, src, num);
  }
}
#endif
#ifdef PADDLE_WITH_XPU
// Host -> XPU. All XPU copies in this section are synchronous.
template <>
void Copy<platform::XPUPlace, platform::CPUPlace>(platform::XPUPlace dst_place,
                                                  void* dst,
                                                  platform::CPUPlace src_place,
                                                  const void* src, size_t num) {
  if (num > 0) {
    platform::MemcpySyncH2D(dst, src, num, dst_place);
  } else {
    VLOG(1) << "memcpy XPU_HOST_TO_DEVICE size <= 0 (" << num << ")";
  }
}

// XPU -> host.
template <>
void Copy<platform::CPUPlace, platform::XPUPlace>(platform::CPUPlace dst_place,
                                                  void* dst,
                                                  platform::XPUPlace src_place,
                                                  const void* src, size_t num) {
  if (num > 0) {
    platform::MemcpySyncD2H(dst, src, num, src_place);
  } else {
    VLOG(1) << "memcpy XPU_DEVICE_TO_HOST size <= 0 (" << num << ")";
  }
}

// XPU -> XPU (the wrapper handles the cross-device case).
template <>
void Copy<platform::XPUPlace, platform::XPUPlace>(platform::XPUPlace dst_place,
                                                  void* dst,
                                                  platform::XPUPlace src_place,
                                                  const void* src, size_t num) {
  if (num > 0) {
    platform::MemcpySyncD2D(dst, dst_place, src, src_place, num);
  } else {
    VLOG(1) << "memcpy XPU_DEVICE_TO_DEVICE size <= 0 (" << num << ")";
  }
}

// NOTE: only for (CPUPlace and XPUPlace) -> (XPUPlace).
template <>
void Copy<phi::XPUPlace, phi::Place>(phi::XPUPlace dst_place, void* dst,
                                     phi::Place src_place, const void* src,
                                     size_t num) {
  const auto src_kind = src_place.GetType();
  if (src_kind == phi::AllocationType::CPU) {
    Copy(dst_place, dst, platform::CPUPlace(), src, num);
  } else if (src_kind == phi::AllocationType::XPU) {
    Copy(dst_place, dst, platform::XPUPlace(src_place.GetDeviceId()), src, num);
  }
}

// NOTE: only for (XPUPlace) -> (CPUPlace and XPUPlace).
template <>
void Copy<phi::Place, phi::XPUPlace>(phi::Place dst_place, void* dst,
                                     phi::XPUPlace src_place, const void* src,
                                     size_t num) {
  const auto dst_kind = dst_place.GetType();
  if (dst_kind == phi::AllocationType::CPU) {
    Copy(platform::CPUPlace(), dst, src_place, src, num);
  } else if (dst_kind == phi::AllocationType::XPU) {
    Copy(platform::XPUPlace(dst_place.GetDeviceId()), dst, src_place, src, num);
  }
}
#endif

#ifdef PADDLE_WITH_ASCEND_CL
// Host -> NPU copy.
//
// If `stream` is non-null the copy is issued asynchronously on that ACL
// stream; otherwise it is a blocking copy preceded by a device-context wait.
// (Fixes the "by thream(" typo in the log message.)
template <>
void Copy<platform::NPUPlace, platform::CPUPlace>(platform::NPUPlace dst_place,
                                                  void* dst,
                                                  platform::CPUPlace src_place,
                                                  const void* src, size_t num,
                                                  void* stream) {
  if (UNLIKELY(num == 0)) return;

  platform::SetNPUDeviceId(dst_place.device);

  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by stream(" << stream << ")";

  if (stream) {
    platform::RecordEvent record_event(
        "NpuMemcpyAsync:CPU->NPU", platform::TracerEventType::UserDefined, 1);
    platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE,
                             reinterpret_cast<aclrtStream>(stream));
  } else {
    // On NPU, async operation after sync operation is ok, while sync operation
    // after async is not ok, since the async operation may not be done.
    // So, it is needed to do wait before the sync operation.
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    static_cast<platform::NPUDeviceContext*>(pool.Get(dst_place))->Wait();

    platform::RecordEvent record_event(
        "NpuMemcpySync:CPU->NPU", platform::TracerEventType::UserDefined, 1);
    platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE);
  }
}

// NPU -> host copy; async on `stream` when given, otherwise blocking with a
// prior device-context wait. (Fixes the "by thream(" typo in the log message.)
template <>
void Copy<platform::CPUPlace, platform::NPUPlace>(platform::CPUPlace dst_place,
                                                  void* dst,
                                                  platform::NPUPlace src_place,
                                                  const void* src, size_t num,
                                                  void* stream) {
  if (UNLIKELY(num == 0)) return;

  platform::SetNPUDeviceId(src_place.device);

  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by stream(" << stream << ")";

  if (stream) {
    platform::RecordEvent record_event(
        "NpuMemcpyAsync:NPU->CPU", platform::TracerEventType::UserDefined, 1);
    platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST,
                             reinterpret_cast<aclrtStream>(stream));
  } else {
    // Wait for pending async work on this device before the blocking copy.
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    static_cast<platform::NPUDeviceContext*>(pool.Get(src_place))->Wait();

    platform::RecordEvent record_event(
        "NpuMemcpySync:NPU->CPU", platform::TracerEventType::UserDefined, 1);
    platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST);
  }
}

// NPU -> NPU copy. Same-device pairs use a D2D memcpy; different devices
// require peer access between the two NPUs.
template <>
void Copy<platform::NPUPlace, platform::NPUPlace>(platform::NPUPlace dst_place,
                                                  void* dst,
                                                  platform::NPUPlace src_place,
                                                  const void* src, size_t num,
                                                  void* stream) {
  if (UNLIKELY(num == 0)) return;

  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by stream(" << stream << ")";
  if (dst_place == src_place) {
    platform::SetNPUDeviceId(src_place.device);
    if (stream) {
      platform::RecordEvent record_event("NpuMemcpyAsync(same_npu):NPU->NPU",
                                         platform::TracerEventType::UserDefined,
                                         1);
      platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE,
                               reinterpret_cast<aclrtStream>(stream));
    } else {
      platform::DeviceContextPool& pool =
          platform::DeviceContextPool::Instance();
      static_cast<platform::NPUDeviceContext*>(pool.Get(dst_place))->Wait();

      platform::RecordEvent record_event("NpuMemcpySync(same_npu):NPU->NPU",
                                         platform::TracerEventType::UserDefined,
                                         1);
      platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE);
    }
  } else {
    // BUGFIX: compare the destination against the *source* device. The
    // original checked dst_place.device against itself, which made the peer
    // access check vacuous.
    if (!platform::NPUCanAccessPeer(dst_place.device, src_place.device)) {
      PADDLE_THROW(platform::errors::Unavailable(
          "Peer access between NPU places is not allowed."));
    }
    if (stream) {
      // TODO(zhiqiu): support peer access?
      platform::RecordEvent record_event("NpuMemcpyPeerAsync:NPU->NPU",
                                         platform::TracerEventType::UserDefined,
                                         1);
      platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE,
                               reinterpret_cast<aclrtStream>(stream));
    } else {
      platform::DeviceContextPool& pool =
          platform::DeviceContextPool::Instance();
      static_cast<platform::NPUDeviceContext*>(pool.Get(dst_place))->Wait();

      platform::RecordEvent record_event("NpuMemcpyPeerSync:NPU->NPU",
                                         platform::TracerEventType::UserDefined,
                                         1);
      platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE);
    }
  }
}

// NPU-pinned memory is ordinary host memory, so the three copies below are
// plain memcpy plus logging.
template <>
void Copy<platform::CPUPlace, platform::NPUPinnedPlace>(
    platform::CPUPlace dst_place, void* dst, platform::NPUPinnedPlace src_place,
    const void* src, size_t num) {
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  if (!UNLIKELY(num == 0)) std::memcpy(dst, src, num);
}

template <>
void Copy<platform::NPUPinnedPlace, platform::CPUPlace>(
    platform::NPUPinnedPlace dst_place, void* dst, platform::CPUPlace src_place,
    const void* src, size_t num) {
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  if (!UNLIKELY(num == 0)) std::memcpy(dst, src, num);
}

template <>
void Copy<platform::NPUPinnedPlace, platform::NPUPinnedPlace>(
    platform::NPUPinnedPlace dst_place, void* dst,
    platform::NPUPinnedPlace src_place, const void* src, size_t num) {
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  if (!UNLIKELY(num == 0)) std::memcpy(dst, src, num);
}

// NPU -> NPU-pinned host memory; async on `stream` when given, otherwise
// blocking with a prior device-context wait. (Fixes the "by thream(" typo in
// the log message.)
template <>
void Copy<platform::NPUPinnedPlace, platform::NPUPlace>(
    platform::NPUPinnedPlace dst_place, void* dst, platform::NPUPlace src_place,
    const void* src, size_t num, void* stream) {
  if (UNLIKELY(num == 0)) return;

  platform::SetNPUDeviceId(src_place.device);

  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by stream(" << stream << ")";

  if (stream) {
    platform::RecordEvent record_event("NpuMemcpyAsync:NPU->NPUPinned",
                                       platform::TracerEventType::UserDefined,
                                       1);
    platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST,
                             reinterpret_cast<aclrtStream>(stream));
  } else {
    // Wait for pending async work on this device before the blocking copy.
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    static_cast<platform::NPUDeviceContext*>(pool.Get(src_place))->Wait();

    platform::RecordEvent record_event("NpuMemcpySync:NPU->NPUPinned",
                                       platform::TracerEventType::UserDefined,
                                       1);
    platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST);
  }
}

// NPU-pinned host memory -> NPU; async on `stream` when given, otherwise
// blocking with a prior device-context wait. (Fixes the "by thream(" typo in
// the log message.)
template <>
void Copy<platform::NPUPlace, platform::NPUPinnedPlace>(
    platform::NPUPlace dst_place, void* dst, platform::NPUPinnedPlace src_place,
    const void* src, size_t num, void* stream) {
  if (UNLIKELY(num == 0)) return;

  platform::SetNPUDeviceId(dst_place.device);

  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by stream(" << stream << ")";

  if (stream) {
    platform::RecordEvent record_event("NpuMemcpyAsync:NPUPinned->NPU",
                                       platform::TracerEventType::UserDefined,
                                       1);
    platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE,
                             reinterpret_cast<aclrtStream>(stream));
  } else {
    // On NPU, async operation after sync operation is ok, while sync operation
    // after async is not ok, since the async operation may not be done.
    // So, it is needed to do wait before the sync operation.
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    static_cast<platform::NPUDeviceContext*>(pool.Get(dst_place))->Wait();

    platform::RecordEvent record_event("NpuMemcpySync:NPUPinned->NPU",
                                       platform::TracerEventType::UserDefined,
                                       1);
    platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE);
  }
}

// NOTE: only for CPUPlace, NPUPlace and NPUPinnedPlace.
template <>
438 439 440 441 442
void Copy<phi::Place, phi::Place>(phi::Place dst_place, void* dst,
                                  phi::Place src_place, const void* src,
                                  size_t num, aclrtStream stream) {
  if (src_place.GetType() == phi::AllocationType::CPU &&
      dst_place.GetType() == phi::AllocationType::CPU) {
443 444
    platform::CPUPlace place_dst, place_src;
    return Copy(place_dst, dst, place_src, src, num);
445 446
  } else if (src_place.GetType() == phi::AllocationType::CPU &&
             dst_place.GetType() == phi::AllocationType::NPU) {
447 448 449
    platform::NPUPlace place_dst(dst_place.GetDeviceId());
    platform::CPUPlace place_src;
    return Copy(place_dst, dst, place_src, src, num, stream);
450 451
  } else if (src_place.GetType() == phi::AllocationType::NPU &&
             dst_place.GetType() == phi::AllocationType::CPU) {
452 453 454
    platform::NPUPlace place_src(src_place.GetDeviceId());
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num, stream);
455 456
  } else if (src_place.GetType() == phi::AllocationType::NPU &&
             dst_place.GetType() == phi::AllocationType::NPU) {
457 458 459
    platform::NPUPlace place_src(src_place.GetDeviceId());
    platform::NPUPlace place_dst(dst_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num, stream);
460 461
  } else if (src_place.GetType() == phi::AllocationType::CPU &&
             dst_place.GetType() == phi::AllocationType::NPUPINNED) {
462 463 464
    platform::CPUPlace place_src;
    platform::NPUPinnedPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num);
465 466
  } else if (src_place.GetType() == phi::AllocationType::NPUPINNED &&
             dst_place.GetType() == phi::AllocationType::CPU) {
467 468 469
    platform::CPUPlace place_dst;
    platform::NPUPinnedPlace place_src;
    return Copy(place_dst, dst, place_src, src, num);
470 471
  } else if (src_place.GetType() == phi::AllocationType::NPUPINNED &&
             dst_place.GetType() == phi::AllocationType::NPUPINNED) {
472 473 474
    platform::NPUPinnedPlace place_dst;
    platform::NPUPinnedPlace place_src;
    return Copy(place_dst, dst, place_src, src, num);
475 476
  } else if (src_place.GetType() == phi::AllocationType::NPUPINNED &&
             dst_place.GetType() == phi::AllocationType::NPU) {
477 478 479
    platform::NPUPinnedPlace place_src;
    platform::NPUPlace place_dst(dst_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num, stream);
480 481
  } else if (src_place.GetType() == phi::AllocationType::NPU &&
             dst_place.GetType() == phi::AllocationType::NPUPINNED) {
482 483 484
    platform::NPUPinnedPlace place_dst;
    platform::NPUPlace place_src(src_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num, stream);
485
#ifdef PADDLE_WITH_CUSTOM_DEVICE
486 487
  } else if (src_place.GetType() == phi::AllocationType::CPU &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CUSTOM) {
488 489 490
    platform::CPUPlace place_src;
    platform::CustomPlace place_dst(dst_place);
    return Copy(place_dst, dst, place_src, src, num, stream);
491 492
  } else if (src_place.GetType() == phi::AllocationType::CUSTOM &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CPU) {
493 494 495
    platform::CustomPlace place_src(src_place);
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num, stream);
496 497
  } else if (src_place.GetType() == phi::AllocationType::CUSTOM &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CUSTOM) {
498 499 500 501
    platform::CustomPlace place_src(src_place);
    platform::CustomPlace place_dst(dst_place);
    return Copy(place_dst, dst, place_src, src, num, stream);
#endif
502 503 504 505 506
  }
}

// NOTE: only for (CPUPlace, NPUPlace and NPUPinnedPlace) -> (CPUPlace).
template <>
507 508 509 510
void Copy<phi::CPUPlace, phi::Place>(phi::CPUPlace dst_place, void* dst,
                                     phi::Place src_place, const void* src,
                                     size_t num, aclrtStream stream) {
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num, stream);
511 512 513 514
}

// NOTE: only for (CPUPlace) -> (CPUPlace, NPUPlace and NPUPinnedPlace).
template <>
515 516 517 518
void Copy<phi::Place, phi::CPUPlace>(phi::Place dst_place, void* dst,
                                     phi::CPUPlace src_place, const void* src,
                                     size_t num, aclrtStream stream) {
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num, stream);
519 520 521 522
}

// NOTE: only for (CPUPlace, NPUPlace and NPUPinnedPlace) -> (NPUPlace)
template <>
523 524 525 526 527
void Copy<phi::NPUPlace, phi::Place>(phi::NPUPlace dst_place, void* dst,
                                     phi::Place src_place, const void* src,
                                     size_t num, aclrtStream stream) {
  Copy(phi::Place(dst_place.GetType(), dst_place.GetDeviceId()), dst, src_place,
       src, num, stream);
528 529 530 531
}

// NOTE: only for (NPUPlace) -> (CPUPlace, NPUPlace and NPUPinnedPlace)
template <>
532 533 534 535 536
void Copy<phi::Place, phi::NPUPlace>(phi::Place dst_place, void* dst,
                                     phi::NPUPlace src_place, const void* src,
                                     size_t num, aclrtStream stream) {
  Copy(dst_place, dst, phi::Place(src_place.GetType(), src_place.GetDeviceId()),
       src, num, stream);
537 538 539 540
}

// NOTE: only for (CPUPlace, NPUPlace and NPUPinnedPlace) -> (NPUPinnedPlace)
template <>
541 542 543 544 545
void Copy<phi::NPUPinnedPlace, phi::Place>(phi::NPUPinnedPlace dst_place,
                                           void* dst, phi::Place src_place,
                                           const void* src, size_t num,
                                           aclrtStream stream) {
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num, stream);
546 547 548 549
}

// NOTE: only for (NPUPinnedPlace) -> (CPUPlace, NPUPlace and NPUPinnedPlace)
template <>
550 551 552 553 554
void Copy<phi::Place, phi::NPUPinnedPlace>(phi::Place dst_place, void* dst,
                                           phi::NPUPinnedPlace src_place,
                                           const void* src, size_t num,
                                           aclrtStream stream) {
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num, stream);
555 556 557 558
}

// NOTE: only for (CPUPlace) -> (NPUPinnedPlace)
template <>
559 560 561 562
void Copy<phi::NPUPinnedPlace, phi::Place>(phi::NPUPinnedPlace dst_place,
                                           void* dst, phi::Place src_place,
                                           const void* src, size_t num) {
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num, nullptr);
563 564 565 566
}

// NOTE: only for (NPUPinnedPlace) -> (CPUPlace)
template <>
567 568 569 570
void Copy<phi::Place, phi::NPUPinnedPlace>(phi::Place dst_place, void* dst,
                                           phi::NPUPinnedPlace src_place,
                                           const void* src, size_t num) {
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num, nullptr);
571
}
572 573
#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Small sync copies (<= this many bytes) are followed by a stream sync so the
// host buffer can be reused immediately.
static constexpr size_t kMaxGpuAsyncCopyBytes = 64 * 1024;  // 64K

#ifdef PADDLE_WITH_HIP
// Block until the default stream drains. On Windows the stream is polled
// instead of synchronized.
inline void SyncCUDAStream() {
#if !defined(_WIN32)
  hipStreamSynchronize(0);
#else
  hipError_t status = hipSuccess;
  while ((status = hipStreamQuery(0)) != hipSuccess) {
    if (status != hipErrorNotReady) break;
  }
#endif
}
#else
inline void SyncCUDAStream() {
#if !defined(_WIN32)
  cudaStreamSynchronize(0);
#else
  cudaError_t status = cudaSuccess;
  while ((status = cudaStreamQuery(0)) != cudaSuccess) {
    if (status != cudaErrorNotReady) break;
  }
#endif
}
#endif
// NOTE(zcd): Do not use GpuMemcpySync as much as possible.
// because GpuMemcpySync issues the copying command to the default stream,
// which will make two commands from different streams cannot run concurrently.
// Reference:
// https://devblogs.nvidia.com/gpu-pro-tip-cuda-7-streams-simplify-concurrency/

template <>
D
dzhwinter 已提交
610 611
void Copy<platform::CPUPlace, platform::CUDAPlace>(
    platform::CPUPlace dst_place, void* dst, platform::CUDAPlace src_place,
612
    const void* src, size_t num, void* stream) {
Z
Zeng Jinle 已提交
613
  if (UNLIKELY(num == 0)) return;
614

615 616
  platform::SetDeviceId(src_place.device);
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
617
          << dst_place << " by stream(" << stream << ")";
618
  if (stream) {
619 620
    platform::RecordEvent record_event(
        "GpuMemcpyAsync:GPU->CPU", platform::TracerEventType::UserDefined, 1);
621
#ifdef PADDLE_WITH_HIP
622 623
    platform::GpuMemcpyAsync(dst, src, num, hipMemcpyDeviceToHost,
                             reinterpret_cast<gpuStream_t>(stream));
624
#else
625 626
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost,
                             reinterpret_cast<gpuStream_t>(stream));
627
#endif
628
  } else {
629 630
    platform::RecordEvent record_event(
        "GpuMemcpySync:GPU->CPU", platform::TracerEventType::UserDefined, 1);
631 632 633
#ifdef PADDLE_WITH_HIP
    platform::GpuMemcpySync(dst, src, num, hipMemcpyDeviceToHost);
#else
634
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
635
#endif
S
sneaxiy 已提交
636 637
    // FIXME(zjl): do we really need it?
    if (num <= kMaxGpuAsyncCopyBytes) {
638
      SyncCUDAStream();
S
sneaxiy 已提交
639
    }
640
  }
641 642 643
}

template <>
D
dzhwinter 已提交
644 645
void Copy<platform::CUDAPlace, platform::CPUPlace>(
    platform::CUDAPlace dst_place, void* dst, platform::CPUPlace src_place,
646
    const void* src, size_t num, void* stream) {
Z
Zeng Jinle 已提交
647 648
  if (UNLIKELY(num == 0)) return;

L
liaogang 已提交
649
  platform::SetDeviceId(dst_place.device);
650
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
651
          << dst_place << " by stream(" << stream << ")";
652
  if (stream) {
653 654
    platform::RecordEvent record_event(
        "GpuMemcpyAsync:CPU->GPU", platform::TracerEventType::UserDefined, 1);
655
#ifdef PADDLE_WITH_HIP
656 657
    platform::GpuMemcpyAsync(dst, src, num, hipMemcpyHostToDevice,
                             reinterpret_cast<gpuStream_t>(stream));
658
#else
659 660
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice,
                             reinterpret_cast<gpuStream_t>(stream));
661
#endif
662
  } else {
663 664
    platform::RecordEvent record_event(
        "GpuMemcpySync:CPU->GPU", platform::TracerEventType::UserDefined, 1);
665 666 667
#ifdef PADDLE_WITH_HIP
    platform::GpuMemcpySync(dst, src, num, hipMemcpyHostToDevice);
#else
668
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
669
#endif
S
sneaxiy 已提交
670 671
    // FIXME(zjl): do we really need it?
    if (num <= kMaxGpuAsyncCopyBytes) {
672
      SyncCUDAStream();
S
sneaxiy 已提交
673
    }
674
  }
675 676 677
}

template <>
D
dzhwinter 已提交
678 679
void Copy<platform::CUDAPlace, platform::CUDAPlace>(
    platform::CUDAPlace dst_place, void* dst, platform::CUDAPlace src_place,
680
    const void* src, size_t num, void* stream) {
Z
Zeng Jinle 已提交
681 682
  if (UNLIKELY(num == 0)) return;

683
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
684
          << dst_place << " by stream(" << stream << ")";
685
  if (dst_place == src_place) {
L
liaogang 已提交
686
    platform::SetDeviceId(src_place.device);
687
    if (stream) {
688 689 690
      platform::RecordEvent record_event("GpuMemcpyAsync(same_gpu):GPU->GPU",
                                         platform::TracerEventType::UserDefined,
                                         1);
691
#ifdef PADDLE_WITH_HIP
692 693
      platform::GpuMemcpyAsync(dst, src, num, hipMemcpyDeviceToDevice,
                               reinterpret_cast<gpuStream_t>(stream));
694
#else
695 696
      platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice,
                               reinterpret_cast<gpuStream_t>(stream));
697
#endif
698
    } else {
699 700 701
      platform::RecordEvent record_event("GpuMemcpySync(same_gpu):GPU->GPU",
                                         platform::TracerEventType::UserDefined,
                                         1);
702 703 704
#ifdef PADDLE_WITH_HIP
      platform::GpuMemcpySync(dst, src, num, hipMemcpyDeviceToDevice);
#else
705
      platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice);
706
#endif
707
    }
708
  } else {
709
    if (stream) {
710 711 712
      platform::RecordEvent record_event("GpuMemcpyPeerAsync:GPU->GPU",
                                         platform::TracerEventType::UserDefined,
                                         1);
713
      platform::GpuMemcpyPeerAsync(dst, dst_place.device, src, src_place.device,
714
                                   num, reinterpret_cast<gpuStream_t>(stream));
715
    } else {
716 717 718
      platform::RecordEvent record_event("GpuMemcpyPeerSync:GPU->GPU",
                                         platform::TracerEventType::UserDefined,
                                         1);
719
      platform::GpuMemcpyPeerSync(dst, dst_place.device, src, src_place.device,
F
fengjiayi 已提交
720
                                  num);
721
    }
722 723 724
  }
}

C
chengduoZH 已提交
725 726 727 728
template <>
void Copy<platform::CPUPlace, platform::CUDAPinnedPlace>(
    platform::CPUPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
729 730
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
Z
Zeng Jinle 已提交
731
  if (UNLIKELY(num == 0)) return;
C
chengduoZH 已提交
732 733 734 735 736 737 738
  std::memcpy(dst, src, num);
}

// Copies `num` bytes from pageable host memory to CUDA-pinned host memory.
// Both places live in the host address space, so a plain memcpy suffices.
template <>
void Copy<platform::CUDAPinnedPlace, platform::CPUPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CPUPlace src_place, const void* src, size_t num) {
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  if (UNLIKELY(num == 0)) return;
  std::memcpy(dst, src, num);
}

// Copies `num` bytes between two CUDA-pinned host buffers via plain memcpy.
template <>
void Copy<platform::CUDAPinnedPlace, platform::CUDAPinnedPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  if (UNLIKELY(num == 0)) return;
  std::memcpy(dst, src, num);
}

// Copies `num` bytes from GPU device memory to CUDA-pinned host memory.
// Async on the given stream when one is supplied, otherwise blocking.
// Fixes log-message typo: "thream" -> "stream".
template <>
void Copy<platform::CUDAPinnedPlace, platform::CUDAPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CUDAPlace src_place, const void* src, size_t num, void* stream) {
  if (UNLIKELY(num == 0)) return;
  platform::SetDeviceId(src_place.device);
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by stream(" << stream << ")";
  if (stream) {
    platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CUDAPinned",
                                       platform::TracerEventType::UserDefined,
                                       1);
#ifdef PADDLE_WITH_HIP
    platform::GpuMemcpyAsync(dst, src, num, hipMemcpyDeviceToHost,
                             reinterpret_cast<gpuStream_t>(stream));
#else
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost,
                             reinterpret_cast<gpuStream_t>(stream));
#endif
  } else {
    platform::RecordEvent record_event("GpuMemcpySync:GPU->CUDAPinned",
                                       platform::TracerEventType::UserDefined,
                                       1);
#ifdef PADDLE_WITH_HIP
    platform::GpuMemcpySync(dst, src, num, hipMemcpyDeviceToHost);
#else
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
#endif
  }
}

// Copies `num` bytes from CUDA-pinned host memory to GPU device memory.
// Async on the given stream when one is supplied, otherwise blocking.
// Fixes log-message typo: "thream" -> "stream".
template <>
void Copy<platform::CUDAPlace, platform::CUDAPinnedPlace>(
    platform::CUDAPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num,
    void* stream) {
  if (UNLIKELY(num == 0)) return;

  platform::SetDeviceId(dst_place.device);
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by stream(" << stream << ")";
  if (stream) {
    platform::RecordEvent record_event("GpuMemcpyAsync:CUDAPinned->GPU",
                                       platform::TracerEventType::UserDefined,
                                       1);
#ifdef PADDLE_WITH_HIP
    platform::GpuMemcpyAsync(dst, src, num, hipMemcpyHostToDevice,
                             reinterpret_cast<gpuStream_t>(stream));
#else
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice,
                             reinterpret_cast<gpuStream_t>(stream));
#endif
  } else {
    platform::RecordEvent record_event("GpuMemcpySync:CUDAPinned->GPU",
                                       platform::TracerEventType::UserDefined,
                                       1);
#ifdef PADDLE_WITH_HIP
    platform::GpuMemcpySync(dst, src, num, hipMemcpyHostToDevice);
#else
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
#endif
  }
}

819 820
// NOTE: only for CPUPlace、CUDAPlace and CUDAPinnedPlace.
template <>
821 822 823 824 825
void Copy<phi::Place, phi::Place>(phi::Place dst_place, void* dst,
                                  phi::Place src_place, const void* src,
                                  size_t num, void* stream) {
  if (src_place.GetType() == phi::AllocationType::CPU &&
      dst_place.GetType() == phi::AllocationType::CPU) {
826 827
    platform::CPUPlace place_dst, place_src;
    return Copy(place_dst, dst, place_src, src, num);
828 829
  } else if (src_place.GetType() == phi::AllocationType::CPU &&
             dst_place.GetType() == phi::AllocationType::GPU) {
830 831 832
    platform::CUDAPlace place_dst(dst_place.GetDeviceId());
    platform::CPUPlace place_src;
    return Copy(place_dst, dst, place_src, src, num, stream);
833 834
  } else if (src_place.GetType() == phi::AllocationType::GPU &&
             dst_place.GetType() == phi::AllocationType::CPU) {
835 836 837
    platform::CUDAPlace place_src(src_place.GetDeviceId());
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num, stream);
838 839
  } else if (src_place.GetType() == phi::AllocationType::GPU &&
             dst_place.GetType() == phi::AllocationType::GPU) {
840 841 842
    platform::CUDAPlace place_src(src_place.GetDeviceId());
    platform::CUDAPlace place_dst(dst_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num, stream);
843 844
  } else if (src_place.GetType() == phi::AllocationType::CPU &&
             dst_place.GetType() == phi::AllocationType::GPUPINNED) {
845 846 847
    platform::CPUPlace place_src;
    platform::CUDAPinnedPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num);
848 849
  } else if (src_place.GetType() == phi::AllocationType::GPUPINNED &&
             dst_place.GetType() == phi::AllocationType::CPU) {
850 851 852
    platform::CPUPlace place_dst;
    platform::CUDAPinnedPlace place_src;
    return Copy(place_dst, dst, place_src, src, num);
853 854
  } else if (src_place.GetType() == phi::AllocationType::GPUPINNED &&
             dst_place.GetType() == phi::AllocationType::GPUPINNED) {
855 856 857
    platform::CUDAPinnedPlace place_dst;
    platform::CUDAPinnedPlace place_src;
    return Copy(place_dst, dst, place_src, src, num);
858 859
  } else if (src_place.GetType() == phi::AllocationType::GPUPINNED &&
             dst_place.GetType() == phi::AllocationType::GPU) {
860 861 862
    platform::CUDAPinnedPlace place_src;
    platform::CUDAPlace place_dst(dst_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num, stream);
863 864
  } else if (src_place.GetType() == phi::AllocationType::GPU &&
             dst_place.GetType() == phi::AllocationType::GPUPINNED) {
865 866 867
    platform::CUDAPinnedPlace place_dst;
    platform::CUDAPlace place_src(src_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num, stream);
868
#ifdef PADDLE_WITH_CUSTOM_DEVICE
869 870
  } else if (src_place.GetType() == phi::AllocationType::CPU &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CUSTOM) {
871 872 873
    platform::CPUPlace place_src;
    platform::CustomPlace place_dst(dst_place);
    return Copy(place_dst, dst, place_src, src, num, stream);
874 875
  } else if (src_place.GetType() == phi::AllocationType::CUSTOM &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CPU) {
876 877 878
    platform::CustomPlace place_src(src_place);
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num, stream);
879 880
  } else if (src_place.GetType() == phi::AllocationType::CUSTOM &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CUSTOM) {
881 882 883 884
    platform::CustomPlace place_src(src_place);
    platform::CustomPlace place_dst(dst_place);
    return Copy(place_dst, dst, place_src, src, num, stream);
#endif
885 886 887 888 889
  }
}

// NOTE: only for (CPUPlace, CUDAPlace and CUDAPinnedPlace) -> (CPUPlace).
template <>
890 891 892 893
void Copy<phi::CPUPlace, phi::Place>(phi::CPUPlace dst_place, void* dst,
                                     phi::Place src_place, const void* src,
                                     size_t num, void* stream) {
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num, stream);
894 895 896 897
}

// NOTE: only for (CPUPlace) -> (CPUPlace, CUDAPlace and CUDAPinnedPlace).
template <>
898 899 900 901
void Copy<phi::Place, phi::CPUPlace>(phi::Place dst_place, void* dst,
                                     phi::CPUPlace src_place, const void* src,
                                     size_t num, void* stream) {
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num, stream);
902 903 904 905
}

// NOTE: only for (CPUPlace, CUDAPlace and CUDAPinnedPlace) -> (CUDAPlace)
template <>
906 907 908 909 910
void Copy<phi::GPUPlace, phi::Place>(phi::GPUPlace dst_place, void* dst,
                                     phi::Place src_place, const void* src,
                                     size_t num, void* stream) {
  Copy(phi::Place(dst_place.GetType(), dst_place.GetDeviceId()), dst, src_place,
       src, num, stream);
911 912 913 914
}

// NOTE: only for (CUDAPlace) -> (CPUPlace, CUDAPlace and CUDAPinnedPlace)
template <>
915 916 917 918 919
void Copy<phi::Place, phi::GPUPlace>(phi::Place dst_place, void* dst,
                                     phi::GPUPlace src_place, const void* src,
                                     size_t num, void* stream) {
  Copy(dst_place, dst, phi::Place(src_place.GetType(), src_place.GetDeviceId()),
       src, num, stream);
920 921 922 923
}

// NOTE: only for (CPUPlace, CUDAPlace and CUDAPinnedPlace) -> (CUDAPinnedPlace)
template <>
924 925 926 927 928
void Copy<phi::GPUPinnedPlace, phi::Place>(phi::GPUPinnedPlace dst_place,
                                           void* dst, phi::Place src_place,
                                           const void* src, size_t num,
                                           void* stream) {
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num, stream);
929 930 931 932
}

// NOTE: only for (CUDAPinnedPlace) -> (CPUPlace, CUDAPlace and CUDAPinnedPlace)
template <>
933 934 935 936 937
void Copy<phi::Place, phi::GPUPinnedPlace>(phi::Place dst_place, void* dst,
                                           phi::GPUPinnedPlace src_place,
                                           const void* src, size_t num,
                                           void* stream) {
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num, stream);
938 939 940 941
}

// NOTE: only for (CPUPlace) -> (CUDAPinnedPlace)
template <>
942 943 944 945
void Copy<phi::GPUPinnedPlace, phi::Place>(phi::GPUPinnedPlace dst_place,
                                           void* dst, phi::Place src_place,
                                           const void* src, size_t num) {
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num, nullptr);
946 947 948 949
}

// NOTE: only for (CUDAPinnedPlace) -> (CPUPlace)
template <>
950 951 952 953
void Copy<phi::Place, phi::GPUPinnedPlace>(phi::Place dst_place, void* dst,
                                           phi::GPUPinnedPlace src_place,
                                           const void* src, size_t num) {
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num, nullptr);
954
}
L
Luo Tao 已提交
955
#endif
Y
Yi Wang 已提交
956

F
fwenguang 已提交
957 958 959 960 961 962
#ifdef PADDLE_WITH_MLU
// Copies `num` bytes from MLU device memory to host memory.
// Async on the given MLU stream when one is supplied; otherwise waits for
// the source device context, then performs a blocking copy.
template <>
void Copy<platform::CPUPlace, platform::MLUPlace>(platform::CPUPlace dst_place,
                                                  void* dst,
                                                  platform::MLUPlace src_place,
                                                  const void* src, size_t num,
                                                  void* stream) {
  if (UNLIKELY(num == 0)) return;

  platform::SetMLUDeviceId(src_place.device);
  if (stream) {
    VLOG(4) << "Async memory::Copy " << num << " Bytes from " << src_place
            << " to " << dst_place << " by mlu stream(" << stream << ")";
    platform::RecordEvent record_event("MLUMemcpyD2HAsync:MLU->CPU",
                                       platform::TracerEventType::UserDefined,
                                       1);
    platform::MLUMemcpyD2HAsync(dst, src, num,
                                reinterpret_cast<mluStream>(stream));
  } else {
    // Drain pending work on the source device before the blocking copy.
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    static_cast<platform::MLUDeviceContext*>(pool.Get(src_place))->Wait();

    VLOG(4) << "Sync memory::Copy " << num << " Bytes from " << src_place
            << " to " << dst_place;
    platform::RecordEvent record_event(
        "MLUMemcpyD2HSync:MLU->CPU", platform::TracerEventType::UserDefined, 1);
    platform::MLUMemcpyD2HSync(dst, src, num);
  }
}

// Copies `num` bytes from host memory to MLU device memory.
// Async on the given MLU stream when one is supplied; otherwise blocking.
template <>
void Copy<platform::MLUPlace, platform::CPUPlace>(platform::MLUPlace dst_place,
                                                  void* dst,
                                                  platform::CPUPlace src_place,
                                                  const void* src, size_t num,
                                                  void* stream) {
  if (UNLIKELY(num == 0)) return;

  platform::SetMLUDeviceId(dst_place.device);
  if (stream) {
    VLOG(4) << "Async memory::Copy " << num << " Bytes from " << src_place
            << " to " << dst_place << " by mlu stream(" << stream << ")";
    platform::RecordEvent record_event("MLUMemcpyH2DAsync:CPU->MLU",
                                       platform::TracerEventType::UserDefined,
                                       1);
    platform::MLUMemcpyH2DAsync(dst, src, num,
                                reinterpret_cast<mluStream>(stream));
  } else {
    // NOTE(review): this waits on the context of `src_place` (a CPUPlace),
    // mirroring the original code; confirm whether `dst_place` was intended.
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    static_cast<platform::MLUDeviceContext*>(pool.Get(src_place))->Wait();

    VLOG(4) << "Sync memory::Copy " << num << " Bytes from " << src_place
            << " to " << dst_place;
    platform::RecordEvent record_event(
        "MLUMemcpyH2DSync:CPU->MLU", platform::TracerEventType::UserDefined, 1);
    platform::MLUMemcpyH2DSync(dst, src, num);
  }
}

// Copies `num` bytes between two MLU buffers.
// Same device: device-to-device copy (async on `stream` when given,
// otherwise waits for the device context then copies synchronously).
// Different devices: peer-to-peer copy.
template <>
void Copy<platform::MLUPlace, platform::MLUPlace>(platform::MLUPlace dst_place,
                                                  void* dst,
                                                  platform::MLUPlace src_place,
                                                  const void* src, size_t num,
                                                  void* stream) {
  if (UNLIKELY(num == 0)) return;

  if (dst_place == src_place) {
    platform::SetMLUDeviceId(dst_place.device);
    if (stream) {
      VLOG(4) << "Async memory::Copy " << num << " Bytes from " << src_place
              << " to " << dst_place << " by mlu stream(" << stream << ")";
      platform::RecordEvent record_event("MLUMemcpyD2DAsync(same_mlu):MLU->MLU",
                                         platform::TracerEventType::UserDefined,
                                         1);
      platform::MLUMemcpyD2DAsync(dst, src, num,
                                  reinterpret_cast<mluStream>(stream));
    } else {
      // Drain pending work on the device before the blocking copy.
      platform::DeviceContextPool& pool =
          platform::DeviceContextPool::Instance();
      static_cast<platform::MLUDeviceContext*>(pool.Get(src_place))->Wait();

      VLOG(4) << "Sync memory::Copy " << num << " Bytes from " << src_place
              << " to " << dst_place;
      platform::RecordEvent record_event("MLUMemcpyD2DSync(same_mlu):MLU->MLU",
                                         platform::TracerEventType::UserDefined,
                                         1);
      platform::MLUMemcpyD2DSync(dst, src, num);
    }
  } else {
    if (stream) {
      VLOG(4) << "Async memory::Copy " << num << " Bytes from " << src_place
              << " to " << dst_place << " by mlu stream(" << stream << ")";
      platform::RecordEvent record_event("MLUMemcpyPeerAsync:MLU->MLU",
                                         platform::TracerEventType::UserDefined,
                                         1);
      platform::MLUMemcpyPeerAsync(dst, dst_place.device, src, src_place.device,
                                   num, reinterpret_cast<mluStream>(stream));
    } else {
      VLOG(4) << "Sync memory::Copy " << num << " Bytes from " << src_place
              << " to " << dst_place;
      platform::RecordEvent record_event("MLUMemcpyPeerSync:MLU->MLU",
                                         platform::TracerEventType::UserDefined,
                                         1);
      platform::MLUMemcpyPeerSync(dst, dst_place.device, src, src_place.device,
                                  num);
    }
  }
}

1067 1068
// NOTE: only for CPUPlace and MLUPlace.
template <>
1069 1070 1071 1072 1073
void Copy<phi::Place, phi::Place>(phi::Place dst_place, void* dst,
                                  phi::Place src_place, const void* src,
                                  size_t num, void* stream) {
  if (src_place.GetType() == phi::AllocationType::CPU &&
      dst_place.GetType() == phi::AllocationType::CPU) {
1074 1075
    platform::CPUPlace place_dst, place_src;
    return Copy(place_dst, dst, place_src, src, num);
1076 1077
  } else if (src_place.GetType() == phi::AllocationType::CPU &&
             dst_place.GetType() == phi::AllocationType::MLU) {
1078 1079 1080
    platform::MLUPlace place_dst(dst_place.GetDeviceId());
    platform::CPUPlace place_src;
    return Copy(place_dst, dst, place_src, src, num, stream);
1081 1082
  } else if (src_place.GetType() == phi::AllocationType::MLU &&
             dst_place.GetType() == phi::AllocationType::CPU) {
1083 1084 1085
    platform::MLUPlace place_src(src_place.GetDeviceId());
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num, stream);
1086 1087
  } else if (src_place.GetType() == phi::AllocationType::MLU &&
             dst_place.GetType() == phi::AllocationType::MLU) {
1088 1089 1090
    platform::MLUPlace place_src(src_place.GetDeviceId());
    platform::MLUPlace place_dst(dst_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num, stream);
1091
#ifdef PADDLE_WITH_CUSTOM_DEVICE
1092 1093
  } else if (src_place.GetType() == phi::AllocationType::CPU &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CUSTOM) {
1094 1095 1096
    platform::CPUPlace place_src;
    platform::CustomPlace place_dst(dst_place);
    return Copy(place_dst, dst, place_src, src, num, stream);
1097 1098
  } else if (src_place.GetType() == phi::AllocationType::CUSTOM &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CPU) {
1099 1100 1101
    platform::CustomPlace place_src(src_place);
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num, stream);
1102 1103
  } else if (src_place.GetType() == phi::AllocationType::CUSTOM &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CUSTOM) {
1104 1105 1106 1107
    platform::CustomPlace place_src(src_place);
    platform::CustomPlace place_dst(dst_place);
    return Copy(place_dst, dst, place_src, src, num, stream);
#endif
1108 1109 1110 1111 1112
  }
}

// NOTE: only for (CPUPlace and MLUPlace) -> (MLUPlace)
template <>
1113 1114 1115 1116 1117
void Copy<phi::MLUPlace, phi::Place>(phi::MLUPlace dst_place, void* dst,
                                     phi::Place src_place, const void* src,
                                     size_t num, void* stream) {
  Copy(phi::Place(dst_place.GetType(), dst_place.GetDeviceId()), dst, src_place,
       src, num, stream);
1118 1119 1120 1121
}

// NOTE: only for (MLUPlace) -> (CPUPlace and MLUPlace)
template <>
1122 1123 1124 1125 1126
void Copy<phi::Place, phi::MLUPlace>(phi::Place dst_place, void* dst,
                                     phi::MLUPlace src_place, const void* src,
                                     size_t num, void* stream) {
  Copy(dst_place, dst, phi::Place(src_place.GetType(), src_place.GetDeviceId()),
       src, num, stream);
1127 1128
}

F
fwenguang 已提交
1129 1130
// NOTE: only for (MLUPlace) -> (CPUPlace) with mluStream.
template <>
1131 1132 1133 1134
void Copy<phi::CPUPlace, phi::Place>(phi::CPUPlace dst_place, void* dst,
                                     phi::Place src_place, const void* src,
                                     size_t num, void* stream) {
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num, stream);
F
fwenguang 已提交
1135 1136 1137 1138
}

// NOTE: only for (CPUPlace) -> (MLUPlace) with mluStream.
template <>
1139 1140 1141 1142
void Copy<phi::Place, phi::CPUPlace>(phi::Place dst_place, void* dst,
                                     phi::CPUPlace src_place, const void* src,
                                     size_t num, void* stream) {
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num, stream);
F
fwenguang 已提交
1143 1144
}

F
fwenguang 已提交
1145 1146
#endif  // PADDLE_WITH_MLU

1147 1148
// NOTE: Only for CPUPlace, XPUPlace and PinnedPlace.
template <>
1149 1150 1151
void Copy<phi::Place, phi::Place>(phi::Place dst_place, void* dst,
                                  phi::Place src_place, const void* src,
                                  size_t num) {
1152 1153 1154
  if (UNLIKELY(num == 0)) return;
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
1155 1156
  if (src_place.GetType() == phi::AllocationType::CPU &&
      dst_place.GetType() == phi::AllocationType::CPU) {
1157 1158 1159
    std::memcpy(dst, src, num);
  }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
1160 1161
  else if (src_place.GetType() == phi::AllocationType::CPU &&  // NOLINT
           dst_place.GetType() == phi::AllocationType::GPUPINNED) {
1162
    std::memcpy(dst, src, num);
1163 1164
  } else if (src_place.GetType() == phi::AllocationType::GPUPINNED &&
             dst_place.GetType() == phi::AllocationType::CPU) {
1165
    std::memcpy(dst, src, num);
1166 1167
  } else if (src_place.GetType() == phi::AllocationType::GPUPINNED &&
             dst_place.GetType() == phi::AllocationType::GPUPINNED) {
1168 1169 1170 1171
    std::memcpy(dst, src, num);
  }
#endif
#ifdef PADDLE_WITH_ASCEND_CL
1172 1173
  else if (src_place.GetType() == phi::AllocationType::CPU &&  // NOLINT
           dst_place.GetType() == phi::AllocationType::NPUPINNED) {
1174
    std::memcpy(dst, src, num);
1175 1176
  } else if (src_place.GetType() == phi::AllocationType::NPUPINNED &&
             dst_place.GetType() == phi::AllocationType::CPU) {
1177
    std::memcpy(dst, src, num);
1178 1179
  } else if (src_place.GetType() == phi::AllocationType::NPUPINNED &&
             dst_place.GetType() == phi::AllocationType::NPUPINNED) {
1180 1181 1182 1183
    std::memcpy(dst, src, num);
  }
#endif
#ifdef PADDLE_WITH_XPU
1184 1185
  else if (src_place.GetType() == phi::AllocationType::CPU &&  // NOLINT
           dst_place.GetType() == phi::AllocationType::CPU) {
1186 1187
    platform::CPUPlace place_dst, place_src;
    return Copy(place_dst, dst, place_src, src, num);
1188 1189
  } else if (src_place.GetType() == phi::AllocationType::CPU &&
             dst_place.GetType() == phi::AllocationType::XPU) {
1190 1191 1192
    platform::XPUPlace place_dst(dst_place.GetDeviceId());
    platform::CPUPlace place_src;
    return Copy(place_dst, dst, place_src, src, num);
1193 1194
  } else if (src_place.GetType() == phi::AllocationType::XPU &&
             dst_place.GetType() == phi::AllocationType::CPU) {
1195 1196 1197
    platform::XPUPlace place_src(src_place.GetDeviceId());
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num);
1198 1199
  } else if (src_place.GetType() == phi::AllocationType::XPU &&
             dst_place.GetType() == phi::AllocationType::XPU) {
1200 1201 1202 1203 1204
    platform::XPUPlace place_src(src_place.GetDeviceId());
    platform::XPUPlace place_dst(dst_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num);
  }
#endif
A
Allen Guo 已提交
1205
#ifdef PADDLE_WITH_IPU
1206 1207
  else if (src_place.GetType() == phi::AllocationType::CPU &&  // NOLINT
           dst_place.GetType() == phi::AllocationType::IPU) {
A
Allen Guo 已提交
1208 1209 1210
    platform::IPUPlace place_dst(dst_place.GetDeviceId());
    platform::CPUPlace place_src;
    return Copy(place_dst, dst, place_src, src, num);
1211 1212
  } else if (src_place.GetType() == phi::AllocationType::IPU &&
             dst_place.GetType() == phi::AllocationType::CPU) {
A
Allen Guo 已提交
1213 1214 1215
    platform::IPUPlace place_src(src_place.GetDeviceId());
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num);
1216 1217
  } else if (src_place.GetType() == phi::AllocationType::IPU &&
             dst_place.GetType() == phi::AllocationType::IPU) {
A
Allen Guo 已提交
1218 1219 1220 1221 1222
    platform::IPUPlace place_src(src_place.GetDeviceId());
    platform::IPUPlace place_dst(dst_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num);
  }
#endif
1223 1224 1225 1226
}

// NOTE: Only for (CPUPlace) -> (CPUPlace and PinnedPlace).
template <>
1227 1228 1229 1230
void Copy<phi::Place, phi::CPUPlace>(phi::Place dst_place, void* dst,
                                     phi::CPUPlace src_place, const void* src,
                                     size_t num) {
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num);
1231 1232 1233 1234
}

// NOTE: Only for (CPUPlace and PinnedPlace) -> (CPUPlace).
template <>
1235 1236 1237 1238
void Copy<phi::CPUPlace, phi::Place>(phi::CPUPlace dst_place, void* dst,
                                     phi::Place src_place, const void* src,
                                     size_t num) {
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num);
1239 1240
}

1241 1242 1243 1244 1245
#if defined(PADDLE_WITH_CUSTOM_DEVICE) && !defined(PADDLE_WITH_CUDA) && \
    !defined(PADDLE_WITH_ASCEND_CL) && !defined(PADDLE_WITH_HIP) &&     \
    !defined(PADDLE_WITH_MLU)

// Stream-aware dispatcher between generic phi::Places when only the
// custom-device backend is compiled in. Supports CPU<->Custom and
// Custom<->Custom; other pairs perform no copy.
template <>
void Copy<phi::Place, phi::Place>(phi::Place dst_place, void* dst,
                                  phi::Place src_place, const void* src,
                                  size_t num, void* stream) {
  if (src_place.GetType() == phi::AllocationType::CPU &&  // NOLINT
      dst_place.GetType() == phi::AllocationType::CUSTOM) {
    platform::CPUPlace place_src;
    platform::CustomPlace place_dst(dst_place);
    return Copy(place_dst, dst, place_src, src, num, stream);
  } else if (src_place.GetType() == phi::AllocationType::CUSTOM &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CPU) {
    platform::CustomPlace place_src(src_place);
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num, stream);
  } else if (src_place.GetType() == phi::AllocationType::CUSTOM &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CUSTOM) {
    platform::CustomPlace place_src(src_place);
    platform::CustomPlace place_dst(dst_place);
    return Copy(place_dst, dst, place_src, src, num, stream);
  }
}

// Thin wrapper: erases the typed CPU destination and forwards to the
// generic stream-aware dispatcher above.
template <>
void Copy<phi::CPUPlace, phi::Place>(phi::CPUPlace dst_place, void* dst,
                                     phi::Place src_place, const void* src,
                                     size_t num, void* stream) {
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num, stream);
}

// NOTE: only for (CPUPlace) -> (CPUPlace, CUDAPlace and CUDAPinnedPlace).
// Thin wrapper: erases the typed CPU source and forwards to the generic
// stream-aware dispatcher above.
template <>
void Copy<phi::Place, phi::CPUPlace>(phi::Place dst_place, void* dst,
                                     phi::CPUPlace src_place, const void* src,
                                     size_t num, void* stream) {
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num, stream);
}
#endif

Y
Yi Wang 已提交
1283 1284
}  // namespace memory
}  // namespace paddle