提交 ecfc7df9 编写于 作者: X xujiaqi01 提交者: dongdaxiang

add dataset factory && fix style

上级 328f11b8
...@@ -181,7 +181,7 @@ graph_to_program_pass variable_helper trainer_library data_feed_proto ${NGRAPH_E ...@@ -181,7 +181,7 @@ graph_to_program_pass variable_helper trainer_library data_feed_proto ${NGRAPH_E
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else() else()
cc_library(executor SRCS executor.cc multi_trainer.cc cc_library(executor SRCS executor.cc multi_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
pull_dense_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry pull_dense_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
...@@ -202,7 +202,7 @@ cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory. ...@@ -202,7 +202,7 @@ cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
data_set.cc data_set.cc dataset_factory.cc
DEPS op_registry device_context scope framework_proto DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass data_feed_proto feed_fetch_method graph_to_program_pass data_feed_proto
......
...@@ -158,8 +158,6 @@ bool InMemoryDataFeed<T>::Start() { ...@@ -158,8 +158,6 @@ bool InMemoryDataFeed<T>::Start() {
DataFeed::CheckSetFileList(); DataFeed::CheckSetFileList();
if (shuffled_ins_->Size() == 0 && shuffled_ins_out_->Size() == 0) { if (shuffled_ins_->Size() == 0 && shuffled_ins_out_->Size() == 0) {
FillMemoryDataToChannel(); FillMemoryDataToChannel();
//std::unique_lock<std::mutex> lock(*mutex_for_update_memory_data_);
//std::vector<T>().swap(memory_data_);
} }
DataFeed::finish_start_ = true; DataFeed::finish_start_ = true;
return true; return true;
...@@ -227,13 +225,13 @@ void InMemoryDataFeed<T>::SetTrainerNum(int trainer_num) { ...@@ -227,13 +225,13 @@ void InMemoryDataFeed<T>::SetTrainerNum(int trainer_num) {
template <typename T> template <typename T>
void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) { void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
T ins; T ins;
DeserializeIns(ins, ins_str); DeserializeIns(&ins, ins_str);
shuffled_ins_->Push(std::move(ins)); shuffled_ins_->Push(std::move(ins));
} }
template <typename T> template <typename T>
void InMemoryDataFeed<T>::FillMemoryDataToChannel() { void InMemoryDataFeed<T>::FillMemoryDataToChannel() {
VLOG(3) << "InMemoryDataFeed<T>::FillMemoryDataToChannel, thread_id=" << thread_id_; VLOG(3) << "FillMemoryDataToChannel, thread_id=" << thread_id_;
int64_t start = 0; int64_t start = 0;
int64_t end = 0; int64_t end = 0;
int64_t size = memory_data_->size(); int64_t size = memory_data_->size();
...@@ -252,7 +250,7 @@ void InMemoryDataFeed<T>::FillMemoryDataToChannel() { ...@@ -252,7 +250,7 @@ void InMemoryDataFeed<T>::FillMemoryDataToChannel() {
template <typename T> template <typename T>
void InMemoryDataFeed<T>::FillChannelToMemoryData() { void InMemoryDataFeed<T>::FillChannelToMemoryData() {
VLOG(3) << "InMemoryDataFeed<T>::FillChannelToMemoryData, thread_id=" << thread_id_; VLOG(3) << "FillChannelToMemoryData, thread_id=" << thread_id_;
std::vector<T> local_vec; std::vector<T> local_vec;
std::shared_ptr<paddle::framework::BlockingQueue<T>> channel = nullptr; std::shared_ptr<paddle::framework::BlockingQueue<T>> channel = nullptr;
if (cur_channel_ == 0) { if (cur_channel_ == 0) {
...@@ -274,11 +272,12 @@ void InMemoryDataFeed<T>::FillChannelToMemoryData() { ...@@ -274,11 +272,12 @@ void InMemoryDataFeed<T>::FillChannelToMemoryData() {
template <typename T> template <typename T>
void InMemoryDataFeed<T>::LoadIntoMemory() { void InMemoryDataFeed<T>::LoadIntoMemory() {
VLOG(3) << "InMemoryDataFeed<T>::LoadIntoMemory() begin, thread_id=" << thread_id_; VLOG(3) << "LoadIntoMemory() begin, thread_id=" << thread_id_;
std::vector<T> local_vec; std::vector<T> local_vec;
std::string filename; std::string filename;
while (DataFeed::PickOneFile(&filename)) { while (DataFeed::PickOneFile(&filename)) {
VLOG(3) << "PickOneFile, filename=" << filename << ", thread_id=" << thread_id_; VLOG(3) << "PickOneFile, filename=" << filename
<< ", thread_id=" << thread_id_;
int err_no = 0; int err_no = 0;
PrivateQueueDataFeed<T>::fp_ = PrivateQueueDataFeed<T>::fp_ =
fs_open_read(filename, &err_no, PrivateQueueDataFeed<T>::pipe_command_); fs_open_read(filename, &err_no, PrivateQueueDataFeed<T>::pipe_command_);
...@@ -287,36 +286,38 @@ void InMemoryDataFeed<T>::LoadIntoMemory() { ...@@ -287,36 +286,38 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
while (ParseOneInstanceFromPipe(&instance)) { while (ParseOneInstanceFromPipe(&instance)) {
local_vec.push_back(instance); local_vec.push_back(instance);
} }
VLOG(3) << "InMemoryDataFeed<T>::LoadIntoMemory() read all lines, thread_id=" << thread_id_; VLOG(3) << "LoadIntoMemory() read all lines, file="
<< filename <<", thread_id=" << thread_id_;
{ {
std::lock_guard<std::mutex> lock(*mutex_for_update_memory_data_); std::lock_guard<std::mutex> lock(*mutex_for_update_memory_data_);
memory_data_->insert(memory_data_->end(), local_vec.begin(), local_vec.end()); memory_data_->insert(memory_data_->end(),
local_vec.begin(), local_vec.end());
} }
std::vector<T>().swap(local_vec); std::vector<T>().swap(local_vec);
} }
VLOG(3) << "InMemoryDataFeed<T>::LoadIntoMemory() end, thread_id=" << thread_id_; VLOG(3) << "LoadIntoMemory() end, thread_id=" << thread_id_;
} }
template <typename T> template <typename T>
void InMemoryDataFeed<T>::LocalShuffle() { void InMemoryDataFeed<T>::LocalShuffle() {
VLOG(3) << "InMemoryDataFeed<T>::LocalShuffle() begin, thread_id=" << thread_id_; VLOG(3) << "LocalShuffle() begin, thread_id=" << thread_id_;
FillMemoryDataToChannel(); FillMemoryDataToChannel();
VLOG(3) << "InMemoryDataFeed<T>::LocalShuffle() end, thread_id=" << thread_id_; VLOG(3) << "LocalShuffle() end, thread_id=" << thread_id_;
} }
template <typename T> template <typename T>
void InMemoryDataFeed<T>::GlobalShuffle() { void InMemoryDataFeed<T>::GlobalShuffle() {
VLOG(3) << "GlobalShuffle(), thread_id=" << thread_id_;
auto fleet_ptr = FleetWrapper::GetInstance(); auto fleet_ptr = FleetWrapper::GetInstance();
std::vector<std::string> send_str_vec(trainer_num_); std::vector<std::string> send_str_vec(trainer_num_);
for (int64_t i = 0; i < memory_data_->size(); ++i) { for (int64_t i = 0; i < memory_data_->size(); ++i) {
// todo get ins id // todo get ins id
//std::string ins_id = memory_data_[i].ins_id; // std::string ins_id = memory_data_[i].ins_id;
// todo hash // todo hash
//int64_t hash_id = paddle::ps::local_random_engine()(); int64_t random_num = fleet_ptr->local_random_engine()();
int64_t hash_id = 0; int64_t node_id = random_num % trainer_num_;
int64_t node_id = hash_id % trainer_num_;
std::string str; std::string str;
SerializeIns((*memory_data_)[i], str); SerializeIns((*memory_data_)[i], &str);
send_str_vec[node_id] += str; send_str_vec[node_id] += str;
if (i % fleet_send_batch_size_ == 0 && i != 0) { if (i % fleet_send_batch_size_ == 0 && i != 0) {
for (int j = 0; j < send_str_vec.size(); ++j) { for (int j = 0; j < send_str_vec.size(); ++j) {
...@@ -821,12 +822,12 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec( ...@@ -821,12 +822,12 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
// todo serialize ins in global shuffle // todo serialize ins in global shuffle
void MultiSlotInMemoryDataFeed::SerializeIns( void MultiSlotInMemoryDataFeed::SerializeIns(
const std::vector<MultiSlotType>& ins, std::string& str) { const std::vector<MultiSlotType>& ins, std::string* str) {
auto fleet_ptr = FleetWrapper::GetInstance(); auto fleet_ptr = FleetWrapper::GetInstance();
fleet_ptr->Serialize(ins, str); fleet_ptr->Serialize(ins, str);
} }
// todo deserialize ins in global shuffle // todo deserialize ins in global shuffle
void MultiSlotInMemoryDataFeed::DeserializeIns(std::vector<MultiSlotType>& ins, void MultiSlotInMemoryDataFeed::DeserializeIns(std::vector<MultiSlotType>* ins,
const std::string& str) { const std::string& str) {
auto fleet_ptr = FleetWrapper::GetInstance(); auto fleet_ptr = FleetWrapper::GetInstance();
fleet_ptr->Deserialize(ins, str); fleet_ptr->Deserialize(ins, str);
......
...@@ -212,13 +212,16 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> { ...@@ -212,13 +212,16 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
virtual void LoadIntoMemory(); virtual void LoadIntoMemory();
virtual void LocalShuffle(); virtual void LocalShuffle();
virtual void GlobalShuffle(); virtual void GlobalShuffle();
protected: protected:
virtual void AddInstanceToInsVec(T* vec_ins, const T& instance, int index) = 0; virtual void AddInstanceToInsVec(T* vec_ins,
const T& instance,
int index) = 0;
virtual bool ParseOneInstance(T* instance) = 0; virtual bool ParseOneInstance(T* instance) = 0;
virtual bool ParseOneInstanceFromPipe(T* instance) = 0; virtual bool ParseOneInstanceFromPipe(T* instance) = 0;
virtual void PutToFeedVec(const T& ins_vec) = 0; virtual void PutToFeedVec(const T& ins_vec) = 0;
virtual void SerializeIns(const T& ins, std::string& str) = 0; virtual void SerializeIns(const T& ins, std::string* str) = 0;
virtual void DeserializeIns(T& ins, const std::string& str) = 0; virtual void DeserializeIns(T* ins, const std::string& str) = 0;
int thread_id_; int thread_id_;
int thread_num_; int thread_num_;
...@@ -284,6 +287,28 @@ class MultiSlotType { ...@@ -284,6 +287,28 @@ class MultiSlotType {
const std::string& GetType() const { return type_; } const std::string& GetType() const { return type_; }
std::string& MutableType() { return type_; } std::string& MutableType() { return type_; }
std::string DebugString() {
std::stringstream ss;
ss << "type: " << type_ << "\n";
ss << "offset:\n";
ss << "[";
for (const size_t& i : offset_) {
ss << offset_[i] << ",";
}
ss << "]\ndata:\n[";
if (type_[0] == 'f') {
for (const float& i : float_feasign_) {
ss << i << ",";
}
} else {
for (const uint64_t& i : uint64_feasign_) {
ss << i << ",";
}
}
ss << "]\n";
return ss.str();
}
private: private:
void CheckType(const std::string& type) const { void CheckType(const std::string& type) const {
PADDLE_ENFORCE((type == "uint64") || (type == "float"), PADDLE_ENFORCE((type == "uint64") || (type == "float"),
...@@ -336,8 +361,10 @@ class MultiSlotInMemoryDataFeed ...@@ -336,8 +361,10 @@ class MultiSlotInMemoryDataFeed
virtual bool ParseOneInstance(std::vector<MultiSlotType>* instance); virtual bool ParseOneInstance(std::vector<MultiSlotType>* instance);
virtual bool ParseOneInstanceFromPipe(std::vector<MultiSlotType>* instance); virtual bool ParseOneInstanceFromPipe(std::vector<MultiSlotType>* instance);
virtual void PutToFeedVec(const std::vector<MultiSlotType>& ins_vec); virtual void PutToFeedVec(const std::vector<MultiSlotType>& ins_vec);
virtual void SerializeIns(const std::vector<MultiSlotType>& ins, std::string& str); virtual void SerializeIns(const std::vector<MultiSlotType>& ins,
virtual void DeserializeIns(std::vector<MultiSlotType>& ins, const std::string& str); std::string* str);
virtual void DeserializeIns(std::vector<MultiSlotType>* ins,
const std::string& str);
}; };
} // namespace framework } // namespace framework
......
...@@ -54,7 +54,9 @@ void DatasetImpl<T>::SetThreadNum(int thread_num) { ...@@ -54,7 +54,9 @@ void DatasetImpl<T>::SetThreadNum(int thread_num) {
} }
template <typename T> template <typename T>
void DatasetImpl<T>::SetTrainerNum(int trainer_num) { trainer_num_ = trainer_num; } void DatasetImpl<T>::SetTrainerNum(int trainer_num) {
trainer_num_ = trainer_num;
}
template <typename T> template <typename T>
void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) { void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) {
...@@ -115,10 +117,12 @@ void DatasetImpl<T>::GlobalShuffle() { ...@@ -115,10 +117,12 @@ void DatasetImpl<T>::GlobalShuffle() {
// if it is not InMemory, memory_data_ is empty // if it is not InMemory, memory_data_ is empty
std::random_shuffle(memory_data_.begin(), memory_data_.end()); std::random_shuffle(memory_data_.begin(), memory_data_.end());
auto fleet_ptr = FleetWrapper::GetInstance(); auto fleet_ptr = FleetWrapper::GetInstance();
VLOG(3) << "registe_client2client_msg_handler";
fleet_ptr->registe_client2client_msg_handler(0, fleet_ptr->registe_client2client_msg_handler(0,
[this](int msg_type, int client_id, const std::string& msg) -> int { [this](int msg_type, int client_id, const std::string& msg) -> int {
return this->ReceiveFromClient(msg_type, client_id, msg); return this->ReceiveFromClient(msg_type, client_id, msg);
}); });
VLOG(3) << "start global shuffle threads";
std::vector<std::thread> global_shuffle_threads; std::vector<std::thread> global_shuffle_threads;
for (int i = 0; i < thread_num_; ++i) { for (int i = 0; i < thread_num_; ++i) {
global_shuffle_threads.push_back( global_shuffle_threads.push_back(
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/dataset_factory.h"
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/data_set.h"
namespace paddle {
namespace framework {
typedef std::shared_ptr<Dataset> (*CreateDatasetFunction)();
typedef std::unordered_map<std::string, CreateDatasetFunction> datasetMap;
datasetMap g_dataset_map;
#define REGISTER_DATASET_CLASS(dataset_class) \
namespace { \
std::shared_ptr<Dataset> Creator_##dataset_class() { \
return std::shared_ptr<Dataset>(new dataset_class); \
} \
class __Registerer_##dataset_class { \
public: \
__Registerer_##dataset_class() { \
g_dataset_map[#dataset_class] = &Creator_##dataset_class; \
} \
}; \
__Registerer_##dataset_class g_registerer_##dataset_class; \
} // namespace
std::string DatasetFactory::DatasetTypeList() {
std::string dataset_types;
for (auto iter = g_dataset_map.begin(); iter != g_dataset_map.end();
++iter) {
if (iter != g_dataset_map.begin()) {
dataset_types += ", ";
}
dataset_types += iter->first;
}
return dataset_types;
}
std::shared_ptr<Dataset> DatasetFactory::CreateDataset(
std::string dataset_class) {
if (g_dataset_map.count(dataset_class) < 1) {
LOG(WARNING) << "Your Dataset " << dataset_class
<< "is not supported currently";
LOG(WARNING) << "Supported Dataset: " << DatasetTypeList();
exit(-1);
}
return g_dataset_map[dataset_class]();
}
REGISTER_DATASET_CLASS(MultiSlotDataset);
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/data_set.h"
namespace paddle {
namespace framework {
class DatasetFactory {
public:
static std::string DatasetTypeList();
static std::shared_ptr<Dataset> CreateDataset(std::string dataset_class);
};
} // namespace framework
} // namespace paddle
...@@ -27,6 +27,7 @@ See the License for the specific language governing permissions and ...@@ -27,6 +27,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/fleet/fleet_wrapper.h" #include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include <utility>
#include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/data_feed.h"
namespace paddle { namespace paddle {
...@@ -45,7 +46,7 @@ paddle::ps::Archive<AR>& operator << ( ...@@ -45,7 +46,7 @@ paddle::ps::Archive<AR>& operator << (
ar << ins.GetOffset(); ar << ins.GetOffset();
ar << ins.GetFloatData(); ar << ins.GetFloatData();
ar << ins.GetUint64Data(); ar << ins.GetUint64Data();
return ar; return ar;
} }
template<class AR> template<class AR>
...@@ -56,7 +57,7 @@ paddle::ps::Archive<AR>& operator >> ( ...@@ -56,7 +57,7 @@ paddle::ps::Archive<AR>& operator >> (
ar >> ins.MutableOffset(); ar >> ins.MutableOffset();
ar >> ins.MutableFloatData(); ar >> ins.MutableFloatData();
ar >> ins.MutableUint64Data(); ar >> ins.MutableUint64Data();
return ar; return ar;
} }
#endif #endif
...@@ -291,42 +292,63 @@ void FleetWrapper::PushSparseVarsWithLabelAsync( ...@@ -291,42 +292,63 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
#endif #endif
} }
// todo registe_client2client_msg_handler int FleetWrapper::registe_client2client_msg_handler(
int FleetWrapper::registe_client2client_msg_handler(int msg_type, MsgHandlerFunc handler) { int msg_type, MsgHandlerFunc handler) {
pslib_ptr_->_worker_ptr->registe_client2client_msg_handler(
msg_type, handler);
return 0; return 0;
} }
// todo send_client2client_msg int FleetWrapper::send_client2client_msg(
int FleetWrapper::send_client2client_msg(int msg_type, int to_client_id, const std::string& msg) { int msg_type, int to_client_id, const std::string& msg) {
pslib_ptr_->_worker_ptr->send_client2client_msg(
msg_type, to_client_id, msg);
return 0; return 0;
} }
std::default_random_engine& FleetWrapper::local_random_engine() {
struct engine_wrapper_t {
std::default_random_engine engine;
engine_wrapper_t() {
struct timespec tp;
clock_gettime(CLOCK_REALTIME, &tp);
double cur_time = tp.tv_sec + tp.tv_nsec * 1e-9;
static std::atomic<uint64_t> x(0);
std::seed_seq sseq = {x++, x++, x++,
(uint64_t)(cur_time * 1000)};
engine.seed(sseq);
}
};
thread_local engine_wrapper_t r;
return r.engine;
}
template<typename T> template<typename T>
void FleetWrapper::Serialize(const T& t, std::string& str) { void FleetWrapper::Serialize(const T& t, std::string* str) {
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
paddle::ps::BinaryArchive ar; paddle::ps::BinaryArchive ar;
ar << t; ar << t;
str = std::string(ar.buffer(), ar.length()); *str = std::string(ar.buffer(), ar.length());
#else #else
VLOG(0) << "FleetWrapper::Serialize do nothing when no pslib"; VLOG(0) << "FleetWrapper::Serialize do nothing when no pslib";
#endif #endif
} }
template<typename T> template<typename T>
void FleetWrapper::Deserialize(T& t, const std::string& str) { void FleetWrapper::Deserialize(T* t, const std::string& str) {
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
paddle::ps::BinaryArchive ar; paddle::ps::BinaryArchive ar;
ar.set_read_buffer(const_cast<char*>(str.c_str()), str.length(), nullptr); ar.set_read_buffer(const_cast<char*>(str.c_str()), str.length(), nullptr);
t = ar.get<T>(); *t = ar.get<T>();
#else #else
VLOG(0) << "FleetWrapper::Deserialize do nothing when no pslib"; VLOG(0) << "FleetWrapper::Deserialize do nothing when no pslib";
#endif #endif
} }
template void FleetWrapper::Serialize<std::vector<MultiSlotType>>( template void FleetWrapper::Serialize<std::vector<MultiSlotType>>(
const std::vector<MultiSlotType>&, std::string&); const std::vector<MultiSlotType>&, std::string*);
template void FleetWrapper::Deserialize( template void FleetWrapper::Deserialize(
std::vector<MultiSlotType>&, const std::string&); std::vector<MultiSlotType>*, const std::string&);
} // end namespace framework } // end namespace framework
} // end namespace paddle } // end namespace paddle
...@@ -21,7 +21,7 @@ limitations under the License. */ ...@@ -21,7 +21,7 @@ limitations under the License. */
#endif #endif
#include <random> #include <random>
#include <atomic> #include <atomic>
#include <time.h> #include <ctime>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -116,13 +116,15 @@ class FleetWrapper { ...@@ -116,13 +116,15 @@ class FleetWrapper {
typedef std::function<int32_t (int, int, const std::string&)> MsgHandlerFunc; typedef std::function<int32_t (int, int, const std::string&)> MsgHandlerFunc;
int registe_client2client_msg_handler(int msg_type, MsgHandlerFunc handler); int registe_client2client_msg_handler(int msg_type, MsgHandlerFunc handler);
int send_client2client_msg(int msg_type, int to_client_id, const std::string& msg); int send_client2client_msg(int msg_type,
int to_client_id,
const std::string& msg);
std::default_random_engine& local_random_engine(); std::default_random_engine& local_random_engine();
template<typename T> template<typename T>
void Serialize(const T& t, std::string& str); void Serialize(const T& t, std::string* str);
template<typename T> template<typename T>
void Deserialize(T& t, const std::string& str); void Deserialize(T* t, const std::string& str);
static std::shared_ptr<FleetWrapper> GetInstance() { static std::shared_ptr<FleetWrapper> GetInstance() {
if (NULL == s_instance_) { if (NULL == s_instance_) {
......
...@@ -65,6 +65,7 @@ void MultiTrainer::Finalize() { ...@@ -65,6 +65,7 @@ void MultiTrainer::Finalize() {
for (auto& th : threads_) { for (auto& th : threads_) {
th.join(); th.join();
} }
// todo dataset->DestroyReaders();
} }
} // end namespace framework } // end namespace framework
......
...@@ -21,7 +21,7 @@ limitations under the License. */ ...@@ -21,7 +21,7 @@ limitations under the License. */
#endif #endif
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory>
#include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/text_format.h" #include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/async_executor.h" #include "paddle/fluid/framework/async_executor.h"
...@@ -33,6 +33,7 @@ limitations under the License. */ ...@@ -33,6 +33,7 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/variant.h" #include "paddle/fluid/platform/variant.h"
#include "paddle/fluid/pybind/data_set_py.h" #include "paddle/fluid/pybind/data_set_py.h"
#include "paddle/fluid/framework/dataset_factory.h"
namespace py = pybind11; namespace py = pybind11;
namespace pd = paddle::framework; namespace pd = paddle::framework;
...@@ -41,17 +42,18 @@ namespace paddle { ...@@ -41,17 +42,18 @@ namespace paddle {
namespace pybind { namespace pybind {
void BindDataset(py::module* m) { void BindDataset(py::module* m) {
py::class_<framework::MultiSlotDataset>(*m, "MultiSlotDataset") py::class_<framework::Dataset,
.def(py::init([]() { std::shared_ptr<framework::Dataset>>(*m, "Dataset")
return std::unique_ptr<framework::MultiSlotDataset>(new framework::MultiSlotDataset()); .def(py::init([](const std::string& name = "MultiSlotDataset") {
return framework::DatasetFactory::CreateDataset(name);
})) }))
.def("set_filelist", &framework::MultiSlotDataset::SetFileList) .def("set_filelist", &framework::Dataset::SetFileList)
.def("set_thread_num", &framework::MultiSlotDataset::SetThreadNum) .def("set_thread_num", &framework::Dataset::SetThreadNum)
.def("set_trainer_num", &framework::MultiSlotDataset::SetTrainerNum) .def("set_trainer_num", &framework::Dataset::SetTrainerNum)
.def("set_data_feed_desc", &framework::MultiSlotDataset::SetDataFeedDesc) .def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc)
.def("load_into_memory", &framework::MultiSlotDataset::LoadIntoMemory) .def("load_into_memory", &framework::Dataset::LoadIntoMemory)
.def("local_shuffle", &framework::MultiSlotDataset::LocalShuffle) .def("local_shuffle", &framework::Dataset::LocalShuffle)
.def("global_shuffle", &framework::MultiSlotDataset::GlobalShuffle); .def("global_shuffle", &framework::Dataset::GlobalShuffle);
} }
} // end namespace pybind } // end namespace pybind
......
...@@ -37,7 +37,7 @@ class DatasetBase(object): ...@@ -37,7 +37,7 @@ class DatasetBase(object):
# to decide whether we need create in memory instance # to decide whether we need create in memory instance
self.proto_desc = data_feed_pb2.DataFeedDesc() self.proto_desc = data_feed_pb2.DataFeedDesc()
self.proto_desc.pipe_command = "cat" self.proto_desc.pipe_command = "cat"
self.dataset = core.MultiSlotDataset() self.dataset = core.Dataset("MultiSlotDataset")
self.thread_num = 0 self.thread_num = 0
def set_pipe_command(self, pipe_command): def set_pipe_command(self, pipe_command):
...@@ -119,10 +119,16 @@ class InMemoryDataset(DatasetBase): ...@@ -119,10 +119,16 @@ class InMemoryDataset(DatasetBase):
from .distributed import ps_instance from .distributed import ps_instance
instance = ps_instance.PaddlePSInstance(1, 2) instance = ps_instance.PaddlePSInstance(1, 2)
self.dataset.set_trainer_num(instance.get_worker_num()) self.dataset.set_trainer_num(instance.get_worker_num())
self.global_shuffle() self.dataset.global_shuffle()
class QueueDataset(DatasetBase): class QueueDataset(DatasetBase):
def __init__(self): def __init__(self):
super(QueueDataset, self).__init__() super(QueueDataset, self).__init__()
self.proto_desc.name = "MultiSlotDataFeed" self.proto_desc.name = "MultiSlotDataFeed"
def local_shuffle(self):
pass
def global_shuffle(self):
pass
...@@ -59,7 +59,7 @@ class MultiTrainer(TrainerDesc): ...@@ -59,7 +59,7 @@ class MultiTrainer(TrainerDesc):
def gen_trainer_desc(self): def gen_trainer_desc(self):
super(MultiTrainer, self).gen_trainer_desc() super(MultiTrainer, self).gen_trainer_desc()
self.proto_desc.class_name = "MultiTrainer" self.proto_desc.class_name = "MultiTrainer"
self.device_worker_.gen_worker_desc(self.proto_desc, fleet_desc_) self.device_worker_.gen_worker_desc(self.proto_desc, self.fleet_desc_)
class DistMultiTrainer(TrainerDesc): class DistMultiTrainer(TrainerDesc):
......
...@@ -12,6 +12,9 @@ ...@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from trainer_desc import *
from device_worker import *
__all__ = ["TrainerFactory"] __all__ = ["TrainerFactory"]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册