memcpy.cc 52.5 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/memory/memcpy.h"
16

17
#include "paddle/fluid/platform/device/device_wrapper.h"
18
#include "paddle/fluid/platform/device_context.h"
19
#include "paddle/fluid/platform/profiler.h"
20
#include "paddle/phi/common/place.h"
21

22 23 24 25 26 27 28 29
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
#endif

#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/platform/device/mlu/mlu_info.h"
#endif

30 31 32
namespace paddle {
namespace memory {

33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
#ifdef PADDLE_WITH_CUSTOM_DEVICE
template <>
void Copy<platform::CPUPlace, platform::CustomPlace>(
    platform::CPUPlace dst_place, void* dst, platform::CustomPlace src_place,
    const void* src, size_t num, void* stream) {
  // Device -> host copy for a pluggable custom device, issued on the given
  // (possibly null) stream. No-op for zero-sized copies.
  if (UNLIKELY(num == 0)) return;

  const std::string src_type = platform::PlaceHelper::GetDeviceType(src_place);
  const std::string dst_type = platform::PlaceHelper::GetDeviceType(dst_place);
  platform::RecordEvent record_event("Memcpy:" + src_type + "->" + dst_type);
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << ", stream=" << stream;

  platform::DeviceManager::SetDevice(src_place);
  platform::stream::Stream stream_wrapper(src_place, stream);
  platform::DeviceManager::GetDeviceWithPlace(src_place)->MemoryCopyD2H(
      dst, src, num, &stream_wrapper);
}

template <>
void Copy<platform::CustomPlace, platform::CPUPlace>(
    platform::CustomPlace dst_place, void* dst, platform::CPUPlace src_place,
    const void* src, size_t num, void* stream) {
  // Host -> device copy for a pluggable custom device.
  if (UNLIKELY(num == 0)) return;

  const std::string src_type = platform::PlaceHelper::GetDeviceType(src_place);
  const std::string dst_type = platform::PlaceHelper::GetDeviceType(dst_place);
  platform::RecordEvent record_event("Memcpy:" + src_type + "->" + dst_type);
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << ", stream=" << stream;

  platform::DeviceManager::SetDevice(dst_place);
  platform::stream::Stream stream_wrapper(dst_place, stream);
  platform::DeviceManager::GetDeviceWithPlace(dst_place)->MemoryCopyH2D(
      dst, src, num, &stream_wrapper);
}

template <>
void Copy<platform::CustomPlace, platform::CustomPlace>(
    platform::CustomPlace dst_place, void* dst, platform::CustomPlace src_place,
    const void* src, size_t num, void* stream) {
  // Device <-> device copy between custom devices: same device id uses D2D,
  // different ids of the same device type use P2P; mixed device types are
  // rejected.
  if (UNLIKELY(num == 0)) return;

  const std::string src_type = platform::PlaceHelper::GetDeviceType(src_place);
  const std::string dst_type = platform::PlaceHelper::GetDeviceType(dst_place);
  platform::RecordEvent record_event("Memcpy:" + src_type + "->" + dst_type);
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << ", stream=" << stream;

  // Guard clause: copies across different custom device types are unsupported.
  if (src_type != dst_type) {
    PADDLE_THROW(platform::errors::Unavailable(
        "Copy between %s and %s is not supported.", src_type, dst_type));
  }

  platform::DeviceManager::SetDevice(src_place);
  platform::stream::Stream stream_wrapper(src_place, stream);

  const auto src_id = platform::PlaceHelper::GetDeviceId(src_place);
  const auto dst_id = platform::PlaceHelper::GetDeviceId(dst_place);
  if (src_id == dst_id) {
    platform::DeviceManager::GetDeviceWithPlace(src_place)->MemoryCopyD2D(
        dst, src, num, &stream_wrapper);
  } else {
    platform::DeviceManager::GetDeviceWithPlace(src_place)->MemoryCopyP2P(
        dst_place, dst, src, num, &stream_wrapper);
  }
}
#endif  // PADDLE_WITH_CUSTOM_DEVICE

104 105 106 107
template <>
void Copy<platform::CPUPlace, platform::CPUPlace>(platform::CPUPlace, void* dst,
                                                  platform::CPUPlace,
                                                  const void* src, size_t num) {
  // Host-to-host copy: a plain memcpy after the zero-size early-out.
  if (UNLIKELY(num == 0)) return;
  VLOG(4) << "src: " << src << ", dst: " << dst << ", num: " << num;
  std::memcpy(dst, src, num);
}
112

J
jianghaicheng 已提交
113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
#ifdef PADDLE_WITH_IPU
template <>
void Copy<platform::IPUPlace, platform::CPUPlace>(platform::IPUPlace dst_place,
                                                  void* dst,
                                                  platform::CPUPlace src_place,
                                                  const void* src, size_t num) {
  // CPU -> IPU: implemented here as a plain host memcpy.
  if (num > 0) std::memcpy(dst, src, num);
}
template <>
void Copy<platform::CPUPlace, platform::IPUPlace>(platform::CPUPlace dst_place,
                                                  void* dst,
                                                  platform::IPUPlace src_place,
                                                  const void* src, size_t num) {
  // IPU -> CPU: implemented here as a plain host memcpy.
  if (num > 0) std::memcpy(dst, src, num);
}
template <>
void Copy<platform::IPUPlace, platform::IPUPlace>(platform::IPUPlace dst_place,
                                                  void* dst,
                                                  platform::IPUPlace src_place,
                                                  const void* src, size_t num) {
  // IPU -> IPU: implemented here as a plain host memcpy.
  if (num > 0) std::memcpy(dst, src, num);
}
138 139 140

// NOTE: only for (CPUPlace and IPUPlace) -> (IPUPlace).
template <>
141 142 143 144
void Copy<phi::IPUPlace, phi::Place>(phi::IPUPlace dst_place, void* dst,
                                     phi::Place src_place, const void* src,
                                     size_t num) {
  if (src_place.GetType() == phi::AllocationType::CPU) {
145 146
    platform::CPUPlace place_src;
    return Copy(dst_place, dst, place_src, src, num);
147
  } else if (src_place.GetType() == phi::AllocationType::IPU) {
148 149 150 151 152 153 154
    platform::IPUPlace place_src(src_place.GetDeviceId());
    return Copy(dst_place, dst, place_src, src, num);
  }
}

// NOTE: only for (IPUPlace) -> (CPUPlace and IPUPlace).
template <>
155 156 157 158
void Copy<phi::Place, phi::IPUPlace>(phi::Place dst_place, void* dst,
                                     phi::IPUPlace src_place, const void* src,
                                     size_t num) {
  if (dst_place.GetType() == phi::AllocationType::CPU) {
159 160
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, src_place, src, num);
161
  } else if (dst_place.GetType() == phi::AllocationType::IPU) {
162 163 164 165
    platform::IPUPlace place_dst(dst_place.GetDeviceId());
    return Copy(place_dst, dst, src_place, src, num);
  }
}
J
jianghaicheng 已提交
166
#endif
167

168 169 170 171 172 173 174
#ifdef PADDLE_WITH_XPU
template <>
void Copy<platform::XPUPlace, platform::CPUPlace>(platform::XPUPlace dst_place,
                                                  void* dst,
                                                  platform::CPUPlace src_place,
                                                  const void* src, size_t num) {
  // Host -> XPU, synchronous. `num` is unsigned, so == 0 is the only
  // "non-positive" case.
  if (num == 0) {
    VLOG(1) << "memcpy XPU_HOST_TO_DEVICE size <= 0 (" << num << ")";
    return;
  }
  platform::MemcpySyncH2D(dst, src, num, dst_place);
}

template <>
void Copy<platform::CPUPlace, platform::XPUPlace>(platform::CPUPlace dst_place,
                                                  void* dst,
                                                  platform::XPUPlace src_place,
                                                  const void* src, size_t num) {
  // XPU -> host, synchronous.
  if (num == 0) {
    VLOG(1) << "memcpy XPU_DEVICE_TO_HOST size <= 0 (" << num << ")";
    return;
  }
  platform::MemcpySyncD2H(dst, src, num, src_place);
}

template <>
void Copy<platform::XPUPlace, platform::XPUPlace>(platform::XPUPlace dst_place,
                                                  void* dst,
                                                  platform::XPUPlace src_place,
                                                  const void* src, size_t num) {
  // XPU -> XPU, synchronous (handles same-device and cross-device inside
  // MemcpySyncD2D).
  if (num == 0) {
    VLOG(1) << "memcpy XPU_DEVICE_TO_DEVICE size <= 0 (" << num << ")";
    return;
  }
  platform::MemcpySyncD2D(dst, dst_place, src, src_place, num);
}
204 205 206

// NOTE: only for (CPUPlace and XPUPlace) -> (XPUPlace).
template <>
207 208 209 210
void Copy<phi::XPUPlace, phi::Place>(phi::XPUPlace dst_place, void* dst,
                                     phi::Place src_place, const void* src,
                                     size_t num) {
  if (src_place.GetType() == phi::AllocationType::CPU) {
211 212
    platform::CPUPlace place_src;
    return Copy(dst_place, dst, place_src, src, num);
213
  } else if (src_place.GetType() == phi::AllocationType::XPU) {
214 215 216 217 218 219 220
    platform::XPUPlace place_src(src_place.GetDeviceId());
    return Copy(dst_place, dst, place_src, src, num);
  }
}

// NOTE: only for (XPUPlace) -> (CPUPlace and XPUPlace).
template <>
221 222 223 224
void Copy<phi::Place, phi::XPUPlace>(phi::Place dst_place, void* dst,
                                     phi::XPUPlace src_place, const void* src,
                                     size_t num) {
  if (dst_place.GetType() == phi::AllocationType::CPU) {
225 226
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, src_place, src, num);
227
  } else if (dst_place.GetType() == phi::AllocationType::XPU) {
228 229 230 231
    platform::XPUPlace place_dst(dst_place.GetDeviceId());
    return Copy(place_dst, dst, src_place, src, num);
  }
}
232 233
#endif

234 235 236 237 238 239
#ifdef PADDLE_WITH_ASCEND_CL
template <>
void Copy<platform::NPUPlace, platform::CPUPlace>(platform::NPUPlace dst_place,
                                                  void* dst,
                                                  platform::CPUPlace src_place,
                                                  const void* src, size_t num,
                                                  void* stream) {
  // Host -> NPU copy: async on `stream` when given, blocking otherwise.
  if (UNLIKELY(num == 0)) return;

  platform::SetNPUDeviceId(dst_place.device);

  // (fixed log typo: "thream" -> "stream")
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by stream(" << stream << ")";

  if (stream) {
    platform::RecordEvent record_event("NpuMemcpyAsync:CPU->NPU");
    platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE,
                             reinterpret_cast<aclrtStream>(stream));
  } else {
    // On NPU, async operation after sync operation is ok, while sync operation
    // after async is not ok, since the async operation may not be done.
    // So it is needed to wait before the sync operation.
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    static_cast<platform::NPUDeviceContext*>(pool.Get(dst_place))->Wait();

    platform::RecordEvent record_event("NpuMemcpySync:CPU->NPU");
    platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE);
  }
}

template <>
void Copy<platform::CPUPlace, platform::NPUPlace>(platform::CPUPlace dst_place,
                                                  void* dst,
                                                  platform::NPUPlace src_place,
                                                  const void* src, size_t num,
                                                  void* stream) {
  // NPU -> host copy: async on `stream` when given, blocking otherwise.
  if (UNLIKELY(num == 0)) return;

  platform::SetNPUDeviceId(src_place.device);

  // (fixed log typo: "thream" -> "stream")
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by stream(" << stream << ")";

  if (stream) {
    platform::RecordEvent record_event("NpuMemcpyAsync:NPU->CPU");
    platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST,
                             reinterpret_cast<aclrtStream>(stream));
  } else {
    // Drain pending async work on the source device before a blocking copy.
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    static_cast<platform::NPUDeviceContext*>(pool.Get(src_place))->Wait();

    platform::RecordEvent record_event("NpuMemcpySync:NPU->CPU");
    platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST);
  }
}

template <>
void Copy<platform::NPUPlace, platform::NPUPlace>(platform::NPUPlace dst_place,
                                                  void* dst,
                                                  platform::NPUPlace src_place,
                                                  const void* src, size_t num,
                                                  void* stream) {
  // NPU -> NPU copy. Same place: plain D2D; different devices: peer copy
  // after a peer-access check.
  if (UNLIKELY(num == 0)) return;

  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by stream(" << stream << ")";
  if (dst_place == src_place) {
    platform::SetNPUDeviceId(src_place.device);
    if (stream) {
      platform::RecordEvent record_event("NpuMemcpyAsync(same_npu):NPU->NPU");
      platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE,
                               reinterpret_cast<aclrtStream>(stream));
    } else {
      // Sync copy after pending async work may read stale data; wait first.
      platform::DeviceContextPool& pool =
          platform::DeviceContextPool::Instance();
      static_cast<platform::NPUDeviceContext*>(pool.Get(dst_place))->Wait();

      platform::RecordEvent record_event("NpuMemcpySync(same_npu):NPU->NPU");
      platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE);
    }
  } else {
    // BUG FIX: the check previously compared dst_place.device with itself
    // (dst, dst), which can never detect an inaccessible src->dst pair.
    // It must test peer access between the destination and source devices.
    if (!platform::NPUCanAccessPeer(dst_place.device, src_place.device)) {
      PADDLE_THROW(platform::errors::Unavailable(
          "Peer access between NPU places is not allowed."));
    }
    if (stream) {
      // TODO(zhiqiu): support peer access?
      platform::RecordEvent record_event("NpuMemcpyPeerAsync:NPU->NPU");
      platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE,
                               reinterpret_cast<aclrtStream>(stream));
    } else {
      platform::DeviceContextPool& pool =
          platform::DeviceContextPool::Instance();
      static_cast<platform::NPUDeviceContext*>(pool.Get(dst_place))->Wait();

      platform::RecordEvent record_event("NpuMemcpyPeerSync:NPU->NPU");
      platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE);
    }
  }
}
334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367

template <>
void Copy<platform::CPUPlace, platform::NPUPinnedPlace>(
    platform::CPUPlace dst_place, void* dst, platform::NPUPinnedPlace src_place,
    const void* src, size_t num) {
  // Pinned host memory -> pageable host memory: plain memcpy.
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  if (num > 0) std::memcpy(dst, src, num);
}

template <>
void Copy<platform::NPUPinnedPlace, platform::CPUPlace>(
    platform::NPUPinnedPlace dst_place, void* dst, platform::CPUPlace src_place,
    const void* src, size_t num) {
  // Pageable host memory -> pinned host memory: plain memcpy.
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  if (num > 0) std::memcpy(dst, src, num);
}

template <>
void Copy<platform::NPUPinnedPlace, platform::NPUPinnedPlace>(
    platform::NPUPinnedPlace dst_place, void* dst,
    platform::NPUPinnedPlace src_place, const void* src, size_t num) {
  // Pinned -> pinned host memory: plain memcpy.
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  if (num > 0) std::memcpy(dst, src, num);
}

template <>
void Copy<platform::NPUPinnedPlace, platform::NPUPlace>(
    platform::NPUPinnedPlace dst_place, void* dst, platform::NPUPlace src_place,
    const void* src, size_t num, void* stream) {
  // NPU -> pinned host copy: async on `stream` when given, blocking otherwise.
  if (UNLIKELY(num == 0)) return;

  platform::SetNPUDeviceId(src_place.device);

  // (fixed log typo: "thream" -> "stream")
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by stream(" << stream << ")";

  if (stream) {
    platform::RecordEvent record_event("NpuMemcpyAsync:NPU->NPUPinned");
    platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST,
                             reinterpret_cast<aclrtStream>(stream));
  } else {
    // Drain pending async work on the source device before a blocking copy.
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    static_cast<platform::NPUDeviceContext*>(pool.Get(src_place))->Wait();

    platform::RecordEvent record_event("NpuMemcpySync:NPU->NPUPinned");
    platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST);
  }
}

template <>
void Copy<platform::NPUPlace, platform::NPUPinnedPlace>(
    platform::NPUPlace dst_place, void* dst, platform::NPUPinnedPlace src_place,
    const void* src, size_t num, void* stream) {
  // Pinned host -> NPU copy: async on `stream` when given, blocking otherwise.
  if (UNLIKELY(num == 0)) return;

  platform::SetNPUDeviceId(dst_place.device);

  // (fixed log typo: "thream" -> "stream")
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by stream(" << stream << ")";

  if (stream) {
    platform::RecordEvent record_event("NpuMemcpyAsync:NPUPinned->NPU");
    platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE,
                             reinterpret_cast<aclrtStream>(stream));
  } else {
    // On NPU, async operation after sync operation is ok, while sync operation
    // after async is not ok, since the async operation may not be done.
    // So it is needed to wait before the sync operation.
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    static_cast<platform::NPUDeviceContext*>(pool.Get(dst_place))->Wait();

    platform::RecordEvent record_event("NpuMemcpySync:NPUPinned->NPU");
    platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE);
  }
}

416 417
// NOTE: only for CPUPlace, NPUPlace and NPUPinnedPlace.
template <>
418 419 420 421 422
void Copy<phi::Place, phi::Place>(phi::Place dst_place, void* dst,
                                  phi::Place src_place, const void* src,
                                  size_t num, aclrtStream stream) {
  if (src_place.GetType() == phi::AllocationType::CPU &&
      dst_place.GetType() == phi::AllocationType::CPU) {
423 424
    platform::CPUPlace place_dst, place_src;
    return Copy(place_dst, dst, place_src, src, num);
425 426
  } else if (src_place.GetType() == phi::AllocationType::CPU &&
             dst_place.GetType() == phi::AllocationType::NPU) {
427 428 429
    platform::NPUPlace place_dst(dst_place.GetDeviceId());
    platform::CPUPlace place_src;
    return Copy(place_dst, dst, place_src, src, num, stream);
430 431
  } else if (src_place.GetType() == phi::AllocationType::NPU &&
             dst_place.GetType() == phi::AllocationType::CPU) {
432 433 434
    platform::NPUPlace place_src(src_place.GetDeviceId());
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num, stream);
435 436
  } else if (src_place.GetType() == phi::AllocationType::NPU &&
             dst_place.GetType() == phi::AllocationType::NPU) {
437 438 439
    platform::NPUPlace place_src(src_place.GetDeviceId());
    platform::NPUPlace place_dst(dst_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num, stream);
440 441
  } else if (src_place.GetType() == phi::AllocationType::CPU &&
             dst_place.GetType() == phi::AllocationType::NPUPINNED) {
442 443 444
    platform::CPUPlace place_src;
    platform::NPUPinnedPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num);
445 446
  } else if (src_place.GetType() == phi::AllocationType::NPUPINNED &&
             dst_place.GetType() == phi::AllocationType::CPU) {
447 448 449
    platform::CPUPlace place_dst;
    platform::NPUPinnedPlace place_src;
    return Copy(place_dst, dst, place_src, src, num);
450 451
  } else if (src_place.GetType() == phi::AllocationType::NPUPINNED &&
             dst_place.GetType() == phi::AllocationType::NPUPINNED) {
452 453 454
    platform::NPUPinnedPlace place_dst;
    platform::NPUPinnedPlace place_src;
    return Copy(place_dst, dst, place_src, src, num);
455 456
  } else if (src_place.GetType() == phi::AllocationType::NPUPINNED &&
             dst_place.GetType() == phi::AllocationType::NPU) {
457 458 459
    platform::NPUPinnedPlace place_src;
    platform::NPUPlace place_dst(dst_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num, stream);
460 461
  } else if (src_place.GetType() == phi::AllocationType::NPU &&
             dst_place.GetType() == phi::AllocationType::NPUPINNED) {
462 463 464
    platform::NPUPinnedPlace place_dst;
    platform::NPUPlace place_src(src_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num, stream);
465
#ifdef PADDLE_WITH_CUSTOM_DEVICE
466 467
  } else if (src_place.GetType() == phi::AllocationType::CPU &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CUSTOM) {
468 469 470
    platform::CPUPlace place_src;
    platform::CustomPlace place_dst(dst_place);
    return Copy(place_dst, dst, place_src, src, num, stream);
471 472
  } else if (src_place.GetType() == phi::AllocationType::CUSTOM &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CPU) {
473 474 475
    platform::CustomPlace place_src(src_place);
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num, stream);
476 477
  } else if (src_place.GetType() == phi::AllocationType::CUSTOM &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CUSTOM) {
478 479 480 481
    platform::CustomPlace place_src(src_place);
    platform::CustomPlace place_dst(dst_place);
    return Copy(place_dst, dst, place_src, src, num, stream);
#endif
482 483 484 485 486
  }
}

// NOTE: only for (CPUPlace, NPUPlace and NPUPinnedPlace) -> (CPUPlace).
template <>
487 488 489 490
void Copy<phi::CPUPlace, phi::Place>(phi::CPUPlace dst_place, void* dst,
                                     phi::Place src_place, const void* src,
                                     size_t num, aclrtStream stream) {
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num, stream);
491 492 493 494
}

// NOTE: only for (CPUPlace) -> (CPUPlace, NPUPlace and NPUPinnedPlace).
template <>
495 496 497 498
void Copy<phi::Place, phi::CPUPlace>(phi::Place dst_place, void* dst,
                                     phi::CPUPlace src_place, const void* src,
                                     size_t num, aclrtStream stream) {
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num, stream);
499 500 501 502
}

// NOTE: only for (CPUPlace, NPUPlace and NPUPinnedPlace) -> (NPUPlace)
template <>
503 504 505 506 507
void Copy<phi::NPUPlace, phi::Place>(phi::NPUPlace dst_place, void* dst,
                                     phi::Place src_place, const void* src,
                                     size_t num, aclrtStream stream) {
  Copy(phi::Place(dst_place.GetType(), dst_place.GetDeviceId()), dst, src_place,
       src, num, stream);
508 509 510 511
}

// NOTE: only for (NPUPlace) -> (CPUPlace, NPUPlace and NPUPinnedPlace)
template <>
512 513 514 515 516
void Copy<phi::Place, phi::NPUPlace>(phi::Place dst_place, void* dst,
                                     phi::NPUPlace src_place, const void* src,
                                     size_t num, aclrtStream stream) {
  Copy(dst_place, dst, phi::Place(src_place.GetType(), src_place.GetDeviceId()),
       src, num, stream);
517 518 519 520
}

// NOTE: only for (CPUPlace, NPUPlace and NPUPinnedPlace) -> (NPUPinnedPlace)
template <>
521 522 523 524 525
void Copy<phi::NPUPinnedPlace, phi::Place>(phi::NPUPinnedPlace dst_place,
                                           void* dst, phi::Place src_place,
                                           const void* src, size_t num,
                                           aclrtStream stream) {
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num, stream);
526 527 528 529
}

// NOTE: only for (NPUPinnedPlace) -> (CPUPlace, NPUPlace and NPUPinnedPlace)
template <>
530 531 532 533 534
void Copy<phi::Place, phi::NPUPinnedPlace>(phi::Place dst_place, void* dst,
                                           phi::NPUPinnedPlace src_place,
                                           const void* src, size_t num,
                                           aclrtStream stream) {
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num, stream);
535 536 537 538
}

// NOTE: only for (CPUPlace) -> (NPUPinnedPlace)
template <>
539 540 541 542
void Copy<phi::NPUPinnedPlace, phi::Place>(phi::NPUPinnedPlace dst_place,
                                           void* dst, phi::Place src_place,
                                           const void* src, size_t num) {
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num, nullptr);
543 544 545 546
}

// NOTE: only for (NPUPinnedPlace) -> (CPUPlace)
template <>
547 548 549 550
void Copy<phi::Place, phi::NPUPinnedPlace>(phi::Place dst_place, void* dst,
                                           phi::NPUPinnedPlace src_place,
                                           const void* src, size_t num) {
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num, nullptr);
551
}
552 553
#endif

554
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
S
sneaxiy 已提交
555 556
// Blocking GPU<->CPU copies at or below this size also synchronize the
// default stream afterwards (see the FIXME(zjl) notes below).
static constexpr size_t kMaxGpuAsyncCopyBytes = 64 * 1024;  // 64K

557 558 559 560 561 562 563 564 565 566 567 568 569
#ifdef PADDLE_WITH_HIP
inline void SyncCUDAStream() {
#if !defined(_WIN32)
  hipStreamSynchronize(0);
#else
  hipError_t e_sync = hipSuccess;
  while (e_sync = hipStreamQuery(0)) {
    if (e_sync == hipErrorNotReady) continue;
    break;
  }
#endif
}
#else
570 571 572 573 574 575 576 577 578 579 580
inline void SyncCUDAStream() {
#if !defined(_WIN32)
  cudaStreamSynchronize(0);
#else
  cudaError_t e_sync = cudaSuccess;
  while (e_sync = cudaStreamQuery(0)) {
    if (e_sync == cudaErrorNotReady) continue;
    break;
  }
#endif
}
581
#endif
582

583 584 585 586 587 588
// NOTE(zcd): Do not use GpuMemcpySync as much as possible.
// because GpuMemcpySync issues the copying command to the default stream,
// which will make two commands from different streams cannot run concurrently.
// Reference:
// https://devblogs.nvidia.com/gpu-pro-tip-cuda-7-streams-simplify-concurrency/

589
template <>
D
dzhwinter 已提交
590 591
void Copy<platform::CPUPlace, platform::CUDAPlace>(
    platform::CPUPlace dst_place, void* dst, platform::CUDAPlace src_place,
592
    const void* src, size_t num, void* stream) {
Z
Zeng Jinle 已提交
593
  if (UNLIKELY(num == 0)) return;
594

595 596
  platform::SetDeviceId(src_place.device);
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
597
          << dst_place << " by stream(" << stream << ")";
598
  if (stream) {
599
    platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CPU");
600
#ifdef PADDLE_WITH_HIP
601 602
    platform::GpuMemcpyAsync(dst, src, num, hipMemcpyDeviceToHost,
                             reinterpret_cast<gpuStream_t>(stream));
603
#else
604 605
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost,
                             reinterpret_cast<gpuStream_t>(stream));
606
#endif
607
  } else {
608
    platform::RecordEvent record_event("GpuMemcpySync:GPU->CPU");
609 610 611
#ifdef PADDLE_WITH_HIP
    platform::GpuMemcpySync(dst, src, num, hipMemcpyDeviceToHost);
#else
612
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
613
#endif
S
sneaxiy 已提交
614 615
    // FIXME(zjl): do we really need it?
    if (num <= kMaxGpuAsyncCopyBytes) {
616
      SyncCUDAStream();
S
sneaxiy 已提交
617
    }
618
  }
619 620 621
}

template <>
D
dzhwinter 已提交
622 623
void Copy<platform::CUDAPlace, platform::CPUPlace>(
    platform::CUDAPlace dst_place, void* dst, platform::CPUPlace src_place,
624
    const void* src, size_t num, void* stream) {
Z
Zeng Jinle 已提交
625 626
  if (UNLIKELY(num == 0)) return;

L
liaogang 已提交
627
  platform::SetDeviceId(dst_place.device);
628 629
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by thream(" << stream << ")";
630
  if (stream) {
631
    platform::RecordEvent record_event("GpuMemcpyAsync:CPU->GPU");
632
#ifdef PADDLE_WITH_HIP
633 634
    platform::GpuMemcpyAsync(dst, src, num, hipMemcpyHostToDevice,
                             reinterpret_cast<gpuStream_t>(stream));
635
#else
636 637
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice,
                             reinterpret_cast<gpuStream_t>(stream));
638
#endif
639
  } else {
640
    platform::RecordEvent record_event("GpuMemcpySync:CPU->GPU");
641 642 643
#ifdef PADDLE_WITH_HIP
    platform::GpuMemcpySync(dst, src, num, hipMemcpyHostToDevice);
#else
644
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
645
#endif
S
sneaxiy 已提交
646 647
    // FIXME(zjl): do we really need it?
    if (num <= kMaxGpuAsyncCopyBytes) {
648
      SyncCUDAStream();
S
sneaxiy 已提交
649
    }
650
  }
651 652 653
}

template <>
D
dzhwinter 已提交
654 655
void Copy<platform::CUDAPlace, platform::CUDAPlace>(
    platform::CUDAPlace dst_place, void* dst, platform::CUDAPlace src_place,
656
    const void* src, size_t num, void* stream) {
Z
Zeng Jinle 已提交
657 658
  if (UNLIKELY(num == 0)) return;

659
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
660
          << dst_place << " by stream(" << stream << ")";
661
  if (dst_place == src_place) {
L
liaogang 已提交
662
    platform::SetDeviceId(src_place.device);
663
    if (stream) {
664
      platform::RecordEvent record_event("GpuMemcpyAsync(same_gpu):GPU->GPU");
665
#ifdef PADDLE_WITH_HIP
666 667
      platform::GpuMemcpyAsync(dst, src, num, hipMemcpyDeviceToDevice,
                               reinterpret_cast<gpuStream_t>(stream));
668
#else
669 670
      platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice,
                               reinterpret_cast<gpuStream_t>(stream));
671
#endif
672
    } else {
673
      platform::RecordEvent record_event("GpuMemcpySync(same_gpu):GPU->GPU");
674 675 676
#ifdef PADDLE_WITH_HIP
      platform::GpuMemcpySync(dst, src, num, hipMemcpyDeviceToDevice);
#else
677
      platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice);
678
#endif
679
    }
680
  } else {
681
    if (stream) {
682
      platform::RecordEvent record_event("GpuMemcpyPeerAsync:GPU->GPU");
683
      platform::GpuMemcpyPeerAsync(dst, dst_place.device, src, src_place.device,
684
                                   num, reinterpret_cast<gpuStream_t>(stream));
685
    } else {
686
      platform::RecordEvent record_event("GpuMemcpyPeerSync:GPU->GPU");
687
      platform::GpuMemcpyPeerSync(dst, dst_place.device, src, src_place.device,
F
fengjiayi 已提交
688
                                  num);
689
    }
690 691 692
  }
}

C
chengduoZH 已提交
693 694 695 696
template <>
void Copy<platform::CPUPlace, platform::CUDAPinnedPlace>(
    platform::CPUPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
  // Pinned host memory -> pageable host memory: plain memcpy.
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  if (num > 0) std::memcpy(dst, src, num);
}

template <>
void Copy<platform::CUDAPinnedPlace, platform::CPUPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CPUPlace src_place, const void* src, size_t num) {
  // CPU -> pinned host buffer: both sides live in host address space,
  // so a plain memcpy suffices.
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  if (UNLIKELY(num == 0)) return;
  std::memcpy(dst, src, num);
}

template <>
void Copy<platform::CUDAPinnedPlace, platform::CUDAPinnedPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
  // Pinned -> pinned: host-side copy, no device involvement needed.
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  if (UNLIKELY(num == 0)) return;
  std::memcpy(dst, src, num);
}

template <>
void Copy<platform::CUDAPinnedPlace, platform::CUDAPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CUDAPlace src_place, const void* src, size_t num, void* stream) {
  // Device -> pinned-host copy (D2H). A non-null stream selects the
  // asynchronous path; otherwise the copy blocks until completion.
  if (UNLIKELY(num == 0)) return;
  platform::SetDeviceId(src_place.device);
  // Fix: log message said "by thream(" — corrected to "by stream(".
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by stream(" << stream << ")";
  if (stream) {
    platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CUDAPinned");
#ifdef PADDLE_WITH_HIP
    platform::GpuMemcpyAsync(dst, src, num, hipMemcpyDeviceToHost,
                             reinterpret_cast<gpuStream_t>(stream));
#else
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost,
                             reinterpret_cast<gpuStream_t>(stream));
#endif
  } else {
    platform::RecordEvent record_event("GpuMemcpySync:GPU->CUDAPinned");
#ifdef PADDLE_WITH_HIP
    platform::GpuMemcpySync(dst, src, num, hipMemcpyDeviceToHost);
#else
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
#endif
  }
}

template <>
void Copy<platform::CUDAPlace, platform::CUDAPinnedPlace>(
    platform::CUDAPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num,
    void* stream) {
  // Pinned-host -> device copy (H2D). A non-null stream selects the
  // asynchronous path; otherwise the copy blocks until completion.
  if (UNLIKELY(num == 0)) return;

  platform::SetDeviceId(dst_place.device);
  // Fix: log message said "by thream(" — corrected to "by stream(".
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by stream(" << stream << ")";
  if (stream) {
    platform::RecordEvent record_event("GpuMemcpyAsync:CUDAPinned->GPU");
#ifdef PADDLE_WITH_HIP
    platform::GpuMemcpyAsync(dst, src, num, hipMemcpyHostToDevice,
                             reinterpret_cast<gpuStream_t>(stream));
#else
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice,
                             reinterpret_cast<gpuStream_t>(stream));
#endif
  } else {
    platform::RecordEvent record_event("GpuMemcpySync:CUDAPinned->GPU");
#ifdef PADDLE_WITH_HIP
    platform::GpuMemcpySync(dst, src, num, hipMemcpyHostToDevice);
#else
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
#endif
  }
}

779 780
// NOTE: only for CPUPlace、CUDAPlace and CUDAPinnedPlace.
template <>
781 782 783 784 785
void Copy<phi::Place, phi::Place>(phi::Place dst_place, void* dst,
                                  phi::Place src_place, const void* src,
                                  size_t num, void* stream) {
  if (src_place.GetType() == phi::AllocationType::CPU &&
      dst_place.GetType() == phi::AllocationType::CPU) {
786 787
    platform::CPUPlace place_dst, place_src;
    return Copy(place_dst, dst, place_src, src, num);
788 789
  } else if (src_place.GetType() == phi::AllocationType::CPU &&
             dst_place.GetType() == phi::AllocationType::GPU) {
790 791 792
    platform::CUDAPlace place_dst(dst_place.GetDeviceId());
    platform::CPUPlace place_src;
    return Copy(place_dst, dst, place_src, src, num, stream);
793 794
  } else if (src_place.GetType() == phi::AllocationType::GPU &&
             dst_place.GetType() == phi::AllocationType::CPU) {
795 796 797
    platform::CUDAPlace place_src(src_place.GetDeviceId());
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num, stream);
798 799
  } else if (src_place.GetType() == phi::AllocationType::GPU &&
             dst_place.GetType() == phi::AllocationType::GPU) {
800 801 802
    platform::CUDAPlace place_src(src_place.GetDeviceId());
    platform::CUDAPlace place_dst(dst_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num, stream);
803 804
  } else if (src_place.GetType() == phi::AllocationType::CPU &&
             dst_place.GetType() == phi::AllocationType::GPUPINNED) {
805 806 807
    platform::CPUPlace place_src;
    platform::CUDAPinnedPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num);
808 809
  } else if (src_place.GetType() == phi::AllocationType::GPUPINNED &&
             dst_place.GetType() == phi::AllocationType::CPU) {
810 811 812
    platform::CPUPlace place_dst;
    platform::CUDAPinnedPlace place_src;
    return Copy(place_dst, dst, place_src, src, num);
813 814
  } else if (src_place.GetType() == phi::AllocationType::GPUPINNED &&
             dst_place.GetType() == phi::AllocationType::GPUPINNED) {
815 816 817
    platform::CUDAPinnedPlace place_dst;
    platform::CUDAPinnedPlace place_src;
    return Copy(place_dst, dst, place_src, src, num);
818 819
  } else if (src_place.GetType() == phi::AllocationType::GPUPINNED &&
             dst_place.GetType() == phi::AllocationType::GPU) {
820 821 822
    platform::CUDAPinnedPlace place_src;
    platform::CUDAPlace place_dst(dst_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num, stream);
823 824
  } else if (src_place.GetType() == phi::AllocationType::GPU &&
             dst_place.GetType() == phi::AllocationType::GPUPINNED) {
825 826 827
    platform::CUDAPinnedPlace place_dst;
    platform::CUDAPlace place_src(src_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num, stream);
828
#ifdef PADDLE_WITH_CUSTOM_DEVICE
829 830
  } else if (src_place.GetType() == phi::AllocationType::CPU &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CUSTOM) {
831 832 833
    platform::CPUPlace place_src;
    platform::CustomPlace place_dst(dst_place);
    return Copy(place_dst, dst, place_src, src, num, stream);
834 835
  } else if (src_place.GetType() == phi::AllocationType::CUSTOM &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CPU) {
836 837 838
    platform::CustomPlace place_src(src_place);
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num, stream);
839 840
  } else if (src_place.GetType() == phi::AllocationType::CUSTOM &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CUSTOM) {
841 842 843 844
    platform::CustomPlace place_src(src_place);
    platform::CustomPlace place_dst(dst_place);
    return Copy(place_dst, dst, place_src, src, num, stream);
#endif
845 846 847 848 849
  }
}

// NOTE: only for (CPUPlace, CUDAPlace and CUDAPinnedPlace) -> (CPUPlace).
template <>
850 851 852 853
void Copy<phi::CPUPlace, phi::Place>(phi::CPUPlace dst_place, void* dst,
                                     phi::Place src_place, const void* src,
                                     size_t num, void* stream) {
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num, stream);
854 855 856 857
}

// NOTE: only for (CPUPlace) -> (CPUPlace, CUDAPlace and CUDAPinnedPlace).
template <>
858 859 860 861
void Copy<phi::Place, phi::CPUPlace>(phi::Place dst_place, void* dst,
                                     phi::CPUPlace src_place, const void* src,
                                     size_t num, void* stream) {
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num, stream);
862 863 864 865
}

// NOTE: only for (CPUPlace, CUDAPlace and CUDAPinnedPlace) -> (CUDAPlace)
template <>
866 867 868 869 870
void Copy<phi::GPUPlace, phi::Place>(phi::GPUPlace dst_place, void* dst,
                                     phi::Place src_place, const void* src,
                                     size_t num, void* stream) {
  Copy(phi::Place(dst_place.GetType(), dst_place.GetDeviceId()), dst, src_place,
       src, num, stream);
871 872 873 874
}

// NOTE: only for (CUDAPlace) -> (CPUPlace, CUDAPlace and CUDAPinnedPlace)
template <>
875 876 877 878 879
void Copy<phi::Place, phi::GPUPlace>(phi::Place dst_place, void* dst,
                                     phi::GPUPlace src_place, const void* src,
                                     size_t num, void* stream) {
  Copy(dst_place, dst, phi::Place(src_place.GetType(), src_place.GetDeviceId()),
       src, num, stream);
880 881 882 883
}

// NOTE: only for (CPUPlace, CUDAPlace and CUDAPinnedPlace) -> (CUDAPinnedPlace)
template <>
884 885 886 887 888
void Copy<phi::GPUPinnedPlace, phi::Place>(phi::GPUPinnedPlace dst_place,
                                           void* dst, phi::Place src_place,
                                           const void* src, size_t num,
                                           void* stream) {
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num, stream);
889 890 891 892
}

// NOTE: only for (CUDAPinnedPlace) -> (CPUPlace, CUDAPlace and CUDAPinnedPlace)
template <>
893 894 895 896 897
void Copy<phi::Place, phi::GPUPinnedPlace>(phi::Place dst_place, void* dst,
                                           phi::GPUPinnedPlace src_place,
                                           const void* src, size_t num,
                                           void* stream) {
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num, stream);
898 899 900 901
}

// NOTE: only for (CPUPlace) -> (CUDAPinnedPlace)
template <>
902 903 904 905
void Copy<phi::GPUPinnedPlace, phi::Place>(phi::GPUPinnedPlace dst_place,
                                           void* dst, phi::Place src_place,
                                           const void* src, size_t num) {
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num, nullptr);
906 907 908 909
}

// NOTE: only for (CUDAPinnedPlace) -> (CPUPlace)
template <>
910 911 912 913
void Copy<phi::Place, phi::GPUPinnedPlace>(phi::Place dst_place, void* dst,
                                           phi::GPUPinnedPlace src_place,
                                           const void* src, size_t num) {
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num, nullptr);
914
}
L
Luo Tao 已提交
915
#endif
Y
Yi Wang 已提交
916

F
fwenguang 已提交
917 918 919 920 921 922
#ifdef PADDLE_WITH_MLU
template <>
void Copy<platform::CPUPlace, platform::MLUPlace>(platform::CPUPlace dst_place,
                                                  void* dst,
                                                  platform::MLUPlace src_place,
                                                  const void* src, size_t num,
                                                  void* stream) {
  // MLU -> CPU (D2H). Async when a stream is provided; the sync path waits
  // on the source MLU device context before issuing a blocking copy.
  if (UNLIKELY(num == 0)) return;

  platform::SetMLUDeviceId(src_place.device);
  if (stream == nullptr) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    static_cast<platform::MLUDeviceContext*>(pool.Get(src_place))->Wait();

    VLOG(4) << "Sync memory::Copy " << num << " Bytes from " << src_place
            << " to " << dst_place;
    platform::RecordEvent record_event("MLUMemcpyD2HSync:MLU->CPU");
    platform::MLUMemcpyD2HSync(dst, src, num);
  } else {
    VLOG(4) << "Async memory::Copy " << num << " Bytes from " << src_place
            << " to " << dst_place << " by mlu stream(" << stream << ")";
    platform::RecordEvent record_event("MLUMemcpyD2HAsync:MLU->CPU");
    platform::MLUMemcpyD2HAsync(dst, src, num,
                                reinterpret_cast<mluStream>(stream));
  }
}

template <>
void Copy<platform::MLUPlace, platform::CPUPlace>(platform::MLUPlace dst_place,
                                                  void* dst,
                                                  platform::CPUPlace src_place,
                                                  const void* src, size_t num,
                                                  void* stream) {
  // CPU -> MLU (H2D). Async when a stream is provided; the sync path waits
  // on the destination MLU device context before issuing a blocking copy.
  if (UNLIKELY(num == 0)) return;

  platform::SetMLUDeviceId(dst_place.device);
  if (stream) {
    VLOG(4) << "Async memory::Copy " << num << " Bytes from " << src_place
            << " to " << dst_place << " by mlu stream(" << stream << ")";
    platform::RecordEvent record_event("MLUMemcpyH2DAsync:CPU->MLU");
    platform::MLUMemcpyH2DAsync(dst, src, num,
                                reinterpret_cast<mluStream>(stream));
  } else {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    // Fix: previously waited on pool.Get(src_place), but src_place is a
    // CPUPlace whose pooled context is not an MLUDeviceContext — the cast was
    // invalid. Wait on the destination MLU device instead.
    static_cast<platform::MLUDeviceContext*>(pool.Get(dst_place))->Wait();

    VLOG(4) << "Sync memory::Copy " << num << " Bytes from " << src_place
            << " to " << dst_place;
    platform::RecordEvent record_event("MLUMemcpyH2DSync:CPU->MLU");
    platform::MLUMemcpyH2DSync(dst, src, num);
  }
}

template <>
void Copy<platform::MLUPlace, platform::MLUPlace>(platform::MLUPlace dst_place,
                                                  void* dst,
                                                  platform::MLUPlace src_place,
                                                  const void* src, size_t num,
                                                  void* stream) {
  // MLU device-to-device copy; cross-device transfers use the peer API.
  if (UNLIKELY(num == 0)) return;

  if (dst_place == src_place) {
    platform::SetMLUDeviceId(dst_place.device);
    if (stream == nullptr) {
      // Synchronous same-device copy: drain the device context first.
      platform::DeviceContextPool& pool =
          platform::DeviceContextPool::Instance();
      static_cast<platform::MLUDeviceContext*>(pool.Get(src_place))->Wait();

      VLOG(4) << "Sync memory::Copy " << num << " Bytes from " << src_place
              << " to " << dst_place;
      platform::RecordEvent record_event("MLUMemcpyD2DSync(same_mlu):MLU->MLU");
      platform::MLUMemcpyD2DSync(dst, src, num);
    } else {
      VLOG(4) << "Async memory::Copy " << num << " Bytes from " << src_place
              << " to " << dst_place << " by mlu stream(" << stream << ")";
      platform::RecordEvent record_event(
          "MLUMemcpyD2DAsync(same_mlu):MLU->MLU");
      platform::MLUMemcpyD2DAsync(dst, src, num,
                                  reinterpret_cast<mluStream>(stream));
    }
  } else {
    if (stream == nullptr) {
      VLOG(4) << "Sync memory::Copy " << num << " Bytes from " << src_place
              << " to " << dst_place;
      platform::RecordEvent record_event("MLUMemcpyPeerSync:MLU->MLU");
      platform::MLUMemcpyPeerSync(dst, dst_place.device, src, src_place.device,
                                  num);
    } else {
      VLOG(4) << "Async memory::Copy " << num << " Bytes from " << src_place
              << " to " << dst_place << " by mlu stream(" << stream << ")";
      platform::RecordEvent record_event("MLUMemcpyPeerAsync:MLU->MLU");
      platform::MLUMemcpyPeerAsync(dst, dst_place.device, src, src_place.device,
                                   num, reinterpret_cast<mluStream>(stream));
    }
  }
}

1014 1015
// NOTE: only for CPUPlace and MLUPlace.
template <>
1016 1017 1018 1019 1020
void Copy<phi::Place, phi::Place>(phi::Place dst_place, void* dst,
                                  phi::Place src_place, const void* src,
                                  size_t num, void* stream) {
  if (src_place.GetType() == phi::AllocationType::CPU &&
      dst_place.GetType() == phi::AllocationType::CPU) {
1021 1022
    platform::CPUPlace place_dst, place_src;
    return Copy(place_dst, dst, place_src, src, num);
1023 1024
  } else if (src_place.GetType() == phi::AllocationType::CPU &&
             dst_place.GetType() == phi::AllocationType::MLU) {
1025 1026 1027
    platform::MLUPlace place_dst(dst_place.GetDeviceId());
    platform::CPUPlace place_src;
    return Copy(place_dst, dst, place_src, src, num, stream);
1028 1029
  } else if (src_place.GetType() == phi::AllocationType::MLU &&
             dst_place.GetType() == phi::AllocationType::CPU) {
1030 1031 1032
    platform::MLUPlace place_src(src_place.GetDeviceId());
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num, stream);
1033 1034
  } else if (src_place.GetType() == phi::AllocationType::MLU &&
             dst_place.GetType() == phi::AllocationType::MLU) {
1035 1036 1037
    platform::MLUPlace place_src(src_place.GetDeviceId());
    platform::MLUPlace place_dst(dst_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num, stream);
1038
#ifdef PADDLE_WITH_CUSTOM_DEVICE
1039 1040
  } else if (src_place.GetType() == phi::AllocationType::CPU &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CUSTOM) {
1041 1042 1043
    platform::CPUPlace place_src;
    platform::CustomPlace place_dst(dst_place);
    return Copy(place_dst, dst, place_src, src, num, stream);
1044 1045
  } else if (src_place.GetType() == phi::AllocationType::CUSTOM &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CPU) {
1046 1047 1048
    platform::CustomPlace place_src(src_place);
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num, stream);
1049 1050
  } else if (src_place.GetType() == phi::AllocationType::CUSTOM &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CUSTOM) {
1051 1052 1053 1054
    platform::CustomPlace place_src(src_place);
    platform::CustomPlace place_dst(dst_place);
    return Copy(place_dst, dst, place_src, src, num, stream);
#endif
1055 1056 1057 1058 1059
  }
}

// NOTE: only for (CPUPlace and MLUPlace) -> (MLUPlace)
template <>
1060 1061 1062 1063 1064
void Copy<phi::MLUPlace, phi::Place>(phi::MLUPlace dst_place, void* dst,
                                     phi::Place src_place, const void* src,
                                     size_t num, void* stream) {
  Copy(phi::Place(dst_place.GetType(), dst_place.GetDeviceId()), dst, src_place,
       src, num, stream);
1065 1066 1067 1068
}

// NOTE: only for (MLUPlace) -> (CPUPlace and MLUPlace)
template <>
1069 1070 1071 1072 1073
void Copy<phi::Place, phi::MLUPlace>(phi::Place dst_place, void* dst,
                                     phi::MLUPlace src_place, const void* src,
                                     size_t num, void* stream) {
  Copy(dst_place, dst, phi::Place(src_place.GetType(), src_place.GetDeviceId()),
       src, num, stream);
1074 1075
}

F
fwenguang 已提交
1076 1077
// NOTE: only for (MLUPlace) -> (CPUPlace) with mluStream.
template <>
1078 1079 1080 1081
void Copy<phi::CPUPlace, phi::Place>(phi::CPUPlace dst_place, void* dst,
                                     phi::Place src_place, const void* src,
                                     size_t num, void* stream) {
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num, stream);
F
fwenguang 已提交
1082 1083 1084 1085
}

// NOTE: only for (CPUPlace) -> (MLUPlace) with mluStream.
template <>
1086 1087 1088 1089
void Copy<phi::Place, phi::CPUPlace>(phi::Place dst_place, void* dst,
                                     phi::CPUPlace src_place, const void* src,
                                     size_t num, void* stream) {
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num, stream);
F
fwenguang 已提交
1090 1091
}

F
fwenguang 已提交
1092 1093
#endif  // PADDLE_WITH_MLU

1094 1095
// NOTE: Only for CPUPlace, XPUPlace and PinnedPlace.
template <>
1096 1097 1098
void Copy<phi::Place, phi::Place>(phi::Place dst_place, void* dst,
                                  phi::Place src_place, const void* src,
                                  size_t num) {
1099 1100 1101
  if (UNLIKELY(num == 0)) return;
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
1102 1103
  if (src_place.GetType() == phi::AllocationType::CPU &&
      dst_place.GetType() == phi::AllocationType::CPU) {
1104 1105 1106
    std::memcpy(dst, src, num);
  }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
1107 1108
  else if (src_place.GetType() == phi::AllocationType::CPU &&  // NOLINT
           dst_place.GetType() == phi::AllocationType::GPUPINNED) {
1109
    std::memcpy(dst, src, num);
1110 1111
  } else if (src_place.GetType() == phi::AllocationType::GPUPINNED &&
             dst_place.GetType() == phi::AllocationType::CPU) {
1112
    std::memcpy(dst, src, num);
1113 1114
  } else if (src_place.GetType() == phi::AllocationType::GPUPINNED &&
             dst_place.GetType() == phi::AllocationType::GPUPINNED) {
1115 1116 1117 1118
    std::memcpy(dst, src, num);
  }
#endif
#ifdef PADDLE_WITH_ASCEND_CL
1119 1120
  else if (src_place.GetType() == phi::AllocationType::CPU &&  // NOLINT
           dst_place.GetType() == phi::AllocationType::NPUPINNED) {
1121
    std::memcpy(dst, src, num);
1122 1123
  } else if (src_place.GetType() == phi::AllocationType::NPUPINNED &&
             dst_place.GetType() == phi::AllocationType::CPU) {
1124
    std::memcpy(dst, src, num);
1125 1126
  } else if (src_place.GetType() == phi::AllocationType::NPUPINNED &&
             dst_place.GetType() == phi::AllocationType::NPUPINNED) {
1127 1128 1129 1130
    std::memcpy(dst, src, num);
  }
#endif
#ifdef PADDLE_WITH_XPU
1131 1132
  else if (src_place.GetType() == phi::AllocationType::CPU &&  // NOLINT
           dst_place.GetType() == phi::AllocationType::CPU) {
1133 1134
    platform::CPUPlace place_dst, place_src;
    return Copy(place_dst, dst, place_src, src, num);
1135 1136
  } else if (src_place.GetType() == phi::AllocationType::CPU &&
             dst_place.GetType() == phi::AllocationType::XPU) {
1137 1138 1139
    platform::XPUPlace place_dst(dst_place.GetDeviceId());
    platform::CPUPlace place_src;
    return Copy(place_dst, dst, place_src, src, num);
1140 1141
  } else if (src_place.GetType() == phi::AllocationType::XPU &&
             dst_place.GetType() == phi::AllocationType::CPU) {
1142 1143 1144
    platform::XPUPlace place_src(src_place.GetDeviceId());
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num);
1145 1146
  } else if (src_place.GetType() == phi::AllocationType::XPU &&
             dst_place.GetType() == phi::AllocationType::XPU) {
1147 1148 1149 1150 1151
    platform::XPUPlace place_src(src_place.GetDeviceId());
    platform::XPUPlace place_dst(dst_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num);
  }
#endif
A
Allen Guo 已提交
1152
#ifdef PADDLE_WITH_IPU
1153 1154
  else if (src_place.GetType() == phi::AllocationType::CPU &&  // NOLINT
           dst_place.GetType() == phi::AllocationType::IPU) {
A
Allen Guo 已提交
1155 1156 1157
    platform::IPUPlace place_dst(dst_place.GetDeviceId());
    platform::CPUPlace place_src;
    return Copy(place_dst, dst, place_src, src, num);
1158 1159
  } else if (src_place.GetType() == phi::AllocationType::IPU &&
             dst_place.GetType() == phi::AllocationType::CPU) {
A
Allen Guo 已提交
1160 1161 1162
    platform::IPUPlace place_src(src_place.GetDeviceId());
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num);
1163 1164
  } else if (src_place.GetType() == phi::AllocationType::IPU &&
             dst_place.GetType() == phi::AllocationType::IPU) {
A
Allen Guo 已提交
1165 1166 1167 1168 1169
    platform::IPUPlace place_src(src_place.GetDeviceId());
    platform::IPUPlace place_dst(dst_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num);
  }
#endif
1170 1171 1172 1173
}

// NOTE: Only for (CPUPlace) -> (CPUPlace and PinnedPlace).
template <>
1174 1175 1176 1177
void Copy<phi::Place, phi::CPUPlace>(phi::Place dst_place, void* dst,
                                     phi::CPUPlace src_place, const void* src,
                                     size_t num) {
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num);
1178 1179 1180 1181
}

// NOTE: Only for (CPUPlace and PinnedPlace) -> (CPUPlace).
template <>
1182 1183 1184 1185
void Copy<phi::CPUPlace, phi::Place>(phi::CPUPlace dst_place, void* dst,
                                     phi::Place src_place, const void* src,
                                     size_t num) {
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num);
1186 1187
}

1188 1189 1190 1191 1192
#if defined(PADDLE_WITH_CUSTOM_DEVICE) && !defined(PADDLE_WITH_CUDA) && \
    !defined(PADDLE_WITH_ASCEND_CL) && !defined(PADDLE_WITH_HIP) &&     \
    !defined(PADDLE_WITH_MLU)

template <>
1193 1194 1195 1196 1197
void Copy<phi::Place, phi::Place>(phi::Place dst_place, void* dst,
                                  phi::Place src_place, const void* src,
                                  size_t num, void* stream) {
  if (src_place.GetType() == phi::AllocationType::CPU &&  // NOLINT
      dst_place.GetType() == phi::AllocationType::CUSTOM) {
1198 1199 1200
    platform::CPUPlace place_src;
    platform::CustomPlace place_dst(dst_place);
    return Copy(place_dst, dst, place_src, src, num, stream);
1201 1202
  } else if (src_place.GetType() == phi::AllocationType::CUSTOM &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CPU) {
1203 1204 1205
    platform::CustomPlace place_src(src_place);
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num, stream);
1206 1207
  } else if (src_place.GetType() == phi::AllocationType::CUSTOM &&  // NOLINT
             dst_place.GetType() == phi::AllocationType::CUSTOM) {
1208 1209 1210 1211 1212 1213 1214
    platform::CustomPlace place_src(src_place);
    platform::CustomPlace place_dst(dst_place);
    return Copy(place_dst, dst, place_src, src, num, stream);
  }
}

template <>
1215 1216 1217 1218
void Copy<phi::CPUPlace, phi::Place>(phi::CPUPlace dst_place, void* dst,
                                     phi::Place src_place, const void* src,
                                     size_t num, void* stream) {
  Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num, stream);
1219 1220 1221 1222
}

// NOTE: only for (CPUPlace) -> (CPUPlace, CUDAPlace and CUDAPinnedPlace).
template <>
1223 1224 1225 1226
void Copy<phi::Place, phi::CPUPlace>(phi::Place dst_place, void* dst,
                                     phi::CPUPlace src_place, const void* src,
                                     size_t num, void* stream) {
  Copy(dst_place, dst, phi::Place(src_place.GetType()), src, num, stream);
1227 1228 1229
}
#endif

Y
Yi Wang 已提交
1230 1231
}  // namespace memory
}  // namespace paddle