From ecfc7df913a54dc499a4544905c74aacdad0e263 Mon Sep 17 00:00:00 2001
From: xujiaqi01
Date: Wed, 13 Mar 2019 14:45:20 +0800
Subject: [PATCH] add dataset factory && fix style

---
 paddle/fluid/framework/CMakeLists.txt         |  4 +-
 paddle/fluid/framework/data_feed.cc           | 39 +++++------
 paddle/fluid/framework/data_feed.h            | 37 ++++++++--
 paddle/fluid/framework/data_set.cc            |  6 +-
 paddle/fluid/framework/dataset_factory.cc     | 67 +++++++++++++++++++
 paddle/fluid/framework/dataset_factory.h      | 29 ++++++++
 paddle/fluid/framework/fleet/fleet_wrapper.cc | 50 ++++++++++----
 paddle/fluid/framework/fleet/fleet_wrapper.h  | 10 +--
 paddle/fluid/framework/multi_trainer.cc       |  1 +
 paddle/fluid/pybind/data_set_py.cc            | 24 ++++---
 python/paddle/fluid/dataset.py                | 10 ++-
 python/paddle/fluid/trainer_desc.py           |  4 +-
 python/paddle/fluid/trainer_factory.py        |  3 +
 13 files changed, 224 insertions(+), 60 deletions(-)
 create mode 100644 paddle/fluid/framework/dataset_factory.cc
 create mode 100644 paddle/fluid/framework/dataset_factory.h

diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index 24c181e8c..d13009480 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -181,7 +181,7 @@ graph_to_program_pass variable_helper trainer_library data_feed_proto ${NGRAPH_E
     set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
     set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 else()
-  cc_library(executor SRCS executor.cc multi_trainer.cc
+  cc_library(executor SRCS executor.cc multi_trainer.cc dataset_factory.cc
   dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
   data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
   pull_dense_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
@@ -202,7 +202,7 @@ cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
 executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
 trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
 downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
-data_set.cc
+data_set.cc dataset_factory.cc
 DEPS op_registry device_context scope framework_proto
 trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
 feed_fetch_method graph_to_program_pass data_feed_proto

diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc
index 8ee625b5c..5cc1b8a6e 100644
--- a/paddle/fluid/framework/data_feed.cc
+++ b/paddle/fluid/framework/data_feed.cc
@@ -158,8 +158,6 @@ bool InMemoryDataFeed<T>::Start() {
   DataFeed::CheckSetFileList();
   if (shuffled_ins_->Size() == 0 && shuffled_ins_out_->Size() == 0) {
     FillMemoryDataToChannel();
-    //std::unique_lock<std::mutex> lock(*mutex_for_update_memory_data_);
-    //std::vector<T>().swap(memory_data_);
   }
   DataFeed::finish_start_ = true;
   return true;
@@ -227,13 +225,13 @@ void InMemoryDataFeed<T>::SetTrainerNum(int trainer_num) {
 
 template <typename T>
 void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
   T ins;
-  DeserializeIns(ins, ins_str);
+  DeserializeIns(&ins, ins_str);
   shuffled_ins_->Push(std::move(ins));
 }
 
 template <typename T>
 void InMemoryDataFeed<T>::FillMemoryDataToChannel() {
-  VLOG(3) << "InMemoryDataFeed::FillMemoryDataToChannel, thread_id=" << thread_id_;
+  VLOG(3) << "FillMemoryDataToChannel, thread_id=" << thread_id_;
   int64_t start = 0;
   int64_t end = 0;
   int64_t size = memory_data_->size();
@@ -252,7 +250,7 @@ void InMemoryDataFeed<T>::FillMemoryDataToChannel() {
 
 template <typename T>
 void InMemoryDataFeed<T>::FillChannelToMemoryData() {
-  VLOG(3) << "InMemoryDataFeed::FillChannelToMemoryData, thread_id=" << thread_id_;
+  VLOG(3) << "FillChannelToMemoryData, thread_id=" << thread_id_;
   std::vector<T> local_vec;
   std::shared_ptr<BlockingQueue<T>> channel = nullptr;
   if (cur_channel_ == 0) {
@@ -274,11 +272,12 @@ void InMemoryDataFeed<T>::FillChannelToMemoryData() {
 
 template <typename T>
 void InMemoryDataFeed<T>::LoadIntoMemory() {
-  VLOG(3) << "InMemoryDataFeed::LoadIntoMemory() begin, thread_id=" << thread_id_;
+  VLOG(3) << "LoadIntoMemory() begin, thread_id=" << thread_id_;
   std::vector<T> local_vec;
   std::string filename;
   while (DataFeed::PickOneFile(&filename)) {
-    VLOG(3) << "PickOneFile, filename=" << filename << ", thread_id=" << thread_id_;
+    VLOG(3) << "PickOneFile, filename=" << filename
+            << ", thread_id=" << thread_id_;
     int err_no = 0;
     PrivateQueueDataFeed<T>::fp_ =
         fs_open_read(filename, &err_no, PrivateQueueDataFeed<T>::pipe_command_);
@@ -287,36 +286,38 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
     while (ParseOneInstanceFromPipe(&instance)) {
       local_vec.push_back(instance);
     }
-    VLOG(3) << "InMemoryDataFeed::LoadIntoMemory() read all lines, thread_id=" << thread_id_;
+    VLOG(3) << "LoadIntoMemory() read all lines, file="
+            << filename << ", thread_id=" << thread_id_;
     {
       std::lock_guard<std::mutex> lock(*mutex_for_update_memory_data_);
-      memory_data_->insert(memory_data_->end(), local_vec.begin(), local_vec.end());
+      memory_data_->insert(memory_data_->end(),
+                           local_vec.begin(), local_vec.end());
     }
     std::vector<T>().swap(local_vec);
   }
-  VLOG(3) << "InMemoryDataFeed::LoadIntoMemory() end, thread_id=" << thread_id_;
+  VLOG(3) << "LoadIntoMemory() end, thread_id=" << thread_id_;
 }
 
 template <typename T>
 void InMemoryDataFeed<T>::LocalShuffle() {
-  VLOG(3) << "InMemoryDataFeed::LocalShuffle() begin, thread_id=" << thread_id_;
+  VLOG(3) << "LocalShuffle() begin, thread_id=" << thread_id_;
   FillMemoryDataToChannel();
-  VLOG(3) << "InMemoryDataFeed::LocalShuffle() end, thread_id=" << thread_id_;
+  VLOG(3) << "LocalShuffle() end, thread_id=" << thread_id_;
"LocalShuffle() end, thread_id=" << thread_id_; } template void InMemoryDataFeed::GlobalShuffle() { + VLOG(3) << "GlobalShuffle(), thread_id=" << thread_id_; auto fleet_ptr = FleetWrapper::GetInstance(); std::vector send_str_vec(trainer_num_); for (int64_t i = 0; i < memory_data_->size(); ++i) { // todo get ins id - //std::string ins_id = memory_data_[i].ins_id; + // std::string ins_id = memory_data_[i].ins_id; // todo hash - //int64_t hash_id = paddle::ps::local_random_engine()(); - int64_t hash_id = 0; - int64_t node_id = hash_id % trainer_num_; + int64_t random_num = fleet_ptr->local_random_engine()(); + int64_t node_id = random_num % trainer_num_; std::string str; - SerializeIns((*memory_data_)[i], str); + SerializeIns((*memory_data_)[i], &str); send_str_vec[node_id] += str; if (i % fleet_send_batch_size_ == 0 && i != 0) { for (int j = 0; j < send_str_vec.size(); ++j) { @@ -821,12 +822,12 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec( // todo serialize ins in global shuffle void MultiSlotInMemoryDataFeed::SerializeIns( - const std::vector& ins, std::string& str) { + const std::vector& ins, std::string* str) { auto fleet_ptr = FleetWrapper::GetInstance(); fleet_ptr->Serialize(ins, str); } // todo deserialize ins in global shuffle -void MultiSlotInMemoryDataFeed::DeserializeIns(std::vector& ins, +void MultiSlotInMemoryDataFeed::DeserializeIns(std::vector* ins, const std::string& str) { auto fleet_ptr = FleetWrapper::GetInstance(); fleet_ptr->Deserialize(ins, str); diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index 98aeb4b1f..5afae9ea5 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -212,13 +212,16 @@ class InMemoryDataFeed : public PrivateQueueDataFeed { virtual void LoadIntoMemory(); virtual void LocalShuffle(); virtual void GlobalShuffle(); + protected: - virtual void AddInstanceToInsVec(T* vec_ins, const T& instance, int index) = 0; + virtual void AddInstanceToInsVec(T* vec_ins, + const T& instance, + int index) = 0; virtual bool ParseOneInstance(T* instance) = 0; virtual bool ParseOneInstanceFromPipe(T* instance) = 0; virtual void PutToFeedVec(const T& ins_vec) = 0; - virtual void SerializeIns(const T& ins, std::string& str) = 0; - virtual void DeserializeIns(T& ins, const std::string& str) = 0; + virtual void SerializeIns(const T& ins, std::string* str) = 0; + virtual void DeserializeIns(T* ins, const std::string& str) = 0; int thread_id_; int thread_num_; @@ -284,6 +287,28 @@ class MultiSlotType { const std::string& GetType() const { return type_; } std::string& MutableType() { return type_; } + std::string DebugString() { + std::stringstream ss; + ss << "type: " << type_ << "\n"; + ss << "offset:\n"; + ss << "["; + for (const size_t& i : offset_) { + ss << offset_[i] << ","; + } + ss << "]\ndata:\n["; + if (type_[0] == 'f') { + for (const float& i : float_feasign_) { + ss << i << ","; + } + } else { + for (const uint64_t& i : uint64_feasign_) { + ss << i << ","; + } + } + ss << "]\n"; + return ss.str(); + } + private: void CheckType(const std::string& type) const { PADDLE_ENFORCE((type == "uint64") || (type == "float"), @@ -336,8 +361,10 @@ class MultiSlotInMemoryDataFeed virtual bool ParseOneInstance(std::vector* instance); virtual bool ParseOneInstanceFromPipe(std::vector* instance); virtual void PutToFeedVec(const std::vector& ins_vec); - virtual void SerializeIns(const std::vector& ins, std::string& str); - virtual void DeserializeIns(std::vector& ins, const std::string& str); + virtual void 
+  virtual void SerializeIns(const std::vector<MultiSlotType>& ins,
+                            std::string* str);
+  virtual void DeserializeIns(std::vector<MultiSlotType>* ins,
+                              const std::string& str);
 };
 
 }  // namespace framework

diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc
index 7497e4c9a..adeadf0ce 100644
--- a/paddle/fluid/framework/data_set.cc
+++ b/paddle/fluid/framework/data_set.cc
@@ -54,7 +54,9 @@ void DatasetImpl<T>::SetThreadNum(int thread_num) {
 }
 
 template <typename T>
-void DatasetImpl<T>::SetTrainerNum(int trainer_num) { trainer_num_ = trainer_num; }
+void DatasetImpl<T>::SetTrainerNum(int trainer_num) {
+  trainer_num_ = trainer_num;
+}
 
 template <typename T>
 void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) {
@@ -115,10 +117,12 @@ void DatasetImpl<T>::GlobalShuffle() {
   // if it is not InMemory, memory_data_ is empty
   std::random_shuffle(memory_data_.begin(), memory_data_.end());
   auto fleet_ptr = FleetWrapper::GetInstance();
+  VLOG(3) << "registe_client2client_msg_handler";
   fleet_ptr->registe_client2client_msg_handler(0,
       [this](int msg_type, int client_id, const std::string& msg) -> int {
         return this->ReceiveFromClient(msg_type, client_id, msg);
       });
+  VLOG(3) << "start global shuffle threads";
   std::vector<std::thread> global_shuffle_threads;
   for (int i = 0; i < thread_num_; ++i) {
     global_shuffle_threads.push_back(

diff --git a/paddle/fluid/framework/dataset_factory.cc b/paddle/fluid/framework/dataset_factory.cc
new file mode 100644
index 000000000..56f425c1e
--- /dev/null
+++ b/paddle/fluid/framework/dataset_factory.cc
@@ -0,0 +1,67 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/dataset_factory.h"
+#include <memory>
+#include <string>
+#include <unordered_map>
+
+#include "paddle/fluid/framework/data_set.h"
+
+namespace paddle {
+namespace framework {
+typedef std::shared_ptr<Dataset> (*CreateDatasetFunction)();
+typedef std::unordered_map<std::string, CreateDatasetFunction> datasetMap;
+datasetMap g_dataset_map;
+
+#define REGISTER_DATASET_CLASS(dataset_class)                   \
+  namespace {                                                   \
+  std::shared_ptr<Dataset> Creator_##dataset_class() {          \
+    return std::shared_ptr<Dataset>(new dataset_class);         \
+  }                                                             \
+  class __Registerer_##dataset_class {                          \
+   public:                                                      \
+    __Registerer_##dataset_class() {                            \
+      g_dataset_map[#dataset_class] = &Creator_##dataset_class; \
+    }                                                           \
+  };                                                            \
+  __Registerer_##dataset_class g_registerer_##dataset_class;    \
+  }  // namespace
+
+std::string DatasetFactory::DatasetTypeList() {
+  std::string dataset_types;
+  for (auto iter = g_dataset_map.begin(); iter != g_dataset_map.end();
+       ++iter) {
+    if (iter != g_dataset_map.begin()) {
+      dataset_types += ", ";
+    }
+    dataset_types += iter->first;
+  }
+  return dataset_types;
+}
+
+std::shared_ptr<Dataset> DatasetFactory::CreateDataset(
+    std::string dataset_class) {
+  if (g_dataset_map.count(dataset_class) < 1) {
+    LOG(WARNING) << "Your Dataset " << dataset_class
+                 << " is not supported currently";
+    LOG(WARNING) << "Supported Dataset: " << DatasetTypeList();
+    exit(-1);
+  }
+  return g_dataset_map[dataset_class]();
+}
+
+REGISTER_DATASET_CLASS(MultiSlotDataset);
+}  // namespace framework
+}  // namespace paddle

diff --git a/paddle/fluid/framework/dataset_factory.h b/paddle/fluid/framework/dataset_factory.h
new file mode 100644
index 000000000..2894b69f8
--- /dev/null
+++ b/paddle/fluid/framework/dataset_factory.h
@@ -0,0 +1,29 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include "paddle/fluid/framework/data_set.h"
+
+namespace paddle {
+namespace framework {
+class DatasetFactory {
+ public:
+  static std::string DatasetTypeList();
+  static std::shared_ptr<Dataset> CreateDataset(std::string dataset_class);
+};
+}  // namespace framework
+}  // namespace paddle

diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.cc b/paddle/fluid/framework/fleet/fleet_wrapper.cc
index a2d60927f..2696259f5 100644
--- a/paddle/fluid/framework/fleet/fleet_wrapper.cc
+++ b/paddle/fluid/framework/fleet/fleet_wrapper.cc
@@ -27,6 +27,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/framework/fleet/fleet_wrapper.h"
+#include
 #include "paddle/fluid/framework/data_feed.h"
 
 namespace paddle {
@@ -45,7 +46,7 @@ paddle::ps::Archive<AR>& operator << (
   ar << ins.GetOffset();
   ar << ins.GetFloatData();
   ar << ins.GetUint64Data();
-return ar;
+  return ar;
 }
 
 template <typename AR>
@@ -56,7 +57,7 @@ paddle::ps::Archive<AR>& operator >> (
   ar >> ins.MutableOffset();
   ar >> ins.MutableFloatData();
   ar >> ins.MutableUint64Data();
-return ar;
+  return ar;
 }
 #endif
@@ -291,42 +292,63 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
 #endif
 }
 
-// todo registe_client2client_msg_handler
-int FleetWrapper::registe_client2client_msg_handler(int msg_type, MsgHandlerFunc handler) {
-  return 0;
+int FleetWrapper::registe_client2client_msg_handler(
+    int msg_type, MsgHandlerFunc handler) {
+  pslib_ptr_->_worker_ptr->registe_client2client_msg_handler(
+      msg_type, handler);
+  return 0;
 }
 
-// todo send_client2client_msg
-int FleetWrapper::send_client2client_msg(int msg_type, int to_client_id, const std::string& msg) {
-  return 0;
+int FleetWrapper::send_client2client_msg(
+    int msg_type, int to_client_id, const std::string& msg) {
+  pslib_ptr_->_worker_ptr->send_client2client_msg(
+      msg_type, to_client_id, msg);
+  return 0;
+}
+
+std::default_random_engine& FleetWrapper::local_random_engine() {
+  struct engine_wrapper_t {
+    std::default_random_engine engine;
+    engine_wrapper_t() {
+      struct timespec tp;
+      clock_gettime(CLOCK_REALTIME, &tp);
+      double cur_time = tp.tv_sec + tp.tv_nsec * 1e-9;
+      static std::atomic<uint64_t> x(0);
+      std::seed_seq sseq = {x++, x++, x++,
+                            (uint64_t)(cur_time * 1000)};
+      engine.seed(sseq);
+    }
+  };
+  thread_local engine_wrapper_t r;
+  return r.engine;
 }
 
 template <typename T>
-void FleetWrapper::Serialize(const T& t, std::string& str) {
+void FleetWrapper::Serialize(const T& t, std::string* str) {
 #ifdef PADDLE_WITH_PSLIB
   paddle::ps::BinaryArchive ar;
   ar << t;
-  str = std::string(ar.buffer(), ar.length());
+  *str = std::string(ar.buffer(), ar.length());
 #else
   VLOG(0) << "FleetWrapper::Serialize do nothing when no pslib";
 #endif
 }
 
 template <typename T>
-void FleetWrapper::Deserialize(T& t, const std::string& str) {
+void FleetWrapper::Deserialize(T* t, const std::string& str) {
 #ifdef PADDLE_WITH_PSLIB
   paddle::ps::BinaryArchive ar;
   ar.set_read_buffer(const_cast<char*>(str.c_str()), str.length(), nullptr);
-  t = ar.get<T>();
+  *t = ar.get<T>();
 #else
   VLOG(0) << "FleetWrapper::Deserialize do nothing when no pslib";
 #endif
 }
 
 template void FleetWrapper::Serialize<std::vector<MultiSlotType>>(
-    const std::vector<MultiSlotType>&, std::string&);
+    const std::vector<MultiSlotType>&, std::string*);
 template void FleetWrapper::Deserialize(
-    std::vector<MultiSlotType>&, const std::string&);
+    std::vector<MultiSlotType>*, const std::string&);
 
 }  // end namespace framework
 }  // end namespace paddle

diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h
index f98db1fe8..0e2027fcf 100644
--- a/paddle/fluid/framework/fleet/fleet_wrapper.h
+++ b/paddle/fluid/framework/fleet/fleet_wrapper.h
@@ -21,7 +21,7 @@ limitations under the License. */
 #endif
 #include
 #include
-#include
+#include <random>
 #include
 #include
 #include "paddle/fluid/framework/scope.h"
@@ -116,13 +116,15 @@ class FleetWrapper {
 
   typedef std::function<int32_t(int, int, const std::string&)> MsgHandlerFunc;
   int registe_client2client_msg_handler(int msg_type, MsgHandlerFunc handler);
-  int send_client2client_msg(int msg_type, int to_client_id, const std::string& msg);
+  int send_client2client_msg(int msg_type,
+                             int to_client_id,
+                             const std::string& msg);
   std::default_random_engine& local_random_engine();
   template <typename T>
-  void Serialize(const T& t, std::string& str);
+  void Serialize(const T& t, std::string* str);
   template <typename T>
-  void Deserialize(T& t, const std::string& str);
+  void Deserialize(T* t, const std::string& str);
 
   static std::shared_ptr<FleetWrapper> GetInstance() {
     if (NULL == s_instance_) {

diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc
index 995cef4d0..c3b38fade 100644
--- a/paddle/fluid/framework/multi_trainer.cc
+++ b/paddle/fluid/framework/multi_trainer.cc
@@ -65,6 +65,7 @@ void MultiTrainer::Finalize() {
   for (auto& th : threads_) {
     th.join();
   }
+  // todo dataset->DestroyReaders();
 }
 
 }  // end namespace framework

diff --git a/paddle/fluid/pybind/data_set_py.cc b/paddle/fluid/pybind/data_set_py.cc
index ca0545129..3ed4c01be 100644
--- a/paddle/fluid/pybind/data_set_py.cc
+++ b/paddle/fluid/pybind/data_set_py.cc
@@ -21,7 +21,7 @@ limitations under the License. */
 #endif
 #include <string>
 #include <vector>
-
+#include <memory>
 #include "google/protobuf/io/zero_copy_stream_impl.h"
 #include "google/protobuf/text_format.h"
 #include "paddle/fluid/framework/async_executor.h"
@@ -33,6 +33,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/variant.h"
 #include "paddle/fluid/pybind/data_set_py.h"
+#include "paddle/fluid/framework/dataset_factory.h"
 
 namespace py = pybind11;
 namespace pd = paddle::framework;
@@ -41,17 +42,18 @@ namespace paddle {
 namespace pybind {
 
 void BindDataset(py::module* m) {
-  py::class_<framework::MultiSlotDataset>(*m, "MultiSlotDataset")
-      .def(py::init([]() {
-        return std::unique_ptr<framework::MultiSlotDataset>(new framework::MultiSlotDataset());
+  py::class_<framework::Dataset, std::shared_ptr<framework::Dataset>>(*m, "Dataset")
+      .def(py::init([](const std::string& name = "MultiSlotDataset") {
+        return framework::DatasetFactory::CreateDataset(name);
       }))
-      .def("set_filelist", &framework::MultiSlotDataset::SetFileList)
-      .def("set_thread_num", &framework::MultiSlotDataset::SetThreadNum)
-      .def("set_trainer_num", &framework::MultiSlotDataset::SetTrainerNum)
-      .def("set_data_feed_desc", &framework::MultiSlotDataset::SetDataFeedDesc)
-      .def("load_into_memory", &framework::MultiSlotDataset::LoadIntoMemory)
-      .def("local_shuffle", &framework::MultiSlotDataset::LocalShuffle)
-      .def("global_shuffle", &framework::MultiSlotDataset::GlobalShuffle);
+      .def("set_filelist", &framework::Dataset::SetFileList)
+      .def("set_thread_num", &framework::Dataset::SetThreadNum)
+      .def("set_trainer_num", &framework::Dataset::SetTrainerNum)
+      .def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc)
+      .def("load_into_memory", &framework::Dataset::LoadIntoMemory)
+      .def("local_shuffle", &framework::Dataset::LocalShuffle)
+      .def("global_shuffle", &framework::Dataset::GlobalShuffle);
 }
 
 }  // end namespace pybind

diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py
index 932fb6429..6d239260c 100644
--- a/python/paddle/fluid/dataset.py
+++ b/python/paddle/fluid/dataset.py
@@ -37,7 +37,7 @@ class DatasetBase(object):
         # to decide whether we need create in memory instance
         self.proto_desc = data_feed_pb2.DataFeedDesc()
         self.proto_desc.pipe_command = "cat"
-        self.dataset = core.MultiSlotDataset()
+        self.dataset = core.Dataset("MultiSlotDataset")
         self.thread_num = 0
 
     def set_pipe_command(self, pipe_command):
@@ -119,10 +119,16 @@ class InMemoryDataset(DatasetBase):
         from .distributed import ps_instance
         instance = ps_instance.PaddlePSInstance(1, 2)
         self.dataset.set_trainer_num(instance.get_worker_num())
-        self.global_shuffle()
+        self.dataset.global_shuffle()
 
 
 class QueueDataset(DatasetBase):
     def __init__(self):
         super(QueueDataset, self).__init__()
         self.proto_desc.name = "MultiSlotDataFeed"
+
+    def local_shuffle(self):
+        pass
+
+    def global_shuffle(self):
+        pass

diff --git a/python/paddle/fluid/trainer_desc.py b/python/paddle/fluid/trainer_desc.py
index 176da959f..61165cc6e 100644
--- a/python/paddle/fluid/trainer_desc.py
+++ b/python/paddle/fluid/trainer_desc.py
@@ -20,7 +20,7 @@ from google.protobuf import text_format
 
 __all__ = ['TrainerDesc', 'MultiTrainer', 'DistMultiTrainer']
 
-# can be initialized from train_desc, 
+# can be initialized from train_desc,
 class TrainerDesc(object):
     def __init__(self):
         '''
@@ -59,7 +59,7 @@ class MultiTrainer(TrainerDesc):
     def gen_trainer_desc(self):
         super(MultiTrainer, self).gen_trainer_desc()
         self.proto_desc.class_name = "MultiTrainer"
-        self.device_worker_.gen_worker_desc(self.proto_desc, fleet_desc_)
+        self.device_worker_.gen_worker_desc(self.proto_desc, self.fleet_desc_)
 
 
 class DistMultiTrainer(TrainerDesc):

diff --git a/python/paddle/fluid/trainer_factory.py b/python/paddle/fluid/trainer_factory.py
index 51c7ddb9a..9d3883c5d 100644
--- a/python/paddle/fluid/trainer_factory.py
+++ b/python/paddle/fluid/trainer_factory.py
@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from trainer_desc import *
+from device_worker import *
+
 __all__ = ["TrainerFactory"]
-- 
GitLab
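
Reviewer note (not part of the patch): a minimal sketch of how the Python surface changed above is meant to be used. Class names and `set_pipe_command` come from `python/paddle/fluid/dataset.py` in this diff; the remaining method names mirror the pybind bindings (`set_filelist`, `load_into_memory`, `local_shuffle`) and are assumed to have thin Python wrappers, and the data files are hypothetical placeholders.

```python
# Minimal usage sketch, assuming a Paddle build that contains this patch.
# DatasetBase.__init__ now constructs core.Dataset("MultiSlotDataset"),
# which routes through DatasetFactory::CreateDataset instead of binding
# MultiSlotDataset directly.
from paddle.fluid.dataset import InMemoryDataset, QueueDataset

# In-memory variant: instances are loaded up front so they can be shuffled.
dataset = InMemoryDataset()
dataset.set_pipe_command("cat")  # command that feeds each file to the reader
dataset.set_filelist(["train_0.txt", "train_1.txt"])  # hypothetical files
dataset.load_into_memory()  # runs InMemoryDataFeed::LoadIntoMemory per thread
dataset.local_shuffle()     # shuffle within this node only
# dataset.global_shuffle()  # cross-trainer shuffle; needs a PS instance,
#                           # see InMemoryDataset.global_shuffle in the diff

# Queue variant: streaming reads; after this patch its local_shuffle()
# and global_shuffle() are explicit no-ops rather than missing methods.
q_dataset = QueueDataset()
q_dataset.set_filelist(["train_0.txt"])
```

The factory indirection is what allows the pybind layer to expose the abstract `Dataset` with a `std::shared_ptr` holder and pick the concrete subclass by name at runtime, so new dataset types only need a `REGISTER_DATASET_CLASS` line rather than their own bindings.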