提交 5c52b688 编写于 作者: W wangguibao

Fix bug

上级 11136db7
......@@ -98,7 +98,6 @@ void DataFeed::CheckStart() {
template<typename T>
void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) {
CheckInit();
if (queue_size <= 0) {
LOG(ERROR) << "error: illegal queue size: " << queue_size;
return;
......@@ -165,6 +164,7 @@ void MultiSlotDataFeed::Init(paddle::framework::DataFeedDesc& data_feed_desc) {
}
paddle::framework::MultiSlotDesc multi_slot_desc = data_feed_desc.multi_slot_desc();
SetBatchSize(data_feed_desc.batch());
SetQueueSize(data_feed_desc.batch());
size_t all_slot_num = multi_slot_desc.slots_size();
all_slots_.resize(all_slot_num);
all_slots_type_.resize(all_slot_num);
......
......@@ -176,6 +176,7 @@ void ExecutorThreadWorker::TrainFiles() {
thread_reader_->Start();
int cur_batch;
int batch_cnt = 0;
while ((cur_batch = thread_reader_->Next()) > 0) {
// executor run here
for (auto& op : ops_) {
......@@ -190,8 +191,14 @@ void ExecutorThreadWorker::TrainFiles() {
fetch_values_[i] += avg_inspect;
}
++batch_cnt;
thread_scope_->DropKids();
}
for (int i = 0; i < fetch_var_num; ++i) {
fetch_values_[i] = fetch_values_[i] / batch_cnt;
}
}
void ExecutorThreadWorker::SetThreadId(int tid) {
......
......@@ -19,7 +19,7 @@ import contextlib
import six
from .framework import Program, default_main_program, Variable
from . import core
from .executor import global_scope
from .executor import global_scope, Executor
from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format
......@@ -67,6 +67,19 @@ class AsyncExecutor(object):
scope = global_scope()
self.executor = core.AsyncExecutor(scope, p)
def run_startup_program(self, program=None, place=None):
if program is None:
program = fluid.default_startup_program()
if place is None:
place = core.CPUPlace()
if not isinstance(place, core.CPUPlace):
raise ValueError("AsyncExecutor only supports CPU device")
executor = Executor(place)
executor.run(program)
def run(self, program, data_feed, filelist, thread_num, fetch):
"""
Run program by this Executor. Feed data by feed map, fetch result by fetch_list.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册