提交 271b7147 编写于 作者: X xjqbest

fix dataset bug

test=develop
上级 1c526e1d
......@@ -237,11 +237,21 @@ void InMemoryDataFeed<T>::SetThreadNum(int thread_num) {
thread_num_ = thread_num;
}
template <typename T>
void InMemoryDataFeed<T>::SetTrainerId(int trainer_id) {
trainer_id_ = trainer_id;
}
template <typename T>
void InMemoryDataFeed<T>::SetTrainerNum(int trainer_num) {
trainer_num_ = trainer_num;
}
template <typename T>
void InMemoryDataFeed<T>::SetFleetSendBatchSize(int64_t size) {
fleet_send_batch_size_ = size;
}
template <typename T>
void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
#ifdef _LINUX
......@@ -361,8 +371,15 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
VLOG(3) << "GlobalShuffle() begin, thread_id=" << thread_id_;
auto fleet_ptr = FleetWrapper::GetInstance();
std::vector<std::vector<T*>> send_vec(trainer_num_);
std::vector<int> send_index(trainer_num_);
std::vector<T> local_send_vec;
uint64_t reserve_len = fleet_send_batch_size_ / trainer_num_;
for (auto& vec : send_vec) {
vec.reserve(fleet_send_batch_size_);
vec.reserve(reserve_len);
}
local_send_vec.reserve(reserve_len);
for (int i = 0; i < trainer_num_; ++i) {
send_index[i] = i;
}
std::vector<std::future<int32_t>> total_status;
auto interval = GetMemoryDataInterval();
......@@ -373,9 +390,23 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
// std::string ins_id = memory_data_[i].ins_id;
int64_t random_num = rand_r(&rand_seed);
int64_t node_id = random_num % trainer_num_;
send_vec[node_id].push_back(&((*memory_data_)[i]));
if (node_id == trainer_id_) {
local_send_vec.push_back((*memory_data_)[i]);
} else {
send_vec[node_id].push_back(&((*memory_data_)[i]));
}
if (i % fleet_send_batch_size_ == 0 && i != 0) {
for (int j = 0; j < send_vec.size(); ++j) {
// shuffle the sequence of sending to avoid network timeout error
std::random_shuffle(send_index.begin(), send_index.end());
for (int index = 0; index < send_index.size(); ++index) {
int j = send_index[index];
if (j == trainer_id_) {
VLOG(3) << "send to local, ins num=" << local_send_vec.size()
<< ", node_id=" << j << ", thread_id=" << thread_id_;
shuffled_ins_->Extend(std::move(local_send_vec));
local_send_vec.clear();
continue;
}
std::string send_str;
SerializeIns(send_vec[j], &send_str);
VLOG(3) << "send str_length=" << send_str.length()
......@@ -388,8 +419,14 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
}
}
}
for (int j = 0; j < send_vec.size(); ++j) {
if (send_vec[j].size() != 0) {
// shuffle the sequence of sending to avoid network timeout error
std::random_shuffle(send_index.begin(), send_index.end());
for (int index = 0; index < send_index.size(); ++index) {
int j = send_index[index];
if (j == trainer_id_ && local_send_vec.size() != 0) {
shuffled_ins_->Extend(std::move(local_send_vec));
std::vector<T>().swap(local_send_vec);
} else if (send_vec[j].size() != 0) {
std::string send_str;
SerializeIns(send_vec[j], &send_str);
VLOG(3) << "send str_length=" << send_str.length() << " to node_id=" << j
......
......@@ -91,9 +91,13 @@ class DataFeed {
// This function will do nothing at default
virtual void SetThreadId(int thread_id) {}
// This function will do nothing at default
virtual void SetTrainerId(int trainer_id) {}
// This function will do nothing at default
virtual void SetThreadNum(int thread_num) {}
// This function will do nothing at default
virtual void SetTrainerNum(int trainer_num) {}
// This function will do nothing at default
virtual void SetFleetSendBatchSize(int64_t size) {}
virtual void SetFileListMutex(std::mutex* mutex) {
mutex_for_pick_file_ = mutex;
}
......@@ -211,7 +215,9 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
virtual void SetMemoryDataMutex(std::mutex* mutex);
virtual void SetThreadId(int thread_id);
virtual void SetThreadNum(int thread_num);
virtual void SetTrainerId(int trainer_id);
virtual void SetTrainerNum(int trainer_num);
virtual void SetFleetSendBatchSize(int64_t size);
virtual void PutInsToChannel(const std::string& ins_str);
virtual void FillMemoryDataToChannel();
virtual void FillChannelToMemoryData();
......@@ -231,6 +237,7 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
int thread_id_;
int thread_num_;
int trainer_id_;
int trainer_num_;
uint32_t rand_seed;
std::vector<T>* memory_data_;
......
......@@ -52,6 +52,17 @@ void DatasetImpl<T>::SetThreadNum(int thread_num) {
thread_num_ = thread_num;
}
// if you run distributed, and want to do global shuffle,
// set this before global shuffle.
// be sure you call CreateReaders before SetTrainerId
template <typename T>
void DatasetImpl<T>::SetTrainerId(int trainer_id) {
trainer_id_ = trainer_id;
for (auto reader : readers_) {
reader->SetTrainerId(trainer_id);
}
}
// if you run distributed, and want to do global shuffle,
// set this before global shuffle.
// be sure you call CreateReaders before SetTrainerNum
......@@ -64,6 +75,17 @@ void DatasetImpl<T>::SetTrainerNum(int trainer_num) {
}
}
// if you run distributed, and want to do global shuffle,
// set this before global shuffle.
// be sure you call CreateReaders before SetFleetSendBatchSize
template <typename T>
void DatasetImpl<T>::SetFleetSendBatchSize(int64_t size) {
fleet_send_batch_size_ = size;
for (auto reader : readers_) {
reader->SetFleetSendBatchSize(size);
}
}
template <typename T>
void DatasetImpl<T>::SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi) {
......
......@@ -45,8 +45,12 @@ class Dataset {
virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
// set readers' num
virtual void SetThreadNum(int thread_num) = 0;
// set worker rank
virtual void SetTrainerId(int trainer_id) = 0;
// set workers' num
virtual void SetTrainerNum(int trainer_num) = 0;
// set fleet send batch size
virtual void SetFleetSendBatchSize(int64_t size) = 0;
// set fs name and ugi
virtual void SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi) = 0;
......@@ -57,8 +61,12 @@ class Dataset {
virtual const std::vector<std::string>& GetFileList() = 0;
// get thread num
virtual int GetThreadNum() = 0;
// get worker rank
virtual int GetTrainerId() = 0;
// get worker num
virtual int GetTrainerNum() = 0;
// get fleet send batch size
virtual int64_t GetFleetSendBatchSize() = 0;
// get hdfs config
virtual std::pair<std::string, std::string> GetHdfsConfig() = 0;
// get data fedd desc
......@@ -97,14 +105,18 @@ class DatasetImpl : public Dataset {
virtual void SetFileList(const std::vector<std::string>& filelist);
virtual void SetThreadNum(int thread_num);
virtual void SetTrainerId(int trainer_id);
virtual void SetTrainerNum(int trainer_num);
virtual void SetFleetSendBatchSize(int64_t size);
virtual void SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi);
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);
virtual const std::vector<std::string>& GetFileList() { return filelist_; }
virtual int GetThreadNum() { return thread_num_; }
virtual int GetTrainerId() { return trainer_id_; }
virtual int GetTrainerNum() { return trainer_num_; }
virtual int64_t GetFleetSendBatchSize() { return fleet_send_batch_size_; }
virtual std::pair<std::string, std::string> GetHdfsConfig() {
return std::make_pair(fs_name_, fs_ugi_);
}
......@@ -130,6 +142,7 @@ class DatasetImpl : public Dataset {
std::mutex mutex_for_update_memory_data_;
int thread_num_;
paddle::framework::DataFeedDesc data_feed_desc_;
int trainer_id_;
int trainer_num_;
std::vector<std::string> filelist_;
size_t file_idx_;
......@@ -137,6 +150,7 @@ class DatasetImpl : public Dataset {
std::string fs_name_;
std::string fs_ugi_;
unsigned int rand_seed;
int64_t fleet_send_batch_size_;
};
// use std::vector<MultiSlotType> as data type
......
......@@ -49,12 +49,18 @@ void BindDataset(py::module* m) {
}))
.def("set_filelist", &framework::Dataset::SetFileList)
.def("set_thread_num", &framework::Dataset::SetThreadNum)
.def("set_trainer_id", &framework::Dataset::SetTrainerId)
.def("set_trainer_num", &framework::Dataset::SetTrainerNum)
.def("set_fleet_send_batch_size",
&framework::Dataset::SetFleetSendBatchSize)
.def("set_hdfs_config", &framework::Dataset::SetHdfsConfig)
.def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc)
.def("get_filelist", &framework::Dataset::GetFileList)
.def("get_thread_num", &framework::Dataset::GetThreadNum)
.def("get_trainer_id", &framework::Dataset::GetTrainerId)
.def("get_trainer_num", &framework::Dataset::GetTrainerNum)
.def("get_fleet_send_batch_size",
&framework::Dataset::GetFleetSendBatchSize)
.def("get_hdfs_config", &framework::Dataset::GetHdfsConfig)
.def("get_data_feed_desc", &framework::Dataset::GetDataFeedDesc)
.def("register_client2client_msg_handler",
......
......@@ -796,7 +796,7 @@ class Executor(object):
if dataset == None:
raise RuntimeError("dataset is need and should be initialized")
if self.place == paddle.fluid.CUDAPlace():
if not isinstance(self.place, core.CPUPlace):
raise RuntimeError("train_from_dataset is verified on CPUPlace"
"We will open CUDAPlace in the future")
......
......@@ -29,7 +29,6 @@ class TestDataset(unittest.TestCase):
def test_dataset_create(self):
""" Testcase for dataset create. """
return
try:
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
except:
......@@ -48,7 +47,6 @@ class TestDataset(unittest.TestCase):
def test_dataset_config(self):
""" Testcase for dataset configuration. """
return
dataset = fluid.core.Dataset("MultiSlotDataset")
dataset.set_thread_num(12)
dataset.set_filelist(["a.txt", "b.txt", "c.txt"])
......@@ -75,7 +73,6 @@ class TestDataset(unittest.TestCase):
"""
Testcase for InMemoryDataset from create to run.
"""
return
with open("test_in_memory_dataset_run_a.txt", "w") as f:
data = "1 1 2 3 3 4 5 5 5 5 1 1\n"
data += "1 2 2 3 4 4 6 6 6 6 1 2\n"
......@@ -113,8 +110,7 @@ class TestDataset(unittest.TestCase):
try:
exe.train_from_dataset(fluid.default_main_program(), dataset)
except:
#self.assertTrue(False)
pass
self.assertTrue(False)
os.remove("./test_in_memory_dataset_run_a.txt")
os.remove("./test_in_memory_dataset_run_b.txt")
......@@ -123,7 +119,6 @@ class TestDataset(unittest.TestCase):
"""
Testcase for QueueDataset from create to run.
"""
return
with open("test_queue_dataset_run_a.txt", "w") as f:
data = "1 1 2 3 3 4 5 5 5 5 1 1\n"
data += "1 2 2 3 4 4 6 6 6 6 1 2\n"
......@@ -157,14 +152,11 @@ class TestDataset(unittest.TestCase):
try:
exe.train_from_dataset(fluid.default_main_program(), dataset)
except:
#self.assertTrue(False)
pass
self.assertTrue(False)
os.remove("./test_queue_dataset_run_a.txt")
os.remove("./test_queue_dataset_run_b.txt")
if __name__ == '__main__':
#unittest.main()
import sys
sys.exit(0)
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册