/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/memory/memcpy.h"

#include <cstring>  // for memcpy

namespace paddle {
namespace memory {

// Copy<CPUPlace, CPUPlace>: copies `num` bytes from `src` to `dst` with
// std::memcpy. The two place arguments are unnamed and unused; they only
// select this specialization.
//
// A zero-length request returns without touching the pointers, because
// std::memcpy has undefined behavior when either pointer is null or invalid,
// even if the byte count is zero.
template <>
void Copy<platform::CPUPlace, platform::CPUPlace>(platform::CPUPlace, void* dst,
                                                  platform::CPUPlace,
                                                  const void* src, size_t num) {
  if (num == 0) return;
  std::memcpy(dst, src, num);
}

#ifdef PADDLE_WITH_CUDA
template <>
D
dzhwinter 已提交
31 32
void Copy<platform::CPUPlace, platform::CUDAPlace>(
    platform::CPUPlace dst_place, void* dst, platform::CUDAPlace src_place,
F
fengjiayi 已提交
33
    const void* src, size_t num, cudaStream_t stream) {
L
liaogang 已提交
34
  platform::SetDeviceId(src_place.device);
35 36 37 38 39
  if (stream) {
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
  } else {
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
  }
40 41 42
}

template <>
D
dzhwinter 已提交
43 44
void Copy<platform::CUDAPlace, platform::CPUPlace>(
    platform::CUDAPlace dst_place, void* dst, platform::CPUPlace src_place,
F
fengjiayi 已提交
45
    const void* src, size_t num, cudaStream_t stream) {
L
liaogang 已提交
46
  platform::SetDeviceId(dst_place.device);
47 48 49 50 51
  if (stream) {
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
  } else {
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
  }
52 53 54
}

template <>
D
dzhwinter 已提交
55 56
void Copy<platform::CUDAPlace, platform::CUDAPlace>(
    platform::CUDAPlace dst_place, void* dst, platform::CUDAPlace src_place,
F
fengjiayi 已提交
57
    const void* src, size_t num, cudaStream_t stream) {
58
  if (dst_place == src_place) {
L
liaogang 已提交
59
    platform::SetDeviceId(src_place.device);
60 61 62 63 64
    if (stream) {
      platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
    } else {
      platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice);
    }
65
  } else {
66 67 68 69 70
    if (stream) {
      platform::GpuMemcpyPeerAsync(dst, dst_place.device, src, src_place.device,
                                   num, stream);
    } else {
      platform::GpuMemcpyPeerSync(dst, dst_place.device, src, src_place.device,
F
fengjiayi 已提交
71
                                  num);
72
    }
73 74 75
  }
}

C
chengduoZH 已提交
76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
template <>
void Copy<platform::CPUPlace, platform::CUDAPinnedPlace>(
    platform::CPUPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
  std::memcpy(dst, src, num);
}

// Copy<CUDAPinnedPlace, CPUPlace>: copies `num` bytes from `src` to `dst` with
// std::memcpy. Neither place argument is read; they only select this
// specialization.
//
// A zero-length request returns without touching the pointers, because
// std::memcpy has undefined behavior when either pointer is null or invalid,
// even if the byte count is zero.
template <>
void Copy<platform::CUDAPinnedPlace, platform::CPUPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CPUPlace src_place, const void* src, size_t num) {
  if (num == 0) return;
  std::memcpy(dst, src, num);
}

// Copy<CUDAPinnedPlace, CUDAPinnedPlace>: copies `num` bytes from `src` to
// `dst` with std::memcpy. Neither place argument is read; they only select
// this specialization.
//
// A zero-length request returns without touching the pointers, because
// std::memcpy has undefined behavior when either pointer is null or invalid,
// even if the byte count is zero.
template <>
void Copy<platform::CUDAPinnedPlace, platform::CUDAPinnedPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
  if (num == 0) return;
  std::memcpy(dst, src, num);
}

// Copy<CUDAPinnedPlace, CUDAPlace>: copies `num` bytes from `src` to `dst`
// using the cudaMemcpyDeviceToHost direction.
//
// platform::SetDeviceId(src_place.device) is called before the copy
// (presumably to make the source GPU the current device). `dst_place` is not
// read.
//
// A non-null `stream` issues the copy through platform::GpuMemcpyAsync on that
// stream; a null `stream` uses platform::GpuMemcpySync instead.
template <>
void Copy<platform::CUDAPinnedPlace, platform::CUDAPlace>(
    platform::CUDAPinnedPlace dst_place, void* dst,
    platform::CUDAPlace src_place, const void* src, size_t num,
    cudaStream_t stream) {
  platform::SetDeviceId(src_place.device);
  if (stream) {
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
  } else {
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
  }
}

// Copy<CUDAPlace, CUDAPinnedPlace>: copies `num` bytes from `src` to `dst`
// using the cudaMemcpyHostToDevice direction.
//
// platform::SetDeviceId(dst_place.device) is called before the copy
// (presumably to make the destination GPU the current device). `src_place` is
// not read.
//
// A non-null `stream` issues the copy through platform::GpuMemcpyAsync on that
// stream; a null `stream` uses platform::GpuMemcpySync instead.
template <>
void Copy<platform::CUDAPlace, platform::CUDAPinnedPlace>(
    platform::CUDAPlace dst_place, void* dst,
    platform::CUDAPinnedPlace src_place, const void* src, size_t num,
    cudaStream_t stream) {
  platform::SetDeviceId(dst_place.device);
  if (stream) {
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
  } else {
    platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
  }
}

#endif

}  // namespace memory
}  // namespace paddle