memcpy.cc 44.1 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/memory/memcpy.h"
16

17
#include "paddle/fluid/platform/device/device_wrapper.h"
18
#include "paddle/fluid/platform/device_context.h"
19
#include "paddle/fluid/platform/profiler.h"
20
#include "paddle/pten/common/place.h"
21 22 23 24 25 26 27 28

namespace paddle {
namespace memory {

// Host -> host copy. Zero-byte requests return immediately (before any
// logging); otherwise the request is logged at VLOG(4) and `num` bytes are
// copied with std::memcpy.
template <>
void Copy<platform::CPUPlace, platform::CPUPlace>(platform::CPUPlace, void* dst,
                                                  platform::CPUPlace,
                                                  const void* src, size_t num) {
  if (UNLIKELY(num == 0)) {
    return;  // nothing to copy
  }
  VLOG(4) << "src: " << src << ", dst: " << dst << ", num: " << num;
  std::memcpy(dst, src, num);
}
33

J
jianghaicheng 已提交
34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
#ifdef PADDLE_WITH_IPU
// CPU -> IPU copy. The buffers are copied with a host std::memcpy, so the
// IPU-side pointer is treated as host-addressable. A zero-byte request does
// nothing.
template <>
void Copy<platform::IPUPlace, platform::CPUPlace>(platform::IPUPlace dst_place,
                                                  void* dst,
                                                  platform::CPUPlace src_place,
                                                  const void* src, size_t num) {
  if (num != 0) {
    std::memcpy(dst, src, num);
  }
}
// IPU -> CPU copy. The buffers are copied with a host std::memcpy, so the
// IPU-side pointer is treated as host-addressable. A zero-byte request does
// nothing.
template <>
void Copy<platform::CPUPlace, platform::IPUPlace>(platform::CPUPlace dst_place,
                                                  void* dst,
                                                  platform::IPUPlace src_place,
                                                  const void* src, size_t num) {
  if (num != 0) {
    std::memcpy(dst, src, num);
  }
}
// IPU -> IPU copy. The buffers are copied with a host std::memcpy, so both
// pointers are treated as host-addressable. A zero-byte request does nothing.
template <>
void Copy<platform::IPUPlace, platform::IPUPlace>(platform::IPUPlace dst_place,
                                                  void* dst,
                                                  platform::IPUPlace src_place,
                                                  const void* src, size_t num) {
  if (num != 0) {
    std::memcpy(dst, src, num);
  }
}
59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113

// Dispatches a copy between two type-erased places to the concrete
// CPUPlace/IPUPlace specialization. Only CPU and IPU places are handled; any
// other combination falls through without copying anything.
template <>
void Copy<pten::Place, pten::Place>(pten::Place dst_place, void* dst,
                                    pten::Place src_place, const void* src,
                                    size_t num) {
  const auto src_type = src_place.GetType();
  const auto dst_type = dst_place.GetType();

  if (src_type == pten::AllocationType::CPU) {
    platform::CPUPlace place_src;
    if (dst_type == pten::AllocationType::CPU) {
      platform::CPUPlace place_dst;
      Copy(place_dst, dst, place_src, src, num);
    } else if (dst_type == pten::AllocationType::IPU) {
      platform::IPUPlace place_dst(dst_place.GetDeviceId());
      Copy(place_dst, dst, place_src, src, num);
    }
  } else if (src_type == pten::AllocationType::IPU) {
    platform::IPUPlace place_src(src_place.GetDeviceId());
    if (dst_type == pten::AllocationType::CPU) {
      platform::CPUPlace place_dst;
      Copy(place_dst, dst, place_src, src, num);
    } else if (dst_type == pten::AllocationType::IPU) {
      platform::IPUPlace place_dst(dst_place.GetDeviceId());
      Copy(place_dst, dst, place_src, src, num);
    }
  }
}

// Copies into an IPU place from a type-erased source. CPU and IPU sources are
// forwarded to the concrete specializations; any other source type performs
// no copy.
template <>
void Copy<pten::IPUPlace, pten::Place>(pten::IPUPlace dst_place, void* dst,
                                       pten::Place src_place, const void* src,
                                       size_t num) {
  switch (src_place.GetType()) {
    case pten::AllocationType::CPU: {
      platform::CPUPlace place_src;
      Copy(dst_place, dst, place_src, src, num);
      break;
    }
    case pten::AllocationType::IPU: {
      platform::IPUPlace place_src(src_place.GetDeviceId());
      Copy(dst_place, dst, place_src, src, num);
      break;
    }
    default:
      break;
  }
}

// Copies from an IPU place into a type-erased destination. CPU and IPU
// destinations are forwarded to the concrete specializations; any other
// destination type performs no copy.
template <>
void Copy<pten::Place, pten::IPUPlace>(pten::Place dst_place, void* dst,
                                       pten::IPUPlace src_place,
                                       const void* src, size_t num) {
  switch (dst_place.GetType()) {
    case pten::AllocationType::CPU: {
      platform::CPUPlace place_dst;
      Copy(place_dst, dst, src_place, src, num);
      break;
    }
    case pten::AllocationType::IPU: {
      platform::IPUPlace place_dst(dst_place.GetDeviceId());
      Copy(place_dst, dst, src_place, src, num);
      break;
    }
    default:
      break;
  }
}
J
jianghaicheng 已提交
114
#endif
115

116 117 118 119 120 121 122
#ifdef PADDLE_WITH_XPU
// CPU -> XPU copy through platform::MemcpySyncH2D, passing the destination
// place. A zero-byte request is logged at VLOG(1) and skipped.
template <>
void Copy<platform::XPUPlace, platform::CPUPlace>(platform::XPUPlace dst_place,
                                                  void* dst,
                                                  platform::CPUPlace src_place,
                                                  const void* src, size_t num) {
  if (num > 0) {
    platform::MemcpySyncH2D(dst, src, num, dst_place);
    return;
  }
  // `num` is unsigned, so this is reached only when num == 0.
  VLOG(1) << "memcpy XPU_HOST_TO_DEVICE size <= 0 (" << num << ")";
}

// XPU -> CPU copy through platform::MemcpySyncD2H, passing the source place.
// A zero-byte request is logged at VLOG(1) and skipped.
template <>
void Copy<platform::CPUPlace, platform::XPUPlace>(platform::CPUPlace dst_place,
                                                  void* dst,
                                                  platform::XPUPlace src_place,
                                                  const void* src, size_t num) {
  if (num > 0) {
    platform::MemcpySyncD2H(dst, src, num, src_place);
    return;
  }
  // `num` is unsigned, so this is reached only when num == 0.
  VLOG(1) << "memcpy XPU_DEVICE_TO_HOST size <= 0 (" << num << ")";
}

// XPU -> XPU copy through platform::MemcpySyncD2D, passing both places.
// A zero-byte request is logged at VLOG(1) and skipped.
template <>
void Copy<platform::XPUPlace, platform::XPUPlace>(platform::XPUPlace dst_place,
                                                  void* dst,
                                                  platform::XPUPlace src_place,
                                                  const void* src, size_t num) {
  if (num > 0) {
    platform::MemcpySyncD2D(dst, dst_place, src, src_place, num);
    return;
  }
  // `num` is unsigned, so this is reached only when num == 0.
  VLOG(1) << "memcpy XPU_DEVICE_TO_DEVICE size <= 0 (" << num << ")";
}
152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179

// Copies into an XPU place from a type-erased source. CPU and XPU sources are
// forwarded to the concrete specializations; any other source type performs
// no copy.
template <>
void Copy<pten::XPUPlace, pten::Place>(pten::XPUPlace dst_place, void* dst,
                                       pten::Place src_place, const void* src,
                                       size_t num) {
  switch (src_place.GetType()) {
    case pten::AllocationType::CPU: {
      platform::CPUPlace place_src;
      Copy(dst_place, dst, place_src, src, num);
      break;
    }
    case pten::AllocationType::XPU: {
      platform::XPUPlace place_src(src_place.GetDeviceId());
      Copy(dst_place, dst, place_src, src, num);
      break;
    }
    default:
      break;
  }
}

// Copies from an XPU place into a type-erased destination. CPU and XPU
// destinations are forwarded to the concrete specializations; any other
// destination type performs no copy.
template <>
void Copy<pten::Place, pten::XPUPlace>(pten::Place dst_place, void* dst,
                                       pten::XPUPlace src_place,
                                       const void* src, size_t num) {
  switch (dst_place.GetType()) {
    case pten::AllocationType::CPU: {
      platform::CPUPlace place_dst;
      Copy(place_dst, dst, src_place, src, num);
      break;
    }
    case pten::AllocationType::XPU: {
      platform::XPUPlace place_dst(dst_place.GetDeviceId());
      Copy(place_dst, dst, src_place, src, num);
      break;
    }
    default:
      break;
  }
}
180 181
#endif

182 183 184 185 186 187 188 189 190 191
#ifdef PADDLE_WITH_ASCEND_CL
// Host -> NPU copy.
//
// Selects the destination NPU. If `stream` is non-null, an asynchronous copy
// is enqueued on it. Otherwise the function waits on the destination device
// context and then performs a synchronous copy.
template <>
void Copy<platform::NPUPlace, platform::CPUPlace>(platform::NPUPlace dst_place,
                                                  void* dst,
                                                  platform::CPUPlace src_place,
                                                  const void* src, size_t num,
                                                  aclrtStream stream) {
  if (UNLIKELY(num == 0)) return;

  platform::SetNPUDeviceId(dst_place.device);

  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by thream(" << stream << ")";

  if (stream == nullptr) {
    // On NPU an async operation may follow a sync one, but a sync operation
    // issued after an async one may start before the async work is done, so
    // drain the device context before the synchronous copy.
    auto* npu_ctx = static_cast<platform::NPUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(dst_place));
    npu_ctx->Wait();

    platform::RecordEvent record_event("NpuMemcpySync:CPU->NPU");
    platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE);
    return;
  }

  platform::RecordEvent record_event("NpuMemcpyAsync:CPU->NPU");
  platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE, stream);
}

// NPU -> host copy.
//
// Selects the source NPU. If `stream` is non-null, an asynchronous copy is
// enqueued on it. Otherwise the function waits on the source device context
// and then performs a synchronous copy.
template <>
void Copy<platform::CPUPlace, platform::NPUPlace>(platform::CPUPlace dst_place,
                                                  void* dst,
                                                  platform::NPUPlace src_place,
                                                  const void* src, size_t num,
                                                  aclrtStream stream) {
  if (UNLIKELY(num == 0)) return;

  platform::SetNPUDeviceId(src_place.device);

  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by thream(" << stream << ")";

  if (stream == nullptr) {
    // Wait on the source device context before the synchronous copy.
    auto* npu_ctx = static_cast<platform::NPUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(src_place));
    npu_ctx->Wait();

    platform::RecordEvent record_event("NpuMemcpySync:NPU->CPU");
    platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST);
    return;
  }

  platform::RecordEvent record_event("NpuMemcpyAsync:NPU->CPU");
  platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST, stream);
}

// NPU -> NPU copy.
//
// Same device: selects that device. If `stream` is non-null, the copy is
// enqueued on it. Otherwise the function waits on the destination device
// context and then performs a synchronous copy, because on NPU a sync
// operation issued after async work may start before that work finishes.
//
// Different devices: first checks peer access between the destination and
// source devices and throws platform::errors::Unavailable if it is not
// allowed. The copy is then issued the same way (async on `stream`, or wait
// and then sync).
template <>
void Copy<platform::NPUPlace, platform::NPUPlace>(platform::NPUPlace dst_place,
                                                  void* dst,
                                                  platform::NPUPlace src_place,
                                                  const void* src, size_t num,
                                                  aclrtStream stream) {
  if (UNLIKELY(num == 0)) return;

  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by stream(" << stream << ")";
  if (dst_place == src_place) {
    platform::SetNPUDeviceId(src_place.device);
    if (stream) {
      platform::RecordEvent record_event("NpuMemcpyAsync(same_npu):NPU->NPU");
      platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE,
                               stream);
    } else {
      // Drain pending async work before the synchronous copy.
      platform::DeviceContextPool& pool =
          platform::DeviceContextPool::Instance();
      static_cast<platform::NPUDeviceContext*>(pool.Get(dst_place))->Wait();

      platform::RecordEvent record_event("NpuMemcpySync(same_npu):NPU->NPU");
      platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE);
    }
  } else {
    // The peer check must involve both devices taking part in the copy.
    // Querying the destination device against itself says nothing about
    // whether the source device is reachable.
    // NOTE(review): argument order is (destination, source); confirm against
    // the declaration of NPUCanAccessPeer.
    if (!platform::NPUCanAccessPeer(dst_place.device, src_place.device)) {
      PADDLE_THROW(platform::errors::Unavailable(
          "Peer access between NPU places is not allowed."));
    }
    if (stream) {
      // TODO(zhiqiu): support peer access?
      platform::RecordEvent record_event("NpuMemcpyPeerAsync:NPU->NPU");
      platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE,
                               stream);
    } else {
      // Drain pending async work before the synchronous copy.
      platform::DeviceContextPool& pool =
          platform::DeviceContextPool::Instance();
      static_cast<platform::NPUDeviceContext*>(pool.Get(dst_place))->Wait();

      platform::RecordEvent record_event("NpuMemcpyPeerSync:NPU->NPU");
      platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_DEVICE);
    }
  }
}
280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359

// NPU pinned host memory -> CPU copy using std::memcpy. The request is logged
// at VLOG(4) even when `num` is zero; a zero-byte request copies nothing.
template <>
void Copy<platform::CPUPlace, platform::NPUPinnedPlace>(
    platform::CPUPlace dst_place, void* dst, platform::NPUPinnedPlace src_place,
    const void* src, size_t num) {
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  if (num > 0) {
    std::memcpy(dst, src, num);
  }
}

// CPU -> NPU pinned host memory copy using std::memcpy. The request is logged
// at VLOG(4) even when `num` is zero; a zero-byte request copies nothing.
template <>
void Copy<platform::NPUPinnedPlace, platform::CPUPlace>(
    platform::NPUPinnedPlace dst_place, void* dst, platform::CPUPlace src_place,
    const void* src, size_t num) {
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  if (num > 0) {
    std::memcpy(dst, src, num);
  }
}

// NPU pinned -> NPU pinned host memory copy using std::memcpy. The request is
// logged at VLOG(4) even when `num` is zero; a zero-byte request copies
// nothing.
template <>
void Copy<platform::NPUPinnedPlace, platform::NPUPinnedPlace>(
    platform::NPUPinnedPlace dst_place, void* dst,
    platform::NPUPinnedPlace src_place, const void* src, size_t num) {
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  if (num > 0) {
    std::memcpy(dst, src, num);
  }
}

// NPU -> NPU pinned host memory copy.
//
// Selects the source NPU. If `stream` is non-null, an asynchronous
// device-to-host copy is enqueued on it. Otherwise the function waits on the
// source device context and then copies synchronously.
template <>
void Copy<platform::NPUPinnedPlace, platform::NPUPlace>(
    platform::NPUPinnedPlace dst_place, void* dst, platform::NPUPlace src_place,
    const void* src, size_t num, aclrtStream stream) {
  if (UNLIKELY(num == 0)) return;

  platform::SetNPUDeviceId(src_place.device);

  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by thream(" << stream << ")";

  if (stream == nullptr) {
    auto* npu_ctx = static_cast<platform::NPUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(src_place));
    npu_ctx->Wait();

    platform::RecordEvent record_event("NpuMemcpySync:NPU->NPUPinned");
    platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST);
    return;
  }

  platform::RecordEvent record_event("NpuMemcpyAsync:NPU->NPUPinned");
  platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST, stream);
}

// NPU pinned host memory -> NPU copy.
//
// Selects the destination NPU. If `stream` is non-null, an asynchronous
// host-to-device copy is enqueued on it. Otherwise the function waits on the
// destination device context and then copies synchronously.
template <>
void Copy<platform::NPUPlace, platform::NPUPinnedPlace>(
    platform::NPUPlace dst_place, void* dst, platform::NPUPinnedPlace src_place,
    const void* src, size_t num, aclrtStream stream) {
  if (UNLIKELY(num == 0)) return;

  platform::SetNPUDeviceId(dst_place.device);

  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by thream(" << stream << ")";

  if (stream == nullptr) {
    // On NPU an async operation may follow a sync one, but a sync operation
    // issued after an async one may start before the async work is done, so
    // drain the device context before the synchronous copy.
    auto* npu_ctx = static_cast<platform::NPUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(dst_place));
    npu_ctx->Wait();

    platform::RecordEvent record_event("NpuMemcpySync:NPUPinned->NPU");
    platform::NPUMemcpySync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE);
    return;
  }

  platform::RecordEvent record_event("NpuMemcpyAsync:NPUPinned->NPU");
  platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE, stream);
}

360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481
// Dispatches a copy between two type-erased places to the concrete
// specialization for CPUPlace, NPUPlace or NPUPinnedPlace. `stream` is
// forwarded only to combinations that involve an NPU device. Host-only
// combinations (CPU/NPUPinned on both sides) do not receive it. Any other
// combination performs no copy.
template <>
void Copy<pten::Place, pten::Place>(pten::Place dst_place, void* dst,
                                    pten::Place src_place, const void* src,
                                    size_t num, aclrtStream stream) {
  const auto src_type = src_place.GetType();
  const auto dst_type = dst_place.GetType();

  if (src_type == pten::AllocationType::CPU) {
    platform::CPUPlace place_src;
    if (dst_type == pten::AllocationType::CPU) {
      platform::CPUPlace place_dst;
      Copy(place_dst, dst, place_src, src, num);
    } else if (dst_type == pten::AllocationType::NPU) {
      platform::NPUPlace place_dst(dst_place.GetDeviceId());
      Copy(place_dst, dst, place_src, src, num, stream);
    } else if (dst_type == pten::AllocationType::NPUPINNED) {
      platform::NPUPinnedPlace place_dst;
      Copy(place_dst, dst, place_src, src, num);
    }
  } else if (src_type == pten::AllocationType::NPU) {
    platform::NPUPlace place_src(src_place.GetDeviceId());
    if (dst_type == pten::AllocationType::CPU) {
      platform::CPUPlace place_dst;
      Copy(place_dst, dst, place_src, src, num, stream);
    } else if (dst_type == pten::AllocationType::NPU) {
      platform::NPUPlace place_dst(dst_place.GetDeviceId());
      Copy(place_dst, dst, place_src, src, num, stream);
    } else if (dst_type == pten::AllocationType::NPUPINNED) {
      platform::NPUPinnedPlace place_dst;
      Copy(place_dst, dst, place_src, src, num, stream);
    }
  } else if (src_type == pten::AllocationType::NPUPINNED) {
    platform::NPUPinnedPlace place_src;
    if (dst_type == pten::AllocationType::CPU) {
      platform::CPUPlace place_dst;
      Copy(place_dst, dst, place_src, src, num);
    } else if (dst_type == pten::AllocationType::NPU) {
      platform::NPUPlace place_dst(dst_place.GetDeviceId());
      Copy(place_dst, dst, place_src, src, num, stream);
    } else if (dst_type == pten::AllocationType::NPUPINNED) {
      platform::NPUPinnedPlace place_dst;
      Copy(place_dst, dst, place_src, src, num);
    }
  }
}

// Copies into a CPUPlace from a type-erased source by widening the
// destination to a generic pten::Place and delegating to the
// pten::Place/pten::Place overload. Intended sources: CPUPlace, NPUPlace and
// NPUPinnedPlace.
template <>
void Copy<pten::CPUPlace, pten::Place>(pten::CPUPlace dst_place, void* dst,
                                       pten::Place src_place, const void* src,
                                       size_t num, aclrtStream stream) {
  const pten::Place generic_dst(dst_place.GetType());
  Copy(generic_dst, dst, src_place, src, num, stream);
}

// Copies from a CPUPlace into a type-erased destination by widening the
// source to a generic pten::Place and delegating to the
// pten::Place/pten::Place overload. Intended destinations: CPUPlace, NPUPlace
// and NPUPinnedPlace.
template <>
void Copy<pten::Place, pten::CPUPlace>(pten::Place dst_place, void* dst,
                                       pten::CPUPlace src_place,
                                       const void* src, size_t num,
                                       aclrtStream stream) {
  const pten::Place generic_src(src_place.GetType());
  Copy(dst_place, dst, generic_src, src, num, stream);
}

// Copies into an NPUPlace from a type-erased source. The destination is
// widened to a generic pten::Place, keeping its type and device id, and the
// call is delegated to the pten::Place/pten::Place overload. Intended
// sources: CPUPlace, NPUPlace and NPUPinnedPlace.
template <>
void Copy<pten::NPUPlace, pten::Place>(pten::NPUPlace dst_place, void* dst,
                                       pten::Place src_place, const void* src,
                                       size_t num, aclrtStream stream) {
  const pten::Place generic_dst(dst_place.GetType(), dst_place.GetDeviceId());
  Copy(generic_dst, dst, src_place, src, num, stream);
}

// Copies from an NPUPlace into a type-erased destination. The source is
// widened to a generic pten::Place, keeping its type and device id, and the
// call is delegated to the pten::Place/pten::Place overload. Intended
// destinations: CPUPlace, NPUPlace and NPUPinnedPlace.
template <>
void Copy<pten::Place, pten::NPUPlace>(pten::Place dst_place, void* dst,
                                       pten::NPUPlace src_place,
                                       const void* src, size_t num,
                                       aclrtStream stream) {
  const pten::Place generic_src(src_place.GetType(), src_place.GetDeviceId());
  Copy(dst_place, dst, generic_src, src, num, stream);
}

// Copies into an NPUPinnedPlace from a type-erased source by widening the
// destination to a generic pten::Place and delegating to the
// pten::Place/pten::Place overload. Intended sources: CPUPlace, NPUPlace and
// NPUPinnedPlace.
template <>
void Copy<pten::NPUPinnedPlace, pten::Place>(pten::NPUPinnedPlace dst_place,
                                             void* dst, pten::Place src_place,
                                             const void* src, size_t num,
                                             aclrtStream stream) {
  const pten::Place generic_dst(dst_place.GetType());
  Copy(generic_dst, dst, src_place, src, num, stream);
}

// Copies from an NPUPinnedPlace into a type-erased destination by widening the
// source to a generic pten::Place and delegating to the
// pten::Place/pten::Place overload. Intended destinations: CPUPlace, NPUPlace
// and NPUPinnedPlace.
template <>
void Copy<pten::Place, pten::NPUPinnedPlace>(pten::Place dst_place, void* dst,
                                             pten::NPUPinnedPlace src_place,
                                             const void* src, size_t num,
                                             aclrtStream stream) {
  const pten::Place generic_src(src_place.GetType());
  Copy(dst_place, dst, generic_src, src, num, stream);
}

// Stream-less overload: copies into an NPUPinnedPlace (intended source:
// CPUPlace) by delegating to the generic pten::Place overload with a null
// stream.
template <>
void Copy<pten::NPUPinnedPlace, pten::Place>(pten::NPUPinnedPlace dst_place,
                                             void* dst, pten::Place src_place,
                                             const void* src, size_t num) {
  const pten::Place generic_dst(dst_place.GetType());
  Copy(generic_dst, dst, src_place, src, num, nullptr);
}

// Stream-less overload: copies from an NPUPinnedPlace (intended destination:
// CPUPlace) by delegating to the generic pten::Place overload with a null
// stream.
template <>
void Copy<pten::Place, pten::NPUPinnedPlace>(pten::Place dst_place, void* dst,
                                             pten::NPUPinnedPlace src_place,
                                             const void* src, size_t num) {
  const pten::Place generic_src(src_place.GetType());
  Copy(dst_place, dst, generic_src, src, num, nullptr);
}
482 483
#endif

484
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
S
sneaxiy 已提交
485 486
// Size threshold (64 KiB) for the extra default-stream sync: synchronous
// host<->device copies of at most this many bytes call SyncCUDAStream() after
// the copy returns.
static constexpr size_t kMaxGpuAsyncCopyBytes = size_t{64} * 1024;  // 64 KiB

487 488 489 490 491 492 493 494 495 496 497 498 499
#ifdef PADDLE_WITH_HIP
inline void SyncCUDAStream() {
#if !defined(_WIN32)
  hipStreamSynchronize(0);
#else
  hipError_t e_sync = hipSuccess;
  while (e_sync = hipStreamQuery(0)) {
    if (e_sync == hipErrorNotReady) continue;
    break;
  }
#endif
}
#else
500 501 502 503 504 505 506 507 508 509 510
inline void SyncCUDAStream() {
#if !defined(_WIN32)
  cudaStreamSynchronize(0);
#else
  cudaError_t e_sync = cudaSuccess;
  while (e_sync = cudaStreamQuery(0)) {
    if (e_sync == cudaErrorNotReady) continue;
    break;
  }
#endif
}
511
#endif
512

513 514 515 516 517 518
// NOTE(zcd): Avoid GpuMemcpySync whenever possible. GpuMemcpySync issues the
// copy command on the default stream, which prevents commands from different
// streams from running concurrently.
// Reference:
// https://devblogs.nvidia.com/gpu-pro-tip-cuda-7-streams-simplify-concurrency/

519
template <>
D
dzhwinter 已提交
520 521
void Copy<platform::CPUPlace, platform::CUDAPlace>(
    platform::CPUPlace dst_place, void* dst, platform::CUDAPlace src_place,
522
    const void* src, size_t num, gpuStream_t stream) {
Z
Zeng Jinle 已提交
523
  if (UNLIKELY(num == 0)) return;
524

525 526
  platform::SetDeviceId(src_place.device);
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
527
          << dst_place << " by stream(" << stream << ")";
528
  if (stream) {
529
    platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CPU");
530 531 532
#ifdef PADDLE_WITH_HIP
    platform::GpuMemcpyAsync(dst, src, num, hipMemcpyDeviceToHost, stream);
#else
533
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
534
#endif
535
  } else {
536
    platform::RecordEvent record_event("GpuMemcpySync:GPU->CPU");
537 538 539
#ifdef PADDLE_WITH_HIP
    platform::GpuMemcpySync(dst, src, num, hipMemcpyDeviceToHost);
#else
540
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
541
#endif
S
sneaxiy 已提交
542 543
    // FIXME(zjl): do we really need it?
    if (num <= kMaxGpuAsyncCopyBytes) {
544
      SyncCUDAStream();
S
sneaxiy 已提交
545
    }
546
  }
547 548 549
}

template <>
D
dzhwinter 已提交
550 551
void Copy<platform::CUDAPlace, platform::CPUPlace>(
    platform::CUDAPlace dst_place, void* dst, platform::CPUPlace src_place,
552
    const void* src, size_t num, gpuStream_t stream) {
Z
Zeng Jinle 已提交
553 554
  if (UNLIKELY(num == 0)) return;

L
liaogang 已提交
555
  platform::SetDeviceId(dst_place.device);
556 557
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by thream(" << stream << ")";
558
  if (stream) {
559
    platform::RecordEvent record_event("GpuMemcpyAsync:CPU->GPU");
560 561 562
#ifdef PADDLE_WITH_HIP
    platform::GpuMemcpyAsync(dst, src, num, hipMemcpyHostToDevice, stream);
#else
563
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
564
#endif
565
  } else {
566
    platform::RecordEvent record_event("GpuMemcpySync:CPU->GPU");
567 568 569
#ifdef PADDLE_WITH_HIP
    platform::GpuMemcpySync(dst, src, num, hipMemcpyHostToDevice);
#else
570
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
571
#endif
S
sneaxiy 已提交
572 573
    // FIXME(zjl): do we really need it?
    if (num <= kMaxGpuAsyncCopyBytes) {
574
      SyncCUDAStream();
S
sneaxiy 已提交
575
    }
576
  }
577 578 579
}

template <>
D
dzhwinter 已提交
580 581
void Copy<platform::CUDAPlace, platform::CUDAPlace>(
    platform::CUDAPlace dst_place, void* dst, platform::CUDAPlace src_place,
582
    const void* src, size_t num, gpuStream_t stream) {
Z
Zeng Jinle 已提交
583 584
  if (UNLIKELY(num == 0)) return;

585
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
586
          << dst_place << " by stream(" << stream << ")";
587
  if (dst_place == src_place) {
L
liaogang 已提交
588
    platform::SetDeviceId(src_place.device);
589
    if (stream) {
590
      platform::RecordEvent record_event("GpuMemcpyAsync(same_gpu):GPU->GPU");
591 592 593
#ifdef PADDLE_WITH_HIP
      platform::GpuMemcpyAsync(dst, src, num, hipMemcpyDeviceToDevice, stream);
#else
594
      platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
595
#endif
596
    } else {
597
      platform::RecordEvent record_event("GpuMemcpySync(same_gpu):GPU->GPU");
598 599 600
#ifdef PADDLE_WITH_HIP
      platform::GpuMemcpySync(dst, src, num, hipMemcpyDeviceToDevice);
#else
601
      platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice);
602
#endif
603
    }
604
  } else {
605
    if (stream) {
606
      platform::RecordEvent record_event("GpuMemcpyPeerAsync:GPU->GPU");
607 608 609
      platform::GpuMemcpyPeerAsync(dst, dst_place.device, src, src_place.device,
                                   num, stream);
    } else {
610
      platform::RecordEvent record_event("GpuMemcpyPeerSync:GPU->GPU");
611
      platform::GpuMemcpyPeerSync(dst, dst_place.device, src, src_place.device,
F
fengjiayi 已提交
612
                                  num);
613
    }
614 615 616
  }
}

C
chengduoZH 已提交
617 618 619 620
// CUDA pinned host memory -> CPU copy using std::memcpy. The request is logged
// at VLOG(4) even when `num` is zero; a zero-byte request copies nothing.
template <>
void Copy<platform::CPUPlace, platform::CUDAPinnedPlace>(
    platform::CPUPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  if (num > 0) {
    std::memcpy(dst, src, num);
  }
}

// CPU -> CUDA pinned host memory copy using std::memcpy. The request is logged
// at VLOG(4) even when `num` is zero; a zero-byte request copies nothing.
template <>
void Copy<platform::CUDAPinnedPlace, platform::CPUPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CPUPlace src_place, const void* src, size_t num) {
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  if (num > 0) {
    std::memcpy(dst, src, num);
  }
}

// CUDA pinned -> CUDA pinned host memory copy using std::memcpy. The request
// is logged at VLOG(4) even when `num` is zero; a zero-byte request copies
// nothing.
template <>
void Copy<platform::CUDAPinnedPlace, platform::CUDAPinnedPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;
  if (num > 0) {
    std::memcpy(dst, src, num);
  }
}

// GPU -> CUDA pinned host memory copy. Selects the source device, then copies
// asynchronously on `stream` when it is non-null and synchronously otherwise.
template <>
void Copy<platform::CUDAPinnedPlace, platform::CUDAPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CUDAPlace src_place, const void* src, size_t num,
    gpuStream_t stream) {
  if (UNLIKELY(num == 0)) return;

#ifdef PADDLE_WITH_HIP
  constexpr auto kDeviceToHost = hipMemcpyDeviceToHost;
#else
  constexpr auto kDeviceToHost = cudaMemcpyDeviceToHost;
#endif

  platform::SetDeviceId(src_place.device);
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by thream(" << stream << ")";

  if (stream != nullptr) {
    platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CUDAPinned");
    platform::GpuMemcpyAsync(dst, src, num, kDeviceToHost, stream);
  } else {
    platform::RecordEvent record_event("GpuMemcpySync:GPU->CUDAPinned");
    platform::GpuMemcpySync(dst, src, num, kDeviceToHost);
  }
}

// CUDA pinned host memory -> GPU copy. Selects the destination device, then
// copies asynchronously on `stream` when it is non-null and synchronously
// otherwise.
template <>
void Copy<platform::CUDAPlace, platform::CUDAPinnedPlace>(
    platform::CUDAPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num,
    gpuStream_t stream) {
  if (UNLIKELY(num == 0)) return;

#ifdef PADDLE_WITH_HIP
  constexpr auto kHostToDevice = hipMemcpyHostToDevice;
#else
  constexpr auto kHostToDevice = cudaMemcpyHostToDevice;
#endif

  platform::SetDeviceId(dst_place.device);
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place << " by thream(" << stream << ")";

  if (stream != nullptr) {
    platform::RecordEvent record_event("GpuMemcpyAsync:CUDAPinned->GPU");
    platform::GpuMemcpyAsync(dst, src, num, kHostToDevice, stream);
  } else {
    platform::RecordEvent record_event("GpuMemcpySync:CUDAPinned->GPU");
    platform::GpuMemcpySync(dst, src, num, kHostToDevice);
  }
}

700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821
// NOTE: only for CPUPlace, CUDAPlace and CUDAPinnedPlace.
// Resolves two type-erased pten::Place values to concrete platform places and
// forwards to the matching typed Copy overload. `stream` is forwarded only to
// the overloads that involve a CUDAPlace. Copies among CPU and CUDA-pinned
// memory go through the stream-less overloads. Any other combination of
// allocation types falls through without copying and without reporting an
// error.
template <>
void Copy<pten::Place, pten::Place>(pten::Place dst_place, void* dst,
                                    pten::Place src_place, const void* src,
                                    size_t num, gpuStream_t stream) {
  const pten::AllocationType src_type = src_place.GetType();
  const pten::AllocationType dst_type = dst_place.GetType();

  switch (src_type) {
    case pten::AllocationType::CPU: {
      platform::CPUPlace place_src;
      if (dst_type == pten::AllocationType::CPU) {
        platform::CPUPlace place_dst;
        return Copy(place_dst, dst, place_src, src, num);
      }
      if (dst_type == pten::AllocationType::GPU) {
        platform::CUDAPlace place_dst(dst_place.GetDeviceId());
        return Copy(place_dst, dst, place_src, src, num, stream);
      }
      if (dst_type == pten::AllocationType::GPUPINNED) {
        platform::CUDAPinnedPlace place_dst;
        return Copy(place_dst, dst, place_src, src, num);
      }
      break;
    }
    case pten::AllocationType::GPU: {
      platform::CUDAPlace place_src(src_place.GetDeviceId());
      if (dst_type == pten::AllocationType::CPU) {
        platform::CPUPlace place_dst;
        return Copy(place_dst, dst, place_src, src, num, stream);
      }
      if (dst_type == pten::AllocationType::GPU) {
        platform::CUDAPlace place_dst(dst_place.GetDeviceId());
        return Copy(place_dst, dst, place_src, src, num, stream);
      }
      if (dst_type == pten::AllocationType::GPUPINNED) {
        platform::CUDAPinnedPlace place_dst;
        return Copy(place_dst, dst, place_src, src, num, stream);
      }
      break;
    }
    case pten::AllocationType::GPUPINNED: {
      platform::CUDAPinnedPlace place_src;
      if (dst_type == pten::AllocationType::CPU) {
        platform::CPUPlace place_dst;
        return Copy(place_dst, dst, place_src, src, num);
      }
      if (dst_type == pten::AllocationType::GPUPINNED) {
        platform::CUDAPinnedPlace place_dst;
        return Copy(place_dst, dst, place_src, src, num);
      }
      if (dst_type == pten::AllocationType::GPU) {
        platform::CUDAPlace place_dst(dst_place.GetDeviceId());
        return Copy(place_dst, dst, place_src, src, num, stream);
      }
      break;
    }
    default:
      // Unsupported source type: nothing is copied.
      break;
  }
}

// NOTE: only for (CPUPlace, CUDAPlace and CUDAPinnedPlace) -> (CPUPlace).
// Wraps the CPU destination in a generic pten::Place and delegates to the
// Place-to-Place overload, passing `stream` through unchanged.
template <>
void Copy<pten::CPUPlace, pten::Place>(pten::CPUPlace dst_place, void* dst,
                                       pten::Place src_place, const void* src,
                                       size_t num, gpuStream_t stream) {
  const pten::Place generic_dst(dst_place.GetType());
  Copy(generic_dst, dst, src_place, src, num, stream);
}

// NOTE: only for (CPUPlace) -> (CPUPlace, CUDAPlace and CUDAPinnedPlace).
// Wraps the CPU source in a generic pten::Place and delegates to the
// Place-to-Place overload, passing `stream` through unchanged.
template <>
void Copy<pten::Place, pten::CPUPlace>(pten::Place dst_place, void* dst,
                                       pten::CPUPlace src_place,
                                       const void* src, size_t num,
                                       gpuStream_t stream) {
  const pten::Place generic_src(src_place.GetType());
  Copy(dst_place, dst, generic_src, src, num, stream);
}

// NOTE: only for (CPUPlace, CUDAPlace and CUDAPinnedPlace) -> (CUDAPlace)
// Wraps the GPU destination (type and device id) in a generic pten::Place and
// delegates to the Place-to-Place overload with the same stream.
template <>
void Copy<pten::GPUPlace, pten::Place>(pten::GPUPlace dst_place, void* dst,
                                       pten::Place src_place, const void* src,
                                       size_t num, gpuStream_t stream) {
  const pten::Place generic_dst(dst_place.GetType(),
                                dst_place.GetDeviceId());
  Copy(generic_dst, dst, src_place, src, num, stream);
}

// NOTE: only for (CUDAPlace) -> (CPUPlace, CUDAPlace and CUDAPinnedPlace)
// Wraps the GPU source (type and device id) in a generic pten::Place and
// delegates to the Place-to-Place overload with the same stream.
template <>
void Copy<pten::Place, pten::GPUPlace>(pten::Place dst_place, void* dst,
                                       pten::GPUPlace src_place,
                                       const void* src, size_t num,
                                       gpuStream_t stream) {
  const pten::Place generic_src(src_place.GetType(),
                                src_place.GetDeviceId());
  Copy(dst_place, dst, generic_src, src, num, stream);
}

// NOTE: only for (CPUPlace, CUDAPlace and CUDAPinnedPlace) -> (CUDAPinnedPlace)
// Wraps the pinned-memory destination in a generic pten::Place and delegates
// to the Place-to-Place overload with the same stream.
template <>
void Copy<pten::GPUPinnedPlace, pten::Place>(pten::GPUPinnedPlace dst_place,
                                             void* dst, pten::Place src_place,
                                             const void* src, size_t num,
                                             gpuStream_t stream) {
  const pten::Place generic_dst(dst_place.GetType());
  Copy(generic_dst, dst, src_place, src, num, stream);
}

// NOTE: only for (CUDAPinnedPlace) -> (CPUPlace, CUDAPlace and CUDAPinnedPlace)
// Wraps the pinned-memory source in a generic pten::Place and delegates to
// the Place-to-Place overload with the same stream.
template <>
void Copy<pten::Place, pten::GPUPinnedPlace>(pten::Place dst_place, void* dst,
                                             pten::GPUPinnedPlace src_place,
                                             const void* src, size_t num,
                                             gpuStream_t stream) {
  const pten::Place generic_src(src_place.GetType());
  Copy(dst_place, dst, generic_src, src, num, stream);
}

// NOTE: only for (CPUPlace) -> (CUDAPinnedPlace)
// Stream-less variant: wraps the pinned destination in a generic pten::Place
// and delegates to the Place-to-Place overload with a null stream.
template <>
void Copy<pten::GPUPinnedPlace, pten::Place>(pten::GPUPinnedPlace dst_place,
                                             void* dst, pten::Place src_place,
                                             const void* src, size_t num) {
  const pten::Place generic_dst(dst_place.GetType());
  gpuStream_t no_stream = nullptr;
  Copy(generic_dst, dst, src_place, src, num, no_stream);
}

// NOTE: only for (CUDAPinnedPlace) -> (CPUPlace)
// Stream-less variant: wraps the pinned source in a generic pten::Place and
// delegates to the Place-to-Place overload with a null stream.
template <>
void Copy<pten::Place, pten::GPUPinnedPlace>(pten::Place dst_place, void* dst,
                                             pten::GPUPinnedPlace src_place,
                                             const void* src, size_t num) {
  const pten::Place generic_src(src_place.GetType());
  gpuStream_t no_stream = nullptr;
  Copy(dst_place, dst, generic_src, src, num, no_stream);
}
L
Luo Tao 已提交
822
#endif
Y
Yi Wang 已提交
823

F
fwenguang 已提交
824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839
#ifdef PADDLE_WITH_MLU
// Copies `num` bytes from MLU memory `src` on `src_place` to host memory `dst`.
//   stream: non-null -> MLUMemcpyD2HAsync on that stream;
//           null     -> wait on the source MLU's device context, then
//                       MLUMemcpyD2HSync.
// A zero-byte request returns immediately.
template <>
void Copy<platform::CPUPlace, platform::MLUPlace>(platform::CPUPlace dst_place,
                                                  void* dst,
                                                  platform::MLUPlace src_place,
                                                  const void* src, size_t num,
                                                  mluStream stream) {
  if (UNLIKELY(num == 0)) return;

  // The source MLU owns `src`; select it before any MLU call.
  platform::SetMLUDeviceId(src_place.device);

  if (!stream) {
    // NOTE(review): presumably this wait ensures work already queued on the
    // MLU context finishes before the blocking copy reads `src` — confirm.
    auto& pool = platform::DeviceContextPool::Instance();
    auto* mlu_ctx =
        static_cast<platform::MLUDeviceContext*>(pool.Get(src_place));
    mlu_ctx->Wait();

    VLOG(4) << "Sync memory::Copy " << num << " Bytes from " << src_place
            << " to " << dst_place;
    platform::RecordEvent record_event("MLUMemcpyD2HSync:MLU->CPU");
    platform::MLUMemcpyD2HSync(dst, src, num);
    return;
  }

  VLOG(4) << "Async memory::Copy " << num << " Bytes from " << src_place
          << " to " << dst_place << " by mlu stream(" << stream << ")";
  platform::RecordEvent record_event("MLUMemcpyD2HAsync:MLU->CPU");
  platform::MLUMemcpyD2HAsync(dst, src, num, stream);
}

// Copies `num` bytes from host memory `src` to MLU memory `dst` on `dst_place`.
//   stream: non-null -> MLUMemcpyH2DAsync on that stream;
//           null     -> wait on the destination MLU's device context, then
//                       MLUMemcpyH2DSync.
// A zero-byte request returns immediately.
template <>
void Copy<platform::MLUPlace, platform::CPUPlace>(platform::MLUPlace dst_place,
                                                  void* dst,
                                                  platform::CPUPlace src_place,
                                                  const void* src, size_t num,
                                                  mluStream stream) {
  if (UNLIKELY(num == 0)) return;

  platform::SetMLUDeviceId(dst_place.device);
  if (stream) {
    VLOG(4) << "Async memory::Copy " << num << " Bytes from " << src_place
            << " to " << dst_place << " by mlu stream(" << stream << ")";
    platform::RecordEvent record_event("MLUMemcpyH2DAsync:CPU->MLU");
    platform::MLUMemcpyH2DAsync(dst, src, num, stream);
  } else {
    // The MLU involved in this copy is the destination. Looking up the pool
    // with `src_place` (a CPUPlace) returned the CPU device context, and
    // static_cast-ing that to MLUDeviceContext* is an invalid downcast; the
    // MLU context was never waited on. Use `dst_place`, consistent with the
    // D2H path, which waits on the context of its MLU place.
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    static_cast<platform::MLUDeviceContext*>(pool.Get(dst_place))->Wait();

    VLOG(4) << "Sync memory::Copy " << num << " Bytes from " << src_place
            << " to " << dst_place;
    platform::RecordEvent record_event("MLUMemcpyH2DSync:CPU->MLU");
    platform::MLUMemcpyH2DSync(dst, src, num);
  }
}

// Copies `num` bytes between MLU buffers.
// Same device (dst_place == src_place): select that device, then
//   - non-null stream: MLUMemcpyD2DAsync on `stream`;
//   - null stream: wait on the device context, then MLUMemcpyD2DSync.
// Different devices: peer copy via MLUMemcpyPeerAsync or MLUMemcpyPeerSync,
//   depending on whether `stream` is non-null. This path neither selects a
//   device nor waits on a context beforehand.
// A zero-byte request returns immediately.
template <>
void Copy<platform::MLUPlace, platform::MLUPlace>(platform::MLUPlace dst_place,
                                                  void* dst,
                                                  platform::MLUPlace src_place,
                                                  const void* src, size_t num,
                                                  mluStream stream) {
  if (UNLIKELY(num == 0)) return;

  const bool same_device = (dst_place == src_place);

  if (!same_device) {
    if (!stream) {
      VLOG(4) << "Sync memory::Copy " << num << " Bytes from " << src_place
              << " to " << dst_place;
      platform::RecordEvent record_event("MLUMemcpyPeerSync:MLU->MLU");
      platform::MLUMemcpyPeerSync(dst, dst_place.device, src, src_place.device,
                                  num);
      return;
    }
    VLOG(4) << "Async memory::Copy " << num << " Bytes from " << src_place
            << " to " << dst_place << " by mlu stream(" << stream << ")";
    platform::RecordEvent record_event("MLUMemcpyPeerAsync:MLU->MLU");
    platform::MLUMemcpyPeerAsync(dst, dst_place.device, src, src_place.device,
                                 num, stream);
    return;
  }

  platform::SetMLUDeviceId(dst_place.device);
  if (!stream) {
    auto& pool = platform::DeviceContextPool::Instance();
    auto* mlu_ctx =
        static_cast<platform::MLUDeviceContext*>(pool.Get(src_place));
    mlu_ctx->Wait();

    VLOG(4) << "Sync memory::Copy " << num << " Bytes from " << src_place
            << " to " << dst_place;
    platform::RecordEvent record_event("MLUMemcpyD2DSync(same_mlu):MLU->MLU");
    platform::MLUMemcpyD2DSync(dst, src, num);
    return;
  }

  VLOG(4) << "Async memory::Copy " << num << " Bytes from " << src_place
          << " to " << dst_place << " by mlu stream(" << stream << ")";
  platform::RecordEvent record_event("MLUMemcpyD2DAsync(same_mlu):MLU->MLU");
  platform::MLUMemcpyD2DAsync(dst, src, num, stream);
}

918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964
// NOTE: only for CPUPlace and MLUPlace.
// Resolves two type-erased pten::Place values to concrete CPU/MLU places and
// forwards to the matching typed Copy overload. `stream` is forwarded to
// every overload that involves an MLUPlace. CPU->CPU uses the stream-less
// overload. Any other combination falls through without copying.
template <>
void Copy<pten::Place, pten::Place>(pten::Place dst_place, void* dst,
                                    pten::Place src_place, const void* src,
                                    size_t num, mluStream stream) {
  const pten::AllocationType src_type = src_place.GetType();
  const pten::AllocationType dst_type = dst_place.GetType();

  switch (src_type) {
    case pten::AllocationType::CPU: {
      platform::CPUPlace place_src;
      if (dst_type == pten::AllocationType::CPU) {
        platform::CPUPlace place_dst;
        return Copy(place_dst, dst, place_src, src, num);
      }
      if (dst_type == pten::AllocationType::MLU) {
        platform::MLUPlace place_dst(dst_place.GetDeviceId());
        return Copy(place_dst, dst, place_src, src, num, stream);
      }
      break;
    }
    case pten::AllocationType::MLU: {
      platform::MLUPlace place_src(src_place.GetDeviceId());
      if (dst_type == pten::AllocationType::CPU) {
        platform::CPUPlace place_dst;
        return Copy(place_dst, dst, place_src, src, num, stream);
      }
      if (dst_type == pten::AllocationType::MLU) {
        platform::MLUPlace place_dst(dst_place.GetDeviceId());
        return Copy(place_dst, dst, place_src, src, num, stream);
      }
      break;
    }
    default:
      // Unsupported source type: nothing is copied.
      break;
  }
}

// NOTE: only for (CPUPlace and MLUPlace) -> (MLUPlace)
// Wraps the MLU destination (type and device id) in a generic pten::Place and
// delegates to the Place-to-Place overload with the same stream.
template <>
void Copy<pten::MLUPlace, pten::Place>(pten::MLUPlace dst_place, void* dst,
                                       pten::Place src_place, const void* src,
                                       size_t num, mluStream stream) {
  const pten::Place generic_dst(dst_place.GetType(),
                                dst_place.GetDeviceId());
  Copy(generic_dst, dst, src_place, src, num, stream);
}

// NOTE: only for (MLUPlace) -> (CPUPlace and MLUPlace)
// Wraps the MLU source (type and device id) in a generic pten::Place and
// delegates to the Place-to-Place overload with the same stream.
template <>
void Copy<pten::Place, pten::MLUPlace>(pten::Place dst_place, void* dst,
                                       pten::MLUPlace src_place,
                                       const void* src, size_t num,
                                       mluStream stream) {
  const pten::Place generic_src(src_place.GetType(),
                                src_place.GetDeviceId());
  Copy(dst_place, dst, generic_src, src, num, stream);
}

F
fwenguang 已提交
965 966
#endif  // PADDLE_WITH_MLU

967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042
// NOTE: Only for CPUPlace, XPUPlace and PinnedPlace.
// Stream-less copy of `num` bytes between two type-erased places.
//  - CPU <-> CPU, and (when built with CUDA/HIP) CPU/CUDA-pinned pairs, and
//    (when built with Ascend CL) CPU/NPU-pinned pairs are copied directly
//    with std::memcpy.
//  - (when built with XPU) pairs involving an XPUPlace are forwarded to the
//    typed XPU overloads.
// A zero-byte request returns immediately. Any other combination of place
// types is a silent no-op.
template <>
void Copy<pten::Place, pten::Place>(pten::Place dst_place, void* dst,
                                    pten::Place src_place, const void* src,
                                    size_t num) {
  if (UNLIKELY(num == 0)) return;
  VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
          << dst_place;

  const pten::AllocationType src_type = src_place.GetType();
  const pten::AllocationType dst_type = dst_place.GetType();

  if (src_type == pten::AllocationType::CPU &&
      dst_type == pten::AllocationType::CPU) {
    std::memcpy(dst, src, num);
    return;
  }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  {
    // CPU <-> GPU-pinned and pinned <-> pinned are handled with a plain host
    // memcpy, as before.
    const bool src_host = src_type == pten::AllocationType::CPU ||
                          src_type == pten::AllocationType::GPUPINNED;
    const bool dst_host = dst_type == pten::AllocationType::CPU ||
                          dst_type == pten::AllocationType::GPUPINNED;
    if (src_host && dst_host) {
      std::memcpy(dst, src, num);
      return;
    }
  }
#endif
#ifdef PADDLE_WITH_ASCEND_CL
  {
    // CPU <-> NPU-pinned and pinned <-> pinned use a plain host memcpy.
    const bool src_host = src_type == pten::AllocationType::CPU ||
                          src_type == pten::AllocationType::NPUPINNED;
    const bool dst_host = dst_type == pten::AllocationType::CPU ||
                          dst_type == pten::AllocationType::NPUPINNED;
    if (src_host && dst_host) {
      std::memcpy(dst, src, num);
      return;
    }
  }
#endif
#ifdef PADDLE_WITH_XPU
  // The former CPU -> CPU branch here was unreachable (CPU -> CPU is handled
  // at the top of the function) and has been removed.
  if (src_type == pten::AllocationType::CPU &&
      dst_type == pten::AllocationType::XPU) {
    platform::XPUPlace place_dst(dst_place.GetDeviceId());
    platform::CPUPlace place_src;
    return Copy(place_dst, dst, place_src, src, num);
  }
  if (src_type == pten::AllocationType::XPU &&
      dst_type == pten::AllocationType::CPU) {
    platform::XPUPlace place_src(src_place.GetDeviceId());
    platform::CPUPlace place_dst;
    return Copy(place_dst, dst, place_src, src, num);
  }
  if (src_type == pten::AllocationType::XPU &&
      dst_type == pten::AllocationType::XPU) {
    platform::XPUPlace place_src(src_place.GetDeviceId());
    platform::XPUPlace place_dst(dst_place.GetDeviceId());
    return Copy(place_dst, dst, place_src, src, num);
  }
#endif
}

// NOTE: Only for (CPUPlace) -> (CPUPlace and PinnedPlace).
// Stream-less variant: wraps the CPU source in a generic pten::Place and
// delegates to the stream-less Place-to-Place overload.
template <>
void Copy<pten::Place, pten::CPUPlace>(pten::Place dst_place, void* dst,
                                       pten::CPUPlace src_place,
                                       const void* src, size_t num) {
  const pten::Place generic_src(src_place.GetType());
  Copy(dst_place, dst, generic_src, src, num);
}

// NOTE: Only for (CPUPlace and PinnedPlace) -> (CPUPlace).
// Stream-less variant: wraps the CPU destination in a generic pten::Place and
// delegates to the stream-less Place-to-Place overload.
template <>
void Copy<pten::CPUPlace, pten::Place>(pten::CPUPlace dst_place, void* dst,
                                       pten::Place src_place, const void* src,
                                       size_t num) {
  const pten::Place generic_dst(dst_place.GetType());
  Copy(generic_dst, dst, src_place, src, num);
}

Y
Yi Wang 已提交
1043 1044
}  // namespace memory
}  // namespace paddle