Unverified commit 8b185a7a, authored by cheng cheng, committed by GitHub

Eager:: CudaAllocator (#3079)

* vm::cuda_allocator interface

* Implement of CudaAllocator::Allocate

* Implement of CudaAllocator::Deallocate

* Add cuda allocator test scripts and fix bug of CudaAllocator::pieces_ and pass test

* remove part note

* fix CudaAllocator::total_memory_bytes_ strategy and Pass all test

* CudaAllocator Support dynamic growth and garbage collection

* add interface note

* refine total_memory_bytes when garbage collection

* fix bug of merge

* remove log and strong boundary test

* SingleThreadOnlyAllocator and ThreadSafeAllocator implement and test

* SingleThreadOnlyAllocator for CudaStreamHandleDeviceCtx

* ThreadSafeAllocator for CudaStreamHandleDeviceCtx

* fix spell
Co-authored-by: Li Xinqi <lixinqi2010@gmail.com>
Co-authored-by: guo ran <360112263@qq.com>
Parent c58f1a5c
#include "oneflow/core/vm/cuda_allocator.h" #include "oneflow/core/vm/cuda_allocator.h"
#include "oneflow/core/device/cuda_util.h" #include "oneflow/core/device/cuda_util.h"
#include <iostream>
namespace oneflow {
namespace vm {
namespace {
inline size_t CudaMemAlignedBytes(size_t bytes) { return RoundUp(bytes, kCudaMemAllocAlignSize); }
inline bool IsAlignedSize(size_t size) { return size % kCudaMemAllocAlignSize == 0; }
static const size_t kPieceSplitThreshold = 128 << 20; // 128MiB
} // namespace
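// A quick sanity sketch of the helpers above (assuming kCudaMemAllocAlignSize == 512):
//   CudaMemAlignedBytes(1)   -> 512
//   CudaMemAlignedBytes(512) -> 512
//   CudaMemAlignedBytes(513) -> 1024
//   IsAlignedSize(1024) -> true, IsAlignedSize(1000) -> false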
CudaAllocator::CudaAllocator(int64_t device_id)
: Allocator(), device_id_(device_id), total_memory_bytes_(0), recycle_piece_list_(nullptr) {
bins_.resize(kBinNumSize);
for (int i = 0; i < kBinNumSize; ++i) {
size_t bin_size = BinSize4BinNum(i);
bins_.at(i).size = bin_size;
CHECK_EQ(BinNum4BinSize(bin_size), i);
CHECK_EQ(BinNum4BinSize(bin_size + kCudaMemAllocAlignSize - 1), i);
CHECK_EQ(BinNum4BinSize(bin_size * 2 - 1), i);
CHECK_EQ(BinNum4BinSize(bin_size * 2), i == (kBinNumSize - 1) ? i : i + 1);
}
}
CudaAllocator::~CudaAllocator() {
if (total_memory_bytes_ == 0) {
CHECK_EQ(mem_ptr2block_.size(), 0);
return;
}
cudaSetDevice(device_id_);
for (auto& pair : mem_ptr2block_) { CudaCheck(cudaFree(pair.first)); }
}
void CudaAllocator::InsertPiece2Bin(Piece* piece) {
CHECK(piece->is_free && piece->bin_num == kInvalidBinNum);
int32_t bin_num = BinNum4BinSize(piece->size);
piece->bin_num = bin_num;
CHECK(bins_.at(bin_num).pieces.insert(piece).second);
}
void CudaAllocator::RemovePieceFromBin(Piece* piece) {
CHECK(piece->is_free);
CHECK_NE(piece->bin_num, kInvalidBinNum);
CHECK_GT(bins_.at(piece->bin_num).pieces.erase(piece), 0);
piece->bin_num = kInvalidBinNum;
}
CudaAllocator::Piece* CudaAllocator::AllocatePiece() {
if (recycle_piece_list_) {
Piece* ret = recycle_piece_list_;
recycle_piece_list_ = recycle_piece_list_->next;
return ret;
} else {
pieces_.emplace_back(new Piece());
return pieces_.at(pieces_.size() - 1).get();
}
}
void CudaAllocator::DeallocatePiece(Piece* piece) {
piece->ptr = nullptr;
piece->size = 0;
piece->bin_num = kInvalidBinNum;
piece->is_free = true;
piece->prev = nullptr;
piece->next = recycle_piece_list_;
recycle_piece_list_ = piece;
}
void CudaAllocator::MarkPiece(Piece* piece) {
CHECK_NOTNULL(piece->ptr);
CHECK(ptr2piece_.emplace(piece->ptr, piece).second);
}
void CudaAllocator::UnMarkPiece(Piece* piece) {
CHECK_NOTNULL(piece->ptr);
auto it = ptr2piece_.find(piece->ptr);
CHECK(it != ptr2piece_.end());
ptr2piece_.erase(it);
}
CudaAllocator::Piece* CudaAllocator::FindPiece(size_t aligned_size) {
CHECK(IsAlignedSize(aligned_size));
for (int32_t bin_num = BinNum4BinSize(aligned_size); bin_num < kBinNumSize; ++bin_num) {
Bin* bin = &bins_.at(bin_num);
for (auto it = bin->pieces.begin(); it != bin->pieces.end(); ++it) {
Piece* piece = *it;
CHECK(piece->is_free);
CHECK_NOTNULL(piece->ptr);
CHECK_EQ(piece->bin_num, bin_num);
CHECK(IsAlignedSize(piece->size));
if (piece->size >= aligned_size) {
bin->pieces.erase(it);
piece->bin_num = kInvalidBinNum;
piece->is_free = false;
if (piece->size >= aligned_size * 2 || piece->size - aligned_size >= kPieceSplitThreshold) {
Piece* new_piece = AllocatePiece();
new_piece->ptr = piece->ptr + aligned_size;
new_piece->size = piece->size - aligned_size;
piece->size = aligned_size;
Piece* next_p = piece->next;
piece->next = new_piece;
new_piece->prev = piece;
new_piece->next = next_p;
if (next_p != nullptr) { next_p->prev = new_piece; }
new_piece->is_free = true;
new_piece->bin_num = kInvalidBinNum;
CHECK(IsAlignedSize(piece->size));
CHECK(IsAlignedSize(new_piece->size));
InsertPiece2Bin(new_piece);
MarkPiece(new_piece);
}
return piece;
}
}
}
return nullptr;
}
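// Example of the split rule above (a sketch; sizes assume 512-byte alignment):
// a 300MiB free piece serving a 1MiB request satisfies both split conditions
// (300MiB >= 2 * 1MiB, and 300MiB - 1MiB >= kPieceSplitThreshold), so it is cut
// into a 1MiB piece returned to the caller and a 299MiB free remainder that is
// re-binned and marked; without the split, the whole 300MiB piece would be handed out.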
void CudaAllocator::MergeNeighbourFreePiece(Piece* lhs, Piece* rhs) {
CHECK(lhs->is_free);
CHECK(rhs->is_free);
CHECK(lhs->next == rhs);
CHECK(lhs == rhs->prev);
CHECK(lhs->ptr + lhs->size == rhs->ptr);
lhs->size += rhs->size;
lhs->next = rhs->next;
if (rhs->next != nullptr) { rhs->next->prev = lhs; }
UnMarkPiece(rhs);
DeallocatePiece(rhs);
}
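// E.g. two adjacent free pieces [ptr, ptr+4096) and [ptr+4096, ptr+8192) merge
// into one free piece [ptr, ptr+8192); rhs is unmarked and pushed onto
// recycle_piece_list_ for reuse by AllocatePiece().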
bool CudaAllocator::AllocateBlockToExtendTotalMem(size_t aligned_size) {
CHECK(IsAlignedSize(aligned_size));
size_t allocate_bytes = 1048576; // 1MiB base size
allocate_bytes = std::max(allocate_bytes, aligned_size);
cudaSetDevice(device_id_);
size_t free_bytes = -1;
size_t total_bytes = -1;
CudaCheck(cudaMemGetInfo(&free_bytes, &total_bytes));
const size_t remain_bytes = 50 * 1048576;
const size_t available_bytes = free_bytes - remain_bytes;  // keep at least 50MiB of device memory free
// double the total memory footprint if possible
if (total_memory_bytes_ > 0) {
allocate_bytes = std::max(allocate_bytes, std::min(total_memory_bytes_, available_bytes));
}
const size_t final_allocate_bytes = CudaMemAlignedBytes(allocate_bytes);
if (final_allocate_bytes > available_bytes) { return false; }
if (final_allocate_bytes < aligned_size) { return false; }
char* mem_ptr = nullptr;
if (cudaMalloc(&mem_ptr, final_allocate_bytes) != cudaSuccess) { return false; }
// extension succeeded
total_memory_bytes_ += final_allocate_bytes;
Piece* piece = AllocatePiece();
piece->size = final_allocate_bytes;
piece->ptr = mem_ptr;
piece->prev = nullptr;
piece->next = nullptr;
piece->is_free = true;
piece->bin_num = kInvalidBinNum;
InsertPiece2Bin(piece);
MarkPiece(piece);
CHECK(mem_ptr2block_.emplace(mem_ptr, Block(piece)).second);
return true;
}
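// Growth sketch for the function above: the first block is max(1MiB, aligned_size);
// afterwards each call tries to at least double the footprint. E.g. with
// total_memory_bytes_ == 64MiB and ample free device memory, the next block is
// 64MiB (-> 128MiB total), capped by available_bytes so 50MiB stays in reserve.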
bool CudaAllocator::DeallocateFreeBlockForGarbageCollection() {
size_t total_free_bytes = 0;
HashSet<char*> free_block_ptrs;
for (const auto& pair : mem_ptr2block_) {
const Block& block = pair.second;
bool all_free = true;
Piece* p = block.start_piece;
while (p != nullptr) {
if (!(p->is_free)) {
all_free = false;
break;
}
p = p->next;
}
if (all_free) {
total_free_bytes += block.size;
free_block_ptrs.insert(pair.first);
}
}
total_memory_bytes_ -= total_free_bytes;
if (total_free_bytes > 0) {
LOG(WARNING) << "CudaAllocator try deallocate free block for garbage collection. "
<< " deallocate free bytes : " << total_free_bytes;
cudaSetDevice(device_id_);
for (char* ptr : free_block_ptrs) {
auto it = mem_ptr2block_.find(ptr);
CHECK(it != mem_ptr2block_.end());
const Block& block = it->second;
// delete all Piece on Block
size_t piece_size_sum = 0;
Piece* p = block.start_piece;
CHECK_EQ(block.ptr, block.start_piece->ptr);
CHECK_EQ(block.ptr, ptr);
while (p != nullptr) {
Piece* next_p = p->next;
piece_size_sum += p->size;
RemovePieceFromBin(p);
UnMarkPiece(p);
DeallocatePiece(p);
p = next_p;
}
CHECK_EQ(block.size, piece_size_sum);
mem_ptr2block_.erase(it);
CudaCheck(cudaFree(ptr));
}
}
return total_free_bytes > 0;
}
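// GC note: a block is freed only when every piece on it is free; e.g. a 128MiB
// block holding a single live 512-byte piece is kept entirely, so fragmentation
// can survive a collection pass.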
void CudaAllocator::Allocate(char** mem_ptr, std::size_t size) {
if (size == 0) {
*mem_ptr = nullptr;
return;
}
size_t aligned_size = CudaMemAlignedBytes(size);
Piece* piece = FindPiece(aligned_size);
if (piece == nullptr) {
if (AllocateBlockToExtendTotalMem(aligned_size)) { piece = FindPiece(aligned_size); }
}
if (piece == nullptr) {
if (DeallocateFreeBlockForGarbageCollection() && AllocateBlockToExtendTotalMem(aligned_size)) {
piece = FindPiece(aligned_size);
}
}
CHECK(piece != nullptr) << "Error: out of memory when allocating size " << size;
CHECK_NOTNULL(piece->ptr);
CHECK(ptr2piece_.find(piece->ptr) != ptr2piece_.end());
*mem_ptr = piece->ptr;
}
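// Allocation order above: (1) best-fit search in the bins, (2) cudaMalloc a new
// block and retry, (3) garbage-collect fully free blocks and retry once more;
// only then is the request declared out of memory.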
void CudaAllocator::Deallocate(char* mem_ptr, std::size_t size) {
if (mem_ptr == nullptr) { return; }
auto it = ptr2piece_.find(mem_ptr);
CHECK(it != ptr2piece_.end()) << "Error: tried to deallocate a non-existent mem_ptr. mem_ptr = "
<< mem_ptr << " size = " << size;
Piece* piece = it->second;
CHECK_NOTNULL(piece);
CHECK_EQ(piece->ptr, mem_ptr);
CHECK(!piece->is_free);
piece->is_free = true;
Piece* last_piece_insert_to_bin = piece;
Piece* next_p = piece->next;
Piece* prev_p = piece->prev;
if (next_p != nullptr && next_p->is_free) {
CHECK_EQ(next_p->ptr, piece->ptr + piece->size);
RemovePieceFromBin(next_p);
MergeNeighbourFreePiece(piece, next_p);
}
if (prev_p != nullptr && prev_p->is_free) {
CHECK_EQ(piece->ptr, prev_p->ptr + prev_p->size);
RemovePieceFromBin(prev_p);
MergeNeighbourFreePiece(prev_p, piece);
last_piece_insert_to_bin = prev_p;
}
InsertPiece2Bin(last_piece_insert_to_bin);
}
} // namespace vm
...
@@ -3,20 +3,109 @@
#include <cstdint>
#include "oneflow/core/vm/allocator.h"
#include "oneflow/core/common/util.h"
namespace oneflow {
namespace vm {
class CudaAllocator final : public Allocator {
public:
explicit CudaAllocator(int64_t device_id);
~CudaAllocator() override;
void Allocate(char** mem_ptr, std::size_t size) override;
void Deallocate(char* mem_ptr, std::size_t size) override;
private:
static constexpr int32_t kInvalidBinNum = -1;
static constexpr int32_t kBinNumSize = 20;
// Piece is the basic memory unit of CudaAllocator.
// A Piece is either free (is_free = true) or in use (is_free = false).
// If a Piece is free, a pointer to it is stored in the Bin matching its size.
// Pieces are kept in a doubly linked list; a Piece's prev and next neighbors are
// contiguous with it in physical memory.
struct Piece {
size_t size = 0;
char* ptr = nullptr;
bool is_free = false;
Piece* prev = nullptr;
Piece* next = nullptr;
int32_t bin_num = kInvalidBinNum;
};
// Bin is a structure that stores a set of free Pieces of similar size; each
// Piece in a Bin is at least as large as the Bin's size.
//
// CudaAllocator keeps a set of Bins whose sizes grow by powers of two, used to
// quickly find a free Piece of appropriate size in Allocate().
//
// The smallest bin size is 512 (kCudaMemAllocAlignSize), the smallest unit the
// CudaAllocator hands out, so every allocated size is a multiple of 512.
// Each Bin is twice the size of the previous one:
// BinNum:  Bin0, Bin1, Bin2, Bin3, ..., Bin19
// BinSize: 512,  1024, 2048, 4096, ..., 256MiB
struct Bin {
size_t size = 0;
struct PieceCmp {
bool operator()(const Piece* lhs, const Piece* rhs) const {
if (lhs->size != rhs->size) { return lhs->size < rhs->size; }
return lhs->ptr < rhs->ptr;
}
};
std::set<Piece*, PieceCmp> pieces;
};
// Block is a large chunk of physical memory actually allocated via cudaMalloc.
// A Block may be divided into many consecutive, disjoint Pieces.
struct Block {
size_t size = 0;
char* ptr = nullptr;
Piece* start_piece = nullptr;
Block(Piece* p) : size(p->size), ptr(p->ptr), start_piece(p) {}
};
size_t BinSize4BinNum(int32_t bin_num) { return kCudaMemAllocAlignSize << bin_num; }
int32_t BinNum4BinSize(size_t size) {
uint64_t value = std::max(size, kCudaMemAllocAlignSize) >> 9;
return std::min(kBinNumSize - 1, static_cast<int32_t>(63 ^ __builtin_clzll(value)));
}
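// Worked example of the log2 bit trick above (63 ^ __builtin_clzll(v) == floor(log2(v))):
//   BinNum4BinSize(512)  -> value = 1, floor(log2(1)) = 0 -> Bin0 (bin size 512)
//   BinNum4BinSize(3072) -> value = 6, floor(log2(6)) = 2 -> Bin2 (bin size 2048)
// so FindPiece(3072) starts scanning at Bin2 and walks upward.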
// Try to find a free Piece in the Bins whose size is at least aligned_size.
// Returns nullptr on failure.
Piece* FindPiece(size_t aligned_size);
// Insert a free Piece into the appropriate Bin, whose bin size is not larger than the piece's size
void InsertPiece2Bin(Piece* piece);
// Create a new empty Piece, or recycle one from recycle_piece_list_
Piece* AllocatePiece();
// Reset a Piece and push it onto the recycle_piece_list_
void DeallocatePiece(Piece* piece);
// Insert a {piece->ptr, piece} pair into the ptr2piece_ map so the Piece can be
// looked up in Deallocate()
void MarkPiece(Piece* piece);
// Erase the {piece->ptr, piece} pair from ptr2piece_ once the ptr is no longer valid.
// Usually called before DeallocatePiece()
void UnMarkPiece(Piece* piece);
void MergeNeighbourFreePiece(Piece* lhs, Piece* rhs);
void RemovePieceFromBin(Piece* piece);
bool AllocateBlockToExtendTotalMem(size_t aligned_size);
bool DeallocateFreeBlockForGarbageCollection();
int64_t device_id_;
size_t total_memory_bytes_;
HashMap<char*, Block> mem_ptr2block_;
std::vector<Bin> bins_;
std::vector<std::unique_ptr<Piece>> pieces_;
HashMap<char*, Piece*> ptr2piece_;
Piece* recycle_piece_list_;
};
} // namespace vm
...
#ifdef WITH_CUDA
#include "oneflow/core/vm/cuda_allocator.h"
#include "oneflow/core/vm/thread_safe_allocator.h"
#include "oneflow/core/device/cuda_util.h"
namespace oneflow {
namespace vm {
TEST(CudaAllocator, cuda_allocator) {
int gpu_num = -1;
cudaGetDeviceCount(&gpu_num);
if (gpu_num <= 0) {
LOG(INFO) << "CudaAllocator Test: Skip because of non GPU device.";
return;
}
ASSERT_TRUE(cudaSuccess == cudaSetDevice(0));
size_t free_bytes = -1;
size_t total_bytes = -1;
const size_t remain_bytes = 50 * 1048576;
ASSERT_TRUE(cudaSuccess == cudaMemGetInfo(&free_bytes, &total_bytes));
if (free_bytes <= remain_bytes || free_bytes - remain_bytes < remain_bytes) {
LOG(INFO) << "CudaAllocator Test: Skip because of allocator mem bytes less than 50MiB in GPU 0";
return;
}
std::unique_ptr<Allocator> allo(new CudaAllocator(0));
allo.reset(new SingleThreadOnlyAllocator(std::move(allo)));
Allocator* a = allo.get();
std::vector<char*> ptrs;
for (int i = 0; i < 512; ++i) {
char* ptr = nullptr;
a->Allocate(&ptr, 1);
ASSERT_TRUE(ptr != nullptr);
ptrs.push_back(ptr);
}
std::sort(ptrs.begin(), ptrs.end());
for (int i = 0; i < 512; ++i) {
if (i > 0) {
ASSERT_TRUE(ptrs.at(i) != ptrs.at(i - 1));
ASSERT_TRUE(std::abs(ptrs.at(i) - ptrs.at(i - 1)) >= kCudaMemAllocAlignSize);
}
a->Deallocate(ptrs.at(i), 1);
}
ptrs.clear();
for (int i = 0; i < 2048; ++i) {
char* ptr = nullptr;
a->Allocate(&ptr, 10000);
ASSERT_TRUE(ptr != nullptr);
ptrs.push_back(ptr);
}
std::sort(ptrs.begin(), ptrs.end());
for (int i = 0; i < 2048; ++i) {
if (i > 0) {
ASSERT_TRUE(ptrs.at(i) != ptrs.at(i - 1));
ASSERT_TRUE(std::abs(ptrs.at(i) - ptrs.at(i - 1)) >= kCudaMemAllocAlignSize);
}
a->Deallocate(ptrs.at(i), 10000);
}
char* data_ptr_1 = nullptr;
a->Allocate(&data_ptr_1, 2048 * sizeof(float));
char* data_ptr_2 = nullptr;
a->Allocate(&data_ptr_2, 4096 * sizeof(double));
ASSERT_TRUE(data_ptr_1 != data_ptr_2);
if (data_ptr_1 < data_ptr_2) {
ASSERT_TRUE(data_ptr_1 + 2048 * sizeof(float) <= data_ptr_2);
} else {
ASSERT_TRUE(data_ptr_2 + 4096 * sizeof(double) <= data_ptr_1);
}
a->Deallocate(data_ptr_2, 4096 * sizeof(double));
a->Deallocate(data_ptr_1, 2048 * sizeof(float));
}
} // namespace vm
} // namespace oneflow
#endif // WITH_CUDA
@@ -6,6 +6,7 @@
#include "oneflow/core/device/cuda_stream_handle.h"
#include "oneflow/core/common/callback.msg.h"
#include "oneflow/core/vm/cuda_allocator.h"
#include "oneflow/core/vm/thread_safe_allocator.h"
namespace oneflow {
namespace vm {
@@ -21,7 +22,8 @@ class CudaStreamHandleDeviceCtx : public DeviceCtx {
CudaStreamHandleDeviceCtx(CallbackMsgListPtr callback_msg_list, int64_t device_id)
: cuda_handler_(new CudaStreamHandle(nullptr)),
callback_msg_list_(callback_msg_list),
cuda_allocator_(
new ThreadSafeAllocator(std::unique_ptr<Allocator>(new CudaAllocator(device_id)))) {}
const cudaStream_t& cuda_stream() const override { return *(cuda_handler_->cuda_stream()); }
const cublasHandle_t& cublas_pmh_handle() const override {
@@ -41,12 +43,12 @@ class CudaStreamHandleDeviceCtx : public DeviceCtx {
callback_msg_list_->EmplaceBack(ObjectMsgPtr<CallbackMsg>::New(callback));
}
vm::Allocator* mut_allocator() override { return cuda_allocator_.get(); }
protected:
std::unique_ptr<CudaStreamHandle> cuda_handler_;
CallbackMsgListPtr callback_msg_list_;
std::unique_ptr<Allocator> cuda_allocator_;
};
#endif // WITH_CUDA
...
#include "oneflow/core/vm/thread_safe_allocator.h"
#include "oneflow/core/common/util.h"
namespace oneflow {
namespace vm {
void ThreadSafeAllocator::Allocate(char** mem_ptr, std::size_t size) {
std::unique_lock<std::mutex> lock(mutex4backend_allocator_);
backend_allocator_->Allocate(mem_ptr, size);
}
void ThreadSafeAllocator::Deallocate(char* mem_ptr, std::size_t size) {
std::unique_lock<std::mutex> lock(mutex4backend_allocator_);
backend_allocator_->Deallocate(mem_ptr, size);
}
void SingleThreadOnlyAllocator::Allocate(char** mem_ptr, std::size_t size) {
CheckUniqueThreadAccess();
backend_allocator_->Allocate(mem_ptr, size);
}
void SingleThreadOnlyAllocator::Deallocate(char* mem_ptr, std::size_t size) {
CheckUniqueThreadAccess();
backend_allocator_->Deallocate(mem_ptr, size);
}
void SingleThreadOnlyAllocator::CheckUniqueThreadAccess() {
std::unique_lock<std::mutex> lock(mutex4accessed_thread_id_);
CHECK(accessed_thread_id_ == std::this_thread::get_id());
}
} // namespace vm
} // namespace oneflow
#ifndef ONEFLOW_CORE_VM_THREAD_SAFE_ALLOCATOR_H_
#define ONEFLOW_CORE_VM_THREAD_SAFE_ALLOCATOR_H_
#include <cstdint>
#include <mutex>
#include <thread>
#include "oneflow/core/vm/allocator.h"
namespace oneflow {
namespace vm {
class ThreadSafeAllocator final : public Allocator {
public:
explicit ThreadSafeAllocator(std::unique_ptr<Allocator>&& backend_allocator)
: Allocator(), backend_allocator_(std::move(backend_allocator)) {}
~ThreadSafeAllocator() override = default;
void Allocate(char** mem_ptr, std::size_t size) override;
void Deallocate(char* mem_ptr, std::size_t size) override;
private:
std::unique_ptr<Allocator> backend_allocator_;
std::mutex mutex4backend_allocator_;
};
class SingleThreadOnlyAllocator final : public Allocator {
public:
explicit SingleThreadOnlyAllocator(std::unique_ptr<Allocator>&& backend_allocator)
: Allocator(),
backend_allocator_(std::move(backend_allocator)),
accessed_thread_id_(std::this_thread::get_id()) {}
~SingleThreadOnlyAllocator() override = default;
void Allocate(char** mem_ptr, std::size_t size) override;
void Deallocate(char* mem_ptr, std::size_t size) override;
private:
void CheckUniqueThreadAccess();
std::unique_ptr<Allocator> backend_allocator_;
std::thread::id accessed_thread_id_;
std::mutex mutex4accessed_thread_id_;
};
} // namespace vm
} // namespace oneflow
#endif // ONEFLOW_CORE_VM_THREAD_SAFE_ALLOCATOR_H_
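// Usage sketch (a hedged example mirroring the test above, not part of this diff):
// wrap the device allocator once, then share it through the Allocator interface.
//
//   std::unique_ptr<Allocator> alloc(new CudaAllocator(/*device_id=*/0));
//   alloc.reset(new ThreadSafeAllocator(std::move(alloc)));  // serialize concurrent access
//   char* ptr = nullptr;
//   alloc->Allocate(&ptr, 1024);   // size is rounded up to a multiple of kCudaMemAllocAlignSize (512)
//   alloc->Deallocate(ptr, 1024);  // must pass the pointer returned by Allocate()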