提交 6529a7c5 编写于 作者: W willzhang4a58

two buf in one stream


Former-commit-id: dcf1b183c3b86ac547f9ac4581e2d05de0feb853
上级 51065dc5
op { op {
name: "decode" name: "decode"
decode_ofrecord_conf { decode_ofrecord_conf {
data_dir: "/dataset/mnist_kaggle/test/6" data_dir: "/dataset/mnist_kaggle/train/6"
blob { blob {
name: "img_raw" name: "img_raw"
data_type: kFloat data_type: kFloat
......
...@@ -177,7 +177,7 @@ uint64_t Improver::AvailableMemSize(int64_t machine_id, int64_t memory_zone_id) ...@@ -177,7 +177,7 @@ uint64_t Improver::AvailableMemSize(int64_t machine_id, int64_t memory_zone_id)
JobDesc* job_desc = Global<JobDesc>::Get(); JobDesc* job_desc = Global<JobDesc>::Get();
if (memory_zone_id == job_desc->GpuDeviceNum()) { if (memory_zone_id == job_desc->GpuDeviceNum()) {
mem_size -= job_desc->reserved_host_mem_byte(); mem_size -= job_desc->reserved_host_mem_byte();
mem_size -= job_desc->persistence_buf_byte() * record_load_task_num_.at(machine_id); mem_size -= job_desc->persistence_buf_byte() * 2 * record_load_task_num_.at(machine_id);
} else { } else {
mem_size -= job_desc->reserved_device_mem_byte(); mem_size -= job_desc->reserved_device_mem_byte();
} }
......
#include "oneflow/core/persistence/cyclic_persistent_in_stream_without_local_copy.h" #include "oneflow/core/persistence/cyclic_persistent_in_stream_without_local_copy.h"
#include "oneflow/core/job/job_desc.h"
namespace oneflow { namespace oneflow {
...@@ -10,8 +11,9 @@ CyclicPersistentInStreamWithoutLocalCopy::CyclicPersistentInStreamWithoutLocalCo ...@@ -10,8 +11,9 @@ CyclicPersistentInStreamWithoutLocalCopy::CyclicPersistentInStreamWithoutLocalCo
} }
void CyclicPersistentInStreamWithoutLocalCopy::UpdateBuffer() { void CyclicPersistentInStreamWithoutLocalCopy::UpdateBuffer() {
if (is_first_update_buffer_ == false && file_size() <= mut_buffer()->size() - 1) { if (is_first_update_buffer_ == false
set_cur_buf_begin(mut_buffer()->data()); && file_size() <= Global<JobDesc>::Get()->persistence_buf_byte()) {
set_cur_buf_begin(mut_buffer());
} else { } else {
PersistentInStreamWithoutLocalCopy::UpdateBuffer(); PersistentInStreamWithoutLocalCopy::UpdateBuffer();
} }
......
...@@ -9,7 +9,7 @@ class CyclicPersistentInStreamWithoutLocalCopy final : public PersistentInStream ...@@ -9,7 +9,7 @@ class CyclicPersistentInStreamWithoutLocalCopy final : public PersistentInStream
public: public:
OF_DISALLOW_COPY_AND_MOVE(CyclicPersistentInStreamWithoutLocalCopy); OF_DISALLOW_COPY_AND_MOVE(CyclicPersistentInStreamWithoutLocalCopy);
CyclicPersistentInStreamWithoutLocalCopy() = delete; CyclicPersistentInStreamWithoutLocalCopy() = delete;
~CyclicPersistentInStreamWithoutLocalCopy() = default; ~CyclicPersistentInStreamWithoutLocalCopy() { WaitUntilStandByBufferReadyBytesNotEqualZero(); }
CyclicPersistentInStreamWithoutLocalCopy(fs::FileSystem* fs, const std::string& file_path); CyclicPersistentInStreamWithoutLocalCopy(fs::FileSystem* fs, const std::string& file_path);
......
...@@ -18,6 +18,8 @@ class NormalPersistentInStream final : public PersistentInStreamWithoutLocalCopy ...@@ -18,6 +18,8 @@ class NormalPersistentInStream final : public PersistentInStreamWithoutLocalCopy
NormalPersistentInStream(fs::FileSystem* fs, const std::string& file_path) NormalPersistentInStream(fs::FileSystem* fs, const std::string& file_path)
: NormalPersistentInStream(fs, file_path, 0) {} : NormalPersistentInStream(fs, file_path, 0) {}
~NormalPersistentInStream() { WaitUntilStandByBufferReadyBytesNotEqualZero(); }
private: private:
void AddNForCurFilePos(uint64_t n) override { set_cur_file_pos(cur_file_pos() + n); } void AddNForCurFilePos(uint64_t n) override { set_cur_file_pos(cur_file_pos() + n); }
}; };
......
#include "oneflow/core/persistence/persistent_in_stream_without_local_copy.h" #include "oneflow/core/persistence/persistent_in_stream_without_local_copy.h"
#include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/job_desc.h"
#include <cstring> #include "oneflow/core/thread/thread_pool.h"
namespace oneflow { namespace oneflow {
static ThreadPool g_persistent_in_thread_pool(1);
PersistentInStreamWithoutLocalCopy::~PersistentInStreamWithoutLocalCopy() {
WaitUntilStandByBufferReadyBytesNotEqualZero();
delete[] standby_buffer_;
delete[] buffer_;
}
int32_t PersistentInStreamWithoutLocalCopy::ReadLine(std::string* l) { int32_t PersistentInStreamWithoutLocalCopy::ReadLine(std::string* l) {
if (IsEof()) { return -1; } if (IsEof()) { return -1; }
l->clear(); l->clear();
...@@ -27,8 +35,8 @@ int32_t PersistentInStreamWithoutLocalCopy::Read(char* s, size_t n) { ...@@ -27,8 +35,8 @@ int32_t PersistentInStreamWithoutLocalCopy::Read(char* s, size_t n) {
while (n) { while (n) {
if (cur_buf_begin_ == cur_buf_end_) { UpdateBuffer(); } if (cur_buf_begin_ == cur_buf_end_) { UpdateBuffer(); }
CHECK_LT(cur_buf_begin_, cur_buf_end_); CHECK_LT(cur_buf_begin_, cur_buf_end_);
int64_t copy_size = std::min(cur_buf_end_ - cur_buf_begin_, static_cast<int64_t>(n)); size_t copy_size = std::min<size_t>(cur_buf_end_ - cur_buf_begin_, n);
std::memcpy(s, cur_buf_begin_, static_cast<size_t>(copy_size)); memcpy(s, cur_buf_begin_, copy_size);
s += copy_size; s += copy_size;
cur_buf_begin_ += copy_size; cur_buf_begin_ += copy_size;
n -= copy_size; n -= copy_size;
...@@ -41,26 +49,56 @@ PersistentInStreamWithoutLocalCopy::PersistentInStreamWithoutLocalCopy(fs::FileS ...@@ -41,26 +49,56 @@ PersistentInStreamWithoutLocalCopy::PersistentInStreamWithoutLocalCopy(fs::FileS
uint64_t offset) { uint64_t offset) {
fs->NewRandomAccessFile(file_path, &file_); fs->NewRandomAccessFile(file_path, &file_);
file_size_ = fs->GetFileSize(file_path); file_size_ = fs->GetFileSize(file_path);
CHECK_LT(offset, file_size_);
standby_buffer_ = new char[Global<JobDesc>::Get()->persistence_buf_byte() + 1];
standby_buffer_ready_bytes_ = 0;
cur_file_pos_ = offset; cur_file_pos_ = offset;
buffer_.resize(Global<JobDesc>::Get()->persistence_buf_byte() + 1); file_read_done_ = false;
cur_buf_begin_ = buffer_.data(); buffer_ = new char[Global<JobDesc>::Get()->persistence_buf_byte() + 1];
cur_buf_end_ = buffer_.data(); cur_buf_begin_ = buffer_;
cur_buf_end_ = buffer_;
*cur_buf_end_ = '\0'; *cur_buf_end_ = '\0';
} AsyncUpdateStandByBuffer();
bool PersistentInStreamWithoutLocalCopy::IsEof() const {
return cur_buf_begin_ == cur_buf_end_ && cur_file_pos_ == file_size_;
} }
void PersistentInStreamWithoutLocalCopy::UpdateBuffer() { void PersistentInStreamWithoutLocalCopy::UpdateBuffer() {
CHECK_EQ(cur_buf_begin_, cur_buf_end_); CHECK_EQ(cur_buf_begin_, cur_buf_end_);
uint64_t n = std::min(buffer_.size() - 1, file_size_ - cur_file_pos_); WaitUntilStandByBufferReadyBytesNotEqualZero();
if (n == 0) { return; } if (standby_buffer_ready_bytes_ == -1) { return; }
file_->Read(cur_file_pos_, n, buffer_.data()); std::swap(standby_buffer_, buffer_);
cur_buf_begin_ = buffer_.data(); cur_buf_begin_ = buffer_;
cur_buf_end_ = buffer_.data() + n; cur_buf_end_ = buffer_ + standby_buffer_ready_bytes_;
*cur_buf_end_ = '\0'; *cur_buf_end_ = '\0';
AddNForCurFilePos(n); standby_buffer_ready_bytes_ = 0;
AsyncUpdateStandByBuffer();
}
void PersistentInStreamWithoutLocalCopy::WaitUntilStandByBufferReadyBytesNotEqualZero() {
std::unique_lock<std::mutex> lck(standby_buffer_ready_mtx_);
standby_buffer_ready_cond_.wait(lck, [this]() { return standby_buffer_ready_bytes_ != 0; });
}
void PersistentInStreamWithoutLocalCopy::AsyncUpdateStandByBuffer() {
g_persistent_in_thread_pool.AddWork([this]() {
uint64_t n =
std::min(Global<JobDesc>::Get()->persistence_buf_byte(), file_size_ - cur_file_pos_);
if (n > 0) {
file_->Read(cur_file_pos_, n, standby_buffer_);
AddNForCurFilePos(n);
}
if (cur_file_pos_ == file_size_) { file_read_done_ = true; }
std::unique_lock<std::mutex> lck(standby_buffer_ready_mtx_);
if (n > 0) {
standby_buffer_ready_bytes_ = n;
} else {
standby_buffer_ready_bytes_ = -1;
}
standby_buffer_ready_cond_.notify_all();
});
}
bool PersistentInStreamWithoutLocalCopy::IsEof() const {
return cur_buf_begin_ == cur_buf_end_ && file_read_done_;
} }
} // namespace oneflow } // namespace oneflow
...@@ -10,7 +10,7 @@ class PersistentInStreamWithoutLocalCopy : public PersistentInStream { ...@@ -10,7 +10,7 @@ class PersistentInStreamWithoutLocalCopy : public PersistentInStream {
public: public:
OF_DISALLOW_COPY_AND_MOVE(PersistentInStreamWithoutLocalCopy); OF_DISALLOW_COPY_AND_MOVE(PersistentInStreamWithoutLocalCopy);
PersistentInStreamWithoutLocalCopy() = delete; PersistentInStreamWithoutLocalCopy() = delete;
virtual ~PersistentInStreamWithoutLocalCopy() = default; virtual ~PersistentInStreamWithoutLocalCopy();
int32_t ReadLine(std::string* l) override; int32_t ReadLine(std::string* l) override;
int32_t Read(char* s, size_t n) override; int32_t Read(char* s, size_t n) override;
...@@ -21,18 +21,27 @@ class PersistentInStreamWithoutLocalCopy : public PersistentInStream { ...@@ -21,18 +21,27 @@ class PersistentInStreamWithoutLocalCopy : public PersistentInStream {
virtual void UpdateBuffer(); virtual void UpdateBuffer();
virtual void AddNForCurFilePos(uint64_t n) = 0; virtual void AddNForCurFilePos(uint64_t n) = 0;
uint64_t file_size() const { return file_size_; } uint64_t file_size() const { return file_size_; }
std::vector<char>* mut_buffer() { return &buffer_; } char* mut_buffer() { return buffer_; }
uint64_t cur_file_pos() const { return cur_file_pos_; } uint64_t cur_file_pos() const { return cur_file_pos_; }
void set_cur_file_pos(uint64_t val) { cur_file_pos_ = val; } void set_cur_file_pos(uint64_t val) { cur_file_pos_ = val; }
void set_cur_buf_begin(char* val) { cur_buf_begin_ = val; } void set_cur_buf_begin(char* val) { cur_buf_begin_ = val; }
void WaitUntilStandByBufferReadyBytesNotEqualZero();
private: private:
void AsyncUpdateStandByBuffer();
bool IsEof() const; bool IsEof() const;
std::unique_ptr<fs::RandomAccessFile> file_; std::unique_ptr<fs::RandomAccessFile> file_;
uint64_t file_size_; uint64_t file_size_;
char* standby_buffer_;
int64_t standby_buffer_ready_bytes_;
std::mutex standby_buffer_ready_mtx_;
std::condition_variable standby_buffer_ready_cond_;
uint64_t cur_file_pos_; uint64_t cur_file_pos_;
std::vector<char> buffer_; std::atomic<bool> file_read_done_;
char* buffer_;
char* cur_buf_begin_; char* cur_buf_begin_;
char* cur_buf_end_; char* cur_buf_end_;
}; };
......
#include "oneflow/core/thread/thread_pool.h"
namespace oneflow {
ThreadPool::ThreadPool(int32_t thread_num)
: work_chans_(thread_num), threads_(thread_num), cur_chan_idx_(0) {
FOR_RANGE(int32_t, i, 0, thread_num) {
Channel<std::function<void()>>* chan = &(work_chans_.at(i));
threads_[i] = std::thread([chan]() {
std::function<void()> work;
while (chan->Receive(&work) == 0) { work(); }
});
}
}
ThreadPool::~ThreadPool() {
FOR_RANGE(int32_t, i, 0, work_chans_.size()) {
work_chans_.at(i).CloseSendEnd();
work_chans_.at(i).CloseReceiveEnd();
threads_.at(i).join();
}
}
void ThreadPool::AddWork(std::function<void()> work) {
if (work_chans_.size() > 1) {
std::unique_lock<std::mutex> lck(cur_chan_idx_mtx_);
work_chans_.at(cur_chan_idx_).Send(work);
cur_chan_idx_ = (cur_chan_idx_ + 1) % work_chans_.size();
} else {
work_chans_.at(cur_chan_idx_).Send(work);
}
}
} // namespace oneflow
#ifndef ONEFLOW_CORE_THREAD_THREAD_POOL_H_
#define ONEFLOW_CORE_THREAD_THREAD_POOL_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/channel.h"
namespace oneflow {
class ThreadPool final {
public:
OF_DISALLOW_COPY_AND_MOVE(ThreadPool);
ThreadPool() = delete;
ThreadPool(int32_t thread_num);
~ThreadPool();
void AddWork(std::function<void()> work);
private:
std::vector<Channel<std::function<void()>>> work_chans_;
std::vector<std::thread> threads_;
std::mutex cur_chan_idx_mtx_;
int32_t cur_chan_idx_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_THREAD_THREAD_POOL_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册