Commit b66f0074 authored by dongdaxiang

fix data reading bugs in api, add VLOG(3) log for setup

Parent 71aa307e
@@ -44,10 +44,14 @@ void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
 bool DataFeed::SetFileList(const std::vector<std::string>& files) {
   std::unique_lock<std::mutex> lock(mutex_for_pick_file_);
   CheckInit();
+  // Do not set the finish_set_filelist_ flag here,
+  // since a user may set the filelist many times after initializing the reader
+  /*
   if (finish_set_filelist_) {
     VLOG(3) << "info: you have set the filelist.";
     return false;
   }
+  */
   PADDLE_ENFORCE(files.size(), "You have set an empty filelist.");
   filelist_.assign(files.begin(), files.end());
   file_idx_ = 0;
......
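Note: with the early return commented out above, a filelist can be set repeatedly on the same reader. A minimal Python-side sketch of what this enables, assuming `DatasetFactory` is exported at the `fluid` level as in later releases; the file names are placeholders:

```python
import paddle.fluid as fluid

dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
dataset.set_filelist(["train_part-00000", "train_part-00001"])

# Previously a second call would hit the finish_set_filelist_ guard
# and be ignored; now the same dataset can be pointed at new files.
dataset.set_filelist(["train_part-00002", "train_part-00003"])
```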
@@ -54,6 +54,9 @@ std::string DataFeedFactory::DataFeedTypeList() {
 std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
     std::string data_feed_class) {
   if (g_data_feed_map.count(data_feed_class) < 1) {
+    LOG(WARNING) << "Your DataFeed " << data_feed_class
+                 << " is not supported currently";
+    LOG(WARNING) << "Supported DataFeed: " << DataFeedTypeList();
     exit(-1);
   }
   return g_data_feed_map[data_feed_class]();
......
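The new warnings make a misconfigured feed name diagnosable before the `exit(-1)`. A hedged sketch of how such a name could reach the factory from Python; the misspelled name is invented for illustration:

```python
import paddle.fluid as fluid

dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
# Writing the raw proto field with a typo:
dataset.proto_desc.name = "MultiSlotDataFeedd"  # hypothetical misspelling
# When the trainer later triggers Dataset::CreateReaders(), the factory
# now logs the unsupported name plus DataFeedTypeList() before exiting.
```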
@@ -12,10 +12,10 @@
  * See the License for the specific language governing permissions and
  * limitations under the License. */
+#include "paddle/fluid/framework/data_set.h"
 #include "google/protobuf/io/zero_copy_stream_impl.h"
 #include "google/protobuf/message.h"
 #include "google/protobuf/text_format.h"
-#include "paddle/fluid/framework/data_set.h"
 #include "paddle/fluid/framework/data_feed_factory.h"
 namespace paddle {
@@ -24,6 +24,7 @@ namespace framework {
 Dataset::Dataset() { thread_num_ = 1; }

 void Dataset::SetFileList(const std::vector<std::string>& filelist) {
+  VLOG(3) << "filelist size: " << filelist.size();
   filelist_ = filelist;
   int file_cnt = filelist_.size();
   if (thread_num_ > file_cnt) {
@@ -34,6 +35,8 @@ void Dataset::SetFileList(const std::vector<std::string>& filelist) {
   }
 }

+// buggy here: a user should set the filelist before calling this function,
+// which is not user friendly
 void Dataset::SetThreadNum(int thread_num) {
   int file_cnt = filelist_.size();
   if (file_cnt != 0 && thread_num > file_cnt) {
@@ -48,8 +51,8 @@ void Dataset::SetThreadNum(int thread_num) {
 void Dataset::SetTrainerNum(int trainer_num) { trainer_num_ = trainer_num; }

 void Dataset::SetDataFeedDesc(const std::string& data_feed_desc_str) {
-  google::protobuf::TextFormat::ParseFromString(
-      data_feed_desc_str, &data_feed_desc_);
+  google::protobuf::TextFormat::ParseFromString(data_feed_desc_str,
+                                                &data_feed_desc_);
 }
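As the new comment notes, `SetThreadNum` only clamps the thread count against the filelist when the filelist is already set. A hedged Python sketch of the call order that avoids the pitfall; file names are placeholders:

```python
import paddle.fluid as fluid

dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
# Set the filelist first so the file count is known on the C++ side...
dataset.set_filelist(["part-00000", "part-00001"])
# ...then the thread count; with only two files, a larger value
# is expected to be clipped to the number of files.
dataset.set_thread(4)
```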
 const std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
@@ -107,14 +110,19 @@ void Dataset::GlobalShuffle() {
 }

 void Dataset::CreateReaders() {
+  VLOG(3) << "Calling CreateReaders()";
   CHECK(thread_num_ > 0) << "thread_num should > 0";
+  VLOG(3) << "thread_num in Readers: " << thread_num_;
+  VLOG(3) << "readers size: " << readers_.size();
   if (readers_.size() != 0) {
     return;
   }
+  VLOG(3) << "data feed class name: " << data_feed_desc_.name();
   for (int64_t i = 0; i < thread_num_; ++i) {
     readers_.push_back(DataFeedFactory::CreateDataFeed(data_feed_desc_.name()));
     readers_.back()->Init(data_feed_desc_);
   }
+  VLOG(3) << "Filelist size in readers: " << filelist_.size();
   readers_[0]->SetFileList(filelist_);
 }
......
@@ -23,12 +23,13 @@ namespace paddle {
 namespace framework {

 void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
-                                  Dataset* data_set) {
+                                  Dataset* dataset) {
   thread_num_ = trainer_desc.thread_num();
   workers_.resize(thread_num_);
+  dataset->CreateReaders();
   const std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers =
-      data_set->GetReaders();
+      dataset->GetReaders();
   for (int i = 0; i < thread_num_; ++i) {
     workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
......
@@ -14,8 +14,9 @@ limitations under the License. */
 #include "paddle/fluid/framework/executor.h"
 #include <deque>
-#include <unordered_set>
+#include <memory>
 #include <unordered_map>
+#include <unordered_set>
 #include <utility>
 #include "google/protobuf/io/zero_copy_stream_impl.h"
 #include "google/protobuf/message.h"
......
@@ -90,6 +90,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
   int batch_cnt = 0;
   timeline.Start();
   while ((cur_batch = device_reader_->Next()) > 0) {
+    LOG(WARNING) << "read a batch in thread " << thread_id_;
     timeline.Pause();
     read_time += timeline.ElapsedSec();
     total_time += timeline.ElapsedSec();
......
@@ -26,8 +26,12 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
   thread_num_ = trainer_desc.thread_num();
   // get filelist from trainer_desc here
   workers_.resize(thread_num_);
+  VLOG(3) << "worker thread num: " << thread_num_;
+  dataset->CreateReaders();
+  VLOG(3) << "readers created";
   const std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers =
       dataset->GetReaders();
+  VLOG(3) << "readers num: " << readers.size();
   for (int i = 0; i < thread_num_; ++i) {
     workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
         trainer_desc.device_worker_name());
@@ -50,6 +54,7 @@ void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program,
 }

 void MultiTrainer::Run() {
+  VLOG(3) << "Going to run";
   for (int thidx = 0; thidx < thread_num_; ++thidx) {
     threads_.push_back(
         std::thread(&DeviceWorker::TrainFiles, workers_[thidx].get()));
......
@@ -22,7 +22,7 @@ class DatasetFactory(object):
     def __init__(self):
         pass

-    def create_dataset(self, datafeed_class):
+    def create_dataset(self, datafeed_class="QueueDataset"):
         try:
             dataset = globals()[datafeed_class]()
             return dataset
@@ -38,6 +38,7 @@ class DatasetBase(object):
         self.proto_desc = data_feed_pb2.DataFeedDesc()
         self.proto_desc.pipe_command = "cat"
         self.dataset = core.Dataset()
+        self.thread_num = 0

     def set_pipe_command(self, pipe_command):
         """
@@ -63,6 +64,7 @@ class DatasetBase(object):
     def set_thread(self, thread_num):
         self.dataset.set_thread_num(thread_num)
+        self.thread_num = thread_num

     def set_filelist(self, filelist):
         self.dataset.set_filelist(filelist)
@@ -84,6 +86,9 @@ class DatasetBase(object):
                 "Currently, fluid.dataset only supports dtype=float32 and dtype=int64"
             )

+    def _prepare_to_run(self):
+        self.dataset.set_data_feed_desc(self.desc())
+
     def desc(self):
         """
         Returns a protobuf message for this DataFeedDesc
@@ -104,7 +109,7 @@ class InMemoryDataset(DatasetBase):
         self.proto_desc.name = "MultiSlotInMemoryDataFeed"

     def load_into_memory(self):
-        self.dataset.set_data_feed_desc(self.desc())
+        self._prepare_to_run()
         self.dataset.load_into_memory()

     def local_shuffle(self):
......
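Taken together, the Python changes give `create_dataset` a default class, track `thread_num` on the wrapper, and route the feed description through `_prepare_to_run()`. A hedged end-to-end sketch for the in-memory path; file names are placeholders and any feature-slot setup is elided:

```python
import paddle.fluid as fluid

# datafeed_class now defaults to "QueueDataset" when omitted.
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_filelist(["train_part-00000"])
dataset.set_thread(1)  # also recorded on the wrapper as dataset.thread_num

# load_into_memory() now calls self._prepare_to_run(), which pushes the
# DataFeedDesc proto down to the C++ Dataset before loading starts.
dataset.load_into_memory()
dataset.local_shuffle()
```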
@@ -23,6 +23,7 @@ from .framework import Program, default_main_program, Variable
 from . import core
 from . import compiler
 from .. import compat as cpt
+from .trainer_factory import TrainerFactory

 __all__ = ['Executor', 'global_scope', 'scope_guard']
@@ -616,6 +617,7 @@
                         dataset=None,
                         fetch_list=None,
                         scope=None,
+                        thread=0,
                         opt_info=None):
         if scope is None:
             scope = global_scope()
@@ -624,7 +626,14 @@
         compiled = isinstance(program, compiler.CompiledProgram)
         if not compiled:
             trainer = TrainerFactory().create_trainer(opt_info)
-            self._default_executor.run_from_dataset(program_desc,
-                                                    trainer._desc())
+            if thread <= 0:
+                trainer.set_thread(dataset.thread_num)
+            else:
+                trainer.set_thread(thread)
+            dataset._prepare_to_run()
+            print("run_from_dataset called")
+            self._default_executor.run_from_dataset(program.desc, scope,
+                                                    dataset.dataset,
+                                                    trainer._desc())
         else:
             # For compiled program, more runtime should be implemented
......
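With the new `thread` argument, the executor resolves the worker count either from the caller or from the dataset, then forwards the program desc, scope, dataset, and trainer desc to the C++ side. A hedged usage sketch: the public entry point `train_from_dataset` is assumed from later releases as the wrapper over this path, and network/startup setup is abbreviated. Running with `GLOG_v=3` surfaces the new VLOG(3) setup logs:

```python
import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

dataset = fluid.DatasetFactory().create_dataset("QueueDataset")
dataset.set_filelist(["train_part-00000", "train_part-00001"])
dataset.set_thread(2)

# thread=0 (the default) falls back to dataset.thread_num;
# passing thread explicitly overrides it via trainer.set_thread().
exe.train_from_dataset(program=fluid.default_main_program(),
                       dataset=dataset)
```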