Unverified commit 273ebf3e authored by guru4elephant, committed by GitHub

Update async_executor.md

Parent 683d115e
@@ -13,7 +13,6 @@ def train_loop():
    dataset = fluid.DataFeedDesc('train_data/data.prototxt')
    dataset.set_batch_size(128)  # See API doc for how to change other fields
    print(dataset.desc())  # Debug purpose: see what we get
    # define network
    # input text data
    data = fluid.layers.data(
        name="words", shape=[1], dtype="int64", lod_level=1)
@@ -130,13 +129,10 @@ Inside the function ```TrainFiles```,
void ExecutorThreadWorker::TrainFiles() {
  // todo: configurable
  SetDevice();
  int fetch_var_num = fetch_var_names_.size();
  fetch_values_.clear();
  fetch_values_.resize(fetch_var_num, 0);
  fetch_values_.resize(fetch_var_num);
  thread_reader_->Start();
  int cur_batch;
  int batch_cnt = 0;
  while ((cur_batch = thread_reader_->Next()) > 0) {
@@ -144,25 +140,16 @@ void ExecutorThreadWorker::TrainFiles() {
    for (auto& op : ops_) {
      op->Run(*thread_scope_, place_);
    }
    float avg_inspect = 0.0;
    for (int i = 0; i < fetch_var_num; ++i) {
      avg_inspect = thread_scope_->FindVar(fetch_var_names_[i])
                        ->GetMutable<LoDTensor>()
                        ->data<float>()[0];
      fetch_values_[i] += avg_inspect;
    }
    ++batch_cnt;
    thread_scope_->DropKids();
    if (debug_ == false || thread_id_ != 0) {
      continue;
    }
    if (batch_cnt) {
      // when the number of files is less than the number of threads
      for (int i = 0; i < fetch_var_num; ++i) {
        fetch_values_[i] = fetch_values_[i] / batch_cnt;
      }
    }
    for (int i = 0; i < fetch_var_num; ++i) {
      print_fetch_var(thread_scope_, fetch_var_names_[i]);
    }  // end for (int i = 0...)
  }  // end while ()
}
```
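The debug branch above calls a `print_fetch_var` helper whose definition is not part of this excerpt. Below is a minimal sketch of what such a helper could look like, reusing the `Scope`/`LoDTensor` access pattern from the averaging code that this commit removes and assuming the fetched variable holds a float `LoDTensor`. The body and the framework header paths are illustrative assumptions, not the actual PaddlePaddle implementation.

```c++
// Hypothetical sketch only -- not the real PaddlePaddle helper.
// Assumes the usual framework headers for Scope/Variable/LoDTensor.
#include <iostream>
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"

using paddle::framework::LoDTensor;
using paddle::framework::Scope;
using paddle::framework::Variable;

// Dump the first element of one fetch variable for debugging. The real
// print_fetch_var may print more (full tensor, LoD info) and handle
// more data types than float.
void print_fetch_var(Scope* scope, const std::string& var_name) {
  Variable* var = scope->FindVar(var_name);
  if (var == nullptr) {
    std::cerr << "fetch var " << var_name << " not found in scope\n";
    return;
  }
  LoDTensor* tensor = var->GetMutable<LoDTensor>();
  if (tensor->numel() > 0) {
    // Same access pattern as the removed averaging code: treat the value
    // as a float and inspect element 0 (e.g. an averaged loss metric).
    std::cerr << var_name << ": " << tensor->data<float>()[0] << "\n";
  }
}
```

Because only thread 0 reaches this call when `debug_` is enabled, the output stays readable even when many worker threads are training concurrently.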