Commit ecfc7df9 authored by xujiaqi01, committed by dongdaxiang

add dataset factory && fix style

Parent 328f11b8
@@ -181,7 +181,7 @@ graph_to_program_pass variable_helper trainer_library data_feed_proto ${NGRAPH_E
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
-  cc_library(executor SRCS executor.cc multi_trainer.cc
+  cc_library(executor SRCS executor.cc multi_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
pull_dense_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
@@ -202,7 +202,7 @@ cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
-  data_set.cc
+  data_set.cc dataset_factory.cc
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass data_feed_proto
......
@@ -158,8 +158,6 @@ bool InMemoryDataFeed<T>::Start() {
DataFeed::CheckSetFileList();
if (shuffled_ins_->Size() == 0 && shuffled_ins_out_->Size() == 0) {
FillMemoryDataToChannel();
-    //std::unique_lock<std::mutex> lock(*mutex_for_update_memory_data_);
-    //std::vector<T>().swap(memory_data_);
}
DataFeed::finish_start_ = true;
return true;
@@ -227,13 +225,13 @@ void InMemoryDataFeed<T>::SetTrainerNum(int trainer_num) {
template <typename T>
void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
T ins;
-  DeserializeIns(ins, ins_str);
+  DeserializeIns(&ins, ins_str);
shuffled_ins_->Push(std::move(ins));
}
template <typename T>
void InMemoryDataFeed<T>::FillMemoryDataToChannel() {
VLOG(3) << "InMemoryDataFeed<T>::FillMemoryDataToChannel, thread_id=" << thread_id_;
VLOG(3) << "FillMemoryDataToChannel, thread_id=" << thread_id_;
int64_t start = 0;
int64_t end = 0;
int64_t size = memory_data_->size();
@@ -252,7 +250,7 @@ void InMemoryDataFeed<T>::FillMemoryDataToChannel() {
template <typename T>
void InMemoryDataFeed<T>::FillChannelToMemoryData() {
VLOG(3) << "InMemoryDataFeed<T>::FillChannelToMemoryData, thread_id=" << thread_id_;
VLOG(3) << "FillChannelToMemoryData, thread_id=" << thread_id_;
std::vector<T> local_vec;
std::shared_ptr<paddle::framework::BlockingQueue<T>> channel = nullptr;
if (cur_channel_ == 0) {
@@ -274,11 +272,12 @@ void InMemoryDataFeed<T>::FillChannelToMemoryData() {
template <typename T>
void InMemoryDataFeed<T>::LoadIntoMemory() {
VLOG(3) << "InMemoryDataFeed<T>::LoadIntoMemory() begin, thread_id=" << thread_id_;
VLOG(3) << "LoadIntoMemory() begin, thread_id=" << thread_id_;
std::vector<T> local_vec;
std::string filename;
while (DataFeed::PickOneFile(&filename)) {
VLOG(3) << "PickOneFile, filename=" << filename << ", thread_id=" << thread_id_;
VLOG(3) << "PickOneFile, filename=" << filename
<< ", thread_id=" << thread_id_;
int err_no = 0;
PrivateQueueDataFeed<T>::fp_ =
fs_open_read(filename, &err_no, PrivateQueueDataFeed<T>::pipe_command_);
@@ -287,36 +286,38 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
while (ParseOneInstanceFromPipe(&instance)) {
local_vec.push_back(instance);
}
VLOG(3) << "InMemoryDataFeed<T>::LoadIntoMemory() read all lines, thread_id=" << thread_id_;
VLOG(3) << "LoadIntoMemory() read all lines, file="
<< filename <<", thread_id=" << thread_id_;
{
std::lock_guard<std::mutex> lock(*mutex_for_update_memory_data_);
-      memory_data_->insert(memory_data_->end(), local_vec.begin(), local_vec.end());
+      memory_data_->insert(memory_data_->end(),
+                           local_vec.begin(), local_vec.end());
}
std::vector<T>().swap(local_vec);
}
VLOG(3) << "InMemoryDataFeed<T>::LoadIntoMemory() end, thread_id=" << thread_id_;
VLOG(3) << "LoadIntoMemory() end, thread_id=" << thread_id_;
}
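An aside on the `std::vector<T>().swap(local_vec)` line above: unlike `clear()`, swapping with an empty temporary also releases the vector's heap allocation, which matters here because `local_vec` holds a whole file's worth of instances per loop iteration. A minimal illustration:

```cpp
#include <iostream>
#include <vector>

int main() {
  std::vector<int> v(1 << 20, 42);
  v.clear();                          // size becomes 0, capacity is kept
  std::cout << v.capacity() << "\n";  // still ~1048576
  std::vector<int>().swap(v);         // swap with an empty temporary
  std::cout << v.capacity() << "\n";  // typically 0: memory released
}
```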
template <typename T>
void InMemoryDataFeed<T>::LocalShuffle() {
VLOG(3) << "InMemoryDataFeed<T>::LocalShuffle() begin, thread_id=" << thread_id_;
VLOG(3) << "LocalShuffle() begin, thread_id=" << thread_id_;
FillMemoryDataToChannel();
VLOG(3) << "InMemoryDataFeed<T>::LocalShuffle() end, thread_id=" << thread_id_;
VLOG(3) << "LocalShuffle() end, thread_id=" << thread_id_;
}
template <typename T>
void InMemoryDataFeed<T>::GlobalShuffle() {
VLOG(3) << "GlobalShuffle(), thread_id=" << thread_id_;
auto fleet_ptr = FleetWrapper::GetInstance();
std::vector<std::string> send_str_vec(trainer_num_);
for (int64_t i = 0; i < memory_data_->size(); ++i) {
// todo get ins id
-    //std::string ins_id = memory_data_[i].ins_id;
+    // std::string ins_id = memory_data_[i].ins_id;
// todo hash
-    //int64_t hash_id = paddle::ps::local_random_engine()();
-    int64_t hash_id = 0;
-    int64_t node_id = hash_id % trainer_num_;
+    int64_t random_num = fleet_ptr->local_random_engine()();
+    int64_t node_id = random_num % trainer_num_;
std::string str;
-    SerializeIns((*memory_data_)[i], str);
+    SerializeIns((*memory_data_)[i], &str);
send_str_vec[node_id] += str;
if (i % fleet_send_batch_size_ == 0 && i != 0) {
for (int j = 0; j < send_str_vec.size(); ++j) {
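The loop above is the heart of the global shuffle: each instance draws a random number, `node_id = random_num % trainer_num_` picks a destination trainer, and serialized instances are concatenated per destination and flushed every `fleet_send_batch_size_` instances. The sharding idea reduced to a standalone sketch (names here are illustrative, not the Paddle API):

```cpp
#include <cstdint>
#include <random>
#include <string>
#include <vector>

// Shard serialized instances across trainers the way GlobalShuffle does:
// a pseudo-random draw modulo the trainer count picks the destination.
std::vector<std::string> ShardInstances(const std::vector<std::string>& ins,
                                        int64_t trainer_num,
                                        std::default_random_engine* engine) {
  std::vector<std::string> send_buf(trainer_num);
  for (const auto& s : ins) {
    int64_t node_id = static_cast<int64_t>(
        (*engine)() % static_cast<uint64_t>(trainer_num));
    send_buf[node_id] += s;  // concatenate, as send_str_vec[node_id] += str
  }
  return send_buf;
}
```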
@@ -821,12 +822,12 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
// todo serialize ins in global shuffle
void MultiSlotInMemoryDataFeed::SerializeIns(
-    const std::vector<MultiSlotType>& ins, std::string& str) {
+    const std::vector<MultiSlotType>& ins, std::string* str) {
auto fleet_ptr = FleetWrapper::GetInstance();
fleet_ptr->Serialize(ins, str);
}
// todo deserialize ins in global shuffle
-void MultiSlotInMemoryDataFeed::DeserializeIns(std::vector<MultiSlotType>& ins,
+void MultiSlotInMemoryDataFeed::DeserializeIns(std::vector<MultiSlotType>* ins,
const std::string& str) {
auto fleet_ptr = FleetWrapper::GetInstance();
fleet_ptr->Deserialize(ins, str);
......
@@ -212,13 +212,16 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
virtual void LoadIntoMemory();
virtual void LocalShuffle();
virtual void GlobalShuffle();
protected:
-  virtual void AddInstanceToInsVec(T* vec_ins, const T& instance, int index) = 0;
+  virtual void AddInstanceToInsVec(T* vec_ins,
+                                   const T& instance,
+                                   int index) = 0;
virtual bool ParseOneInstance(T* instance) = 0;
virtual bool ParseOneInstanceFromPipe(T* instance) = 0;
virtual void PutToFeedVec(const T& ins_vec) = 0;
-  virtual void SerializeIns(const T& ins, std::string& str) = 0;
-  virtual void DeserializeIns(T& ins, const std::string& str) = 0;
+  virtual void SerializeIns(const T& ins, std::string* str) = 0;
+  virtual void DeserializeIns(T* ins, const std::string& str) = 0;
int thread_id_;
int thread_num_;
@@ -284,6 +287,28 @@ class MultiSlotType {
const std::string& GetType() const { return type_; }
std::string& MutableType() { return type_; }
+  std::string DebugString() {
+    std::stringstream ss;
+    ss << "type: " << type_ << "\n";
+    ss << "offset:\n";
+    ss << "[";
+    for (const size_t& i : offset_) {
+      ss << i << ",";  // print the offset value itself; indexing offset_
+                       // with its own values would run past the end
+    }
+    ss << "]\ndata:\n[";
+    if (type_[0] == 'f') {
+      for (const float& i : float_feasign_) {
+        ss << i << ",";
+      }
+    } else {
+      for (const uint64_t& i : uint64_feasign_) {
+        ss << i << ",";
+      }
+    }
+    ss << "]\n";
+    return ss.str();
+  }
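For reference, a float slot holding offsets {0, 3} and values {0.1, 0.2, 0.5} would print roughly (modulo stream float formatting):

```
type: float
offset:
[0,3,]
data:
[0.1,0.2,0.5,]
```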
private:
void CheckType(const std::string& type) const {
PADDLE_ENFORCE((type == "uint64") || (type == "float"),
@@ -336,8 +361,10 @@ class MultiSlotInMemoryDataFeed
virtual bool ParseOneInstance(std::vector<MultiSlotType>* instance);
virtual bool ParseOneInstanceFromPipe(std::vector<MultiSlotType>* instance);
virtual void PutToFeedVec(const std::vector<MultiSlotType>& ins_vec);
-  virtual void SerializeIns(const std::vector<MultiSlotType>& ins, std::string& str);
-  virtual void DeserializeIns(std::vector<MultiSlotType>& ins, const std::string& str);
+  virtual void SerializeIns(const std::vector<MultiSlotType>& ins,
+                            std::string* str);
+  virtual void DeserializeIns(std::vector<MultiSlotType>* ins,
+                              const std::string& str);
};
} // namespace framework
......
@@ -54,7 +54,9 @@ void DatasetImpl<T>::SetThreadNum(int thread_num) {
}
template <typename T>
-void DatasetImpl<T>::SetTrainerNum(int trainer_num) { trainer_num_ = trainer_num; }
+void DatasetImpl<T>::SetTrainerNum(int trainer_num) {
+  trainer_num_ = trainer_num;
+}
template <typename T>
void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) {
@@ -115,10 +117,12 @@ void DatasetImpl<T>::GlobalShuffle() {
// if it is not InMemory, memory_data_ is empty
std::random_shuffle(memory_data_.begin(), memory_data_.end());
auto fleet_ptr = FleetWrapper::GetInstance();
VLOG(3) << "registe_client2client_msg_handler";
fleet_ptr->registe_client2client_msg_handler(0,
[this](int msg_type, int client_id, const std::string& msg) -> int {
return this->ReceiveFromClient(msg_type, client_id, msg);
});
VLOG(3) << "start global shuffle threads";
std::vector<std::thread> global_shuffle_threads;
for (int i = 0; i < thread_num_; ++i) {
global_shuffle_threads.push_back(
......
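GlobalShuffle above wires the dataset into the fleet's client-to-client channel: handler 0 is a lambda that forwards incoming messages to `ReceiveFromClient`, so instances shuffled out by peer trainers land back in this dataset's memory. The capture-and-forward shape, reduced to a self-contained sketch (the registry here is a stand-in, not the pslib API):

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <string>

using MsgHandlerFunc =
    std::function<int(int msg_type, int client_id, const std::string& msg)>;
std::map<int, MsgHandlerFunc> g_handlers;  // stand-in for the pslib registry

struct ToyDataset {
  int ReceiveFromClient(int msg_type, int client_id, const std::string& msg) {
    std::cout << "client " << client_id << ": " << msg.size() << " bytes\n";
    return 0;
  }
  void RegisterHandler() {
    // Same lambda pattern as DatasetImpl<T>::GlobalShuffle above.
    g_handlers[0] = [this](int t, int c, const std::string& m) -> int {
      return this->ReceiveFromClient(t, c, m);
    };
  }
};

int main() {
  ToyDataset d;
  d.RegisterHandler();
  g_handlers[0](0, /*client_id=*/7, "serialized-instances");
}
```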
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/dataset_factory.h"
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/data_set.h"
namespace paddle {
namespace framework {
typedef std::shared_ptr<Dataset> (*CreateDatasetFunction)();
typedef std::unordered_map<std::string, CreateDatasetFunction> datasetMap;
datasetMap g_dataset_map;
#define REGISTER_DATASET_CLASS(dataset_class) \
namespace { \
std::shared_ptr<Dataset> Creator_##dataset_class() { \
return std::shared_ptr<Dataset>(new dataset_class); \
} \
class __Registerer_##dataset_class { \
public: \
__Registerer_##dataset_class() { \
g_dataset_map[#dataset_class] = &Creator_##dataset_class; \
} \
}; \
__Registerer_##dataset_class g_registerer_##dataset_class; \
} // namespace
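The macro registers each dataset class at static-initialization time: the file-scope `__Registerer_*` object's constructor runs before `main()` and drops a creator function into `g_dataset_map`. Expanded by hand for `MultiSlotDataset` (see the `REGISTER_DATASET_CLASS(MultiSlotDataset);` line near the end of this file), it is equivalent to:

```cpp
namespace {
std::shared_ptr<Dataset> Creator_MultiSlotDataset() {
  return std::shared_ptr<Dataset>(new MultiSlotDataset);
}
class __Registerer_MultiSlotDataset {
 public:
  __Registerer_MultiSlotDataset() {
    // Runs during static initialization, before main().
    g_dataset_map["MultiSlotDataset"] = &Creator_MultiSlotDataset;
  }
};
__Registerer_MultiSlotDataset g_registerer_MultiSlotDataset;
}  // namespace
```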
std::string DatasetFactory::DatasetTypeList() {
std::string dataset_types;
for (auto iter = g_dataset_map.begin(); iter != g_dataset_map.end();
++iter) {
if (iter != g_dataset_map.begin()) {
dataset_types += ", ";
}
dataset_types += iter->first;
}
return dataset_types;
}
std::shared_ptr<Dataset> DatasetFactory::CreateDataset(
std::string dataset_class) {
if (g_dataset_map.count(dataset_class) < 1) {
LOG(WARNING) << "Your Dataset " << dataset_class
<< "is not supported currently";
LOG(WARNING) << "Supported Dataset: " << DatasetTypeList();
exit(-1);
}
return g_dataset_map[dataset_class]();
}
REGISTER_DATASET_CLASS(MultiSlotDataset);
} // namespace framework
} // namespace paddle
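Call sites then create datasets by name; a minimal usage sketch (the thread count is illustrative):

```cpp
#include "paddle/fluid/framework/dataset_factory.h"

void Example() {
  // Looks up the creator registered above; unknown names log the
  // supported list and exit(-1), per CreateDataset.
  std::shared_ptr<paddle::framework::Dataset> dataset =
      paddle::framework::DatasetFactory::CreateDataset("MultiSlotDataset");
  dataset->SetThreadNum(4);
}
```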
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/data_set.h"
namespace paddle {
namespace framework {
class DatasetFactory {
public:
static std::string DatasetTypeList();
static std::shared_ptr<Dataset> CreateDataset(std::string dataset_class);
};
} // namespace framework
} // namespace paddle
@@ -27,6 +27,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include <utility>
#include "paddle/fluid/framework/data_feed.h"
namespace paddle {
@@ -45,7 +46,7 @@ paddle::ps::Archive<AR>& operator << (
ar << ins.GetOffset();
ar << ins.GetFloatData();
ar << ins.GetUint64Data();
-    return ar;
+  return ar;
}
template<class AR>
@@ -56,7 +57,7 @@ paddle::ps::Archive<AR>& operator >> (
ar >> ins.MutableOffset();
ar >> ins.MutableFloatData();
ar >> ins.MutableUint64Data();
-    return ar;
+  return ar;
}
#endif
@@ -291,42 +292,63 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
#endif
}
// todo registe_client2client_msg_handler
-int FleetWrapper::registe_client2client_msg_handler(int msg_type, MsgHandlerFunc handler) {
-  return 0;
+int FleetWrapper::registe_client2client_msg_handler(
+    int msg_type, MsgHandlerFunc handler) {
+  pslib_ptr_->_worker_ptr->registe_client2client_msg_handler(
+      msg_type, handler);
+  return 0;
}
// todo send_client2client_msg
-int FleetWrapper::send_client2client_msg(int msg_type, int to_client_id, const std::string& msg) {
-  return 0;
+int FleetWrapper::send_client2client_msg(
+    int msg_type, int to_client_id, const std::string& msg) {
+  pslib_ptr_->_worker_ptr->send_client2client_msg(
+      msg_type, to_client_id, msg);
+  return 0;
}
+std::default_random_engine& FleetWrapper::local_random_engine() {
+  struct engine_wrapper_t {
+    std::default_random_engine engine;
+    engine_wrapper_t() {
+      struct timespec tp;
+      clock_gettime(CLOCK_REALTIME, &tp);
+      double cur_time = tp.tv_sec + tp.tv_nsec * 1e-9;
+      static std::atomic<uint64_t> x(0);
+      std::seed_seq sseq = {x++, x++, x++,
+                            (uint64_t)(cur_time * 1000)};
+      engine.seed(sseq);
+    }
+  };
+  thread_local engine_wrapper_t r;
+  return r.engine;
+}
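`local_random_engine()` hands each thread its own engine; the seed mixes wall-clock time with a shared atomic counter so threads constructed in the same instant still diverge. A self-contained demonstration of the thread-local idea (simplified seeding, not the Paddle symbol):

```cpp
#include <atomic>
#include <iostream>
#include <random>
#include <thread>

std::default_random_engine& ThreadLocalEngine() {
  static std::atomic<uint64_t> counter{0};
  // One engine per thread; the counter keeps seeds distinct even when
  // two threads start at the same wall-clock instant.
  thread_local std::default_random_engine engine(
      static_cast<unsigned>(std::random_device{}() + counter.fetch_add(1)));
  return engine;
}

int main() {
  auto worker = [] { std::cout << ThreadLocalEngine()() << "\n"; };
  std::thread t1(worker), t2(worker);
  t1.join();
  t2.join();  // the two printed values are (almost surely) different
}
```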
template<typename T>
-void FleetWrapper::Serialize(const T& t, std::string& str) {
+void FleetWrapper::Serialize(const T& t, std::string* str) {
#ifdef PADDLE_WITH_PSLIB
paddle::ps::BinaryArchive ar;
ar << t;
-  str = std::string(ar.buffer(), ar.length());
+  *str = std::string(ar.buffer(), ar.length());
#else
VLOG(0) << "FleetWrapper::Serialize do nothing when no pslib";
#endif
}
template<typename T>
-void FleetWrapper::Deserialize(T& t, const std::string& str) {
+void FleetWrapper::Deserialize(T* t, const std::string& str) {
#ifdef PADDLE_WITH_PSLIB
paddle::ps::BinaryArchive ar;
ar.set_read_buffer(const_cast<char*>(str.c_str()), str.length(), nullptr);
-  t = ar.get<T>();
+  *t = ar.get<T>();
#else
VLOG(0) << "FleetWrapper::Deserialize do nothing when no pslib";
#endif
}
template void FleetWrapper::Serialize<std::vector<MultiSlotType>>(
-    const std::vector<MultiSlotType>&, std::string&);
+    const std::vector<MultiSlotType>&, std::string*);
template void FleetWrapper::Deserialize(
-    std::vector<MultiSlotType>&, const std::string&);
+    std::vector<MultiSlotType>*, const std::string&);
} // end namespace framework
} // end namespace paddle
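With the signature change, out-parameters are now pointers rather than non-const references, and a round trip through the wrappers reads as below. A sketch assuming a PADDLE_WITH_PSLIB build; without pslib both calls are no-ops that just log:

```cpp
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"

void RoundTrip(const std::vector<paddle::framework::MultiSlotType>& ins) {
  auto fleet_ptr = paddle::framework::FleetWrapper::GetInstance();
  std::string buf;
  fleet_ptr->Serialize(ins, &buf);  // out-param is now a pointer
  std::vector<paddle::framework::MultiSlotType> restored;
  fleet_ptr->Deserialize(&restored, buf);
}
```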
@@ -21,7 +21,7 @@ limitations under the License. */
#endif
#include <random>
#include <atomic>
-#include <time.h>
+#include <ctime>
#include <string>
#include <vector>
#include "paddle/fluid/framework/scope.h"
@@ -116,13 +116,15 @@ class FleetWrapper {
typedef std::function<int32_t (int, int, const std::string&)> MsgHandlerFunc;
int registe_client2client_msg_handler(int msg_type, MsgHandlerFunc handler);
-  int send_client2client_msg(int msg_type, int to_client_id, const std::string& msg);
+  int send_client2client_msg(int msg_type,
+                             int to_client_id,
+                             const std::string& msg);
std::default_random_engine& local_random_engine();
template<typename T>
-  void Serialize(const T& t, std::string& str);
+  void Serialize(const T& t, std::string* str);
template<typename T>
-  void Deserialize(T& t, const std::string& str);
+  void Deserialize(T* t, const std::string& str);
static std::shared_ptr<FleetWrapper> GetInstance() {
if (NULL == s_instance_) {
......
@@ -65,6 +65,7 @@ void MultiTrainer::Finalize() {
for (auto& th : threads_) {
th.join();
}
+  // todo dataset->DestroyReaders();
}
} // end namespace framework
......
@@ -21,7 +21,7 @@ limitations under the License. */
#endif
#include <string>
#include <vector>
#include <memory>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/async_executor.h"
@@ -33,6 +33,7 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/variant.h"
#include "paddle/fluid/pybind/data_set_py.h"
#include "paddle/fluid/framework/dataset_factory.h"
namespace py = pybind11;
namespace pd = paddle::framework;
@@ -41,17 +42,18 @@ namespace paddle {
namespace pybind {
void BindDataset(py::module* m) {
-  py::class_<framework::MultiSlotDataset>(*m, "MultiSlotDataset")
-      .def(py::init([]() {
-        return std::unique_ptr<framework::MultiSlotDataset>(new framework::MultiSlotDataset());
+  py::class_<framework::Dataset,
+             std::shared_ptr<framework::Dataset>>(*m, "Dataset")
+      .def(py::init([](const std::string& name = "MultiSlotDataset") {
+        return framework::DatasetFactory::CreateDataset(name);
}))
.def("set_filelist", &framework::MultiSlotDataset::SetFileList)
.def("set_thread_num", &framework::MultiSlotDataset::SetThreadNum)
.def("set_trainer_num", &framework::MultiSlotDataset::SetTrainerNum)
.def("set_data_feed_desc", &framework::MultiSlotDataset::SetDataFeedDesc)
.def("load_into_memory", &framework::MultiSlotDataset::LoadIntoMemory)
.def("local_shuffle", &framework::MultiSlotDataset::LocalShuffle)
.def("global_shuffle", &framework::MultiSlotDataset::GlobalShuffle);
.def("set_filelist", &framework::Dataset::SetFileList)
.def("set_thread_num", &framework::Dataset::SetThreadNum)
.def("set_trainer_num", &framework::Dataset::SetTrainerNum)
.def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc)
.def("load_into_memory", &framework::Dataset::LoadIntoMemory)
.def("local_shuffle", &framework::Dataset::LocalShuffle)
.def("global_shuffle", &framework::Dataset::GlobalShuffle);
}
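One caveat on the binding above: pybind11 cannot see default argument values written in a lambda's own signature, so the `= "MultiSlotDataset"` default is silently ignored and a Python-side `Dataset()` call with no argument would fail; the Python wrapper below always passes the name explicitly, so this works in practice. Exposing the default properly would take an explicit `py::arg` (a sketch, not part of this commit):

```cpp
.def(py::init([](const std::string& name) {
       return framework::DatasetFactory::CreateDataset(name);
     }),
     py::arg("name") = "MultiSlotDataset")
```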
} // end namespace pybind
......
@@ -37,7 +37,7 @@ class DatasetBase(object):
# to decide whether we need create in memory instance
self.proto_desc = data_feed_pb2.DataFeedDesc()
self.proto_desc.pipe_command = "cat"
-        self.dataset = core.MultiSlotDataset()
+        self.dataset = core.Dataset("MultiSlotDataset")
self.thread_num = 0
def set_pipe_command(self, pipe_command):
@@ -119,10 +119,16 @@ class InMemoryDataset(DatasetBase):
from .distributed import ps_instance
instance = ps_instance.PaddlePSInstance(1, 2)
self.dataset.set_trainer_num(instance.get_worker_num())
-        self.global_shuffle()
+        self.dataset.global_shuffle()
class QueueDataset(DatasetBase):
def __init__(self):
super(QueueDataset, self).__init__()
self.proto_desc.name = "MultiSlotDataFeed"
def local_shuffle(self):
pass
def global_shuffle(self):
pass
@@ -20,7 +20,7 @@ from google.protobuf import text_format
__all__ = ['TrainerDesc', 'MultiTrainer', 'DistMultiTrainer']
-# can be initialized from train_desc, 
+# can be initialized from train_desc,
class TrainerDesc(object):
def __init__(self):
'''
@@ -59,7 +59,7 @@ class MultiTrainer(TrainerDesc):
def gen_trainer_desc(self):
super(MultiTrainer, self).gen_trainer_desc()
self.proto_desc.class_name = "MultiTrainer"
-        self.device_worker_.gen_worker_desc(self.proto_desc, fleet_desc_)
+        self.device_worker_.gen_worker_desc(self.proto_desc, self.fleet_desc_)
class DistMultiTrainer(TrainerDesc):
......
@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from trainer_desc import *
+from device_worker import *
__all__ = ["TrainerFactory"]
......