提交 24863897 编写于 作者: D dongdaxiang

add RunFromDataset in executor

上级 e36bbcc8
......@@ -28,7 +28,7 @@ add_subdirectory(common)
add_subdirectory(io)
#ddim lib
proto_library(framework_proto SRCS framework.proto)
proto_library(async_executor_proto SRCS data_feed.proto)
proto_library(data_feed_proto SRCS data_feed.proto)
proto_library(trainer_desc_proto SRCS trainer_desc.proto)
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce)
......@@ -175,15 +175,11 @@ cc_library(executor_gc_helper SRCS executor_gc_helper.cc DEPS scope proto_desc o
if(WITH_DISTRIBUTE)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog
lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS})
lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS} trainer_library)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
if(WITH_NGRAPH)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator variable_helper)
else(WITH_NGRAPH)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper)
endif(WITH_NGRAPH)
cc_library(executor SRCS executor.cc multi_trainer.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry device_context scope framework_proto data_feed_proto trainer_desc_proto glog lod_rank_table fs shell fleet_wrapper lodtensor_printer feed_fetch_method graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS} timer)
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif()
......@@ -194,28 +190,15 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
graph build_strategy
fast_threaded_ssa_graph_executor variable_helper)
if(WITH_PSLIB)
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
data_set.cc
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass async_executor_proto
variable_helper pslib_brpc pslib timer)
else()
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
data_set.cc
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass async_executor_proto
variable_helper timer)
endif(WITH_PSLIB)
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
data_set.cc DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass data_feed_proto
variable_helper timer)
cc_test(data_feed_test SRCS data_feed_test.cc DEPS async_executor)
cc_library(prune SRCS prune.cc DEPS framework_proto)
......
......@@ -154,14 +154,5 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
return;
}
// todo RunFromDataset
void AsyncExecutor::RunFromDataset(const ProgramDesc& main_program,
Dataset* data_set,
const std::string& trainer_desc_str,
const bool debug) {
}
} // einit_modelnd namespace framework
} // end namespace framework
} // end namespace paddle
......@@ -135,9 +135,7 @@ int PrivateQueueDataFeed<T>::Next() {
return batch_size_;
}
#ifdef _WIN32
template class PrivateQueueDataFeed<std::vector<MultiSlotType>>;
#endif
template <typename T>
InMemoryDataFeed<T>::InMemoryDataFeed() {
......@@ -150,7 +148,7 @@ template <typename T>
bool InMemoryDataFeed<T>::Start() {
DataFeed::CheckSetFileList();
if (memory_data_.size() != 0) {
CHECK(cur_channel_ == 0);
CHECK_EQ(cur_channel_, 0);
shuffled_ins_->Extend(std::move(memory_data_));
std::vector<T>().swap(memory_data_);
}
......@@ -173,30 +171,30 @@ int InMemoryDataFeed<T>::Next() {
CHECK(in_channel != nullptr);
CHECK(out_channel != nullptr);
int index = 0;
T instance;
T ins_vec;
while (index < DataFeed::default_batch_size_) {
if (in_channel->Size() == 0) {
break;
}
in_channel->Pop(instance);
AddInstanceToInsVec(&ins_vec, instance, index++);
out_channel->Push(std::move(instance));
}
DataFeed::batch_size_ = index;
if (DataFeed::batch_size_ != 0) {
PutToFeedVec(ins_vec);
} else {
cur_channel_ = 1 - cur_channel_;
T instance;
T ins_vec;
while (index < DataFeed::default_batch_size_) {
if (in_channel->Size() == 0) {
break;
}
return DataFeed::batch_size_;
in_channel->Pop(instance);
AddInstanceToInsVec(&ins_vec, instance, index++);
out_channel->Push(std::move(instance));
}
DataFeed::batch_size_ = index;
if (DataFeed::batch_size_ != 0) {
PutToFeedVec(ins_vec);
} else {
cur_channel_ = 1 - cur_channel_;
}
return DataFeed::batch_size_;
}
template <typename T>
void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
T ins;
DeserializeIns(ins, ins_str);
shuffled_ins_->Push(std::move(ins));
T ins;
DeserializeIns(ins, ins_str);
shuffled_ins_->Push(std::move(ins));
}
template <typename T>
......@@ -205,11 +203,11 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
std::string filename;
while (DataFeed::PickOneFile(&filename)) {
int err_no = 0;
PrivateQueueDataFeed<T>::fp_ = fs_open_read(filename, &err_no,
PrivateQueueDataFeed<T>::pipe_command_);
PrivateQueueDataFeed<T>::fp_ =
fs_open_read(filename, &err_no, PrivateQueueDataFeed<T>::pipe_command_);
__fsetlocking(&*PrivateQueueDataFeed<T>::fp_, FSETLOCKING_BYCALLER);
T instance;
while(ParseOneInstanceFromPipe(&instance)) {
while (ParseOneInstanceFromPipe(&instance)) {
local_vec.push_back(instance);
}
memory_data_.insert(memory_data_.end(), local_vec.begin(), local_vec.end());
......@@ -242,6 +240,8 @@ void InMemoryDataFeed<T>::GlobalShuffle(int trainer_num) {
}
*/
template class InMemoryDataFeed<std::vector<MultiSlotType>>;
void MultiSlotDataFeed::Init(
const paddle::framework::DataFeedDesc& data_feed_desc) {
finish_init_ = false;
......@@ -633,7 +633,8 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(
}
}
bool MultiSlotInMemoryDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) {
bool MultiSlotInMemoryDataFeed::ParseOneInstance(
std::vector<MultiSlotType>* instance) {
std::string line;
if (getline(file_, line)) {
int use_slots_num = use_slots_.size();
......@@ -725,12 +726,14 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
}
// todo serialize ins in global shuffle
void MultiSlotInMemoryDataFeed::SerializeIns(const std::vector<MultiSlotType>& ins, std::string& str) {
void MultiSlotInMemoryDataFeed::SerializeIns(
const std::vector<MultiSlotType>& ins, std::string& str) {
return;
}
// todo deserialize ins in global shuffle
void MultiSlotInMemoryDataFeed::DeserializeIns(std::vector<MultiSlotType>& ins, const std::string& str) {
void MultiSlotInMemoryDataFeed::DeserializeIns(std::vector<MultiSlotType>& ins,
const std::string& str) {
return;
}
} // namespace framework
......
......@@ -21,7 +21,8 @@ limitations under the License. */
namespace paddle {
namespace framework {
void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc, Dataset* data_set) {
void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
const Dataset& data_set) {
thread_num_ = trainer_desc.thread_num();
workers_.resize(thread_num_);
readers_.resize(thread_num_);
......
......@@ -19,13 +19,16 @@ limitations under the License. */
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/trainer_factory.h"
#include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
......@@ -115,9 +118,39 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
}
}
void Executor::RunFromDataset(const ProgramDesc& pdesc, const Dataset& dataset,
void Executor::RunFromDataset(const ProgramDesc& main_program,
const Dataset& dataset,
const std::string& trainer_desc_str,
const bool debug) {}
const bool debug) {
VLOG(3) << "Start to RunFromDataset in executor";
TrainerDesc trainer_desc;
google::protobuf::TextFormat::ParseFromString(trainer_desc_str,
&trainer_desc);
VLOG(3) << "Going to create trainer, trainer class is "
<< trainer_desc.class_name();
std::shared_ptr<TrainerBase> trainer;
trainer = TrainerFactory::CreateTrainer(trainer_desc.class_name());
// initialize trainer
VLOG(3) << "Going to initialize trainer";
trainer->Initialize(trainer_desc, dataset);
VLOG(3) << "Set root scope here";
trainer->SetScope(root_scope_);
VLOG(3) << "Going to set debug";
trainer->SetDebug(debug);
// prepare training environment and helper environment
VLOG(3) << "Try to init train environment";
trainer->InitTrainerEnv(main_program, place_);
VLOG(3) << "Try to init other environment";
trainer->InitOtherEnv(main_program);
// training and finalize training
VLOG(3) << "Trainer starts to run";
trainer->Run();
VLOG(3) << "Trainer going to finalize";
trainer->Finalize();
VLOG(3) << "Drop current scope kids";
root_scope_->DropKids();
return;
}
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope, bool create_vars,
......
......@@ -22,11 +22,12 @@ namespace paddle {
namespace framework {
void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) {
const Dataset& dataset) {
thread_num_ = trainer_desc.thread_num();
// get filelist from trainer_desc here
workers_.resize(thread_num_);
/*
if (NULL == dataset) {
readers_.resize(thread_num_);
for (int i = 0; i < thread_num_; ++i) {
......@@ -42,6 +43,7 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
} else {
// readers_ = dataset.get_readers(); ?
}
*/
for (int i = 0; i < thread_num_; ++i) {
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
......
......@@ -22,6 +22,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
......@@ -29,7 +30,6 @@ limitations under the License. */
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/reader/blocking_queue.h"
#include "paddle/fluid/framework/data_set.h"
namespace paddle {
namespace framework {
......@@ -41,7 +41,8 @@ class TrainerBase {
// model memory are hosted in root_scope
void SetScope(Scope* root_scope);
void SetDebug(const bool debug) { debug_ = debug; }
virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set) = 0;
virtual void Initialize(const TrainerDesc& trainer_desc,
const Dataset& data_set) = 0;
virtual void InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place) = 0;
virtual void InitOtherEnv(const ProgramDesc& main_program) = 0;
......@@ -60,7 +61,8 @@ class MultiTrainer : public TrainerBase {
public:
MultiTrainer() {}
virtual ~MultiTrainer() {}
virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set);
virtual void Initialize(const TrainerDesc& trainer_desc,
const Dataset& data_set);
virtual void InitTrainerEnv(const ProgramDesc& main_program,
const platform::Place& place);
virtual void InitOtherEnv(const ProgramDesc& main_program) {}
......@@ -78,7 +80,8 @@ class DistMultiTrainer : public MultiTrainer {
public:
DistMultiTrainer() {}
virtual ~DistMultiTrainer() {}
virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set);
virtual void Initialize(const TrainerDesc& trainer_desc,
const Dataset& data_set);
virtual void InitOtherEnv(const ProgramDesc& main_program);
virtual void Finalize();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册