提交 5c52b688 编写于 作者: W wangguibao

Fix bug

上级 11136db7
...@@ -98,7 +98,6 @@ void DataFeed::CheckStart() { ...@@ -98,7 +98,6 @@ void DataFeed::CheckStart() {
template<typename T> template<typename T>
void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) { void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) {
CheckInit();
if (queue_size <= 0) { if (queue_size <= 0) {
LOG(ERROR) << "error: illegal queue size: " << queue_size; LOG(ERROR) << "error: illegal queue size: " << queue_size;
return; return;
...@@ -165,6 +164,7 @@ void MultiSlotDataFeed::Init(paddle::framework::DataFeedDesc& data_feed_desc) { ...@@ -165,6 +164,7 @@ void MultiSlotDataFeed::Init(paddle::framework::DataFeedDesc& data_feed_desc) {
} }
paddle::framework::MultiSlotDesc multi_slot_desc = data_feed_desc.multi_slot_desc(); paddle::framework::MultiSlotDesc multi_slot_desc = data_feed_desc.multi_slot_desc();
SetBatchSize(data_feed_desc.batch()); SetBatchSize(data_feed_desc.batch());
SetQueueSize(data_feed_desc.batch());
size_t all_slot_num = multi_slot_desc.slots_size(); size_t all_slot_num = multi_slot_desc.slots_size();
all_slots_.resize(all_slot_num); all_slots_.resize(all_slot_num);
all_slots_type_.resize(all_slot_num); all_slots_type_.resize(all_slot_num);
......
...@@ -176,6 +176,7 @@ void ExecutorThreadWorker::TrainFiles() { ...@@ -176,6 +176,7 @@ void ExecutorThreadWorker::TrainFiles() {
thread_reader_->Start(); thread_reader_->Start();
int cur_batch; int cur_batch;
int batch_cnt = 0;
while ((cur_batch = thread_reader_->Next()) > 0) { while ((cur_batch = thread_reader_->Next()) > 0) {
// executor run here // executor run here
for (auto& op : ops_) { for (auto& op : ops_) {
...@@ -190,8 +191,14 @@ void ExecutorThreadWorker::TrainFiles() { ...@@ -190,8 +191,14 @@ void ExecutorThreadWorker::TrainFiles() {
fetch_values_[i] += avg_inspect; fetch_values_[i] += avg_inspect;
} }
++batch_cnt;
thread_scope_->DropKids(); thread_scope_->DropKids();
} }
for (int i = 0; i < fetch_var_num; ++i) {
fetch_values_[i] = fetch_values_[i] / batch_cnt;
}
} }
void ExecutorThreadWorker::SetThreadId(int tid) { void ExecutorThreadWorker::SetThreadId(int tid) {
......
...@@ -19,7 +19,7 @@ import contextlib ...@@ -19,7 +19,7 @@ import contextlib
import six import six
from .framework import Program, default_main_program, Variable from .framework import Program, default_main_program, Variable
from . import core from . import core
from .executor import global_scope from .executor import global_scope, Executor
from paddle.fluid.proto import data_feed_pb2 from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format from google.protobuf import text_format
...@@ -67,6 +67,19 @@ class AsyncExecutor(object): ...@@ -67,6 +67,19 @@ class AsyncExecutor(object):
scope = global_scope() scope = global_scope()
self.executor = core.AsyncExecutor(scope, p) self.executor = core.AsyncExecutor(scope, p)
def run_startup_program(self, program=None, place=None):
if program is None:
program = fluid.default_startup_program()
if place is None:
place = core.CPUPlace()
if not isinstance(place, core.CPUPlace):
raise ValueError("AsyncExecutor only supports CPU device")
executor = Executor(place)
executor.run(program)
def run(self, program, data_feed, filelist, thread_num, fetch): def run(self, program, data_feed, filelist, thread_num, fetch):
""" """
Run program by this Executor. Feed data by feed map, fetch result by fetch_list. Run program by this Executor. Feed data by feed map, fetch result by fetch_list.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册