未验证 提交 273ebf3e 编写于 作者: G guru4elephant 提交者: GitHub

Update async_executor.md

上级 683d115e
...@@ -13,7 +13,6 @@ def train_loop(): ...@@ -13,7 +13,6 @@ def train_loop():
dataset = fluid.DataFeedDesc('train_data/data.prototxt') dataset = fluid.DataFeedDesc('train_data/data.prototxt')
dataset.set_batch_size(128) # See API doc for how to change other fields dataset.set_batch_size(128) # See API doc for how to change other fields
print dataset.desc() # Debug purpose: see what we get print dataset.desc() # Debug purpose: see what we get
# define network
# input text data # input text data
data = fluid.layers.data( data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1) name="words", shape=[1], dtype="int64", lod_level=1)
...@@ -130,13 +129,10 @@ Inside the function ```Trainfiles```, ...@@ -130,13 +129,10 @@ Inside the function ```Trainfiles```,
void ExecutorThreadWorker::TrainFiles() { void ExecutorThreadWorker::TrainFiles() {
// todo: configurable // todo: configurable
SetDevice(); SetDevice();
int fetch_var_num = fetch_var_names_.size(); int fetch_var_num = fetch_var_names_.size();
fetch_values_.clear(); fetch_values_.clear();
fetch_values_.resize(fetch_var_num, 0); fetch_values_.resize(fetch_var_num);
thread_reader_->Start(); thread_reader_->Start();
int cur_batch; int cur_batch;
int batch_cnt = 0; int batch_cnt = 0;
while ((cur_batch = thread_reader_->Next()) > 0) { while ((cur_batch = thread_reader_->Next()) > 0) {
...@@ -144,25 +140,16 @@ void ExecutorThreadWorker::TrainFiles() { ...@@ -144,25 +140,16 @@ void ExecutorThreadWorker::TrainFiles() {
for (auto& op : ops_) { for (auto& op : ops_) {
op->Run(*thread_scope_, place_); op->Run(*thread_scope_, place_);
} }
float avg_inspect = 0.0;
for (int i = 0; i < fetch_var_num; ++i) {
avg_inspect = thread_scope_->FindVar(fetch_var_names_[i])
->GetMutable<LoDTensor>()
->data<float>()[0];
fetch_values_[i] += avg_inspect;
}
++batch_cnt; ++batch_cnt;
thread_scope_->DropKids(); thread_scope_->DropKids();
if (debug_ == false || thread_id_ != 0) {
continue;
} }
if (batch_cnt) {
// when the number of files is less than the number of threads
for (int i = 0; i < fetch_var_num; ++i) { for (int i = 0; i < fetch_var_num; ++i) {
fetch_values_[i] = fetch_values_[i] / batch_cnt; print_fetch_var(thread_scope_, fetch_var_names_[i]);
} } // end for (int i = 0...)
} } // end while ()
}
``` ```
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册