Unverified · Commit 1fe89e8a authored by guru4elephant, committed by GitHub

Update async_executor.md

Parent 273ebf3e
@@ -9,7 +9,7 @@ def train_loop():
    with tarfile.open(paddle.dataset.common.download(URL, "imdb", MD5)) as tarf:
        tarf.extractall(path='./')
        tarf.close()
    # Initialize dataset description
    dataset = fluid.DataFeedDesc('train_data/data.prototxt')
    dataset.set_batch_size(128)  # See API doc for how to change other fields
    print dataset.desc()  # Debug purpose: see what we get
@@ -18,7 +18,7 @@ def train_loop():
        name="words", shape=[1], dtype="int64", lod_level=1)
    # label data
    label = fluid.layers.data(name="label", shape=[1], dtype="int64")
    avg_cost, acc, prediction = bow_net(data, label)
    sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=0.002)
    opt_ops, weight_and_grad = sgd_optimizer.minimize(avg_cost)
    # Run startup program
@@ -61,7 +61,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
                                const std::vector<std::string>& fetch_var_names,
                                const bool debug) {
  std::vector<std::thread> threads;
  auto& block = main_program.Block(0);
  for (auto var_name : fetch_var_names) {
    auto var_desc = block.FindVar(var_name);
    auto shapes = var_desc->GetShape();
@@ -83,7 +83,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
  }
  std::vector<std::shared_ptr<DataFeed>> readers;
  PrepareReaders(readers, actual_thread_num, data_feed_desc, filelist);
  std::vector<std::shared_ptr<ExecutorThreadWorker>> workers;
  workers.resize(actual_thread_num);
  for (auto& worker : workers) {
    worker.reset(new ExecutorThreadWorker);
...
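For context, the Python hunk above stops right after the "Run startup program" comment, before the executor is actually launched. Below is a minimal sketch of how the truncated `train_loop` might continue, assuming the Paddle 1.x `fluid.AsyncExecutor` Python API that this design doc describes; the `filelist` shard names, `thread_num` value, and program handles are illustrative assumptions, not part of this commit.

```python
# Hedged sketch (not part of the diff): finishing the train_loop above,
# assuming the Paddle 1.x fluid.AsyncExecutor API from this design doc.
import paddle.fluid as fluid

place = fluid.CPUPlace()                       # AsyncExecutor runs on CPU
startup_exe = fluid.Executor(place)
startup_exe.run(fluid.default_startup_program())  # the "Run startup program" step

async_executor = fluid.AsyncExecutor(place)
filelist = ["train_data/part-%d" % i for i in range(12)]  # assumed input shards
async_executor.run(
    fluid.default_main_program(),  # program holding bow_net + Adagrad ops
    dataset,                       # the DataFeedDesc built above
    filelist,                      # one or more data files per worker thread
    10,                            # thread_num -> actual_thread_num in RunFromFile
    [acc],                         # fetch vars, matching fetch_var_names in the C++ code
    debug=False)
```

On the C++ side, `RunFromFile` then prepares one `DataFeed` reader and one `ExecutorThreadWorker` per thread, as shown in the last hunk.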