未验证 提交 273ebf3e 编写于 作者: G guru4elephant 提交者: GitHub

Update async_executor.md

上级 683d115e
......@@ -13,7 +13,6 @@ def train_loop():
dataset = fluid.DataFeedDesc('train_data/data.prototxt')
dataset.set_batch_size(128) # See API doc for how to change other fields
print dataset.desc() # Debug purpose: see what we get
# define network
# input text data
data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1)
......@@ -22,7 +21,7 @@ def train_loop():
avg_cost, acc, prediction = bow_net(data, label)
sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=0.002)
opt_ops, weight_and_grad = sgd_optimizer.minimize(avg_cost)
# Run startup program
# Run startup program
startup_program = fluid.default_startup_program()
place = fluid.CPUPlace()
executor = fluid.Executor(place)
......@@ -71,13 +70,13 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
"only variables with the last dimension size 1 supported",
var_name);
}
DataFeedDesc data_feed_desc;
DataFeedDesc data_feed_desc;
google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
&data_feed_desc);
int actual_thread_num = thread_num;
int actual_thread_num = thread_num;
int file_cnt = filelist.size();
PADDLE_ENFORCE(file_cnt > 0, "File list cannot be empty");
if (actual_thread_num > file_cnt) {
if (actual_thread_num > file_cnt) {
VLOG(1) << "Thread num = " << thread_num << ", file num = " << file_cnt
<< ". Changing thread_num = " << file_cnt;
actual_thread_num = file_cnt;
......@@ -89,21 +88,21 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
for (auto& worker : workers) {
worker.reset(new ExecutorThreadWorker);
}
// prepare thread resource here
// prepare thread resource here
for (int thidx = 0; thidx < actual_thread_num; ++thidx) {
CreateThreads(workers[thidx].get(), main_program, readers[thidx],
fetch_var_names, root_scope_, thidx, debug);
}
// start executing ops in multiple threads
// start executing ops in multiple threads
for (int thidx = 0; thidx < actual_thread_num; ++thidx) {
threads.push_back(
std::thread(&ExecutorThreadWorker::TrainFiles, workers[thidx].get()));
}
for (auto& th : threads) {
for (auto& th : threads) {
th.join();
}
root_scope_->DropKids();
return;
root_scope_->DropKids();
return;
}
```
......@@ -130,13 +129,10 @@ Inside the function ```Trainfiles```,
void ExecutorThreadWorker::TrainFiles() {
// todo: configurable
SetDevice();
int fetch_var_num = fetch_var_names_.size();
fetch_values_.clear();
fetch_values_.resize(fetch_var_num, 0);
fetch_values_.resize(fetch_var_num);
thread_reader_->Start();
int cur_batch;
int batch_cnt = 0;
while ((cur_batch = thread_reader_->Next()) > 0) {
......@@ -144,25 +140,16 @@ void ExecutorThreadWorker::TrainFiles() {
for (auto& op : ops_) {
op->Run(*thread_scope_, place_);
}
float avg_inspect = 0.0;
for (int i = 0; i < fetch_var_num; ++i) {
avg_inspect = thread_scope_->FindVar(fetch_var_names_[i])
->GetMutable<LoDTensor>()
->data<float>()[0];
fetch_values_[i] += avg_inspect;
}
++batch_cnt;
thread_scope_->DropKids();
}
if (batch_cnt) {
// when the number of files is less than the number of threads
for (int i = 0; i < fetch_var_num; ++i) {
fetch_values_[i] = fetch_values_[i] / batch_cnt;
if (debug_ == false || thread_id_ != 0) {
continue;
}
}
for (int i = 0; i < fetch_var_num; ++i) {
print_fetch_var(thread_scope_, fetch_var_names_[i]);
} // end for (int i = 0...)
} // end while ()
}
```
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册