diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index c53a9b21b278c08d76174223d6a358a2ab311035..fcba99d5f3fda9d612589eec279e0e61571ab8c6 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -44,10 +44,14 @@ void DataFeed::AddFeedVar(Variable* var, const std::string& name) { bool DataFeed::SetFileList(const std::vector<std::string>& files) { std::unique_lock<std::mutex> lock(mutex_for_pick_file_); CheckInit(); + // Do not set finish_set_filelist_ flag, + // since a user may set file many times after init reader + /* if (finish_set_filelist_) { VLOG(3) << "info: you have set the filelist."; return false; } + */ PADDLE_ENFORCE(files.size(), "You have set an empty filelist."); filelist_.assign(files.begin(), files.end()); file_idx_ = 0; diff --git a/paddle/fluid/framework/data_feed_factory.cc b/paddle/fluid/framework/data_feed_factory.cc index 2938655af57c302f1a90ea4c2f533230b1346c66..201d6c0d0b96469afbee1c3262e549d9d4e512dd 100644 --- a/paddle/fluid/framework/data_feed_factory.cc +++ b/paddle/fluid/framework/data_feed_factory.cc @@ -54,6 +54,9 @@ std::string DataFeedFactory::DataFeedTypeList() { std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed( std::string data_feed_class) { if (g_data_feed_map.count(data_feed_class) < 1) { + LOG(WARNING) << "Your DataFeed " << data_feed_class + << " is not supported currently"; + LOG(WARNING) << "Supported DataFeed: " << DataFeedTypeList(); exit(-1); } return g_data_feed_map[data_feed_class](); diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index baa971cde976865259935dd7986b6539d67e5563..ce59bdff8fa3c482755e78939c4735628b49be6b 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -12,10 +12,10 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ +#include "paddle/fluid/framework/data_set.h" #include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/message.h" #include "google/protobuf/text_format.h" -#include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/data_feed_factory.h" namespace paddle { @@ -24,6 +24,7 @@ namespace framework { Dataset::Dataset() { thread_num_ = 1; } void Dataset::SetFileList(const std::vector<std::string>& filelist) { + VLOG(3) << "filelist size: " << filelist.size(); filelist_ = filelist; int file_cnt = filelist_.size(); if (thread_num_ > file_cnt) { @@ -34,6 +35,8 @@ void Dataset::SetFileList(const std::vector<std::string>& filelist) { } } +// buggy here, a user should set filelist first before this function +// not user friendly void Dataset::SetThreadNum(int thread_num) { int file_cnt = filelist_.size(); if (file_cnt != 0 && thread_num > file_cnt) { @@ -48,8 +51,8 @@ void Dataset::SetThreadNum(int thread_num) { void Dataset::SetTrainerNum(int trainer_num) { trainer_num_ = trainer_num; } void Dataset::SetDataFeedDesc(const std::string& data_feed_desc_str) { - google::protobuf::TextFormat::ParseFromString( data_feed_desc_str, &data_feed_desc_); + google::protobuf::TextFormat::ParseFromString(data_feed_desc_str, &data_feed_desc_); } const std::vector<std::shared_ptr<paddle::framework::DataFeed>>& @@ -107,14 +110,19 @@ void Dataset::GlobalShuffle() { } void Dataset::CreateReaders() { + VLOG(3) << "Calling CreateReaders()"; CHECK(thread_num_ > 0) << "thread_num should > 0"; + VLOG(3) << "thread_num in Readers: " << thread_num_; + VLOG(3) << "readers size: " << readers_.size(); if (readers_.size() != 0) { return; } + VLOG(3) << "data feed class name: " << data_feed_desc_.name(); for (int64_t i = 0; i < thread_num_; ++i) { readers_.push_back(DataFeedFactory::CreateDataFeed(data_feed_desc_.name())); readers_.back()->Init(data_feed_desc_); } + VLOG(3) << "Filelist size in readers: " << filelist_.size(); readers_[0]->SetFileList(filelist_); } diff --git a/paddle/fluid/framework/dist_multi_trainer.cc 
b/paddle/fluid/framework/dist_multi_trainer.cc index 9997da01969cbe03a40662fd0cba8dec14002a4a..a56a3cea60acb836701cbe508a74c774cf4d0b14 100644 --- a/paddle/fluid/framework/dist_multi_trainer.cc +++ b/paddle/fluid/framework/dist_multi_trainer.cc @@ -23,12 +23,13 @@ namespace paddle { namespace framework { void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc, - Dataset* data_set) { + Dataset* dataset) { thread_num_ = trainer_desc.thread_num(); workers_.resize(thread_num_); + dataset->CreateReaders(); const std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers = - data_set->GetReaders(); + dataset->GetReaders(); for (int i = 0; i < thread_num_; ++i) { workers_[i] = DeviceWorkerFactory::CreateDeviceWorker( diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 9ba50ff9eea0afa47b940c4bdcb50fb44c7c2c72..501480876b216b36cfe4b6f0e99a7acd7b555193 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -14,8 +14,9 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/executor.h" #include -#include +#include #include +#include #include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/message.h" diff --git a/paddle/fluid/framework/hogwild_worker.cc b/paddle/fluid/framework/hogwild_worker.cc index 148557a95427389ad28db586de13bbc689f3313e..0bc65f484dad3320bb95e0e9986629495bbc5368 100644 --- a/paddle/fluid/framework/hogwild_worker.cc +++ b/paddle/fluid/framework/hogwild_worker.cc @@ -90,6 +90,7 @@ void HogwildWorker::TrainFilesWithProfiler() { int batch_cnt = 0; timeline.Start(); while ((cur_batch = device_reader_->Next()) > 0) { + LOG(WARNING) << "read a batch in thread " << thread_id_; timeline.Pause(); read_time += timeline.ElapsedSec(); total_time += timeline.ElapsedSec(); diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index 0da4fa863f6cf5e134e6322a1e47d65601be1da9..995cef4d076ca131c39c4945f33c127761233178 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -26,8 +26,12 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, thread_num_ = trainer_desc.thread_num(); // get filelist from trainer_desc here workers_.resize(thread_num_); + VLOG(3) << "worker thread num: " << thread_num_; + dataset->CreateReaders(); + VLOG(3) << "readers created"; const std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers = dataset->GetReaders(); + VLOG(3) << "readers num: " << readers.size(); for (int i = 0; i < thread_num_; ++i) { workers_[i] = DeviceWorkerFactory::CreateDeviceWorker( trainer_desc.device_worker_name()); @@ -50,6 +54,7 @@ void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program, } void MultiTrainer::Run() { + VLOG(3) << "Going to run"; for (int thidx = 0; thidx < thread_num_; ++thidx) { threads_.push_back( std::thread(&DeviceWorker::TrainFiles, workers_[thidx].get())); diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py index 
fd6ce02addae57d53be6a0a833ff6b80763a7e01..31cb05558752f4af7711f0e7c9299a587bc248f3 100644 --- a/python/paddle/fluid/dataset.py +++ b/python/paddle/fluid/dataset.py @@ -22,7 +22,7 @@ class DatasetFactory(object): def __init__(self): pass - def create_dataset(self, datafeed_class): + def create_dataset(self, datafeed_class="QueueDataset"): try: dataset = globals()[datafeed_class]() return dataset @@ -38,6 +38,7 @@ class DatasetBase(object): self.proto_desc = data_feed_pb2.DataFeedDesc() self.proto_desc.pipe_command = "cat" self.dataset = core.Dataset() + self.thread_num = 0 def set_pipe_command(self, pipe_command): """ @@ -63,6 +64,7 @@ class DatasetBase(object): def set_thread(self, thread_num): self.dataset.set_thread_num(thread_num) + self.thread_num = thread_num def set_filelist(self, filelist): self.dataset.set_filelist(filelist) @@ -84,6 +86,9 @@ class DatasetBase(object): "Currently, fluid.dataset only supports dtype=float32 and dtype=int64" ) + def _prepare_to_run(self): + self.dataset.set_data_feed_desc(self.desc()) + def desc(self): """ Returns a protobuf message for this DataFeedDesc @@ -104,7 +109,7 @@ class InMemoryDataset(DatasetBase): self.proto_desc.name = "MultiSlotInMemoryDataFeed" def load_into_memory(self): - self.dataset.set_data_feed_desc(self.desc()) + self._prepare_to_run() self.dataset.load_into_memory() def local_shuffle(self): diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 98a16e20112894801debd4186b9e2bd9d699051f..dd8d2c7c08e3b43e3a6cbffcac9b94a730eb2e09 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -23,6 +23,7 @@ from .framework import Program, default_main_program, Variable from . import core from . import compiler from .. 
import compat as cpt +from .trainer_factory import TrainerFactory __all__ = ['Executor', 'global_scope', 'scope_guard'] @@ -616,6 +617,7 @@ class Executor(object): dataset=None, fetch_list=None, scope=None, + thread=0, opt_info=None): if scope is None: scope = global_scope() @@ -624,7 +626,14 @@ class Executor(object): compiled = isinstance(program, compiler.CompiledProgram) if not compiled: trainer = TrainerFactory().create_trainer(opt_info) - self._default_executor.run_from_dataset(program_desc, + if thread <= 0: + trainer.set_thread(dataset.thread_num) + else: + trainer.set_thread(thread) + dataset._prepare_to_run() + print("run_from_dataset called") + self._default_executor.run_from_dataset(program.desc, scope, + dataset.dataset, trainer._desc()) else: # For compiled program, more runtime should be implemented