提交 6529a7c5 编写于 作者: W willzhang4a58

two buf in one stream


Former-commit-id: dcf1b183c3b86ac547f9ac4581e2d05de0feb853
上级 51065dc5
op {
name: "decode"
decode_ofrecord_conf {
data_dir: "/dataset/mnist_kaggle/test/6"
data_dir: "/dataset/mnist_kaggle/train/6"
blob {
name: "img_raw"
data_type: kFloat
......
......@@ -177,7 +177,7 @@ uint64_t Improver::AvailableMemSize(int64_t machine_id, int64_t memory_zone_id)
JobDesc* job_desc = Global<JobDesc>::Get();
if (memory_zone_id == job_desc->GpuDeviceNum()) {
mem_size -= job_desc->reserved_host_mem_byte();
mem_size -= job_desc->persistence_buf_byte() * record_load_task_num_.at(machine_id);
mem_size -= job_desc->persistence_buf_byte() * 2 * record_load_task_num_.at(machine_id);
} else {
mem_size -= job_desc->reserved_device_mem_byte();
}
......
#include "oneflow/core/persistence/cyclic_persistent_in_stream_without_local_copy.h"
#include "oneflow/core/job/job_desc.h"
namespace oneflow {
......@@ -10,8 +11,9 @@ CyclicPersistentInStreamWithoutLocalCopy::CyclicPersistentInStreamWithoutLocalCo
}
void CyclicPersistentInStreamWithoutLocalCopy::UpdateBuffer() {
if (is_first_update_buffer_ == false && file_size() <= mut_buffer()->size() - 1) {
set_cur_buf_begin(mut_buffer()->data());
if (is_first_update_buffer_ == false
&& file_size() <= Global<JobDesc>::Get()->persistence_buf_byte()) {
set_cur_buf_begin(mut_buffer());
} else {
PersistentInStreamWithoutLocalCopy::UpdateBuffer();
}
......
......@@ -9,7 +9,7 @@ class CyclicPersistentInStreamWithoutLocalCopy final : public PersistentInStream
public:
OF_DISALLOW_COPY_AND_MOVE(CyclicPersistentInStreamWithoutLocalCopy);
CyclicPersistentInStreamWithoutLocalCopy() = delete;
~CyclicPersistentInStreamWithoutLocalCopy() = default;
~CyclicPersistentInStreamWithoutLocalCopy() { WaitUntilStandByBufferReadyBytesNotEqualZero(); }
CyclicPersistentInStreamWithoutLocalCopy(fs::FileSystem* fs, const std::string& file_path);
......
......@@ -18,6 +18,8 @@ class NormalPersistentInStream final : public PersistentInStreamWithoutLocalCopy
NormalPersistentInStream(fs::FileSystem* fs, const std::string& file_path)
: NormalPersistentInStream(fs, file_path, 0) {}
~NormalPersistentInStream() { WaitUntilStandByBufferReadyBytesNotEqualZero(); }
private:
void AddNForCurFilePos(uint64_t n) override { set_cur_file_pos(cur_file_pos() + n); }
};
......
#include "oneflow/core/persistence/persistent_in_stream_without_local_copy.h"
#include "oneflow/core/job/job_desc.h"
#include <cstring>
#include "oneflow/core/thread/thread_pool.h"
namespace oneflow {
static ThreadPool g_persistent_in_thread_pool(1);
PersistentInStreamWithoutLocalCopy::~PersistentInStreamWithoutLocalCopy() {
WaitUntilStandByBufferReadyBytesNotEqualZero();
delete[] standby_buffer_;
delete[] buffer_;
}
int32_t PersistentInStreamWithoutLocalCopy::ReadLine(std::string* l) {
if (IsEof()) { return -1; }
l->clear();
......@@ -27,8 +35,8 @@ int32_t PersistentInStreamWithoutLocalCopy::Read(char* s, size_t n) {
while (n) {
if (cur_buf_begin_ == cur_buf_end_) { UpdateBuffer(); }
CHECK_LT(cur_buf_begin_, cur_buf_end_);
int64_t copy_size = std::min(cur_buf_end_ - cur_buf_begin_, static_cast<int64_t>(n));
std::memcpy(s, cur_buf_begin_, static_cast<size_t>(copy_size));
size_t copy_size = std::min<size_t>(cur_buf_end_ - cur_buf_begin_, n);
memcpy(s, cur_buf_begin_, copy_size);
s += copy_size;
cur_buf_begin_ += copy_size;
n -= copy_size;
......@@ -41,26 +49,56 @@ PersistentInStreamWithoutLocalCopy::PersistentInStreamWithoutLocalCopy(fs::FileS
uint64_t offset) {
fs->NewRandomAccessFile(file_path, &file_);
file_size_ = fs->GetFileSize(file_path);
CHECK_LT(offset, file_size_);
standby_buffer_ = new char[Global<JobDesc>::Get()->persistence_buf_byte() + 1];
standby_buffer_ready_bytes_ = 0;
cur_file_pos_ = offset;
buffer_.resize(Global<JobDesc>::Get()->persistence_buf_byte() + 1);
cur_buf_begin_ = buffer_.data();
cur_buf_end_ = buffer_.data();
file_read_done_ = false;
buffer_ = new char[Global<JobDesc>::Get()->persistence_buf_byte() + 1];
cur_buf_begin_ = buffer_;
cur_buf_end_ = buffer_;
*cur_buf_end_ = '\0';
}
bool PersistentInStreamWithoutLocalCopy::IsEof() const {
return cur_buf_begin_ == cur_buf_end_ && cur_file_pos_ == file_size_;
AsyncUpdateStandByBuffer();
}
void PersistentInStreamWithoutLocalCopy::UpdateBuffer() {
CHECK_EQ(cur_buf_begin_, cur_buf_end_);
uint64_t n = std::min(buffer_.size() - 1, file_size_ - cur_file_pos_);
if (n == 0) { return; }
file_->Read(cur_file_pos_, n, buffer_.data());
cur_buf_begin_ = buffer_.data();
cur_buf_end_ = buffer_.data() + n;
WaitUntilStandByBufferReadyBytesNotEqualZero();
if (standby_buffer_ready_bytes_ == -1) { return; }
std::swap(standby_buffer_, buffer_);
cur_buf_begin_ = buffer_;
cur_buf_end_ = buffer_ + standby_buffer_ready_bytes_;
*cur_buf_end_ = '\0';
AddNForCurFilePos(n);
standby_buffer_ready_bytes_ = 0;
AsyncUpdateStandByBuffer();
}
void PersistentInStreamWithoutLocalCopy::WaitUntilStandByBufferReadyBytesNotEqualZero() {
std::unique_lock<std::mutex> lck(standby_buffer_ready_mtx_);
standby_buffer_ready_cond_.wait(lck, [this]() { return standby_buffer_ready_bytes_ != 0; });
}
void PersistentInStreamWithoutLocalCopy::AsyncUpdateStandByBuffer() {
g_persistent_in_thread_pool.AddWork([this]() {
uint64_t n =
std::min(Global<JobDesc>::Get()->persistence_buf_byte(), file_size_ - cur_file_pos_);
if (n > 0) {
file_->Read(cur_file_pos_, n, standby_buffer_);
AddNForCurFilePos(n);
}
if (cur_file_pos_ == file_size_) { file_read_done_ = true; }
std::unique_lock<std::mutex> lck(standby_buffer_ready_mtx_);
if (n > 0) {
standby_buffer_ready_bytes_ = n;
} else {
standby_buffer_ready_bytes_ = -1;
}
standby_buffer_ready_cond_.notify_all();
});
}
bool PersistentInStreamWithoutLocalCopy::IsEof() const {
return cur_buf_begin_ == cur_buf_end_ && file_read_done_;
}
} // namespace oneflow
......@@ -10,7 +10,7 @@ class PersistentInStreamWithoutLocalCopy : public PersistentInStream {
public:
OF_DISALLOW_COPY_AND_MOVE(PersistentInStreamWithoutLocalCopy);
PersistentInStreamWithoutLocalCopy() = delete;
virtual ~PersistentInStreamWithoutLocalCopy() = default;
virtual ~PersistentInStreamWithoutLocalCopy();
int32_t ReadLine(std::string* l) override;
int32_t Read(char* s, size_t n) override;
......@@ -21,18 +21,27 @@ class PersistentInStreamWithoutLocalCopy : public PersistentInStream {
virtual void UpdateBuffer();
virtual void AddNForCurFilePos(uint64_t n) = 0;
uint64_t file_size() const { return file_size_; }
std::vector<char>* mut_buffer() { return &buffer_; }
char* mut_buffer() { return buffer_; }
uint64_t cur_file_pos() const { return cur_file_pos_; }
void set_cur_file_pos(uint64_t val) { cur_file_pos_ = val; }
void set_cur_buf_begin(char* val) { cur_buf_begin_ = val; }
void WaitUntilStandByBufferReadyBytesNotEqualZero();
private:
void AsyncUpdateStandByBuffer();
bool IsEof() const;
std::unique_ptr<fs::RandomAccessFile> file_;
uint64_t file_size_;
char* standby_buffer_;
int64_t standby_buffer_ready_bytes_;
std::mutex standby_buffer_ready_mtx_;
std::condition_variable standby_buffer_ready_cond_;
uint64_t cur_file_pos_;
std::vector<char> buffer_;
std::atomic<bool> file_read_done_;
char* buffer_;
char* cur_buf_begin_;
char* cur_buf_end_;
};
......
#include "oneflow/core/thread/thread_pool.h"
namespace oneflow {
ThreadPool::ThreadPool(int32_t thread_num)
: work_chans_(thread_num), threads_(thread_num), cur_chan_idx_(0) {
FOR_RANGE(int32_t, i, 0, thread_num) {
Channel<std::function<void()>>* chan = &(work_chans_.at(i));
threads_[i] = std::thread([chan]() {
std::function<void()> work;
while (chan->Receive(&work) == 0) { work(); }
});
}
}
ThreadPool::~ThreadPool() {
FOR_RANGE(int32_t, i, 0, work_chans_.size()) {
work_chans_.at(i).CloseSendEnd();
work_chans_.at(i).CloseReceiveEnd();
threads_.at(i).join();
}
}
void ThreadPool::AddWork(std::function<void()> work) {
if (work_chans_.size() > 1) {
std::unique_lock<std::mutex> lck(cur_chan_idx_mtx_);
work_chans_.at(cur_chan_idx_).Send(work);
cur_chan_idx_ = (cur_chan_idx_ + 1) % work_chans_.size();
} else {
work_chans_.at(cur_chan_idx_).Send(work);
}
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_THREAD_THREAD_POOL_H_
#define ONEFLOW_CORE_THREAD_THREAD_POOL_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/channel.h"
namespace oneflow {
class ThreadPool final {
public:
OF_DISALLOW_COPY_AND_MOVE(ThreadPool);
ThreadPool() = delete;
ThreadPool(int32_t thread_num);
~ThreadPool();
void AddWork(std::function<void()> work);
private:
std::vector<Channel<std::function<void()>>> work_chans_;
std::vector<std::thread> threads_;
std::mutex cur_chan_idx_mtx_;
int32_t cur_chan_idx_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_THREAD_THREAD_POOL_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册