Unverified commit f929f87e, authored by guru4elephant, committed by GitHub

Update async_executor.md

Parent 599fe363
...@@ -39,16 +39,16 @@ Why we use multiple queues for data reader? a experimental result page needs to ...@@ -39,16 +39,16 @@ Why we use multiple queues for data reader? a experimental result page needs to
## Main Interface of Async Executor
The RunFromFiles interface is the execution entry point that users call. Each call takes a main_program, which runs in the previously defined global scope, together with a list of file names and the corresponding Dataset. Inside RunFromFiles, readers are created from the Dataset configuration, and the files are fed into those readers.
``` c++ ``` c++
void AsyncExecutor::RunFromFiles( std::vector<float> AsyncExecutor::RunFromFile(
const ProgramDesc& main_program, const ProgramDesc& main_program,
const DataFeedDesc& data_feed_desc, const std::string& data_feed_desc_str,
const std::vector<std::string> & files, const std::vector<std::string>& filelist,
const int thread_num) { const int thread_num,
// todo: remove fluid related interface const std::vector<std::string>& fetch_var_names) {
root_scope_->DropKids();
std::vector<std::thread> threads; std::vector<std::thread> threads;
threads.resize(thread_num);
DataFeedDesc data_feed_desc;
google::protobuf::TextFormat::ParseFromString(data_feed_desc_str, &data_feed_desc);
/* /*
readerDesc: protobuf description for reader initialization readerDesc: protobuf description for reader initialization
argument: class_name, batch_size, use_slot, queue_size, buffer_size, padding_index argument: class_name, batch_size, use_slot, queue_size, buffer_size, padding_index
...@@ -59,22 +59,9 @@ void AsyncExecutor::RunFromFiles( ...@@ -59,22 +59,9 @@ void AsyncExecutor::RunFromFiles(
2) each reader has a Next() interface, that can fetch an instance 2) each reader has a Next() interface, that can fetch an instance
from the input queue from the input queue
*/ */
// todo: should be factory method for creating datafeed
std::vector<std::shared_ptr<DataFeed> > readers; std::vector<std::shared_ptr<DataFeed> > readers;
readers.resize(thread_num); PrepareReaders(readers, thread_num, data_feed_desc, filelist);
for (int i = 0; i < readers.size(); ++i) {
readers[i] = DataFeedFactory::CreateDataFeed(data_feed_desc.name());
}
// todo(dongdaxiang): add the following code for worker generalization
/*
std::vector<std::shared_ptr<ExecutorStrategy> > workers;
workers.resize(thread_num);
std::string str_name = strategy_.name;
for (auto& worker : workers) {
worker.reset(
ExecutorStrategyFactory::CreateExecutorStrategy(str_name));
}
*/
std::vector<std::shared_ptr<ExecutorThreadWorker> > workers; std::vector<std::shared_ptr<ExecutorThreadWorker> > workers;
workers.resize(thread_num); workers.resize(thread_num);
...@@ -85,10 +72,11 @@ void AsyncExecutor::RunFromFiles( ...@@ -85,10 +72,11 @@ void AsyncExecutor::RunFromFiles(
// prepare thread resource here // prepare thread resource here
for (int thidx = 0; thidx < thread_num; ++thidx) { for (int thidx = 0; thidx < thread_num; ++thidx) {
CreateThreads(workers[thidx].get(), main_program, CreateThreads(workers[thidx].get(), main_program,
readers[thidx].get(), root_scope_, thidx); readers[thidx], fetch_var_names, root_scope_, thidx);
} }
// start executing ops in multiple threads // start executing ops in multiple threads
for (int thidx = 0; thidx < thread_num_; ++thidx) { for (int thidx = 0; thidx < thread_num; ++thidx) {
threads.push_back(std::thread(&ExecutorThreadWorker::TrainFiles, threads.push_back(std::thread(&ExecutorThreadWorker::TrainFiles,
workers[thidx].get())); workers[thidx].get()));
} }
...@@ -96,22 +84,44 @@ void AsyncExecutor::RunFromFiles( ...@@ -96,22 +84,44 @@ void AsyncExecutor::RunFromFiles(
for (auto& th : threads) { for (auto& th : threads) {
th.join(); th.join();
} }
// fetch variables in scope 0, and return, to be added
std::vector<float> fetch_values;
fetch_values.resize(fetch_var_names.size(), 0);
std::vector<std::vector<float>*> fetch_value_vectors;
fetch_value_vectors.resize(thread_num);
for (int i = 0; i < thread_num; ++i) {
fetch_value_vectors[i] = &workers[i]->GetFetchValues();
}
for (unsigned int i = 0; i < fetch_var_names.size(); ++i) {
float value = 0.0;
for (int j = 0; j < thread_num; ++j) {
value += fetch_value_vectors[j]->at(i);
}
value /= thread_num;
fetch_values[i] = value;
}
return fetch_values;
} }
``` ```
Inside the function ```CreateThreads```,
``` c++ ``` c++
void AsyncExecutor::CreateThreads(const ExecutorThreadWorker* worker, void AsyncExecutor::CreateThreads(
ExecutorThreadWorker* worker,
const ProgramDesc& main_program, const ProgramDesc& main_program,
const DataFeed& reader, const std::shared_ptr<DataFeed>& reader,
const Scope& root_scope, const std::vector<std::string>& fetch_var_names,
Scope& root_scope,
const int thread_index) { const int thread_index) {
worker->SetThreadid(thread_index); worker->SetThreadId(thread_index);
worker->SetRootScope(&root_scope);
worker->CreateThreadResource(main_program, place_); worker->CreateThreadResource(main_program, place_);
worker->SetDataFeed(reader); worker->SetDataFeed(reader);
worker->BindingDataFeedMemory(reader); worker->SetFetchVarNames(fetch_var_names);
worker->SetRootScope(root_scope); worker->BindingDataFeedMemory();
} }
``` ```
...@@ -120,20 +130,44 @@ Inside the function ```Trainfiles```, ...@@ -120,20 +130,44 @@ Inside the function ```Trainfiles```,
void ExecutorThreadWorker::TrainFiles() { void ExecutorThreadWorker::TrainFiles() {
// todo: configurable // todo: configurable
SetDevice(); SetDevice();
thread_reader_->Start(); // start reading thread within reader
while (int cur_batch = thread_reader_->Next()) { int fetch_var_num = fetch_var_names_.size();
fetch_values_.clear();
fetch_values_.resize(fetch_var_num, 0);
thread_reader_->Start();
int cur_batch;
int batch_cnt = 0;
while ((cur_batch = thread_reader_->Next()) > 0) {
// executor run here // executor run here
for (auto& op : ops_) { for (auto& op : ops_) {
op->Run(*thread_scope_, place_); op->Run(*thread_scope_, place_);
} }
float avg_inspect = 0.0;
for (int i = 0; i < fetch_var_num; ++i) {
avg_inspect = thread_scope_->FindVar(fetch_var_names_[i])
->GetMutable<LoDTensor>()
->data<float>()[0];
fetch_values_[i] += avg_inspect;
}
++batch_cnt;
thread_scope_->DropKids(); thread_scope_->DropKids();
} }
}
if (batch_cnt) {
// when the number of files is less than the number of threads
for (int i = 0; i < fetch_var_num; ++i) {
fetch_values_[i] = fetch_values_[i] / batch_cnt;
}
}
``` ```
## How to print variable information during execution
Nothing is printed inside async_executor. Variables can be fetched through an execution of async_executor, and the fetched values can then be printed from Python. Because async_executor trains on several files of instances, the fetched values are approximate rather than exact. In this version of the design, each thread averages the first element of every fetched variable over the batches it processed, and the per-thread results are then averaged over thread_num.
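The TrainFiles and RunFromFile listings above combine fetched values in two stages: a mean over batches inside each thread, then a mean over threads. The following is a minimal, self-contained sketch of that arithmetic only. The helper names `MeanOverBatches` and `MeanOverThreads` and the plain `std::vector<float>` inputs are assumptions made for illustration and are not part of async_executor.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not an async_executor API): per-thread stage.
// batch_values[b][i] is the scalar read from fetched variable i after batch b.
std::vector<float> MeanOverBatches(
    const std::vector<std::vector<float>>& batch_values,
    std::size_t fetch_var_num) {
  std::vector<float> sums(fetch_var_num, 0.0f);
  for (const auto& batch : batch_values) {
    assert(batch.size() == fetch_var_num);
    for (std::size_t i = 0; i < fetch_var_num; ++i) {
      sums[i] += batch[i];
    }
  }
  // A thread that received no files processes zero batches; its sums stay 0.
  if (!batch_values.empty()) {
    for (auto& s : sums) {
      s /= static_cast<float>(batch_values.size());
    }
  }
  return sums;
}

// Hypothetical helper (not an async_executor API): cross-thread stage.
// per_thread_means[t][i] is the value produced by thread t for variable i.
std::vector<float> MeanOverThreads(
    const std::vector<std::vector<float>>& per_thread_means,
    std::size_t fetch_var_num) {
  std::vector<float> result(fetch_var_num, 0.0f);
  if (per_thread_means.empty()) {
    return result;
  }
  for (const auto& means : per_thread_means) {
    assert(means.size() == fetch_var_num);
    for (std::size_t i = 0; i < fetch_var_num; ++i) {
      result[i] += means[i];
    }
  }
  // Every thread counts in the divisor, including threads that saw no data.
  for (auto& r : result) {
    r /= static_cast<float>(per_thread_means.size());
  }
  return result;
}
```

In this sketch an idle thread contributes zeros and still counts in the divisor, so the overall mean is pulled toward zero when there are fewer files than threads. This also means the result is a mean of per-thread means, which differs from a mean over all batches whenever threads process different numbers of batches.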
## How to save models
Models can be saved between executions of async_executor through the io.save method.
...@@ -144,11 +178,6 @@ Models can be saved between execution of async_executor through io.save method. ...@@ -144,11 +178,6 @@ Models can be saved between execution of async_executor through io.save method.
* data preparation
* performance and accuracy
### Text Matching
* network configuration
* data preparation
* performance and accuracy
## references
1. [Sentiment Analysis](https://arxiv.org/pdf/1801.07883.pdf)
2. [Word2Vec](https://arxiv.org/abs/1301.3781)
......