Commit 824b84d1 authored by xjqbest, committed by dongdaxiang

add DataSet and InMemoryDataFeed to support loading data into memory and shuffling it

Parent 08c25995
......@@ -199,6 +199,7 @@ if(WITH_PSLIB)
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
data_set.cc
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass async_executor_proto
......@@ -208,6 +209,7 @@ else()
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
data_set.cc
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass async_executor_proto
......
......@@ -154,5 +154,14 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
return;
}
// todo RunFromDataset
void AsyncExecutor::RunFromDataset(const ProgramDesc& main_program,
Dataset* data_set,
const std::string& trainer_desc_str,
const bool debug) {
}
} // end namespace framework
} // end namespace paddle
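The stub above is left as a TODO; below is a hedged sketch of the flow it will presumably grow into, mirroring the existing RunFromFile path. TrainerFactory, Initialize(desc, dataset), InitTrainerEnv and Finalize all appear elsewhere in this commit; trainer_desc.class_name(), root_scope_, place_ and Run() are assumptions, not confirmed by this diff.
// Hedged sketch only -- not the committed implementation.
void AsyncExecutor::RunFromDataset(const ProgramDesc& main_program,
                                   Dataset* data_set,
                                   const std::string& trainer_desc_str,
                                   const bool debug) {
  TrainerDesc trainer_desc;
  // parse the text-format proto, as RunFromFile does
  google::protobuf::TextFormat::ParseFromString(trainer_desc_str, &trainer_desc);
  std::shared_ptr<TrainerBase> trainer =
      TrainerFactory::CreateTrainer(trainer_desc.class_name());  // assumed field
  trainer->SetScope(root_scope_);  // root_scope_ / place_ assumed AsyncExecutor members
  trainer->SetDebug(debug);
  trainer->Initialize(trainer_desc, data_set);  // new signature in this commit
  trainer->InitTrainerEnv(main_program, place_);
  trainer->InitOtherEnv(main_program);
  trainer->Run();       // assumed trainer entry point
  trainer->Finalize();
}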
......@@ -30,6 +30,7 @@ limitations under the License. */
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/data_set.h"
namespace paddle {
namespace framework {
......
......@@ -33,6 +33,14 @@ class BlockingQueue {
cv_.notify_one();
}
void Push(T &&item) {
{
std::lock_guard<std::mutex> g(mutex_);
q_.emplace_back(std::move(item));
}
cv_.notify_one();
}
template <typename U>
void Extend(const U &items) {
{
......@@ -44,6 +52,17 @@ class BlockingQueue {
cv_.notify_all();
}
template <typename U>
void Extend(U &&items) {
{
std::lock_guard<std::mutex> g(mutex_);
for (auto &item : items) {
q_.emplace_back(std::move(item));
}
}
cv_.notify_all();
}
std::deque<T> PopAll(size_t ms, bool *timeout) {
auto time =
std::chrono::system_clock::now() + std::chrono::milliseconds(ms);
......@@ -64,6 +83,18 @@ class BlockingQueue {
return rc;
}
void Pop(T &t) {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [=] { return !q_.empty(); });
t = std::move(q_.front());
q_.pop_front();
}
size_t Size() {
std::lock_guard<std::mutex> lock(mutex_);
return q_.size();
}
private:
std::mutex mutex_;
std::condition_variable cv_;
......
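For reference, a minimal usage sketch of the overloads added above; the element type and values are illustrative only:
// Illustrative usage of the new rvalue Push/Extend and the blocking Pop.
paddle::framework::BlockingQueue<std::vector<int>> q;

std::vector<int> ins = {1, 2, 3};
q.Push(std::move(ins));            // moves the instance in, wakes one waiter

std::vector<std::vector<int>> more = {{4}, {5}};
q.Extend(std::move(more));         // moves each element in, wakes all waiters

std::vector<int> out;
q.Pop(out);                        // blocks until an item is available
size_t remaining = q.Size();       // thread-safe size query
Note that Pop has no timeout: on an empty queue it blocks until a producer pushes, so producers must outlive consumers or a sentinel item is needed.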
......@@ -139,6 +139,109 @@ int PrivateQueueDataFeed<T>::Next() {
template class PrivateQueueDataFeed<std::vector<MultiSlotType>>;
#endif
template <typename T>
InMemoryDataFeed<T>::InMemoryDataFeed() {
cur_channel_ = 0;
shuffled_ins_ = nullptr;
shuffled_ins_out_ = nullptr;
}
template <typename T>
bool InMemoryDataFeed<T>::Start() {
DataFeed::CheckSetFileList();
if (memory_data_.size() != 0) {
CHECK(cur_channel_ == 0);
shuffled_ins_->Extend(std::move(memory_data_));
std::vector<T>().swap(memory_data_);
}
DataFeed::finish_start_ = true;
return true;
}
template <typename T>
int InMemoryDataFeed<T>::Next() {
DataFeed::CheckStart();
std::shared_ptr<paddle::framework::BlockingQueue<T>> in_channel = nullptr;
std::shared_ptr<paddle::framework::BlockingQueue<T>> out_channel = nullptr;
if (cur_channel_ == 0) {
in_channel = shuffled_ins_;
out_channel = shuffled_ins_out_;
} else {
in_channel = shuffled_ins_out_;
out_channel = shuffled_ins_;
}
CHECK(in_channel != nullptr);
CHECK(out_channel != nullptr);
int index = 0;
T instance;
T ins_vec;
while (index < DataFeed::default_batch_size_) {
if (in_channel->Size() == 0) {
break;
}
in_channel->Pop(instance);
AddInstanceToInsVec(&ins_vec, instance, index++);
out_channel->Push(std::move(instance));
}
DataFeed::batch_size_ = index;
if (DataFeed::batch_size_ != 0) {
PutToFeedVec(ins_vec);
} else {
cur_channel_ = 1 - cur_channel_;
}
return DataFeed::batch_size_;
}
template <typename T>
void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
T ins;
DeserializeIns(ins, ins_str);
shuffled_ins_->Push(std::move(ins));
}
template <typename T>
void InMemoryDataFeed<T>::LoadIntoMemory() {
std::vector<T> local_vec;
std::string filename;
while (DataFeed::PickOneFile(&filename)) {
int err_no = 0;
PrivateQueueDataFeed<T>::fp_ = fs_open_read(filename, &err_no,
PrivateQueueDataFeed<T>::pipe_command_);
__fsetlocking(&*PrivateQueueDataFeed<T>::fp_, FSETLOCKING_BYCALLER);
T instance;
while (ParseOneInstanceFromPipe(&instance)) {
local_vec.push_back(instance);
}
memory_data_.insert(memory_data_.end(), local_vec.begin(), local_vec.end());
std::vector<T>().swap(local_vec);
}
}
template <typename T>
void InMemoryDataFeed<T>::LocalShuffle() {
std::random_shuffle(memory_data_.begin(), memory_data_.end());
}
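One caveat worth flagging: std::random_shuffle, used above, is deprecated since C++14 and removed in C++17. A hedged drop-in using std::shuffle with an explicit engine (requires <algorithm> and <random>):
template <typename T>
void InMemoryDataFeed<T>::LocalShuffle() {
  // std::shuffle needs an explicit URBG; seeding from random_device
  // gives each reader thread its own permutation.
  std::random_device rd;
  std::mt19937 g(rd());
  std::shuffle(memory_data_.begin(), memory_data_.end(), g);
}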
// todo global shuffle
/*
template <typename T>
void InMemoryDataFeed<T>::GlobalShuffle(int trainer_num) {
std::random_shuffle(memory_data_.begin(), memory_data_.end());
for (int64_t i = 0; i < memory_data_.size(); ++i) {
// todo get ins id
//std::string ins_id = memory_data_[i].ins_id;
// todo hash
int64_t hash_id = paddle::ps::local_random_engine()();
//int64_t hash_id = hash(ins_id);
int64_t node_id = hash_id % trainer_num_;
std::string str;
SerializeIns(memory_data_[i], str);
auto fleet_ptr = FleetWrapper::GetInstance();
auto ret = fleet_ptr->send_client2client_msg(0, node_id, str);
}
}
*/
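The hash inside the commented-out GlobalShuffle is itself a TODO. One plausible routing scheme, assuming each instance eventually carries an ins_id string (not defined in this diff), is std::hash modulo the trainer count:
// Assumption-laden sketch of the "todo hash" above: ins_id is hypothetical.
int64_t node_id = static_cast<int64_t>(
    std::hash<std::string>{}(ins_id) % static_cast<size_t>(trainer_num_));
std::string str;
SerializeIns(memory_data_[i], str);
auto fleet_ptr = FleetWrapper::GetInstance();
fleet_ptr->send_client2client_msg(0, node_id, str);
Hashing by instance id (rather than drawing from a random engine) would route identical ids to the same trainer, which is presumably the point of leaving the hash as a TODO instead of keeping local_random_engine.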
void MultiSlotDataFeed::Init(
const paddle::framework::DataFeedDesc& data_feed_desc) {
finish_init_ = false;
......@@ -445,5 +548,190 @@ void MultiSlotDataFeed::PutToFeedVec(
}
}
void MultiSlotInMemoryDataFeed::Init(
const paddle::framework::DataFeedDesc& data_feed_desc) {
finish_init_ = false;
finish_set_filelist_ = false;
finish_start_ = false;
PADDLE_ENFORCE(data_feed_desc.has_multi_slot_desc(),
"Multi_slot_desc has not been set.");
paddle::framework::MultiSlotDesc multi_slot_desc =
data_feed_desc.multi_slot_desc();
SetBatchSize(data_feed_desc.batch_size());
SetQueueSize(data_feed_desc.batch_size());
size_t all_slot_num = multi_slot_desc.slots_size();
all_slots_.resize(all_slot_num);
all_slots_type_.resize(all_slot_num);
use_slots_index_.resize(all_slot_num);
use_slots_.clear();
use_slots_is_dense_.clear();
for (size_t i = 0; i < all_slot_num; ++i) {
const auto& slot = multi_slot_desc.slots(i);
all_slots_[i] = slot.name();
all_slots_type_[i] = slot.type();
use_slots_index_[i] = slot.is_used() ? use_slots_.size() : -1;
if (slot.is_used()) {
use_slots_.push_back(all_slots_[i]);
use_slots_is_dense_.push_back(slot.is_dense());
}
}
feed_vec_.resize(use_slots_.size());
pipe_command_ = data_feed_desc.pipe_command();
finish_init_ = true;
}
bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(
std::vector<MultiSlotType>* instance) {
thread_local string::LineFileReader reader;
if (!reader.getline(&*(fp_.get()))) {
return false;
} else {
int use_slots_num = use_slots_.size();
instance->resize(use_slots_num);
const char* str = reader.get();
std::string line = std::string(str);
VLOG(3) << line;
char* endptr = const_cast<char*>(str);
int pos = 0;
for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i];
int num = strtol(&str[pos], &endptr, 10);
PADDLE_ENFORCE(
num,
"The number of ids can not be zero, you need padding "
"it in data generator; or if there is something wrong with "
"the data, please check if the data contains unresolvable "
"characters.\nplease check this error line: %s",
str);
if (idx != -1) {
(*instance)[idx].Init(all_slots_type_[i]);
if ((*instance)[idx].GetType()[0] == 'f') { // float
for (int j = 0; j < num; ++j) {
float feasign = strtof(endptr, &endptr);
(*instance)[idx].AddValue(feasign);
}
} else if ((*instance)[idx].GetType()[0] == 'u') { // uint64
for (int j = 0; j < num; ++j) {
uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10);
(*instance)[idx].AddValue(feasign);
}
}
pos = endptr - str;
} else {
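// skip the count token plus its num values; this assumes every token of an
// unused slot is followed by a space (otherwise pos runs past the line end)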
for (int j = 0; j <= num; ++j) {
// pos = line.find_first_of(' ', pos + 1);
while (line[pos + 1] != ' ') {
pos++;
}
}
}
}
return true;
}
}
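For context, each input line is a sequence of slots, each encoded as a count followed by that many values. A self-contained sketch of the strtol/strtof walk over a hypothetical two-slot line:
#include <cstdio>
#include <cstdlib>

int main() {
  // Hypothetical line: slot 0 has 3 uint64 feasigns, slot 1 has 2 floats.
  const char* str = "3 17 23 51 2 0.5 0.25";
  char* endptr = const_cast<char*>(str);
  long num = strtol(endptr, &endptr, 10);            // reads the count: 3
  for (long j = 0; j < num; ++j) {
    unsigned long long id = strtoull(endptr, &endptr, 10);
    std::printf("uint64 feasign: %llu\n", id);
  }
  num = strtol(endptr, &endptr, 10);                 // next count: 2
  for (long j = 0; j < num; ++j) {
    float v = strtof(endptr, &endptr);
    std::printf("float feasign: %f\n", v);
  }
  return 0;
}
Each strtol consumes a leading count; strtoull/strtof then consume exactly that many values, advancing endptr through the buffer, which is exactly how the parser above walks a line.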
bool MultiSlotInMemoryDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) {
std::string line;
if (getline(file_, line)) {
int use_slots_num = use_slots_.size();
instance->resize(use_slots_num);
// parse line
const char* str = line.c_str();
char* endptr = const_cast<char*>(str);
int pos = 0;
for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i];
int num = strtol(&str[pos], &endptr, 10);
PADDLE_ENFORCE(
num,
"The number of ids can not be zero, you need padding "
"it in data generator; or if there is something wrong with "
"the data, please check if the data contains unresolvable "
"characters.\nplease check this error line: %s",
str);
if (idx != -1) {
(*instance)[idx].Init(all_slots_type_[i]);
if ((*instance)[idx].GetType()[0] == 'f') { // float
for (int j = 0; j < num; ++j) {
float feasign = strtof(endptr, &endptr);
(*instance)[idx].AddValue(feasign);
}
} else if ((*instance)[idx].GetType()[0] == 'u') { // uint64
for (int j = 0; j < num; ++j) {
uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10);
(*instance)[idx].AddValue(feasign);
}
}
pos = endptr - str;
} else {
for (int j = 0; j <= num; ++j) {
pos = line.find_first_of(' ', pos + 1);
}
}
}
} else {
return false;
}
return true;
}
void MultiSlotInMemoryDataFeed::AddInstanceToInsVec(
std::vector<MultiSlotType>* ins_vec,
const std::vector<MultiSlotType>& instance, int index) {
if (index == 0) {
ins_vec->resize(instance.size());
for (size_t i = 0; i < instance.size(); ++i) {
(*ins_vec)[i].Init(instance[i].GetType());
(*ins_vec)[i].InitOffset();
}
}
for (size_t i = 0; i < instance.size(); ++i) {
(*ins_vec)[i].AddIns(instance[i]);
}
}
void MultiSlotInMemoryDataFeed::PutToFeedVec(
const std::vector<MultiSlotType>& ins_vec) {
for (size_t i = 0; i < use_slots_.size(); ++i) {
const auto& type = ins_vec[i].GetType();
const auto& offset = ins_vec[i].GetOffset();
int total_instance = static_cast<int>(offset.back());
if (type[0] == 'f') { // float
const auto& feasign = ins_vec[i].GetFloatData();
float* tensor_ptr = feed_vec_[i]->mutable_data<float>(
{total_instance, 1}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
} else if (type[0] == 'u') { // uint64
// there is no uint64_t tensor type in PaddlePaddle, so uint64 slots are stored as int64_t
const auto& feasign = ins_vec[i].GetUint64Data();
int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>(
{total_instance, 1}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t));
}
LoD data_lod{offset};
feed_vec_[i]->set_lod(data_lod);
if (use_slots_is_dense_[i]) {
int dim = total_instance / batch_size_;
feed_vec_[i]->Resize({batch_size_, dim});
}
}
}
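For context, the per-slot offset vector built in AddInstanceToInsVec doubles as the LoD here: offsets {0, 2, 5}, for example, mean instance 0 owns rows [0, 2) and instance 1 owns rows [2, 5) of the flattened {total_instance, 1} tensor.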
// todo serialize ins in global shuffle
void MultiSlotInMemoryDataFeed::SerializeIns(const std::vector<MultiSlotType>& ins, std::string& str) {
}
// todo deserialize ins in global shuffle
void MultiSlotInMemoryDataFeed::DeserializeIns(std::vector<MultiSlotType>& ins, const std::string& str) {
}
} // namespace framework
} // namespace paddle
......@@ -27,6 +27,8 @@ limitations under the License. */
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
namespace paddle {
namespace framework {
......@@ -76,6 +78,19 @@ class DataFeed {
// This function is used for binding feed_vec memory
virtual void AddFeedVar(Variable* var, const std::string& name);
virtual void LoadIntoMemory() {
PADDLE_THROW("This function(LoadIntoMemory) is not implemented.");
}
virtual void LocalShuffle() {
PADDLE_THROW("This function(LocalShuffle) is not implemented.");
}
virtual void GlobalShuffle(int trainer_num) {
PADDLE_THROW("This function(GlobalShuffle) is not implemented.");
}
virtual void PutInsToChannel(const std::string& ins_str) {
PADDLE_THROW("This function(PutToChannel) is not implemented.");
}
protected:
// The following three functions are used to check if it is executed in this
// order:
......@@ -161,6 +176,35 @@ class PrivateQueueDataFeed : public DataFeed {
std::unique_ptr<paddle::operators::reader::BlockingQueue<T>> queue_;
};
template <typename T>
class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
public:
InMemoryDataFeed();
virtual ~InMemoryDataFeed() {}
virtual bool Start();
virtual int Next();
virtual void PutInsToChannel(const std::string& ins_str);
virtual void LoadIntoMemory();
virtual void LocalShuffle();
// todo global shuffle
//virtual void GlobalShuffle(int trainer_num);
protected:
virtual void AddInstanceToInsVec(T* vec_ins, const T& instance, int index) = 0;
virtual bool ParseOneInstance(T* instance) = 0;
virtual bool ParseOneInstanceFromPipe(T* instance) = 0;
virtual void PutToFeedVec(const T& ins_vec) = 0;
virtual void SerializeIns(const T& ins, std::string& str) = 0;
virtual void DeserializeIns(T& ins, const std::string& str) = 0;
std::vector<T> memory_data_;
// When reading instances we move them from one channel to the other;
// when a pass over the data finishes, cur_channel_ is flipped (1 - cur_channel_),
// so if cur_channel_ == 0 all data is in shuffled_ins_, else in shuffled_ins_out_.
int cur_channel_;
std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_;
std::shared_ptr<paddle::framework::BlockingQueue<T>> shuffled_ins_out_;
};
// This class define the data type of instance(ins_vec) in MultiSlotDataFeed
class MultiSlotType {
public:
......@@ -245,5 +289,23 @@ class MultiSlotDataFeed
virtual bool ParseOneInstanceFromPipe(std::vector<MultiSlotType>* instance);
virtual void PutToFeedVec(const std::vector<MultiSlotType>& ins_vec);
};
class MultiSlotInMemoryDataFeed
: public InMemoryDataFeed<std::vector<MultiSlotType>> {
public:
MultiSlotInMemoryDataFeed() {}
virtual ~MultiSlotInMemoryDataFeed() {}
virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc);
protected:
virtual void AddInstanceToInsVec(std::vector<MultiSlotType>* vec_ins,
const std::vector<MultiSlotType>& instance,
int index);
virtual bool ParseOneInstance(std::vector<MultiSlotType>* instance);
virtual bool ParseOneInstanceFromPipe(std::vector<MultiSlotType>* instance);
virtual void PutToFeedVec(const std::vector<MultiSlotType>& ins_vec);
virtual void SerializeIns(const std::vector<MultiSlotType>& ins, std::string& str);
virtual void DeserializeIns(std::vector<MultiSlotType>& ins, const std::string& str);
};
} // namespace framework
} // namespace paddle
......@@ -60,5 +60,6 @@ std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
}
REGISTER_DATAFEED_CLASS(MultiSlotDataFeed);
REGISTER_DATAFEED_CLASS(MultiSlotInMemoryDataFeed);
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/data_feed_factory.h"
namespace paddle {
namespace framework {
Dataset::Dataset() {
thread_num_ = 1;
}
void Dataset::SetFileList(const std::vector<std::string>& filelist) {
filelist_ = filelist;
int file_cnt = filelist_.size();
if (thread_num_ > file_cnt) {
VLOG(1) << "DataSet thread num = " << thread_num_ << ", file num = " << file_cnt
<< ". Changing DataSet thread num = " << file_cnt;
thread_num_ = file_cnt;
}
}
void Dataset::SetThreadNum(int thread_num) {
int file_cnt = filelist_.size();
if (file_cnt != 0 && thread_num > file_cnt) {
VLOG(1) << "DataSet thread num = " << thread_num << ", file num = " << file_cnt
<< ". Changing DataSet thread num = " << file_cnt;
thread_num = file_cnt;
}
thread_num_ = thread_num;
}
void Dataset::SetTrainerNum(int trainer_num) {
trainer_num_ = trainer_num;
}
void Dataset::SetDataFeedDesc(const paddle::framework::DataFeedDesc& data_feed_desc) {
data_feed_desc_ = data_feed_desc;
}
std::vector<std::shared_ptr<paddle::framework::DataFeed>> Dataset::GetReaders() {
return readers_;
}
void Dataset::LoadIntoMemory() {
if (readers_.size() == 0) {
CreateReaders();
}
std::vector<std::thread> load_threads;
for (int64_t i = 0; i < thread_num_; ++i) {
load_threads.push_back(std::thread(&paddle::framework::DataFeed::LoadIntoMemory,
readers_[i].get()));
}
for (std::thread& t : load_threads) {
t.join();
}
}
void Dataset::LocalShuffle() {
if (readers_.size() == 0) {
CreateReaders();
}
std::vector<std::thread> local_shuffle_threads;
for (int64_t i = 0; i < thread_num_; ++i) {
local_shuffle_threads.push_back(std::thread(&paddle::framework::DataFeed::LocalShuffle,
readers_[i].get()));
}
for (std::thread& t : local_shuffle_threads) {
t.join();
}
}
// todo global shuffle
void Dataset::GlobalShuffle() {
/*
auto fleet_ptr = FleetWrapper::GetInstance();
fleet_ptr->registe_client2client_msg_handler(0,
[this](int msg_type, int client_id, const std::string& msg) -> int {
return this->ReceiveFromClient(msg_type, client_id, msg);
});
if (readers_.size() == 0) {
CreateReaders();
}
std::vector<std::thread> global_shuffle_threads;
for (int64_t i = 0; i < thread_num_; ++i) {
global_shuffle_threads.push_back(std::thread(&paddle::framework::DataFeed::GlobalShuffle,
readers_[i].get(), trainer_num_));
}
for (std::thread& t : global_shuffle_threads) {
t.join();
}*/
}
void Dataset::CreateReaders() {
CHECK(thread_num_ > 0) << "thread_num should > 0";
if (readers_.size() != 0) {
return;
}
for (int64_t i = 0; i < thread_num_; ++i) {
readers_.push_back(DataFeedFactory::CreateDataFeed(data_feed_desc_.name()));
readers_.back()->Init(data_feed_desc_);
}
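// the file list is presumably shared across DataFeed instances (static
// state in DataFeed), so setting it once on readers_[0] is sufficient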
readers_[0]->SetFileList(filelist_);
}
int Dataset::ReceiveFromClient(int msg_type, int client_id, const std::string& msg) {
// can also use hash
// int64_t index = paddle::ps::local_random_engine()() % thread_num_;
// todo
int64_t index = 0;
readers_[index]->PutInsToChannel(msg);
return 0;
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <fstream>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <thread> // NOLINT
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
namespace paddle {
namespace framework {
class Dataset {
public:
Dataset();
virtual ~Dataset() {}
virtual void SetFileList(const std::vector<std::string>& filelist);
virtual void SetThreadNum(int thread_num);
virtual void SetTrainerNum(int trainer_num);
virtual void SetDataFeedDesc(const paddle::framework::DataFeedDesc& data_feed_desc);
virtual const std::vector<std::string>& GetFileList() {
return filelist_;
}
virtual int GetThreadNum() {
return thread_num_;
}
virtual int GetTrainerNum() {
return trainer_num_;
}
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() {
return data_feed_desc_;
}
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>> GetReaders();
virtual void LoadIntoMemory();
virtual void LocalShuffle();
// todo global shuffle
virtual void GlobalShuffle();
virtual void CreateReaders();
protected:
virtual int ReceiveFromClient(int msg_type, int client_id, const std::string& msg);
std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
int thread_num_;
std::string fs_name_;
std::string fs_ugi_;
paddle::framework::DataFeedDesc data_feed_desc_;
std::vector<std::string> filelist_;
int trainer_num_;
};
} // namespace framework
} // namespace paddle
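A minimal sketch of the intended C++ call sequence. The shard names are hypothetical, and set_name/set_batch_size are the protobuf-generated setters assumed from the name()/batch_size() accessors used in data_set.cc above:
// Illustrative call sequence; LoadIntoMemory creates readers lazily.
paddle::framework::Dataset dataset;
dataset.SetThreadNum(4);
dataset.SetFileList({"part-000", "part-001"});  // hypothetical shards

paddle::framework::DataFeedDesc desc;
desc.set_name("MultiSlotInMemoryDataFeed");     // registered in this commit
desc.set_batch_size(32);
// ... fill in multi_slot_desc and pipe_command as the feed requires ...
dataset.SetDataFeedDesc(desc);

dataset.LoadIntoMemory();  // one loader thread per reader
dataset.LocalShuffle();    // each reader shuffles its own memory_data_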
......@@ -21,7 +21,7 @@ limitations under the License. */
namespace paddle {
namespace framework {
void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc) {
void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc, Dataset* data_set) {
thread_num_ = trainer_desc.thread_num();
workers_.resize(thread_num_);
readers_.resize(thread_num_);
......
......@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/framework/data_set.h"
namespace paddle {
namespace framework {
......@@ -115,7 +116,7 @@ class Executor {
const std::string& trainer_desc_str,
const bool debug);
void RunFromDataset(const ProgramDesc& main_program, const Dataset* dataset,
void RunFromDataset(const ProgramDesc& main_program, Dataset* dataset,
const std::string& trainer_desc_str, const bool debug);
public:
......
......@@ -29,6 +29,7 @@ limitations under the License. */
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/framework/data_set.h"
namespace paddle {
namespace framework {
......@@ -40,7 +41,7 @@ class TrainerBase {
// model memory are hosted in root_scope
void SetScope(Scope* root_scope);
void SetDebug(const bool debug) { debug_ = debug; }
virtual void Initialize(const TrainerDesc& trainer_desc) = 0;
virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set) = 0;
virtual void InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place) = 0;
virtual void InitOtherEnv(const ProgramDesc& main_program) = 0;
......@@ -59,7 +60,7 @@ class MultiTrainer : public TrainerBase {
public:
MultiTrainer() {}
virtual ~MultiTrainer() {}
virtual void Initialize(const TrainerDesc& trainer_desc);
virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set);
virtual void InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place);
virtual void InitOtherEnv(const ProgramDesc& main_program) {}
......@@ -77,7 +78,7 @@ class DistMultiTrainer : public MultiTrainer {
public:
DistMultiTrainer() {}
virtual ~DistMultiTrainer() {}
virtual void Initialize(const TrainerDesc& trainer_desc);
virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set);
virtual void InitOtherEnv(const ProgramDesc& main_program);
virtual void Finalize();
......
......@@ -49,6 +49,7 @@ void BindAsyncExecutor(py::module* m) {
new framework::AsyncExecutor(scope, place));
}))
.def("run_from_files", &framework::AsyncExecutor::RunFromFile)
.def("run_from_dataset", &framework::AsyncExecutor::RunFromDataset)
.def("init_server", &framework::AsyncExecutor::InitServer)
.def("init_worker", &framework::AsyncExecutor::InitWorker)
.def("start_server", &framework::AsyncExecutor::StartServer)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fcntl.h>
// To avoid conflicting definition in gcc-4.8.2 headers and pyconfig.h (2.7.3)
#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif
#ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE
#endif
#include <string>
#include <vector>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/async_executor.h"
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/variant.h"
#include "paddle/fluid/pybind/async_executor_py.h"
#include "paddle/fluid/framework/data_set.h"
namespace py = pybind11;
namespace pd = paddle::framework;
namespace paddle {
namespace pybind {
void BindDataset(py::module* m) {
py::class_<framework::DataSet>(*m, "Dataset")
.def(py::init([]() {
return std::unique_ptr<framework::Dataset>(
new framework::Dataset());
}))
.def("set_filelist", &framework::Dataset::SetFileList)
.def("set_thread_num", &framework::Dataset::SetThreadNum)
.def("set_trainer_num", &framework::Dataset::SetTrainerNum)
.def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc)
.def("load_into_memory", &framework::Dataset::LoadIntoMemory)
.def("local_shuffle", &framework::Dataset::LocalShuffle)
.def("global_shuffle", &framework::Dataset::GLobalShuffle)
}
} // end namespace pybind
} // end namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
void BindDataset(py::module* m);
} // namespace pybind
} // namespace paddle
......@@ -61,6 +61,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/recordio.h"
#include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/fluid/string/to_string.h"
#include "paddle/fluid/pybind/data_set_py.h"
#ifdef PADDLE_WITH_CUDA
#ifndef _WIN32
......@@ -1359,6 +1360,7 @@ All parameter, weight, gradient are variables in Paddle.
BindGraph(&m);
BindNode(&m);
BindInferenceApi(&m);
BindDataset(&m);
}
} // namespace pybind
} // namespace paddle