From dd67ad08a21a4b0b3be1fc32baf5827578fde82d Mon Sep 17 00:00:00 2001 From: xjqbest <173596896@qq.com> Date: Sat, 9 Mar 2019 16:02:24 +0800 Subject: [PATCH] modify c++ and python dataset related code & fix bug --- paddle/fluid/framework/CMakeLists.txt | 2 +- paddle/fluid/framework/async_executor.cc | 12 +++++++ paddle/fluid/framework/data_feed.cc | 7 ++-- paddle/fluid/framework/data_set.cc | 9 +++-- paddle/fluid/framework/data_set.h | 3 +- paddle/fluid/framework/dist_multi_trainer.cc | 2 +- paddle/fluid/framework/executor.cc | 6 ++-- paddle/fluid/framework/executor.h | 4 ++- paddle/fluid/framework/multi_trainer.cc | 2 +- paddle/fluid/framework/trainer.h | 6 ++-- python/paddle/fluid/__init__.py | 3 ++ python/paddle/fluid/data_feed_desc.py | 4 --- python/paddle/fluid/dataset.py | 34 +++++++++++++------ .../paddle/fluid/distributed/ps_instance.py | 12 +++++++ 14 files changed, 74 insertions(+), 32 deletions(-) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 8c73de9cda2..e6e4a2ce48f 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -206,7 +206,7 @@ cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory. DEPS op_registry device_context scope framework_proto trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer feed_fetch_method graph_to_program_pass data_feed_proto - variable_helper timer) + variable_helper timer fs shell) cc_test(data_feed_test SRCS data_feed_test.cc DEPS async_executor) diff --git a/paddle/fluid/framework/async_executor.cc b/paddle/fluid/framework/async_executor.cc index d1a086f7148..078bd3961fb 100644 --- a/paddle/fluid/framework/async_executor.cc +++ b/paddle/fluid/framework/async_executor.cc @@ -59,6 +59,12 @@ void AsyncExecutor::GatherServers(const std::vector& host_sign_list, fleet_ptr_->GatherServers(host_sign_list, node_num); } +// todo InitModel +void AsyncExecutor::InitModel() { } + +// todo SaveModel +void AsyncExecutor::SaveModel(const std::string& path) { } + void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, const std::string& data_feed_desc_str, const std::vector& filelist, @@ -154,5 +160,11 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, return; } +// todo RunFromDataset +void AsyncExecutor::RunFromDataset(const ProgramDesc& main_program, + Dataset* data_set, + const std::string& trainer_desc_str, + const bool debug) { } + } // end namespace framework } // end namespace paddle diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index bf7ade95b28..c53a9b21b27 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/framework/data_feed.h" #include +#include #include "gflags/gflags.h" #include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/message.h" @@ -135,6 +136,7 @@ int PrivateQueueDataFeed::Next() { return batch_size_; } +// explicit instantiation template class PrivateQueueDataFeed>; template @@ -220,8 +222,6 @@ void InMemoryDataFeed::LocalShuffle() { std::random_shuffle(memory_data_.begin(), memory_data_.end()); } -template class InMemoryDataFeed>; - // todo global shuffle /* template @@ -242,6 +242,9 @@ void InMemoryDataFeed::GlobalShuffle(int trainer_num) { } */ +// explicit instantiation +template class InMemoryDataFeed>; + void MultiSlotDataFeed::Init( const paddle::framework::DataFeedDesc& data_feed_desc) { finish_init_ = false; diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index 047b172df45..457ae9360d7 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -12,6 +12,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "google/protobuf/io/zero_copy_stream_impl.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" #include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/data_feed_factory.h" @@ -44,9 +47,9 @@ void Dataset::SetThreadNum(int thread_num) { void Dataset::SetTrainerNum(int trainer_num) { trainer_num_ = trainer_num; } -void Dataset::SetDataFeedDesc( - const paddle::framework::DataFeedDesc& data_feed_desc) { - data_feed_desc_ = data_feed_desc; +void Dataset::SetDataFeedDesc(const std::string& data_feed_desc_str) { + google::protobuf::TextFormat::ParseFromString( + data_feed_desc_str, &data_feed_desc_); } std::vector> diff --git a/paddle/fluid/framework/data_set.h b/paddle/fluid/framework/data_set.h index 91998e98ad7..06f47da3224 100644 --- a/paddle/fluid/framework/data_set.h +++ b/paddle/fluid/framework/data_set.h @@ -34,8 +34,7 @@ class Dataset { virtual void SetFileList(const std::vector& filelist); virtual void SetThreadNum(int thread_num); virtual void SetTrainerNum(int trainer_num); - virtual void SetDataFeedDesc( - const paddle::framework::DataFeedDesc& data_feed_desc); + virtual void SetDataFeedDesc(const std::string& data_feed_desc_str); virtual const std::vector& GetFileList() { return filelist_; } virtual int GetThreadNum() { return thread_num_; } diff --git a/paddle/fluid/framework/dist_multi_trainer.cc b/paddle/fluid/framework/dist_multi_trainer.cc index 44509486ceb..cbfd2950130 100644 --- a/paddle/fluid/framework/dist_multi_trainer.cc +++ b/paddle/fluid/framework/dist_multi_trainer.cc @@ -22,7 +22,7 @@ namespace paddle { namespace framework { void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc, - const Dataset& data_set) { + Dataset* data_set) { thread_num_ = trainer_desc.thread_num(); workers_.resize(thread_num_); readers_.resize(thread_num_); diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index ef84d387637..9eccea7aca7 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -14,11 +14,9 @@ limitations under the License. */ #include "paddle/fluid/framework/executor.h" #include -#include -#include #include +#include #include - #include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/message.h" #include "google/protobuf/text_format.h" @@ -119,7 +117,7 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope, } void Executor::RunFromDataset(const ProgramDesc& main_program, - const Dataset& dataset, + Dataset* dataset, const std::string& trainer_desc_str, const bool debug) { VLOG(3) << "Start to RunFromDataset in executor"; diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index 8685ad8028a..6368d9b38f1 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -19,6 +19,8 @@ limitations under the License. */ #include #include #include +#include +#include #include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/op_info.h" @@ -112,7 +114,7 @@ class Executor { void EnableMKLDNN(const ProgramDesc& program); - void RunFromDataset(const ProgramDesc& main_program, const Dataset& dataset, + void RunFromDataset(const ProgramDesc& main_program, Dataset* dataset, const std::string& trainer_desc_str, const bool debug); public: diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index dd52d3608a5..7d9b6839e38 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -22,7 +22,7 @@ namespace paddle { namespace framework { void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, - const Dataset& dataset) { + Dataset* dataset) { thread_num_ = trainer_desc.thread_num(); // get filelist from trainer_desc here workers_.resize(thread_num_); diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index 2de4d93cb87..30f19704859 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -42,7 +42,7 @@ class TrainerBase { void SetScope(Scope* root_scope); void SetDebug(const bool debug) { debug_ = debug; } virtual void Initialize(const TrainerDesc& trainer_desc, - const Dataset& data_set) = 0; + Dataset* data_set) = 0; virtual void InitTrainerEnv(const ProgramDesc& main_program, const platform::Place& place) = 0; virtual void InitOtherEnv(const ProgramDesc& main_program) = 0; @@ -62,7 +62,7 @@ class MultiTrainer : public TrainerBase { MultiTrainer() {} virtual ~MultiTrainer() {} virtual void Initialize(const TrainerDesc& trainer_desc, - const Dataset& data_set); + Dataset* data_set); virtual void InitTrainerEnv(const ProgramDesc& main_program, const platform::Place& place); virtual void InitOtherEnv(const ProgramDesc& main_program) {} @@ -81,7 +81,7 @@ class DistMultiTrainer : public MultiTrainer { DistMultiTrainer() {} virtual ~DistMultiTrainer() {} virtual void Initialize(const TrainerDesc& trainer_desc, - const Dataset& data_set); + Dataset* data_set); virtual void InitOtherEnv(const ProgramDesc& main_program); virtual void Finalize(); diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 24c8a6934fe..b67651bf310 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -24,6 +24,9 @@ from .executor import * from . import data_feed_desc from .data_feed_desc import * +from . import dataset +from .dataset import * + from . import async_executor from .async_executor import * diff --git a/python/paddle/fluid/data_feed_desc.py b/python/paddle/fluid/data_feed_desc.py index b041ba90cff..80745aac830 100644 --- a/python/paddle/fluid/data_feed_desc.py +++ b/python/paddle/fluid/data_feed_desc.py @@ -139,10 +139,6 @@ class DataFeedDesc(object): self.proto_desc.multi_slot_desc.slots[self.__name_to_index[ name]].is_used = True - def global_shuffle(self): - self.data.global_shuffle() - pass - def desc(self): """ Returns a protobuf message for this DataFeedDesc diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py index 10963511642..fd6ce02adda 100644 --- a/python/paddle/fluid/dataset.py +++ b/python/paddle/fluid/dataset.py @@ -23,9 +23,9 @@ class DatasetFactory(object): pass def create_dataset(self, datafeed_class): - datafeed_class = datafeed_class.capitalize() try: dataset = globals()[datafeed_class]() + return dataset except: raise ValueError("datafeed class %s does not exist" % datafeed_class) @@ -37,6 +37,7 @@ class DatasetBase(object): # to decide whether we need create in memory instance self.proto_desc = data_feed_pb2.DataFeedDesc() self.proto_desc.pipe_command = "cat" + self.dataset = core.Dataset() def set_pipe_command(self, pipe_command): """ @@ -60,17 +61,23 @@ class DatasetBase(object): """ self.proto_desc.batch_size = batch_size + def set_thread(self, thread_num): + self.dataset.set_thread_num(thread_num) + + def set_filelist(self, filelist): + self.dataset.set_filelist(filelist) + def set_use_var(self, var_list): - multi_slot = self.proto_desc.multi_slot_desc() + multi_slot = self.proto_desc.multi_slot_desc for var in var_list: - slot_var = multi_slot.add() + slot_var = multi_slot.slots.add() slot_var.is_used = True slot_var.name = var.name if var.lod_level == 0: slot_var.is_dense = True - if var.dtype == core.VarType.FP32: + if var.dtype == core.VarDesc.VarType.FP32: slot_var.type = "float32" - elif var.dtype == core.VarType.INT64: + elif var.dtype == core.VarDesc.VarType.INT64: slot_var.type = "uint64" else: raise ValueError( @@ -93,17 +100,24 @@ class DatasetBase(object): class InMemoryDataset(DatasetBase): def __init__(self): - super(InMemoryDataset.__init__()) - self.proto_desc.name = "InMemoryDataFeed" + super(InMemoryDataset, self).__init__() + self.proto_desc.name = "MultiSlotInMemoryDataFeed" + + def load_into_memory(self): + self.dataset.set_data_feed_desc(self.desc()) + self.dataset.load_into_memory() def local_shuffle(self): - pass + self.dataset.local_shuffle() def global_shuffle(self): - pass + from .distributed import ps_instance + instance = ps_instance.PaddlePSInstance(1, 2) + self.dataset.set_trainer_num(instance.get_worker_num()) + self.global_shuffle() class QueueDataset(DatasetBase): def __init__(self): - super(QueueDataset.__init__()) + super(QueueDataset, self).__init__() self.proto_desc.name = "MultiSlotDataFeed" diff --git a/python/paddle/fluid/distributed/ps_instance.py b/python/paddle/fluid/distributed/ps_instance.py index d3ce3ce6934..19d661c660e 100644 --- a/python/paddle/fluid/distributed/ps_instance.py +++ b/python/paddle/fluid/distributed/ps_instance.py @@ -121,6 +121,18 @@ class PaddlePSInstance(object): """ return self._nodes + def get_worker_num(self): + """ + Return worker num + """ + return self._worker_num + + def get_server_num(self): + """ + Return server num + """ + return self._server_num + def barrier_all(self): """ barrier workers and servers -- GitLab