未验证 提交 273ebf3e 编写于 作者: G guru4elephant 提交者: GitHub

Update async_executor.md

上级 683d115e
...@@ -13,7 +13,6 @@ def train_loop(): ...@@ -13,7 +13,6 @@ def train_loop():
dataset = fluid.DataFeedDesc('train_data/data.prototxt') dataset = fluid.DataFeedDesc('train_data/data.prototxt')
dataset.set_batch_size(128) # See API doc for how to change other fields dataset.set_batch_size(128) # See API doc for how to change other fields
print dataset.desc() # Debug purpose: see what we get print dataset.desc() # Debug purpose: see what we get
# define network
# input text data # input text data
data = fluid.layers.data( data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1) name="words", shape=[1], dtype="int64", lod_level=1)
...@@ -22,7 +21,7 @@ def train_loop(): ...@@ -22,7 +21,7 @@ def train_loop():
avg_cost, acc, prediction = bow_net(data, label) avg_cost, acc, prediction = bow_net(data, label)
sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=0.002) sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=0.002)
opt_ops, weight_and_grad = sgd_optimizer.minimize(avg_cost) opt_ops, weight_and_grad = sgd_optimizer.minimize(avg_cost)
# Run startup program # Run startup program
startup_program = fluid.default_startup_program() startup_program = fluid.default_startup_program()
place = fluid.CPUPlace() place = fluid.CPUPlace()
executor = fluid.Executor(place) executor = fluid.Executor(place)
...@@ -71,13 +70,13 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, ...@@ -71,13 +70,13 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
"only variables with the last dimension size 1 supported", "only variables with the last dimension size 1 supported",
var_name); var_name);
} }
DataFeedDesc data_feed_desc; DataFeedDesc data_feed_desc;
google::protobuf::TextFormat::ParseFromString(data_feed_desc_str, google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
&data_feed_desc); &data_feed_desc);
int actual_thread_num = thread_num; int actual_thread_num = thread_num;
int file_cnt = filelist.size(); int file_cnt = filelist.size();
PADDLE_ENFORCE(file_cnt > 0, "File list cannot be empty"); PADDLE_ENFORCE(file_cnt > 0, "File list cannot be empty");
if (actual_thread_num > file_cnt) { if (actual_thread_num > file_cnt) {
VLOG(1) << "Thread num = " << thread_num << ", file num = " << file_cnt VLOG(1) << "Thread num = " << thread_num << ", file num = " << file_cnt
<< ". Changing thread_num = " << file_cnt; << ". Changing thread_num = " << file_cnt;
actual_thread_num = file_cnt; actual_thread_num = file_cnt;
...@@ -89,21 +88,21 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, ...@@ -89,21 +88,21 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
for (auto& worker : workers) { for (auto& worker : workers) {
worker.reset(new ExecutorThreadWorker); worker.reset(new ExecutorThreadWorker);
} }
// prepare thread resource here // prepare thread resource here
for (int thidx = 0; thidx < actual_thread_num; ++thidx) { for (int thidx = 0; thidx < actual_thread_num; ++thidx) {
CreateThreads(workers[thidx].get(), main_program, readers[thidx], CreateThreads(workers[thidx].get(), main_program, readers[thidx],
fetch_var_names, root_scope_, thidx, debug); fetch_var_names, root_scope_, thidx, debug);
} }
// start executing ops in multiple threads // start executing ops in multiple threads
for (int thidx = 0; thidx < actual_thread_num; ++thidx) { for (int thidx = 0; thidx < actual_thread_num; ++thidx) {
threads.push_back( threads.push_back(
std::thread(&ExecutorThreadWorker::TrainFiles, workers[thidx].get())); std::thread(&ExecutorThreadWorker::TrainFiles, workers[thidx].get()));
} }
for (auto& th : threads) { for (auto& th : threads) {
th.join(); th.join();
} }
root_scope_->DropKids(); root_scope_->DropKids();
return; return;
} }
``` ```
...@@ -130,13 +129,10 @@ Inside the function ```TrainFiles```, ...@@ -130,13 +129,10 @@ Inside the function ```TrainFiles```,
void ExecutorThreadWorker::TrainFiles() { void ExecutorThreadWorker::TrainFiles() {
// todo: configurable // todo: configurable
SetDevice(); SetDevice();
int fetch_var_num = fetch_var_names_.size(); int fetch_var_num = fetch_var_names_.size();
fetch_values_.clear(); fetch_values_.clear();
fetch_values_.resize(fetch_var_num, 0); fetch_values_.resize(fetch_var_num);
thread_reader_->Start(); thread_reader_->Start();
int cur_batch; int cur_batch;
int batch_cnt = 0; int batch_cnt = 0;
while ((cur_batch = thread_reader_->Next()) > 0) { while ((cur_batch = thread_reader_->Next()) > 0) {
...@@ -144,25 +140,16 @@ void ExecutorThreadWorker::TrainFiles() { ...@@ -144,25 +140,16 @@ void ExecutorThreadWorker::TrainFiles() {
for (auto& op : ops_) { for (auto& op : ops_) {
op->Run(*thread_scope_, place_); op->Run(*thread_scope_, place_);
} }
float avg_inspect = 0.0;
for (int i = 0; i < fetch_var_num; ++i) {
avg_inspect = thread_scope_->FindVar(fetch_var_names_[i])
->GetMutable<LoDTensor>()
->data<float>()[0];
fetch_values_[i] += avg_inspect;
}
++batch_cnt; ++batch_cnt;
thread_scope_->DropKids(); thread_scope_->DropKids();
} if (debug_ == false || thread_id_ != 0) {
continue;
if (batch_cnt) {
// when the number of files is less than the number of threads
for (int i = 0; i < fetch_var_num; ++i) {
fetch_values_[i] = fetch_values_[i] / batch_cnt;
} }
} for (int i = 0; i < fetch_var_num; ++i) {
print_fetch_var(thread_scope_, fetch_var_names_[i]);
} // end for (int i = 0...)
} // end while ()
}
``` ```
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册