提交 b66f0074 编写于 作者: D dongdaxiang

Fix data-reading bugs in the API and add VLOG(3) logging for setup

上级 71aa307e
......@@ -44,10 +44,14 @@ void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
bool DataFeed::SetFileList(const std::vector<std::string>& files) {
std::unique_lock<std::mutex> lock(mutex_for_pick_file_);
CheckInit();
// Do not set finish_set_filelist_ flag,
// since a user may set file many times after init reader
/*
if (finish_set_filelist_) {
VLOG(3) << "info: you have set the filelist.";
return false;
}
*/
PADDLE_ENFORCE(files.size(), "You have set an empty filelist.");
filelist_.assign(files.begin(), files.end());
file_idx_ = 0;
......
......@@ -54,6 +54,9 @@ std::string DataFeedFactory::DataFeedTypeList() {
// Instantiate a DataFeed by its registered class name.
// Aborts the process with a diagnostic when the class was never
// registered in g_data_feed_map.
std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
    std::string data_feed_class) {
  if (g_data_feed_map.count(data_feed_class) < 1) {
    // BUG FIX: leading space added so the message does not render as
    // "Your DataFeed Xis not supported currently".
    LOG(WARNING) << "Your DataFeed " << data_feed_class
                 << " is not supported currently";
    LOG(WARNING) << "Supported DataFeed: " << DataFeedTypeList();
    exit(-1);
  }
  return g_data_feed_map[data_feed_class]();
......
......@@ -12,10 +12,10 @@
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/framework/data_set.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/data_feed_factory.h"
namespace paddle {
......@@ -24,6 +24,7 @@ namespace framework {
Dataset::Dataset() { thread_num_ = 1; }
void Dataset::SetFileList(const std::vector<std::string>& filelist) {
VLOG(3) << "filelist size: " << filelist.size();
filelist_ = filelist;
int file_cnt = filelist_.size();
if (thread_num_ > file_cnt) {
......@@ -34,6 +35,8 @@ void Dataset::SetFileList(const std::vector<std::string>& filelist) {
}
}
// FIXME: a user must call SetFileList before SetThreadNum; otherwise
// filelist_ is empty and the file-count cap below is silently skipped —
// not user friendly.
void Dataset::SetThreadNum(int thread_num) {
int file_cnt = filelist_.size();
if (file_cnt != 0 && thread_num > file_cnt) {
......@@ -48,8 +51,8 @@ void Dataset::SetThreadNum(int thread_num) {
void Dataset::SetTrainerNum(int trainer_num) { trainer_num_ = trainer_num; }
// Parse a text-format DataFeedDesc protobuf into data_feed_desc_.
// BUG FIX: the original body contained the same ParseFromString call
// twice (stale diff residue); exactly one parse is performed now.
// NOTE(review): the boolean result of ParseFromString is ignored, so a
// malformed description is silently dropped — consider enforcing success.
void Dataset::SetDataFeedDesc(const std::string& data_feed_desc_str) {
  google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
                                                &data_feed_desc_);
}
const std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
......@@ -107,14 +110,19 @@ void Dataset::GlobalShuffle() {
}
// Lazily create one DataFeed reader per worker thread.
// Idempotent: if readers were already created, they are kept untouched.
void Dataset::CreateReaders() {
  VLOG(3) << "Calling CreateReaders()";
  CHECK(thread_num_ > 0) << "thread_num should > 0";
  VLOG(3) << "thread_num in Readers: " << thread_num_;
  VLOG(3) << "readers size: " << readers_.size();
  if (!readers_.empty()) {
    return;
  }
  VLOG(3) << "data feed class name: " << data_feed_desc_.name();
  readers_.reserve(thread_num_);
  // Loop index is int to match thread_num_'s type (was int64_t,
  // a needless signed-width mismatch).
  for (int i = 0; i < thread_num_; ++i) {
    readers_.push_back(DataFeedFactory::CreateDataFeed(data_feed_desc_.name()));
    readers_.back()->Init(data_feed_desc_);
  }
  VLOG(3) << "Filelist size in readers: " << filelist_.size();
  // Only the first reader receives the filelist; presumably the list is
  // shared among readers through DataFeed's mutex-guarded file picking —
  // TODO(review): confirm against DataFeed::SetFileList.
  readers_[0]->SetFileList(filelist_);
}
......
......@@ -23,12 +23,13 @@ namespace paddle {
namespace framework {
void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* data_set) {
Dataset* dataset) {
thread_num_ = trainer_desc.thread_num();
workers_.resize(thread_num_);
dataset->CreateReaders();
const std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers =
data_set->GetReaders();
dataset->GetReaders();
for (int i = 0; i < thread_num_; ++i) {
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
......
......@@ -14,8 +14,9 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h"
#include <deque>
#include <unordered_set>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
......
......@@ -90,6 +90,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
int batch_cnt = 0;
timeline.Start();
while ((cur_batch = device_reader_->Next()) > 0) {
LOG(WARNING) << "read a batch in thread " << thread_id_;
timeline.Pause();
read_time += timeline.ElapsedSec();
total_time += timeline.ElapsedSec();
......
......@@ -26,8 +26,12 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
thread_num_ = trainer_desc.thread_num();
// get filelist from trainer_desc here
workers_.resize(thread_num_);
VLOG(3) << "worker thread num: " << thread_num_;
dataset->CreateReaders();
VLOG(3) << "readers created";
const std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers =
dataset->GetReaders();
VLOG(3) << "readers num: " << readers.size();
for (int i = 0; i < thread_num_; ++i) {
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
......@@ -50,6 +54,7 @@ void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program,
}
void MultiTrainer::Run() {
VLOG(3) << "Going to run";
for (int thidx = 0; thidx < thread_num_; ++thidx) {
threads_.push_back(
std::thread(&DeviceWorker::TrainFiles, workers_[thidx].get()));
......
......@@ -22,7 +22,7 @@ class DatasetFactory(object):
def __init__(self):
pass
def create_dataset(self, datafeed_class):
def create_dataset(self, datafeed_class="QueueDataset"):
try:
dataset = globals()[datafeed_class]()
return dataset
......@@ -38,6 +38,7 @@ class DatasetBase(object):
self.proto_desc = data_feed_pb2.DataFeedDesc()
self.proto_desc.pipe_command = "cat"
self.dataset = core.Dataset()
self.thread_num = 0
def set_pipe_command(self, pipe_command):
"""
......@@ -63,6 +64,7 @@ class DatasetBase(object):
def set_thread(self, thread_num):
    """Set the number of reader threads.

    Records the value locally and forwards it to the underlying core
    dataset object.

    Args:
        thread_num (int): number of threads to use for reading data.
    """
    self.thread_num = thread_num
    self.dataset.set_thread_num(thread_num)
def set_filelist(self, filelist):
    """Hand the list of data-file paths to the underlying core dataset.

    Args:
        filelist (list[str]): paths of the files to be read.
    """
    self.dataset.set_filelist(filelist)
......@@ -84,6 +86,9 @@ class DatasetBase(object):
"Currently, fluid.dataset only supports dtype=float32 and dtype=int64"
)
def _prepare_to_run(self):
self.dataset.set_data_feed_desc(self.desc())
def desc(self):
"""
Returns a protobuf message for this DataFeedDesc
......@@ -104,7 +109,7 @@ class InMemoryDataset(DatasetBase):
self.proto_desc.name = "MultiSlotInMemoryDataFeed"
def load_into_memory(self):
    """Configure the data feed, then load the dataset files into memory.

    BUG FIX: the original called bare ``_prepare_to_run()`` — a NameError
    at runtime; it must be invoked on ``self``. The stale duplicate
    ``set_data_feed_desc`` line is removed as well.
    """
    self._prepare_to_run()
    self.dataset.load_into_memory()
def local_shuffle(self):
......
......@@ -23,6 +23,7 @@ from .framework import Program, default_main_program, Variable
from . import core
from . import compiler
from .. import compat as cpt
from .trainer_factory import TrainerFactory
__all__ = ['Executor', 'global_scope', 'scope_guard']
......@@ -616,6 +617,7 @@ class Executor(object):
dataset=None,
fetch_list=None,
scope=None,
thread=0,
opt_info=None):
if scope is None:
scope = global_scope()
......@@ -624,7 +626,14 @@ class Executor(object):
compiled = isinstance(program, compiler.CompiledProgram)
if not compiled:
trainer = TrainerFactory().create_trainer(opt_info)
self._default_executor.run_from_dataset(program_desc,
if thread <= 0:
trainer.set_thread(dataset.thread_num)
else:
trainer.set_thread(thread)
dataset._prepare_to_run()
print("run_from_dataset called")
self._default_executor.run_from_dataset(program.desc, scope,
dataset.dataset,
trainer._desc())
else:
# For compiled program, more runtime should be implemented
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册