Unverified commit f929f87e, authored by guru4elephant, committed by GitHub

Update async_executor.md

Parent 599fe363
...
## Main Interface of Async Executor
We provide a ```RunFromFile``` interface as the execution entry point for users. Every time a user calls ```RunFromFile```, a main_program should be provided, and it runs in the global scope defined previously. A list of file names and a corresponding ```DataFeedDesc``` should also be provided. Inside ```RunFromFile```, readers are created from the ```DataFeedDesc``` configuration, and the files are fed into the created readers.
``` c++
std::vector<float> AsyncExecutor::RunFromFile(
    const ProgramDesc& main_program,
    const std::string& data_feed_desc_str,
    const std::vector<std::string>& filelist,
    const int thread_num,
    const std::vector<std::string>& fetch_var_names) {
  std::vector<std::thread> threads;
  threads.resize(thread_num);
  DataFeedDesc data_feed_desc;
  google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
                                                &data_feed_desc);
  /*
    readerDesc: protobuf description for reader initialization
    argument: class_name, batch_size, use_slot, queue_size, buffer_size, padding_index
    ...
    2) each reader has a Next() interface, that can fetch an instance
    from the input queue
  */
  std::vector<std::shared_ptr<DataFeed> > readers;
  readers.resize(thread_num);
  PrepareReaders(readers, thread_num, data_feed_desc, filelist);

  std::vector<std::shared_ptr<ExecutorThreadWorker> > workers;
  workers.resize(thread_num);
  for (auto& worker : workers) {
    // ...
  }

  // prepare thread resource here
  for (int thidx = 0; thidx < thread_num; ++thidx) {
    CreateThreads(workers[thidx].get(), main_program,
                  readers[thidx], fetch_var_names, root_scope_, thidx);
  }
  // start executing ops in multiple threads
  for (int thidx = 0; thidx < thread_num; ++thidx) {
    threads.push_back(std::thread(&ExecutorThreadWorker::TrainFiles,
                                  workers[thidx].get()));
  }
  // ...
  for (auto& th : threads) {
    th.join();
  }
  // aggregate fetch values from all threads and return
  std::vector<float> fetch_values;
  fetch_values.resize(fetch_var_names.size(), 0);
  std::vector<std::vector<float>*> fetch_value_vectors;
  fetch_value_vectors.resize(thread_num);
  for (int i = 0; i < thread_num; ++i) {
    fetch_value_vectors[i] = &workers[i]->GetFetchValues();
  }
  for (unsigned int i = 0; i < fetch_var_names.size(); ++i) {
    float value = 0.0;
    for (int j = 0; j < thread_num; ++j) {
      value += fetch_value_vectors[j]->at(i);
    }
    value /= thread_num;
    fetch_values[i] = value;
  }
  return fetch_values;
}
```
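```RunFromFile``` delegates reader construction to ```PrepareReaders```, which is not shown in this document. Below is a minimal sketch of what it could look like, assuming the ```DataFeedFactory::CreateDataFeed``` call that earlier revisions of this design used directly; the ```Init``` and ```SetFileList``` helpers are assumptions for illustration.
``` c++
// Sketch only: create one DataFeed per thread from the protobuf description.
void PrepareReaders(std::vector<std::shared_ptr<DataFeed> >& readers,
                    const int thread_num,
                    const DataFeedDesc& data_feed_desc,
                    const std::vector<std::string>& filelist) {
  readers.resize(thread_num);
  for (size_t i = 0; i < readers.size(); ++i) {
    // each thread gets its own reader instance of the configured class
    readers[i] = DataFeedFactory::CreateDataFeed(data_feed_desc.name());
    readers[i]->Init(data_feed_desc);  // assumed: batch_size, queue_size, slots, ...
  }
  // assumed: the file list is shared; files are dispatched to per-thread readers
  readers[0]->SetFileList(filelist);
}
```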
Inside the function ```CreateThreads```,
``` c++
void AsyncExecutor::CreateThreads(
    ExecutorThreadWorker* worker,
    const ProgramDesc& main_program,
    const std::shared_ptr<DataFeed>& reader,
    const std::vector<std::string>& fetch_var_names,
    Scope& root_scope,
    const int thread_index) {
  worker->SetThreadId(thread_index);
  worker->SetRootScope(&root_scope);
  worker->CreateThreadResource(main_program, place_);
  worker->SetDataFeed(reader);
  worker->SetFetchVarNames(fetch_var_names);
  worker->BindingDataFeedMemory();
}
```
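The final call, ```BindingDataFeedMemory```, is what connects the reader to the thread-local scope. The following is only a rough sketch of the idea, assuming the worker holds ```thread_reader_``` and ```thread_scope_``` members (as in ```TrainFiles``` below) and that ```DataFeed``` exposes the used slot names and an ```AddFeedVar``` hook; both accessors are assumptions here.
``` c++
// Sketch only: for every input slot the reader produces, create (or look up)
// a variable in the thread scope and hand it to the reader, so that
// thread_reader_->Next() writes each batch directly into the variables
// that the ops in this thread will read.
void ExecutorThreadWorker::BindingDataFeedMemory() {
  const std::vector<std::string>& input_feed =
      thread_reader_->GetUseSlotAlias();  // assumed accessor
  for (const auto& name : input_feed) {
    thread_reader_->AddFeedVar(thread_scope_->Var(name), name);  // assumed hook
  }
}
```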
Inside the function ```TrainFiles```,
``` c++
void ExecutorThreadWorker::TrainFiles() {
  // todo: configurable
  SetDevice();
  int fetch_var_num = fetch_var_names_.size();
  fetch_values_.clear();
  fetch_values_.resize(fetch_var_num, 0);
  thread_reader_->Start();  // start reading thread within reader
  int cur_batch;
  int batch_cnt = 0;
  while ((cur_batch = thread_reader_->Next()) > 0) {
    // executor run here
    for (auto& op : ops_) {
      op->Run(*thread_scope_, place_);
    }
    float avg_inspect = 0.0;
    for (int i = 0; i < fetch_var_num; ++i) {
      avg_inspect = thread_scope_->FindVar(fetch_var_names_[i])
                        ->GetMutable<LoDTensor>()
                        ->data<float>()[0];
      fetch_values_[i] += avg_inspect;
    }
    ++batch_cnt;
    thread_scope_->DropKids();
  }
  if (batch_cnt) {
    // batch_cnt can be zero when the number of files is less than the number of threads
    for (int i = 0; i < fetch_var_num; ++i) {
      fetch_values_[i] = fetch_values_[i] / batch_cnt;
    }
  }
}
```
## How to print variable information during execution
Inside async_executor, no information is printed. Variables can be fetched through an execution of async_executor, and the fetched variables can then be printed through Python. Since we train over several files of instances within async_executor, the fetched values are only approximate: in this version of the design, each thread accumulates the fetched variables over the batches it processes and averages them by its batch count, and ```RunFromFile``` then averages these per-thread results over thread_num.
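For intuition, the two-level averaging can be reduced to the toy computation below. It is a standalone sketch with made-up numbers that only mirrors the arithmetic in ```TrainFiles``` and ```RunFromFile```.
``` c++
#include <iostream>
#include <vector>

int main() {
  // per-thread fetched values of one variable, one entry per batch (made-up numbers)
  std::vector<std::vector<float> > per_thread_batches = {
      {0.9f, 0.7f, 0.8f},   // thread 0 processed 3 batches
      {0.6f, 0.4f}};        // thread 1 processed 2 batches

  // step 1: TrainFiles accumulates over batches and divides by batch_cnt per thread
  std::vector<float> per_thread_avg;
  for (const auto& batches : per_thread_batches) {
    float sum = 0.0f;
    for (float v : batches) sum += v;
    per_thread_avg.push_back(batches.empty() ? 0.0f : sum / batches.size());
  }

  // step 2: RunFromFile sums the per-thread averages and divides by thread_num
  float value = 0.0f;
  for (float v : per_thread_avg) value += v;
  value /= per_thread_avg.size();

  std::cout << "fetched value: " << value << std::endl;  // (0.8 + 0.5) / 2 = 0.65
  return 0;
}
```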
## How to save models
Models can be saved between executions of async_executor through the io.save method.
...
* data preparation
* performance and accuracy
### Text Matching
* network configuration
* data preparation
* performance and accuracy
## References
1. [Sentiment Analysis](https://arxiv.org/pdf/1801.07883.pdf)
2. [Word2Vec](https://arxiv.org/abs/1301.3781)