提交 eb6a941f 编写于 作者: W wangguibao

Fix async_executor interfaces: 1) Remove all protobufs; 2) Stop after each epoch

上级 1d239cc8
......@@ -40,13 +40,8 @@ limitations under the License. */
namespace paddle {
namespace framework {
// Storage for the static members shared by every ExecutorThreadWorker.

// Mutex taken by ExecutorThreadWorker::PickOneFile() while it hands out a file.
std::mutex ExecutorThreadWorker::s_locker_for_pick_file_;
// Index into s_thread_filelist_ of the next file PickOneFile() will return.
unsigned int ExecutorThreadWorker::s_current_file_idx_ = 0;
// Counter bumped by UpdateEpochNum(); reset once it reaches the file count.
size_t ExecutorThreadWorker::s_current_finished_file_cnt_ = 0;
// Epoch counter advanced by UpdateEpochNum() when the counter above wraps.
unsigned int ExecutorThreadWorker::s_current_epoch_ = 0;
// NOTE(review): no use of this member is visible in this file section.
int ExecutorThreadWorker::s_current_save_epoch_ = 0;
// NOTE(review): no use of this member is visible in this file section.
bool ExecutorThreadWorker::s_is_first_worker_ = false;
// Training files, appended to by AddTrainFile().
std::vector<std::string> ExecutorThreadWorker::s_thread_filelist_;
// Set to true by AsyncExecutor::Run() after its first PrepareThreads() call.
bool AsyncExecutor::workers_initialized_ = false;
void CreateTensor(Variable* var, proto::VarType::Type var_type) {
if (var_type == proto::VarType::LOD_TENSOR) {
......@@ -124,7 +119,6 @@ static void SaveModel(
{{"X", {var->Name()}}},
{},
attrs);
save_op->Run(*scope, place);
} else {
paralist.push_back(var->Name());
......@@ -140,15 +134,14 @@ static void SaveModel(
{{"X", paralist}},
{},
attrs);
save_op->Run(*scope, place);
}
} // end SaveModel
void ExecutorThreadWorker::AddTrainFile(const std::string& file) {
s_thread_filelist_.push_back(file);
// Discards the inspected values accumulated by this worker so far.
// AsyncExecutor::Run() calls this on every worker before starting its thread.
void ExecutorThreadWorker::Reset() {
inspect_values_.clear();
}
void ExecutorThreadWorker::CreateThreadOperators(const ProgramDesc& program) {
auto& block = program.Block(0);
op_names_.clear();
......@@ -175,8 +168,12 @@ void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) {
}
}
void ExecutorThreadWorker::SetDataFeed(const std::shared_ptr<DataFeed>& datafeed) {
local_reader_ = datafeed;
// Gives this worker a private reader cloned from `datafeed`.
//
// Only feeds whose dynamic type is exactly TextClassDataFeed are supported.
// Any other type is reported and the worker keeps whatever reader it had
// before (possibly none), instead of being skipped silently.
//
// @param datafeed  template feed to copy-construct the worker's reader from.
void ExecutorThreadWorker::SetDataFeed(DataFeed& datafeed) {  // NOLINT
  if (typeid(datafeed) != typeid(TextClassDataFeed)) {
    LOG(ERROR) << "SetDataFeed: unsupported DataFeed type "
               << typeid(datafeed).name() << "; reader of thread "
               << thread_id_ << " is left unchanged";
    return;
  }

  // The typeid check above proved the exact dynamic type, so a static_cast
  // is sufficient and avoids a second RTTI lookup.
  local_reader_.reset(
      new TextClassDataFeed(static_cast<TextClassDataFeed&>(datafeed)));
  local_reader_->SetThreadId(thread_id_);
}
void ExecutorThreadWorker::BindingDataFeedMemory() {
......@@ -186,9 +183,11 @@ void ExecutorThreadWorker::BindingDataFeedMemory() {
}
}
void ExecutorThreadWorker::SetInspectVarName(
const std::string& inspect_var_name) {
inspect_var_name_ = inspect_var_name;
// Replaces the list of variable names this worker inspects while training.
//
// Plain copy-assignment is used (as SetModelParamNames does) so the call is
// also correct when `inspect_var_names` refers to this worker's own list;
// clearing first and then inserting would leave the list empty in that case.
//
// @param inspect_var_names  names of the variables to inspect.
void ExecutorThreadWorker::SetInspectVarNames(
    const std::vector<std::string>& inspect_var_names) {
  inspect_var_names_ = inspect_var_names;
}
void ExecutorThreadWorker::SetModelParamNames(
......@@ -196,11 +195,6 @@ void ExecutorThreadWorker::SetModelParamNames(
model_param_names_ = param_names;
}
// Replaces this worker's sparse-communication table with a copy of
// `param_names` (a map from name to int).
void ExecutorThreadWorker::SetSparseCommData(
const std::map<std::string, int>& param_names) {
sparse_comm_data_ = param_names;
}
void ExecutorThreadWorker::SetDevice() {
static unsigned priority[] = {
0, 1, 2, 3, 4, 5,
......@@ -228,138 +222,79 @@ void ExecutorThreadWorker::SetDevice() {
CPU_ZERO(&mask);
if ((0 == sched_getaffinity(0, sizeof(mask), &mask))
&& CPU_ISSET(proc, &mask)) {
LOG(ERROR) << "TRACE: Thread " << i << " is running on processor " << proc << "...";
LOG(ERROR) << "TRACE: Thread " << i
<< " is running on processor " << proc
<< "...";
}
}
}
}
// Counts one more finished file. When the count reaches the number of files
// in the shared list, the count restarts at zero and the shared epoch
// counter moves forward by one.
// NOTE(review): the static counters are modified without taking a lock.
void ExecutorThreadWorker::UpdateEpochNum() {
  if (++s_current_finished_file_cnt_ < s_thread_filelist_.size()) {
    return;  // the current epoch still has unfinished files
  }
  s_current_finished_file_cnt_ = 0;
  ++s_current_epoch_;
}
void ExecutorThreadWorker::Train() {
LOG(ERROR) << "begin to train";
SetDevice();
const char* ExecutorThreadWorker::PickOneFile() {
std::string file_to_be_preocessed;
std::lock_guard<std::mutex> lock(s_locker_for_pick_file_);
int inspect_var_num = inspect_var_names_.size();
inspect_values_.clear();
inspect_values_.resize(inspect_var_num, 0);
if (s_current_file_idx_ >= s_thread_filelist_.size()) {
std::random_shuffle(s_thread_filelist_.begin(),
s_thread_filelist_.end());
s_current_file_idx_ = 0;
// s_current_epoch_++; //example: when one file, one thread, it's bug
LOG(ERROR) << "thread " << thread_id_
<< ": finish traing for epoch " << s_current_epoch_ + 1;
}
file_to_be_preocessed = s_thread_filelist_[s_current_file_idx_];
local_reader_->WaitNextEpoch();
int epoch = local_reader_->GetCurrentEpoch();
s_current_file_idx_++;
return file_to_be_preocessed.c_str();
}
LOG(ERROR) << "epoch: " << epoch;
void ExecutorThreadWorker::Train() {
LOG(ERROR) << "begin to train";
SetDevice();
#ifdef LOCAL_PROF
std::vector<double> op_total_time;
std::vector<std::string> op_name;
// int total_batch = 0;
for (auto& op : ops_) {
op_name.push_back(op->Type());
}
op_total_time.resize(ops_.size());
for (int i = 0; i < op_total_time.size(); ++i) {
op_total_time[i] = 0.0;
}
#endif
std::string inspect_key = "inspect";
if (!inspect_var_name_.empty()) {
inspect_key = inspect_var_name_.substr(0,
inspect_var_name_.find_first_of('_'));
}
for (unsigned i = 0; i < max_epoch_; ++i) {
LOG(ERROR) << "epoch: " << i;
#ifdef LOCAL_PROF
Timer timeline;
double total_time = 0.0;
double read_time = 0.0;
#endif
float total_inspect = 0;
int batch_num = 1;
while (i == s_current_epoch_) {
const char* filename = PickOneFile();
local_reader_->SetFile(filename);
while (true) {
#ifdef LOCAL_PROF
timeline.start();
#endif
bool flag = local_reader_->ReadBatch();
if (!flag) {
const char *file = local_reader_->PickOneFile();
if (file == NULL) {
break;
}
#ifdef LOCAL_PROF
timeline.pause();
read_time += timeline.elapsed_sec();
total_time += timeline.elapsed_sec();
#endif
if (!local_reader_->SetFile(file)) {
break;
}
while (true) {
bool flag = local_reader_->ReadBatch();
if (!flag) {
break;
}
for (unsigned int i = 0; i < ops_.size(); ++i) {
#ifdef LOCAL_PROF
timeline.start();
#endif
ops_[i]->Run(*thread_scope_, place_);
#ifdef LOCAL_PROF
timeline.pause();
op_total_time[i] += timeline.elapsed_sec();
total_time += timeline.elapsed_sec();
#endif
}
batch_num++;
float avg_inspect = 0.0;
if (!inspect_var_name_.empty()) {
avg_inspect = thread_scope_->FindVar(inspect_var_name_)
for (int i = 0; i < inspect_var_num; ++i) {
avg_inspect = thread_scope_->FindVar(inspect_var_names_[i])
->GetMutable<LoDTensor>()
->data<float>()[0];
inspect_values_[i] += avg_inspect;
}
total_inspect += avg_inspect;
thread_scope_->DropKids();
}
UpdateEpochNum();
LOG(ERROR) << "memory used after epoch " << i + 1
local_reader_->UpdateEpochNum();
LOG(ERROR) << "memory used after epoch " << epoch + 1
<< " called: " << memory::memory_usage(place_);
}
for (int i = 0; i < inspect_var_num; ++i) {
inspect_values_[i] /= batch_num;
std::string var = inspect_var_names_[i].substr(
0,
inspect_var_names_[i].find_first_of("_"));
LOG(ERROR) << "mean " << var.c_str()
<< " of epoch " << i + 1 << ": " << inspect_values_[i];
}
#ifdef LOCAL_PROF
for (int i = 0; i < op_total_time.size(); ++i) {
std::cerr << "op_name:[" << i << "][" << op_name[i] << "]"
<< " op_mean_time:[" << op_total_time[i] << "s]"
<< std::endl;
}
std::cerr << "read time: " << read_time << "s" << std::endl;
#endif
}
#ifdef LOCAL_PROF
LOG(ERROR) << "mean " << inspect_key.c_str()
<< " of epoch " << i + 1 << ": " << total_inspect / batch_num
<< ", total_time: " << total_time;
#else
LOG(ERROR) << "mean " << inspect_key.c_str()
<< " of epoch " << i + 1 << ": " << total_inspect / batch_num;
#endif
if (thread_id_ == 0) {
char modelfile[1024];
snprintf(&modelfile[0],
sizeof(modelfile),
"%s_epoch%d.model",
model_prefix_.c_str(),
i);
snprintf(&modelfile[0], sizeof(modelfile), "%s_epoch%d.model",
model_prefix_.c_str(), epoch);
std::string model_filename = std::string(modelfile);
// this save_inference_model can only save imdbtask, should make this
// general
......@@ -372,7 +307,6 @@ void ExecutorThreadWorker::Train() {
model_filename,
true);
}
}
}
void ExecutorThreadWorker::SetThreadId(int tid) {
......@@ -396,7 +330,20 @@ void ExecutorThreadWorker::SetMaxTrainingEpoch(int max_epoch) {
max_epoch_ = max_epoch;
}
// Builds an executor bound to `place`. Only place_ is set here.
// NOTE(review): the remaining members are expected to be set through the
// Set*() methods before use — confirm they cannot be read uninitialized.
AsyncExecutor::AsyncExecutor(const platform::Place& place) : place_(place) {}
// Builds an executor for `main_program`.
//
// The program and the data feed are stored by reference, so both must
// outlive this executor.
//
// Scalar and pointer members without a constructor argument are given
// defined values (0 / nullptr). Previously they were left uninitialized,
// although max_epoch_ is later forwarded to every worker in PrepareThreads().
//
// @param main_program  program the worker threads will run.
// @param param_names   names of the model parameters; copied.
// @param data_feed     feed each worker clones its own reader from.
// @param thread_num    number of worker threads.
// @param place         device the program is executed on.
AsyncExecutor::AsyncExecutor(ProgramDesc& main_program,  // NOLINT
                             const std::vector<std::string>& param_names,
                             TextClassDataFeed& data_feed,  // NOLINT
                             unsigned int thread_num,
                             const platform::Place& place)
    : thread_num_(thread_num),
      max_epoch_(0),
      batch_size_(0),
      comm_batch_(0),
      model_param_names_(param_names),
      root_scope_(nullptr),
      place_(place),
      main_program_(main_program),
      data_feed_(data_feed) {}
void AsyncExecutor::InitRootScope(Scope* scope) {
root_scope_ = scope;
......@@ -406,10 +353,6 @@ void AsyncExecutor::SetMaxTrainingEpoch(int max_epoch) {
max_epoch_ = max_epoch;
}
// Stores the name of the data feed class to use.
//
// @param feedname  NUL-terminated name. A null pointer is treated as an
//                  empty name, because constructing std::string from a null
//                  pointer is undefined behaviour.
void AsyncExecutor::SetDataFeedName(const char* feedname) {
  feed_name_ = (feedname != NULL) ? std::string(feedname) : std::string();
}
// Stores the model prefix; PrepareThreads() forwards it to each worker.
void AsyncExecutor::SetModelPrefix(const std::string& model_prefix) {
model_prefix_ = model_prefix;
}
......@@ -463,60 +406,16 @@ std::unique_ptr<ProgramDesc> AsyncExecutor::LoadDescFromFile(
return program;
}
// Replaces the stored dense-communication tensor names with a copy of
// `dense_comm_tensor`, element for element.
void AsyncExecutor::SetDenseCommTensor(
    const std::vector<std::string>& dense_comm_tensor) {
  dense_comm_tensor_ = dense_comm_tensor;
}
// Replaces the stored sparse-communication tensor names with a copy of
// `sparse_comm_tensor`, element for element.
void AsyncExecutor::SetSparseCommTensor(
    const std::vector<std::string>& sparse_comm_tensor) {
  sparse_comm_tensor_ = sparse_comm_tensor;
}
// Copies `sparse_comm_data` into the executor and logs how many entries the
// stored table now holds.
void AsyncExecutor::SetSparseCommData(
    const std::map<std::string, int>& sparse_comm_data) {
  sparse_comm_data_ = sparse_comm_data;
  const size_t entry_count = sparse_comm_data_.size();
  LOG(INFO) << "Sparse comm data: " << entry_count;
}
void AsyncExecutor::SetFileList(const char* filelist) {
filelist_.clear();
std::ifstream fin(filelist);
std::string filename;
while (fin >> filename) {
LOG(ERROR) << "add " << filename.c_str() << " to filelist";
filelist_.push_back(filename);
}
fin.close();
}
// Replaces the training file list with the entries of `tfiles`, in order.
void AsyncExecutor::SetFileList(std::vector<std::string> tfiles) {
  filelist_.assign(tfiles.begin(), tfiles.end());
}
// Stores the name of the single variable to inspect; PrepareThreads()
// forwards it to each worker.
void AsyncExecutor::SetInspectVarName(const std::string& inspect_var_name) {
inspect_var_name_ = inspect_var_name;
}
// Stores a copy of the model parameter names; PrepareThreads() forwards the
// list to each worker via SetModelParamNames().
void AsyncExecutor::SetParamNames(const std::vector<std::string>& param_names) {
model_param_names_ = param_names;
}
void AsyncExecutor::SetThreadNum(const int thread_num) {
thread_num_ = thread_num;
// Replaces the list of variable names whose values Run() reports.
//
// Copy-assignment is used so the call stays correct when the argument is
// this executor's own inspect_var_names_ (a public member that callers can
// pass straight back into Run()). Clearing first and then inserting would
// leave the list empty in that case.
//
// @param inspect_var_names  names of the variables to inspect.
void AsyncExecutor::SetInspectVarNames(
    const std::vector<std::string>& inspect_var_names) {
  inspect_var_names_ = inspect_var_names;
}
void AsyncExecutor::PrepareThreads(const ProgramDesc& host_program) {
workers_.resize(thread_num_);
for (unsigned i = 0; i < thread_num_; ++i) {
for (int i = 0; i < thread_num_; ++i) {
workers_[i].reset(new ExecutorThreadWorker);
workers_[i]->SetThreadId(i);
workers_[i]->CreateThreadOperators(host_program);
......@@ -524,34 +423,31 @@ void AsyncExecutor::PrepareThreads(const ProgramDesc& host_program) {
workers_[i]->SetPlace(place_);
workers_[i]->SetMaxTrainingEpoch(max_epoch_);
workers_[i]->CreateThreadScope(host_program);
workers_[i]->SetInspectVarName(inspect_var_name_);
workers_[i]->SetInspectVarNames(inspect_var_names_);
workers_[i]->SetModelParamNames(model_param_names_);
workers_[i]->SetSparseCommData(sparse_comm_data_);
workers_[i]->SetMainProgram(host_program);
workers_[i]->SetModelPrefix(model_prefix_);
}
for (unsigned i = 0; i < filelist_.size(); ++i) {
// suppose at least one trainer thread here, and
// filelist is static so that we only add filelist once
workers_[0]->AddTrainFile(filelist_[i]);
}
for (unsigned i = 0; i < thread_num_; ++i) {
//
// new a datafeed here
std::shared_ptr<DataFeed> local_feed = CreateDataFeed(feed_name_.c_str());
local_feed->Init();
local_feed->SetBatchSize(batch_size_);
workers_[i]->SetDataFeed(local_feed);
workers_[i]->SetDataFeed(data_feed_);
workers_[i]->BindingDataFeedMemory();
workers_[i]->SetThreadId(i);
}
}
void AsyncExecutor::RunAsyncExecutor(const ProgramDesc& host_program) {
std::vector<float>& AsyncExecutor::Run(
const std::vector<std::string>& inspect_var_names) {
SetInspectVarNames(inspect_var_names);
threads_.clear();
// thread binding here?
PrepareThreads(host_program);
for (unsigned i = 0; i < thread_num_; ++i) {
if (workers_initialized_ == false) {
PrepareThreads(main_program_);
workers_initialized_ = true;
}
for (int i = 0; i < thread_num_; ++i) {
workers_[i]->Reset();
workers_[i]->SetInspectVarNames(inspect_var_names);
threads_.push_back(std::thread(&ExecutorThreadWorker::Train,
workers_[i].get()));
}
......@@ -559,6 +455,27 @@ void AsyncExecutor::RunAsyncExecutor(const ProgramDesc& host_program) {
for (auto& th : threads_) {
th.join();
}
inspect_values_.clear();
inspect_values_.resize(inspect_var_names_.size(), 0);
std::vector<std::vector<float>*> inspect_value_vectors;
inspect_value_vectors.resize(thread_num_);
for (int i = 0; i < thread_num_; ++i) {
inspect_value_vectors[i] = &workers_[i]->GetInspectValues();
}
for (unsigned int i = 0; i < inspect_var_names_.size(); ++i) {
float value = 0.0;
for (int j = 0; j < thread_num_; ++j) {
value += inspect_value_vectors[j]->at(i);
}
value /= thread_num_;
inspect_values_[i] = value;
}
return inspect_values_;
}
void AsyncExecutor::LoadInitModel() {
......
......@@ -22,6 +22,7 @@ limitations under the License. */
#include <string>
#include <thread> // NOLINT
#include <vector>
#include <typeinfo>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/datafeed_creator.h"
#include "paddle/fluid/framework/executor.h"
......@@ -36,10 +37,9 @@ class ExecutorThreadWorker {
public:
ExecutorThreadWorker() {}
~ExecutorThreadWorker() {}
void CreateThreadScope(const framework::ProgramDesc& program);
void SetDataFeed(const DataFeed& datafeed);
void CreateThreadScope(const ProgramDesc& program);
void SetThreadId(int tid);
void CreateThreadOperators(const framework::ProgramDesc& program);
void CreateThreadOperators(const ProgramDesc& program);
void SetRootScope(Scope* g_scope);
void SetDevice();
void AddFidSet();
......@@ -52,25 +52,16 @@ class ExecutorThreadWorker {
void SetModelPrefix(const std::string& prefix) { model_prefix_ = prefix; }
void SetInspectVarName(const std::string& inspect_var_name);
void SetInspectVarNames(const std::vector<std::string>& inspect_var_names);
void SetModelParamNames(const std::vector<std::string>& param_names);
void SetSparseCommData(const std::map<std::string, int>& param_names);
void SetDataFeed(const std::shared_ptr<DataFeed>& datafeed);
void SetDataFeed(DataFeed& datafeed); // NOLINT
void Train();
const char* PickOneFile();
void UpdateEpochNum();
void Reset();
void SetDenseCommTensor(const std::vector<std::string>& param_names) {}
void Initialize() {}
public:
static std::mutex s_locker_for_pick_file_;
static unsigned int s_current_file_idx_;
static size_t s_current_finished_file_cnt_;
static unsigned int s_current_epoch_;
static int s_current_save_epoch_;
static std::vector<std::string> s_thread_filelist_; // filelist
static bool s_is_first_worker_;
std::vector<float>& GetInspectValues() {return inspect_values_;}
protected:
// thread index
......@@ -88,14 +79,13 @@ class ExecutorThreadWorker {
std::vector<OperatorBase *> ops_;
// main program for training
std::unique_ptr<framework::ProgramDesc> main_program_;
std::unique_ptr<ProgramDesc> main_program_;
// binary data reader
std::shared_ptr<DataFeed> local_reader_;
std::unique_ptr<DataFeed> local_reader_;
std::string inspect_var_name_;
std::vector<std::string> inspect_var_names_;
std::vector<std::string> model_param_names_;
std::map<std::string, int> sparse_comm_data_;
// execution place
platform::Place place_;
......@@ -105,24 +95,26 @@ class ExecutorThreadWorker {
// a thread scope; its parent scope is the global scope, which is shared
Scope* thread_scope_;
private:
std::vector<float> inspect_values_;
};
class AsyncExecutor {
public:
explicit AsyncExecutor(const platform::Place& place);
explicit AsyncExecutor(ProgramDesc& main_program, // NOLINT
const std::vector<std::string>& param_names,
TextClassDataFeed& data_feed, // NOLINT
unsigned int thread_num,
const platform::Place& place);
virtual ~AsyncExecutor() {}
static std::unique_ptr<ProgramDesc> LoadDescFromFile(
const std::string& filename);
void InitRootScope(Scope* scope);
void SetInspectVarName(const std::string& inspect_var_name);
void SetParamNames(const std::vector<std::string>& param_names);
void SetMaxTrainingEpoch(const int max_epoch);
Scope* GetRootScope() { return root_scope_; }
void SetThreadNum(const int thread_num);
void SetBatchSize(const int batch_size) { batch_size_ = batch_size; }
void SetFileList(const char* filelist);
void SetFileList(const std::vector<std::string> filelist);
void SetDataFeedName(const char* feedname);
void SetCommBatch(int comm_batch) {
comm_batch_ = comm_batch;
}
......@@ -140,37 +132,38 @@ class AsyncExecutor {
}
void SetModelPrefix(const std::string& model_prefix);
void SetDenseCommTensor(const std::vector<std::string>& dense_comm_tensor);
void SetSparseCommTensor(
const std::vector<std::string>& sparse_comm_tensor);
void SetSparseCommData(const std::map<std::string, int>& sparse_comm_data);
virtual void PrepareThreads(const framework::ProgramDesc& host_program);
void RunStartupProgram(const framework::ProgramDesc& program,
framework::Scope* scope);
void RunAsyncExecutor(const ProgramDesc& host_program);
virtual void PrepareThreads(const ProgramDesc& host_program);
void RunStartupProgram(const ProgramDesc& program, Scope* scope);
std::vector<float>& Run(const std::vector<std::string>& inspect_var_names);
void LoadInitModel();
private:
void SetInspectVarNames(const std::vector<std::string>& inspect_var_names);
public:
unsigned int thread_num_;
int thread_num_;
int max_epoch_;
int batch_size_;
int comm_batch_;
std::vector<std::shared_ptr<ExecutorThreadWorker> > workers_;
std::vector<std::thread> threads_;
std::vector<std::string> filelist_;
std::string inspect_var_name_;
std::vector<std::string> inspect_var_names_;
std::vector<std::string> model_param_names_;
std::vector<std::string> dense_comm_tensor_;
std::vector<std::string> sparse_comm_tensor_;
std::map<std::string, int> sparse_comm_data_;
std::string model_prefix_;
std::string model_path_;
std::string init_prog_file_;
std::string init_model_file_;
std::string feed_name_;
Scope* root_scope_;
platform::Place place_;
private:
ProgramDesc& main_program_;
TextClassDataFeed& data_feed_;
std::vector<float> inspect_values_;
private:
static bool workers_initialized_;
};
} // namespace framework
......
......@@ -38,6 +38,16 @@ DEFINE_bool(is_text_feed, false, "is_text_feed");
namespace paddle {
namespace framework {
// Storage for the static members shared by every TextClassDataFeed.

// Training files; filled by SetFileList(), shuffled by StartOneEpoch().
std::vector<std::string> TextClassDataFeed::s_filelist_;
// Mutex held by PickOneFile() and StartOneEpoch().
std::mutex TextClassDataFeed::s_locker_for_pick_file_;
// Index into s_filelist_ of the next file PickOneFile() hands out.
unsigned int TextClassDataFeed::s_current_file_idx_ = 0;
// Counter bumped by UpdateEpochNum(); reset once it reaches the file count.
size_t TextClassDataFeed::s_current_finished_file_cnt_ = 0;
// Epoch counter advanced by UpdateEpochNum() when the counter above wraps.
unsigned int TextClassDataFeed::s_current_epoch_ = 0;
// NOTE(review): no use of this member is visible in this file section.
int TextClassDataFeed::s_current_save_epoch_ = 0;
// Mutex, condition variable and flag used together: StartOneEpoch() sets the
// flag and notifies, WaitNextEpoch() blocks until the flag is true, and
// UpdateEpochNum() clears the flag when an epoch completes.
std::mutex TextClassDataFeed::s_locker_epoch_start_;
std::condition_variable TextClassDataFeed::s_condition_epoch_start_;
bool TextClassDataFeed::s_epoch_start_flag_ = false;
void TextClassDataFeed::Init() {
// hard coding for a specific datafeed
feed_vec_.resize(2);
......@@ -59,6 +69,12 @@ void TextClassDataFeed::Init() {
label_host_.reset(new int[10240],
[](int *p) {delete[] p;}); // max label in a batch
label_ptr_ = label_host_.get();
field_names_.clear();
}
// Default constructor: all per-instance setup is delegated to Init().
TextClassDataFeed::TextClassDataFeed() {
Init();
}
// TODO: use a more elegant implementation for this function
......@@ -69,6 +85,7 @@ bool TextClassDataFeed::ReadBatch() {
int inst_idx = 0;
offset.resize(batch_size_ + 1);
offset[0] = 0;
while (inst_idx < batch_size_) {
int ptr_offset = 0;
if (file_content_buffer_ptr_ - file_content_buffer_ >= file_size_) {
......@@ -125,6 +142,12 @@ bool TextClassDataFeed::ReadBatch() {
return true;
}
// Copy constructor. The new feed is set up from scratch by Init(), and only
// the batch size and the field names are then taken over from `data_feed`.
TextClassDataFeed::TextClassDataFeed(const TextClassDataFeed& data_feed) {
  Init();  // must run first: it also empties field_names_
  batch_size_ = data_feed.batch_size_;
  field_names_ = data_feed.field_names_;
}
void TextClassDataFeed::AddFeedVar(Variable* feed, const std::string& name) {
for (unsigned int i = 0; i < use_slot_alias_.size(); ++i) {
if (name == use_slot_alias_[i]) {
......@@ -133,30 +156,99 @@ void TextClassDataFeed::AddFeedVar(Variable* feed, const std::string& name) {
}
}
void TextClassDataFeed::SetFileList(const char* filelist) {
s_filelist_.clear();
std::ifstream fin(filelist);
PADDLE_ENFORCE(fin.good(),
"Opening file %s fail",
filelist);
std::string filename;
while (fin >> filename) {
LOG(ERROR) << "add " << filename.c_str() << " to filelist";
s_filelist_.push_back(filename);
}
fin.close();
}
// Replaces the feed's field names with a copy of `field_names`.
//
// Copy-assignment is used so the call is also correct when `field_names`
// refers to this feed's own list. Clearing first and then inserting would
// leave the list empty in that case.
//
// @param field_names  names of the input fields, in order.
void TextClassDataFeed::SetFieldNames(
    const std::vector<std::string>& field_names) {
  field_names_ = field_names;
}
// Loads the whole content of `filename` into the feed's content buffer and
// rewinds the read cursor to the start of that buffer.
//
// Expected record layout: termnum termid termid ... termid label
//
// The size is validated BEFORE anything is copied, so a failed tellg() (-1)
// or an oversized file can no longer reach read(). The size is kept in a
// stream offset rather than an int, so very large files cannot wrap around
// into the accepted range.
//
// @param filename  path of the data file to load.
// @return true when the file was read completely. false when it cannot be
//         opened, its size cannot be determined or is too large, or the
//         read came up short.
bool TextClassDataFeed::SetFile(const char* filename) {
  std::ifstream ifs(filename, std::ios::binary);
  if (ifs.fail()) {
    return false;
  }

  ifs.seekg(0, std::ios::end);
  const std::streamoff filesize = ifs.tellg();
  // todo , remove magic number
  // NOTE(review): 1 GiB is presumably the capacity of file_content_buffer_
  // allocated in Init() — confirm.
  if (filesize < 0 || filesize >= 1024 * 1024 * 1024) {
    return false;
  }

  ifs.seekg(0, std::ios::beg);
  ifs.read(file_content_buffer_, filesize);
  file_content_buffer_ptr_ = file_content_buffer_;
  if (ifs.gcount() != filesize) {
    // Short read: the buffer holds a truncated file. Expose it as empty so
    // that a partial record is never handed to the parser.
    file_size_ = 0;
    return false;
  }

  file_size_ = static_cast<int>(filesize);
  return true;
}
int TextClassDataFeed::ReadWholeFile(const std::string& filename,
char* buffer) {
std::ifstream ifs(filename.c_str(), std::ios::binary);
if (ifs.fail()) {
return -1;
// Counts one more completed unit of work. When the count reaches the number
// of files in the shared list, the count restarts at zero, the shared epoch
// counter advances, and the epoch-start flag is cleared so that
// WaitNextEpoch() blocks until StartOneEpoch() is called again.
//
// Every worker thread calls this on shared static state, so the whole
// read-modify-write sequence runs under s_locker_epoch_start_, the mutex
// that already guards the epoch-start flag.
// NOTE(review): GetCurrentEpoch() and PickOneFile() still read
// s_current_epoch_ without this mutex.
void TextClassDataFeed::UpdateEpochNum() {
  std::lock_guard<std::mutex> lock(s_locker_epoch_start_);

  ++s_current_finished_file_cnt_;
  if (s_current_finished_file_cnt_ < s_filelist_.size()) {
    return;
  }

  s_current_finished_file_cnt_ = 0;
  ++s_current_epoch_;
  LOG(WARNING) << "UpdateEpochNum: epoch = " << s_current_epoch_;
  s_epoch_start_flag_ = false;
}
ifs.seekg(0, std::ios::end);
int file_size = ifs.tellg();
ifs.seekg(0, std::ios::beg);
ifs.read(buffer, file_size);
return file_size;
// Opens a new epoch: reshuffles the shared file list, rewinds the file
// cursor, raises the epoch-start flag and wakes every thread blocked in
// WaitNextEpoch().
// NOTE(review): std::random_shuffle was removed in C++17; std::shuffle is
// the replacement if the toolchain is upgraded.
void TextClassDataFeed::StartOneEpoch() {
  std::lock_guard<std::mutex> file_guard(s_locker_for_pick_file_);

  std::random_shuffle(s_filelist_.begin(), s_filelist_.end());
  s_current_file_idx_ = 0;
  LOG(INFO) << "Beginning epoch " << s_current_epoch_;

  // Set the flag under its own mutex and release that mutex before
  // notifying the waiters.
  std::unique_lock<std::mutex> flag_guard(s_locker_epoch_start_);
  s_epoch_start_flag_ = true;
  flag_guard.unlock();

  s_condition_epoch_start_.notify_all();
}
// Blocks the calling thread until the epoch-start flag is true.
void TextClassDataFeed::WaitNextEpoch() {
  std::unique_lock<std::mutex> guard(s_locker_epoch_start_);
  // Re-check the flag after every wake-up so spurious ones are ignored.
  while (!s_epoch_start_flag_) {
    s_condition_epoch_start_.wait(guard);
  }
}
// Hands out the next unprocessed file of the current epoch.
//
// The name is copied into per-thread storage, so the returned pointer stays
// valid after the function returns and after the lock is released. The
// previous code returned c_str() of a local std::string, which is destroyed
// on return and left the caller with a dangling pointer.
//
// @return path of the next file, valid until the calling thread calls
//         PickOneFile() again; NULL once every file of the epoch has been
//         handed out.
const char* TextClassDataFeed::PickOneFile() {
  thread_local std::string file_to_be_processed;

  std::lock_guard<std::mutex> lock(s_locker_for_pick_file_);
  // One epoch has run over
  // Wait for next epoch
  if (s_current_file_idx_ >= s_filelist_.size()) {
    LOG(ERROR) << "thread " << thread_id_
               << ": finish traing for epoch " << s_current_epoch_ + 1;
    return NULL;
  }

  file_to_be_processed = s_filelist_[s_current_file_idx_];
  s_current_file_idx_++;
  return file_to_be_processed.c_str();
}
} // namespace framework
......
......@@ -47,24 +47,9 @@ struct Instance {
std::vector<Gauc> gauc_vec;
};
class DataFeed {
DataFeed() {}
virtual ~DataFeed() {}
};
class BlockingQueueDataFeed : DataFeed {
BlockingQueueDataFeed() {}
virtual ~BlockingQueueDataFeed() {}
};
class ThreadedDataFeed : DataFeed {
ThreadedDataFeed() {}
virtual ~ThreadedDataFeed() {}
};
class DataFeed {
public:
DataFeed() {}
DataFeed() : default_batch_size_(1), batch_size_(0), thread_id_(0) {}
virtual ~DataFeed() {}
virtual void Init() = 0;
/*
......@@ -93,6 +78,11 @@ class DataFeed {
virtual void SetBatchSize(int batch) { default_batch_size_ = batch; }
virtual int GetBatchSize() { return batch_size_; }
virtual void SetBufferSize(int buffer_size) {}
virtual unsigned int GetCurrentEpoch() = 0;
virtual const char *PickOneFile() = 0;
virtual void UpdateEpochNum() = 0;
virtual void StartOneEpoch() = 0;
virtual void WaitNextEpoch() = 0;
std::vector<LoDTensor*>& GetFeedVec() {
return feed_vec_;
......@@ -103,6 +93,9 @@ class DataFeed {
return feed_vec_;
}
int GetThreadId() {return thread_id_;}
void SetThreadId(int thread_id) {thread_id_ = thread_id;}
protected:
std::vector<uint16_t> all_slot_ids_;
std::vector<uint16_t> use_slot_ids_;
......@@ -110,9 +103,14 @@ class DataFeed {
std::vector<LoDTensor*> feed_vec_;
int default_batch_size_;
int batch_size_;
int thread_id_;
};
class TextClassDataFeed : public DataFeed {
public:
TextClassDataFeed();
TextClassDataFeed(const TextClassDataFeed& data_feed);
public:
virtual ~TextClassDataFeed() {}
virtual void Init();
......@@ -120,25 +118,45 @@ class TextClassDataFeed : public DataFeed {
virtual void AddFeedVar(Variable* feed, const std::string& name);
virtual void BindScope(Scope* scope) {}
virtual bool SetFile(const char* filename);
virtual bool CheckFile(const char* filename) {
// TODO(xxx)
return false;
}
void SetBatchSize(int batch) {batch_size_ = batch;}
unsigned int GetCurrentEpoch() {return s_current_epoch_;}
void UpdateEpochNum();
void StartOneEpoch();
void WaitNextEpoch();
public:
void SetFieldNames(const std::vector<std::string>& field_names);
public:
static void SetFileList(const char* filelist);
private:
const char* PickOneFile();
private:
int ReadWholeFile(const std::string& filename, char* buffer);
char* file_content_buffer_;
char* file_content_buffer_ptr_;
int* batch_id_buffer_;
int* label_ptr_;
int file_size_;
std::vector<std::string> names_;
std::vector<std::string> field_names_;
std::shared_ptr<char> file_content_buffer_host_;
std::shared_ptr<int> batch_id_host_;
std::shared_ptr<int> label_host_;
static std::vector<std::string> s_filelist_;
static std::mutex s_locker_for_pick_file_;
static unsigned int s_current_file_idx_;
static size_t s_current_finished_file_cnt_;
static unsigned int s_current_epoch_;
static int s_current_save_epoch_;
static std::mutex s_locker_epoch_start_;
static std::condition_variable s_condition_epoch_start_;
static bool s_epoch_start_flag_;
};
} // namespace framework
......
......@@ -21,7 +21,10 @@ limitations under the License. */
#ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE
#endif
#include <vector>
#include <string>
#include "paddle/fluid/pybind/async_executor_py.h"
#include "google/protobuf/text_format.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "paddle/fluid/inference/io.h"
......@@ -29,58 +32,36 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/framework/async_executor_param.pb.h"
#include "paddle/fluid/framework/async_executor.h"
#include "paddle/fluid/pybind/async_executor_py.h"
#include "paddle/fluid/framework/data_feed.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
void BindAsyncExecutor(py::module* m) {
py::class_<paddle::AsyncExecutorParameter>(*m, "AsyncExecutorParameter")
.def(py::init<>())
.def("parse",
[](paddle::AsyncExecutorParameter &self, const std::string &conf_file) {
int file_descriptor = open(conf_file.c_str(), O_RDONLY);
google::protobuf::io::FileInputStream file_input(file_descriptor);
google::protobuf::TextFormat::Parse(&file_input, &self);
close(file_descriptor);
}
);
py::class_<framework::AsyncExecutor>(*m, "AsyncExecutor")
.def(py::init<const platform::Place&>())
.def("init",
[](framework::AsyncExecutor &self,
paddle::AsyncExecutorParameter &parameter,
framework::Scope *scope) {
paddle::BaseParameter base_param = parameter.base_param();
py::class_<framework::DataFeed>(*m, "DataFeed");
py::class_<framework::TextClassDataFeed,
framework::DataFeed>(*m, "TextDataFeed")
.def(py::init())
.def("set_filelist",
[] (framework::TextClassDataFeed &self, const char *data_list_file) {
self.SetFileList(data_list_file);
})
.def("set_batch_size", &framework::TextClassDataFeed::SetBatchSize)
.def("set_field_names", &framework::TextClassDataFeed::SetFieldNames)
.def("start_one_epoch", &framework::TextClassDataFeed::StartOneEpoch);
// TODO Extract parameter list from python side, instead of
// providing them in confgurations manually
std::vector<std::string> param_names;
for (int i = 0; i < base_param.model_param_names_size(); ++i) {
param_names.push_back(base_param.model_param_names(i));
}
paddle::framework::InitDevices(false);
self.InitRootScope(scope);
self.SetThreadNum(base_param.thread_num());
self.SetMaxTrainingEpoch(base_param.max_epoch());
self.SetFileList(base_param.filelist().c_str());
self.SetBatchSize(base_param.batch_size());
self.SetDataFeedName(base_param.datafeed_class().c_str());
self.SetInspectVarName(base_param.inspect_var_name());
self.SetParamNames(param_names);
self.SetModelPath(base_param.model_path());
self.SetModelPrefix(base_param.model_prefix());
self.SetInitProgFile(base_param.init_prog_file());
self.SetInitModelFile(base_param.init_model_file());
return;
}
)
py::class_<framework::AsyncExecutor>(*m, "AsyncExecutor")
.def(py::init<framework::ProgramDesc&,
std::vector<std::string>&,
framework::TextClassDataFeed&,
unsigned int,
const platform::Place&>())
.def("init_root_scope", &framework::AsyncExecutor::InitRootScope)
.def("run_startup_program", &framework::AsyncExecutor::RunStartupProgram)
.def("load_init_model", &framework::AsyncExecutor::LoadInitModel)
.def("run", &framework::AsyncExecutor::RunAsyncExecutor);
.def("run", &framework::AsyncExecutor::Run);
} // end BindAsyncExecutor
} // end namespace framework
} // end namespace pybind
} // end namespace paddle
/* vim: set expandtab ts=2 sw=2 sts=2 tw=80: */
......@@ -15,6 +15,7 @@
#ifndef PADDLE_FLUID_PYBIND_ASYNC_EXECUTOR_PY_H_
#define PADDLE_FLUID_PYBIND_ASYNC_EXECUTOR_PY_H_
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
......
......@@ -21,22 +21,28 @@ from .framework import Program, default_main_program, Variable
from . import core
from . import Executor
__all__ = ['AsyncExecutorParameter', 'AsyncExecutor']
__all__ = ['TextDataFeed', 'AsyncExecutor']
g_scope = core.Scope()
class AsyncExecutorParameter(object):
"""
AsyncExecutor configure parameter
Args:
None
"""
class TextDataFeed():
def __init__(self):
self.parameter = core.AsyncExecutorParameter()
self.feed = core.TextDataFeed()
def set_filelist(self, filelist):
self.feed.set_filelist(filelist)
def set_batch_size(self, batch_size):
self.feed.set_batch_size(batch_size)
def set_field_names(self, field_names):
if isinstance(field_names, Variable):
field_names = [field_names]
self.feed.set_field_names(field_names)
def parse(self, conf_file):
self.parameter.parse(conf_file)
def start_an_epoch(self):
self.feed.start_one_epoch()
class AsyncExecutor(object):
"""
......@@ -50,39 +56,31 @@ class AsyncExecutor(object):
"""
def __init__(self,
async_executor_parameter,
place,
scope):
if not isinstance(async_executor_parameter, AsyncExecutorParameter):
raise TypeError(
"AsyncExecutor requires AsyncExecutorParameter as its parameter. "
"But you passed in %s" %s (type(async_executor_parameter))
)
self.place = place
p = core.Place()
p.set_place(place)
self.executor = core.AsyncExecutor(p)
self.executor.init(async_executor_parameter.parameter, scope)
self._closed = False
self.parameter = async_executor_parameter.parameter
program,
param_names,
data_feed,
thread_num,
place=None,
scope=None):
if program is None:
program = default_main_program()
program_desc = program.desc
def close(self):
"""
Close this executor.
if not isinstance(data_feed, TextDataFeed):
raise ValueError("data_feed for AsyncExecutor.run() type error")
You can no long use this executor after calling this method.
For the distributed training, this method would free the resource on PServers related to
the current Trainer.
if place is None:
place = core.CPUPlace()
if not isinstance(place, core.CPUPlace):
raise ValueError("AsyncExecutor only supports CPU device")
if isinstance(param_names, Variable):
param_names = [param_names]
p = core.Place()
p.set_place(place)
self.executor = core.AsyncExecutor(program_desc, param_names, data_feed.feed, thread_num, p)
Example:
>>> cpu = core.CPUPlace()
>>> exe = Executor(cpu)
>>> ...
>>> exe.close()
"""
if not self._closed:
self._closed = True
def run_startup_program(self,
program=None,
scope=None):
......@@ -95,7 +93,7 @@ class AsyncExecutor(object):
self.executor.run_startup_program(program_desc, scope)
def run(self, program=None, scope=None):
def run(self, inspect_vars, scope=None):
"""
Run program by this Executor. Feed data by feed map, fetch result by fetch_list.
Python executor takes a program, add feed operators and fetch operators to this program according
......@@ -138,23 +136,16 @@ class AsyncExecutor(object):
>>> feed={'X': x},
>>> fetch_list=[loss.name])
"""
if self._closed:
raise RuntimeError("Attempted to use a closed Executor")
if program is None:
program = default_main_program()
program_desc = program.desc
if not isinstance(program, Program):
raise TypeError(
"Executor requires Program as its Parameter. But you passed in %s"
% (type(program)))
if inspect_vars is not None:
if isinstance(inspect_vars, Variable):
inspect_vars = [inspect_vars]
inspect_var_names = [var.name for var in inspect_vars]
if scope is None:
scope = g_scope
self.executor.run(program.desc)
self.executor.init_root_scope(scope)
evaluation = self.executor.run(inspect_var_names)
return evaluation
def load_init_model(self):
return self.executor.load_init_model()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册