Commit be74de2c authored by xjqbest, committed by dongdaxiang

fix code style & fix register bug & add release_memory

test=develop
Parent a0b59773
@@ -83,10 +83,10 @@ class BlockingQueue {
     return rc;
   }
-  void Pop(T &t) {
+  void Pop(T *t) {
     std::unique_lock<std::mutex> lock(mutex_);
     cv_.wait(lock, [=] { return !q_.empty(); });
-    t = std::move(q_.front());
+    *t = std::move(q_.front());
     q_.pop_front();
   }
......
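Editor's note: the Pop signature change above follows the Google C++ style rule that output parameters are passed by pointer rather than non-const reference, so the write-out is visible at each call site (Pop(&instance) instead of Pop(instance)). A minimal, self-contained sketch of the pattern — the class and names are illustrative, not the Paddle implementation:

#include <condition_variable>
#include <deque>
#include <iostream>
#include <mutex>

// Illustrative blocking queue; the output parameter is a pointer,
// so the caller writes q.Pop(&item) and the mutation is obvious.
template <typename T>
class MiniBlockingQueue {
 public:
  void Push(T v) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      q_.push_back(std::move(v));
    }
    cv_.notify_one();
  }
  void Pop(T* t) {  // pointer, not T&, per the style fix above
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return !q_.empty(); });
    *t = std::move(q_.front());
    q_.pop_front();
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  std::deque<T> q_;
};

int main() {
  MiniBlockingQueue<int> q;
  q.Push(42);
  int item = 0;
  q.Pop(&item);  // the & at the call site is what the rule buys
  std::cout << item << "\n";
  return 0;
}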
@@ -48,7 +48,7 @@ bool DataFeed::SetFileList(const std::vector<std::string>& files) {
     return false;
   }
   */
-  //PADDLE_ENFORCE(files.size(), "You have set an empty filelist.");
+  // PADDLE_ENFORCE(files.size(), "You have set an empty filelist.");
   filelist_.assign(files.begin(), files.end());
   finish_set_filelist_ = true;
@@ -190,7 +190,8 @@ int InMemoryDataFeed<T>::Next() {
       if (in_channel->Size() == 0) {
         break;
       }
-      in_channel->Pop(instance);
+      in_channel->Pop(&instance);
       AddInstanceToInsVec(&ins_vec, instance, index++);
       out_channel->Push(std::move(instance));
     }
@@ -268,17 +269,19 @@ void InMemoryDataFeed<T>::FillChannelToMemoryData() {
   }
   CHECK(channel != nullptr);
   CHECK(pre_channel != nullptr);
-  CHECK(pre_channel->Size() == 0);
+  CHECK_EQ(pre_channel->Size(), 0);
   local_vec.resize(channel->Size());
   for (int64_t i = 0; i < local_vec.size(); ++i) {
-    channel->Pop(local_vec[i]);
+    channel->Pop(&local_vec[i]);
   }
-  VLOG(3) << "local_vec size=" << local_vec.size() <<", thread_id=" << thread_id_;
+  VLOG(3) << "local_vec size=" << local_vec.size()
+          << ", thread_id=" << thread_id_;
   {
     std::lock_guard<std::mutex> g(*mutex_for_update_memory_data_);
     VLOG(3) << "before insert, memory_data_ size=" << memory_data_->size()
             << ", thread_id=" << thread_id_;
-    memory_data_->insert(memory_data_->end(), local_vec.begin(), local_vec.end());
+    memory_data_->insert(memory_data_->end(), local_vec.begin(),
+                         local_vec.end());
     VLOG(3) << "after insert memory_data_ size=" << memory_data_->size()
             << ", thread_id=" << thread_id_;
   }
@@ -574,7 +577,7 @@ bool MultiSlotDataFeed::ParseOneInstanceFromPipe(
   const char* str = reader.get();
   std::string line = std::string(str);
-  //VLOG(3) << line;
+  // VLOG(3) << line;
   char* endptr = const_cast<char*>(str);
   int pos = 0;
   for (size_t i = 0; i < use_slots_index_.size(); ++i) {
@@ -750,7 +753,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(
   const char* str = reader.get();
   std::string line = std::string(str);
-  //VLOG(3) << line;
+  // VLOG(3) << line;
   char* endptr = const_cast<char*>(str);
   int pos = 0;
   for (size_t i = 0; i < use_slots_index_.size(); ++i) {
......
......@@ -21,7 +21,8 @@ limitations under the License. */
#include <thread> // NOLINT
#include <vector>
#include <sstream>
#include <future>
#include <future> // NOLINT
#include <utility>
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
......
@@ -82,6 +82,18 @@ DatasetImpl<T>::GetReaders() {
   return readers_;
 }
+// if messages are sent between workers, call this function first
+template <typename T>
+void DatasetImpl<T>::RegisterClientToClientMsgHandler() {
+  auto fleet_ptr = FleetWrapper::GetInstance();
+  VLOG(3) << "RegisterClientToClientMsgHandler";
+  fleet_ptr->RegisterClientToClientMsgHandler(
+      0, [this](int msg_type, int client_id, const std::string& msg) -> int {
+        return this->ReceiveFromClient(msg_type, client_id, msg);
+      });
+  VLOG(3) << "RegisterClientToClientMsgHandler done";
+}
+
 // load data into memory, Dataset holds this memory,
 // which will later be fed into readers' channel
 template <typename T>
@@ -106,6 +118,14 @@ void DatasetImpl<T>::LoadIntoMemory() {
           << ", cost time=" << timeline.ElapsedSec() << " seconds";
 }
+// release memory data
+template <typename T>
+void DatasetImpl<T>::ReleaseMemory() {
+  VLOG(3) << "DatasetImpl<T>::ReleaseMemory() begin";
+  std::vector<T>().swap(memory_data_);
+  VLOG(3) << "DatasetImpl<T>::ReleaseMemory() end";
+}
+
 // do local shuffle
 template <typename T>
 void DatasetImpl<T>::LocalShuffle() {
@@ -137,12 +157,6 @@ void DatasetImpl<T>::GlobalShuffle() {
   VLOG(3) << "DatasetImpl<T>::GlobalShuffle() begin";
   platform::Timer timeline;
   timeline.Start();
-  auto fleet_ptr = FleetWrapper::GetInstance();
-  VLOG(3) << "RegisterClientToClientMsgHandler";
-  fleet_ptr->RegisterClientToClientMsgHandler(
-      0, [this](int msg_type, int client_id, const std::string& msg) -> int {
-        return this->ReceiveFromClient(msg_type, client_id, msg);
-      });
   if (readers_.size() == 0) {
     CreateReaders();
   }
......
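Editor's note: the hunk above is the "register bug" fix from the commit message. Previously the client-to-client handler was registered inside GlobalShuffle() itself, so a fast worker could start pushing shuffled records to a peer that had not yet installed its handler. Splitting registration into RegisterClientToClientMsgHandler() lets callers register on every worker and then barrier before any shuffle traffic flows (the Python wrapper hunk at the bottom of this diff does exactly that: register, barrier, then global_shuffle). A toy, runnable illustration of the failure mode and the fixed ordering — the in-process "bus" below is a stand-in, not the fleet RPC API:

#include <functional>
#include <iostream>
#include <map>
#include <string>

// Toy message bus standing in for the fleet RPC layer; all names
// here are illustrative, not the Paddle API.
std::map<int, std::function<int(int, int, const std::string&)>> handlers;

int SendToClient(int msg_type, int from_id, const std::string& msg) {
  auto it = handlers.find(msg_type);
  if (it == handlers.end()) {
    std::cout << "dropped: no handler for type " << msg_type << "\n";
    return -1;  // the failure mode the commit fixes
  }
  return it->second(msg_type, from_id, msg);
}

int main() {
  // Buggy order: a fast worker sends before the peer has registered.
  SendToClient(0, /*from_id=*/1, "shuffled record");

  // Fixed order: register first, barrier so every worker has
  // registered, then let GlobalShuffle() start sending.
  handlers[0] = [](int msg_type, int client_id, const std::string& msg) -> int {
    std::cout << "received from client " << client_id << ": " << msg << "\n";
    return 0;
  };
  SendToClient(0, /*from_id=*/1, "shuffled record");
  return 0;
}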
@@ -40,22 +40,43 @@ class Dataset {
  public:
   Dataset() {}
   virtual ~Dataset() {}
+  // set file list
   virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
+  // set readers' num
   virtual void SetThreadNum(int thread_num) = 0;
+  // set workers' num
   virtual void SetTrainerNum(int trainer_num) = 0;
+  // set fs name and ugi
   virtual void SetHdfsConfig(const std::string& fs_name,
                              const std::string& fs_ugi) = 0;
+  // set data feed desc, which contains:
+  // data feed name, batch size, slots
   virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
+  // get file list
   virtual const std::vector<std::string>& GetFileList() = 0;
+  // get thread num
   virtual int GetThreadNum() = 0;
+  // get worker num
   virtual int GetTrainerNum() = 0;
+  // get data feed desc
   virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0;
+  // get readers, the reader num depends on both thread num
+  // and filelist size
   virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
       GetReaders() = 0;
+  // register message handler between workers
+  virtual void RegisterClientToClientMsgHandler() = 0;
+  // load all data into memory
   virtual void LoadIntoMemory() = 0;
+  // release all memory data
+  virtual void ReleaseMemory() = 0;
+  // local shuffle data
   virtual void LocalShuffle() = 0;
+  // global shuffle data
   virtual void GlobalShuffle() = 0;
+  // create readers
   virtual void CreateReaders() = 0;
+  // destroy readers
   virtual void DestroyReaders() = 0;

  protected:
@@ -84,10 +105,12 @@ class DatasetImpl : public Dataset {
   virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() {
     return data_feed_desc_;
   }
   virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
       GetReaders();
+  virtual void RegisterClientToClientMsgHandler();
   virtual void LoadIntoMemory();
+  virtual void ReleaseMemory();
   virtual void LocalShuffle();
   virtual void GlobalShuffle();
   virtual void CreateReaders();
......
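Editor's note: DatasetImpl<T>::ReleaseMemory(), declared above and defined earlier in this diff, frees memory_data_ via the swap-with-a-temporary idiom: vector::clear() destroys the elements but typically keeps the allocation, while swapping with an empty temporary hands the whole buffer to the temporary, which releases it on destruction. A standalone demonstration (not Paddle code):

#include <iostream>
#include <vector>

int main() {
  std::vector<int> data(1 << 20, 7);  // ~4 MB of payload
  data.clear();
  // clear() destroyed the elements, but the buffer usually stays allocated:
  std::cout << "after clear(): capacity = " << data.capacity() << "\n";

  std::vector<int>().swap(data);  // the idiom ReleaseMemory uses
  std::cout << "after swap():  capacity = " << data.capacity() << "\n";  // 0
  return 0;
}

Since C++11, data.shrink_to_fit() is an alternative, but it is only a non-binding request; the swap idiom guarantees the capacity is actually dropped.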
@@ -23,6 +23,7 @@ limitations under the License. */
 #endif
 #include <string>
 #include <vector>
+#include <memory>
 #include "google/protobuf/io/zero_copy_stream_impl.h"
 #include "google/protobuf/text_format.h"
@@ -49,7 +50,6 @@ void BindAsyncExecutor(py::module* m) {
                         new framework::AsyncExecutor(scope, place));
       }))
       .def("run_from_files", &framework::AsyncExecutor::RunFromFile)
-      //.def("run_from_dataset", &framework::AsyncExecutor::RunFromDataset)
       .def("init_server", &framework::AsyncExecutor::InitServer)
       .def("init_worker", &framework::AsyncExecutor::InitWorker)
       .def("start_server", &framework::AsyncExecutor::StartServer)
......
@@ -52,7 +52,10 @@ void BindDataset(py::module* m) {
       .def("set_trainer_num", &framework::Dataset::SetTrainerNum)
       .def("set_hdfs_config", &framework::Dataset::SetHdfsConfig)
       .def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc)
+      .def("register_client2client_msg_handler",
+           &framework::Dataset::RegisterClientToClientMsgHandler)
       .def("load_into_memory", &framework::Dataset::LoadIntoMemory)
+      .def("release_memory", &framework::Dataset::ReleaseMemory)
       .def("local_shuffle", &framework::Dataset::LocalShuffle)
       .def("global_shuffle", &framework::Dataset::GlobalShuffle);
 }
......
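Editor's note: each .def above binds a C++ member function to a snake_case Python method on the Dataset object, which is what makes dataset.release_memory() and dataset.register_client2client_msg_handler() callable from the Python wrapper below. A self-contained pybind11 sketch of the same pattern, using a hypothetical Counter class and example module rather than the Paddle types:

// Minimal pybind11 binding sketch (illustrative; "example" and Counter
// are hypothetical, not Paddle code).
#include <pybind11/pybind11.h>
namespace py = pybind11;

class Counter {
 public:
  void Add(int v) { total_ += v; }
  int Total() const { return total_; }

 private:
  int total_ = 0;
};

PYBIND11_MODULE(example, m) {
  py::class_<Counter>(m, "Counter")
      .def(py::init<>())
      .def("add", &Counter::Add)      // snake_case Python name -> C++ method
      .def("total", &Counter::Total);
}

Built as an extension module, this exposes Counter().add(1) to Python; the Paddle bindings follow the identical recipe with framework::Dataset.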
@@ -237,7 +237,10 @@ class InMemoryDataset(DatasetBase):
         if fleet is not None:
             fleet.fleet_instance.role_maker_.barrier_worker()
             trainer_num = fleet.worker_num()
+            self.dataset.register_client2client_msg_handler()
         self.dataset.set_trainer_num(trainer_num)
+        if fleet is not None:
+            fleet.fleet_instance.role_maker_.barrier_worker()
         self.dataset.global_shuffle()
         if fleet is not None:
             fleet.fleet_instance.role_maker_.barrier_worker()
......