diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc
index a1b718fe2c41f7028ab1f7d1a49a7c1264a8cfad..be983e614ef41dde2883ba333c6db217d9c7b124 100755
--- a/paddle/fluid/framework/data_feed.cc
+++ b/paddle/fluid/framework/data_feed.cc
@@ -98,7 +98,6 @@ void DataFeed::CheckStart() {
 template <typename T>
 void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) {
-  CheckInit();
   if (queue_size <= 0) {
     LOG(ERROR) << "error: illegal queue size: " << queue_size;
     return;
   }
@@ -165,6 +164,7 @@ void MultiSlotDataFeed::Init(paddle::framework::DataFeedDesc& data_feed_desc) {
   }
   paddle::framework::MultiSlotDesc multi_slot_desc = data_feed_desc.multi_slot_desc();
   SetBatchSize(data_feed_desc.batch());
+  SetQueueSize(data_feed_desc.batch());
   size_t all_slot_num = multi_slot_desc.slots_size();
   all_slots_.resize(all_slot_num);
   all_slots_type_.resize(all_slot_num);
diff --git a/paddle/fluid/framework/executor_thread_worker.cc b/paddle/fluid/framework/executor_thread_worker.cc
index 6a84136ac7033dc5f17e382f625b81c3fca762e8..c360fe5d8aeb8882ec9abc0e8cb38262e24bb1b3 100644
--- a/paddle/fluid/framework/executor_thread_worker.cc
+++ b/paddle/fluid/framework/executor_thread_worker.cc
@@ -176,6 +176,7 @@ void ExecutorThreadWorker::TrainFiles() {
   thread_reader_->Start();
 
   int cur_batch;
+  int batch_cnt = 0;
   while ((cur_batch = thread_reader_->Next()) > 0) {
     // executor run here
     for (auto& op : ops_) {
@@ -190,8 +191,16 @@
       fetch_values_[i] += avg_inspect;
     }
 
+    ++batch_cnt;
     thread_scope_->DropKids();
   }
+
+  if (batch_cnt > 0) {
+    for (int i = 0; i < fetch_var_num; ++i) {
+      fetch_values_[i] = fetch_values_[i] / batch_cnt;
+    }
+  }
+
 }
 
 void ExecutorThreadWorker::SetThreadId(int tid) {
diff --git a/python/paddle/fluid/async_executor.py b/python/paddle/fluid/async_executor.py
index feabfdcfa21e6dc9d52c6ff174afeb0024972a1d..7896a61f02ecbb20e0163343e2831d373025293f 100644
--- a/python/paddle/fluid/async_executor.py
+++ b/python/paddle/fluid/async_executor.py
@@ -19,7 +19,7 @@
 import contextlib
 import six
-from .framework import Program, default_main_program, Variable
+from .framework import Program, default_main_program, default_startup_program, Variable
 from . import core
-from .executor import global_scope
+from .executor import global_scope, Executor
 from paddle.fluid.proto import data_feed_pb2
 from google.protobuf import text_format
@@ -75,6 +75,19 @@ class AsyncExecutor(object):
         scope = global_scope()
         self.executor = core.AsyncExecutor(scope, p)
 
+    def run_startup_program(self, program=None, place=None):
+        if program is None:
+            program = default_startup_program()
+
+        if place is None:
+            place = core.CPUPlace()
+
+        if not isinstance(place, core.CPUPlace):
+            raise ValueError("AsyncExecutor only supports CPU device")
+
+        executor = Executor(place)
+        executor.run(program)
+
     def run(self, program, data_feed, filelist, thread_num, fetch):
         """
         Run program by this Executor. Feed data by feed map, fetch result by fetch_list.