From 5c52b6885cb48c501eff8cda6d650910a43d8bbc Mon Sep 17 00:00:00 2001 From: wangguibao Date: Wed, 21 Nov 2018 15:11:25 +0800 Subject: [PATCH] Fix bug --- paddle/fluid/framework/data_feed.cc | 2 +- paddle/fluid/framework/executor_thread_worker.cc | 7 +++++++ python/paddle/fluid/async_executor.py | 15 ++++++++++++++- 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index a1b718fe2c4..be983e614ef 100755 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -98,7 +98,6 @@ void DataFeed::CheckStart() { template void PrivateQueueDataFeed::SetQueueSize(int queue_size) { - CheckInit(); if (queue_size <= 0) { LOG(ERROR) << "error: illegal queue size: " << queue_size; return; @@ -165,6 +164,7 @@ void MultiSlotDataFeed::Init(paddle::framework::DataFeedDesc& data_feed_desc) { } paddle::framework::MultiSlotDesc multi_slot_desc = data_feed_desc.multi_slot_desc(); SetBatchSize(data_feed_desc.batch()); + SetQueueSize(data_feed_desc.batch()); size_t all_slot_num = multi_slot_desc.slots_size(); all_slots_.resize(all_slot_num); all_slots_type_.resize(all_slot_num); diff --git a/paddle/fluid/framework/executor_thread_worker.cc b/paddle/fluid/framework/executor_thread_worker.cc index 6a84136ac70..c360fe5d8ae 100644 --- a/paddle/fluid/framework/executor_thread_worker.cc +++ b/paddle/fluid/framework/executor_thread_worker.cc @@ -176,6 +176,7 @@ void ExecutorThreadWorker::TrainFiles() { thread_reader_->Start(); int cur_batch; + int batch_cnt = 0; while ((cur_batch = thread_reader_->Next()) > 0) { // executor run here for (auto& op : ops_) { @@ -190,8 +191,14 @@ void ExecutorThreadWorker::TrainFiles() { fetch_values_[i] += avg_inspect; } + ++batch_cnt; thread_scope_->DropKids(); } + + for (int i = 0; i < fetch_var_num; ++i) { + fetch_values_[i] = fetch_values_[i] / batch_cnt; + } + } void ExecutorThreadWorker::SetThreadId(int tid) { diff --git a/python/paddle/fluid/async_executor.py b/python/paddle/fluid/async_executor.py index a6bb4ce75ce..202abb20f3e 100644 --- a/python/paddle/fluid/async_executor.py +++ b/python/paddle/fluid/async_executor.py @@ -19,7 +19,7 @@ import contextlib import six from .framework import Program, default_main_program, Variable from . import core -from .executor import global_scope +from .executor import global_scope, Executor from paddle.fluid.proto import data_feed_pb2 from google.protobuf import text_format @@ -67,6 +67,19 @@ class AsyncExecutor(object): scope = global_scope() self.executor = core.AsyncExecutor(scope, p) + def run_startup_program(self, program=None, place=None): + if program is None: + program = fluid.default_startup_program() + + if place is None: + place = core.CPUPlace() + + if not isinstance(place, core.CPUPlace): + raise ValueError("AsyncExecutor only supports CPU device") + + executor = Executor(place) + executor.run(program) + def run(self, program, data_feed, filelist, thread_num, fetch): """ Run program by this Executor. Feed data by feed map, fetch result by fetch_list. -- GitLab