memory.cc 4.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/memory/memory.h"
L
liaogang 已提交
16 17
#include <cstring>

#include "paddle/memory/detail/buddy_allocator.h"
#include "paddle/memory/detail/system_allocator.h"
18

19 20 21
namespace paddle {
namespace memory {

L
liaogang 已提交
22 23 24 25 26 27 28
// Returns the process-wide buddy allocator that serves CPUPlace requests.
//
// The allocator is created on the first call. It is never deleted in this
// file, so the returned pointer stays valid for the rest of the process.
// Creation happens in the initializer of a function-local static. C++11
// guarantees that initializer runs exactly once, even when several threads
// make the first call concurrently. A plain `if (a == nullptr) a = new ...`
// check would race, and could construct (and leak) more than one allocator.
//
// NOTE(review): the BuddyAllocator presumably takes ownership of the
// CPUAllocator passed to it — confirm in buddy_allocator.h.
detail::BuddyAllocator* GetCPUBuddyAllocator() {
  static detail::BuddyAllocator* a =
      new detail::BuddyAllocator(new detail::CPUAllocator,
                                 platform::CpuMinChunkSize(),
                                 platform::CpuMaxChunkSize());
  return a;
}

L
liaogang 已提交
32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
// Allocates `size` bytes of host memory from the shared CPU buddy allocator.
// `place` carries no information for CPUPlace and is not consulted.
template <>
void* Alloc<platform::CPUPlace>(platform::CPUPlace place, size_t size) {
  detail::BuddyAllocator* cpu_allocator = GetCPUBuddyAllocator();
  return cpu_allocator->Alloc(size);
}

// Returns `p` to the shared CPU buddy allocator.
// `p` is presumably a pointer previously obtained from Alloc<CPUPlace>.
template <>
void Free<platform::CPUPlace>(platform::CPUPlace place, void* p) {
  auto* const cpu_allocator = GetCPUBuddyAllocator();
  cpu_allocator->Free(p);
}

// Reports BuddyAllocator::Used() of the shared CPU allocator. This is
// presumably the number of bytes currently handed out; confirm in
// buddy_allocator.h.
template <>
size_t Used<platform::CPUPlace>(platform::CPUPlace place) {
  const auto* cpu_allocator = GetCPUBuddyAllocator();
  const size_t used_bytes = cpu_allocator->Used();
  return used_bytes;
}

L
liaogang 已提交
47 48
// Copies `num` bytes of host memory from `src` to `dst`.
//
// Both places are CPUPlace, so the place arguments carry no information and
// are ignored. The buffers must not overlap (a std::memcpy requirement).
template <>
void Copy<platform::CPUPlace, platform::CPUPlace>(platform::CPUPlace, void* dst,
                                                  platform::CPUPlace,
                                                  const void* src, size_t num) {
  // std::memcpy is undefined for null pointers even when the length is 0.
  // An empty copy, which may come with null buffers, is therefore a no-op.
  if (num == 0) return;
  std::memcpy(dst, src, num);
}

L
liaogang 已提交
54
#ifndef PADDLE_ONLY_CPU
L
liaogang 已提交
55 56 57 58

// Returns the buddy allocator for GPU `gpu_id` and makes that GPU the current
// device.
//
// First call: one allocator is created per device reported by
// platform::GetDeviceCount(). Neither the array nor the allocators are deleted
// in this file, so the returned pointer stays valid for the process lifetime.
// Creation happens in the initializer of a function-local static, which
// C++11 guarantees runs exactly once even under concurrent first calls. A
// plain `if (as == NULL)` check would race.
//
// Device selection: each allocator is constructed while its own GPU is
// current, so the init loop leaves the LAST device current afterwards. The
// returned allocator may later obtain memory from its system allocator on
// whatever device is current (presumably via cudaMalloc — confirm in
// system_allocator.h). The requested device is therefore made current before
// returning; otherwise memory for GPU k could be allocated on another GPU.
//
// NOTE(review): `gpu_id` is not range-checked here; callers must pass a value
// in [0, platform::GetDeviceCount()).
detail::BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) {
  static detail::BuddyAllocator** as = [] {
    int gpu_num = platform::GetDeviceCount();
    detail::BuddyAllocator** allocators = new detail::BuddyAllocator*[gpu_num];
    for (int gpu = 0; gpu < gpu_num; gpu++) {
      platform::SetDeviceId(gpu);
      allocators[gpu] = new detail::BuddyAllocator(new detail::GPUAllocator,
                                                   platform::GpuMinChunkSize(),
                                                   platform::GpuMaxChunkSize());
    }
    return allocators;
  }();
  platform::SetDeviceId(gpu_id);
  return as[gpu_id];
}

L
liaogang 已提交
71 72 73 74
// Allocates `size` bytes from the buddy allocator that belongs to the GPU
// identified by `place.device`.
template <>
void* Alloc<platform::GPUPlace>(platform::GPUPlace place, size_t size) {
  const int device_id = place.device;
  detail::BuddyAllocator* gpu_allocator = GetGPUBuddyAllocator(device_id);
  return gpu_allocator->Alloc(size);
}
L
liaogang 已提交
75

L
liaogang 已提交
76 77 78
// Returns `p` to the buddy allocator of the GPU identified by `place.device`.
// `p` is presumably a pointer obtained from Alloc<GPUPlace> with the same
// place.
template <>
void Free<platform::GPUPlace>(platform::GPUPlace place, void* p) {
  auto* const gpu_allocator = GetGPUBuddyAllocator(place.device);
  gpu_allocator->Free(p);
}

L
liaogang 已提交
81 82 83
// Reports BuddyAllocator::Used() for the GPU identified by `place.device`.
// This is presumably the number of bytes currently handed out on that
// device; confirm in buddy_allocator.h.
template <>
size_t Used<platform::GPUPlace>(platform::GPUPlace place) {
  const int device_id = place.device;
  const auto* gpu_allocator = GetGPUBuddyAllocator(device_id);
  const size_t used_bytes = gpu_allocator->Used();
  return used_bytes;
}

L
liaogang 已提交
86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
// Asynchronously copies `num` bytes from device memory `src` (on GPU
// `src_place.device`) to host memory `dst`, enqueued on `stream`.
//
// The template arguments are <DstPlace, SrcPlace>: the destination is the
// host and the source is the device. The transfer kind is therefore
// cudaMemcpyDeviceToHost.
template <>
void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace, void* dst,
                                                  platform::GPUPlace src_place,
                                                  const void* src, size_t num,
                                                  cudaStream_t stream) {
  // Make the GPU that owns `src` current. The copy is then not issued
  // against whatever device an earlier call left current (this matters e.g.
  // for the default stream).
  platform::SetDeviceId(src_place.device);
  platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
}

// Asynchronously copies `num` bytes from host memory `src` to device memory
// `dst` (on GPU `dst_place.device`), enqueued on `stream`.
//
// The template arguments are <DstPlace, SrcPlace>: the destination is the
// device and the source is the host. The transfer kind is therefore
// cudaMemcpyHostToDevice.
template <>
void Copy<platform::GPUPlace, platform::CPUPlace>(platform::GPUPlace dst_place,
                                                  void* dst,
                                                  platform::CPUPlace,
                                                  const void* src, size_t num,
                                                  cudaStream_t stream) {
  // Make the GPU that owns `dst` current. The copy is then not issued
  // against whatever device an earlier call left current (this matters e.g.
  // for the default stream).
  platform::SetDeviceId(dst_place.device);
  platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
}

// Asynchronously copies `num` bytes between two GPU buffers, enqueued on
// `stream`.
//
// If both buffers live on the same GPU, a plain device-to-device memcpy is
// issued after making that GPU current. Otherwise a peer copy is used; it
// receives both device ids explicitly.
template <>
void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place,
                                                  void* dst,
                                                  platform::GPUPlace src_place,
                                                  const void* src, size_t num,
                                                  cudaStream_t stream) {
  if (dst_place == src_place) {
    // Same-device copy. Select that device so the copy is not issued against
    // whatever device an earlier call left current.
    platform::SetDeviceId(src_place.device);
    platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
  } else {
    platform::GpuMemcpyPeer(dst, dst_place.device, src, src_place.device, num,
                            stream);
  }
}

116
#endif  // PADDLE_ONLY_CPU
117 118 119

}  // namespace memory
}  // namespace paddle