Commit 91fc8f35 authored by W wangguibao

Interface rework

Parent 274ec6a1
@@ -36,7 +36,7 @@ add_subdirectory(details)
endif (NOT WIN32)
# ddim lib
proto_library(framework_proto SRCS framework.proto)
proto_library(async_executor_param SRCS async_executor_param.proto)
proto_library(async_executor_proto SRCS data_feed.proto)
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
@@ -138,31 +138,23 @@ cc_test(version_test SRCS version_test.cc DEPS version)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version)
cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto)
if(NOT WIN32)
cc_library(ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog
shape_inference data_transform lod_tensor profiler)
endif(NOT WIN32)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
if (NOT WIN32)
py_proto_compile(framework_py_proto SRCS framework.proto)
# Generate an empty __init__.py to make framework_py_proto a valid python module.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(framework_py_proto framework_py_proto_init)
if (NOT WIN32)
add_custom_command(TARGET framework_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto
COMMAND cp *.py ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto/
COMMENT "Copy generated python proto into directory paddle/fluid/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
else(NOT WIN32)
string(REPLACE "/" "\\" proto_dstpath "${PADDLE_BINARY_DIR}/python/paddle/fluid/proto/")
add_custom_command(TARGET framework_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto
COMMAND copy /Y *.py ${proto_dstpath}
COMMENT "Copy generated python proto into directory paddle/fluid/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
add_custom_command(TARGET framework_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto
COMMAND cp *.py ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto/
COMMENT "Copy generated python proto into directory paddle/fluid/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
endif(NOT WIN32)
cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
@@ -176,11 +168,7 @@ if(WITH_DISTRIBUTE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
if(NOT WIN32)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator)
else(NOT WIN32)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass)
endif(NOT WIN32)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator)
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif()
@@ -192,10 +180,11 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
endif() # NOT WIN32
cc_library(async_executor
SRCS async_executor.cc data_feed.cc datafeed_creator.cc
SRCS async_executor.cc data_feed.cc data_feed_factory.cc
executor_thread_worker.cc
DEPS op_registry device_context scope framework_proto glog
lod_rank_table feed_fetch_method graph_to_program_pass
async_executor_param)
async_executor_proto)
cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
@@ -36,43 +36,13 @@ limitations under the License. */
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/framework/executor_thread_worker.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/pybind/pybind.h"
namespace paddle {
namespace framework {
bool AsyncExecutor::workers_initialized_ = false;
void CreateTensor(Variable* var, proto::VarType::Type var_type) {
if (var_type == proto::VarType::LOD_TENSOR) {
var->GetMutable<LoDTensor>();
} else if (var_type == proto::VarType::SELECTED_ROWS) {
var->GetMutable<SelectedRows>();
} else if (var_type == proto::VarType::FEED_MINIBATCH) {
var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarType::FETCH_LIST) {
var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarType::STEP_SCOPES) {
var->GetMutable<std::vector<Scope>>();
} else if (var_type == proto::VarType::LOD_RANK_TABLE) {
var->GetMutable<LoDRankTable>();
} else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
var->GetMutable<LoDTensorArray>();
} else if (var_type == proto::VarType::PLACE_LIST) {
var->GetMutable<platform::PlaceList>();
} else if (var_type == proto::VarType::READER) {
var->GetMutable<ReaderHolder>();
} else if (var_type == proto::VarType::RAW) {
// GetMutable will be called in operator
} else {
PADDLE_THROW(
"Variable type %d is not in "
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
"LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, RAW]",
var_type);
}
}
static void ReadBinaryFile(const std::string& filename,
std::string* content) {
std::string &contents = *content;
@@ -139,343 +109,100 @@ static void SaveModel(
}
} // end SaveModel
void ExecutorThreadWorker::Reset() {
inspect_values_.clear();
}
void ExecutorThreadWorker::CreateThreadOperators(const ProgramDesc& program) {
auto& block = program.Block(0);
op_names_.clear();
for (auto& op_desc : block.AllOps()) {
std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
op_names_.push_back(op_desc->Type());
OperatorBase* local_op_ptr = local_op.release();
ops_.push_back(local_op_ptr);
continue;
}
}
void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) {
auto& block = program.Block(0);
thread_scope_ = &root_scope_->NewScope();
for (auto& var : block.AllVars()) {
if (var->Persistable()) {
auto* ptr = root_scope_->Var(var->Name());
CreateTensor(ptr, var->GetType());
} else {
auto* ptr = thread_scope_->Var(var->Name());
CreateTensor(ptr, var->GetType());
}
}
}
void ExecutorThreadWorker::SetDataFeed(DataFeed& datafeed) {
if (typeid(datafeed) == typeid(TextClassDataFeed)) {
local_reader_.reset(
new TextClassDataFeed(dynamic_cast<TextClassDataFeed &>(datafeed)));
local_reader_->SetThreadId(thread_id_);
}
}
void ExecutorThreadWorker::BindingDataFeedMemory() {
const std::vector<std::string>& input_feed = local_reader_->GetUseSlotAlias();
for (auto name : input_feed) {
local_reader_->AddFeedVar(thread_scope_->Var(name), name);
}
}
AsyncExecutor::AsyncExecutor(Scope& scope, const platform::Place& place)
: root_scope_(scope), place_(place) {}
void ExecutorThreadWorker::SetInspectVarNames(
const std::vector<std::string>& inspect_var_names) {
inspect_var_names_.clear();
inspect_var_names_.insert(inspect_var_names_.end(),
inspect_var_names.begin(), inspect_var_names.end());
void AsyncExecutor::CreateThreads(
ExecutorThreadWorker* worker,
const ProgramDesc& main_program,
const std::shared_ptr<DataFeed>& reader,
const std::vector<std::string>& fetch_var_names,
Scope& root_scope,
const int thread_index) {
worker->SetThreadId(thread_index);
worker->SetRootScope(&root_scope);
worker->CreateThreadResource(main_program, place_);
worker->SetDataFeed(reader);
worker->SetFetchVarNames(fetch_var_names);
worker->BindingDataFeedMemory();
}
void ExecutorThreadWorker::SetModelParamNames(
const std::vector<std::string>& param_names) {
model_param_names_ = param_names;
}
void ExecutorThreadWorker::SetDevice() {
static unsigned priority[] = {
0, 1, 2, 3, 4, 5,
6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23,
24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35,
36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47
};
unsigned int i = this->thread_id_;
if (i < sizeof(priority) / sizeof(unsigned)) {
unsigned proc = priority[i];
cpu_set_t mask;
CPU_ZERO(&mask);
CPU_SET(proc, &mask);
if (-1 == sched_setaffinity(0, sizeof(mask), &mask)) {
LOG(ERROR) << "WARNING: Failed to set thread affinity for thread " << i;
} else {
CPU_ZERO(&mask);
if ((0 == sched_getaffinity(0, sizeof(mask), &mask))
&& CPU_ISSET(proc, &mask)) {
LOG(ERROR) << "TRACE: Thread " << i
<< " is running on processor " << proc
<< "...";
}
}
}
}
void ExecutorThreadWorker::Train() {
LOG(ERROR) << "begin to train";
SetDevice();
int inspect_var_num = inspect_var_names_.size();
inspect_values_.clear();
inspect_values_.resize(inspect_var_num, 0);
local_reader_->WaitNextEpoch();
int epoch = local_reader_->GetCurrentEpoch();
LOG(ERROR) << "epoch: " << epoch;
int batch_num = 1;
while (true) {
const char *file = local_reader_->PickOneFile();
if (file == NULL) {
break;
}
if (!local_reader_->SetFile(file)) {
break;
}
while (true) {
bool flag = local_reader_->ReadBatch();
if (!flag) {
break;
}
for (unsigned int i = 0; i < ops_.size(); ++i) {
ops_[i]->Run(*thread_scope_, place_);
}
batch_num++;
float avg_inspect = 0.0;
for (int i = 0; i < inspect_var_num; ++i) {
avg_inspect = thread_scope_->FindVar(inspect_var_names_[i])
->GetMutable<LoDTensor>()
->data<float>()[0];
inspect_values_[i] += avg_inspect;
}
thread_scope_->DropKids();
}
local_reader_->UpdateEpochNum();
LOG(ERROR) << "memory used after epoch " << epoch + 1
<< " called: " << memory::memory_usage(place_);
}
for (int i = 0; i < inspect_var_num; ++i) {
inspect_values_[i] /= batch_num;
std::string var = inspect_var_names_[i].substr(
0,
inspect_var_names_[i].find_first_of("_"));
LOG(ERROR) << "mean " << var.c_str()
<< " of epoch " << i + 1 << ": " << inspect_values_[i];
}
if (thread_id_ == 0) {
char modelfile[1024];
snprintf(&modelfile[0], sizeof(modelfile), "%s_epoch%d.model",
model_prefix_.c_str(), epoch);
std::string model_filename = std::string(modelfile);
// this save_inference_model can only save the imdb task; it should be made
// general
//
// currently comment it
LOG(ERROR) << "Going to save model " << modelfile;
SaveModel(main_program_,
thread_scope_,
model_param_names_,
model_filename,
true);
}
}
void ExecutorThreadWorker::SetThreadId(int tid) {
thread_id_ = tid;
}
void ExecutorThreadWorker::SetPlace(const platform::Place& place) {
place_ = place;
}
void ExecutorThreadWorker::SetMainProgram(
const ProgramDesc& main_program_desc) {
main_program_.reset(new ProgramDesc(main_program_desc));
}
void ExecutorThreadWorker::SetRootScope(Scope* g_scope) {
root_scope_ = g_scope;
}
void ExecutorThreadWorker::SetMaxTrainingEpoch(int max_epoch) {
max_epoch_ = max_epoch;
}
AsyncExecutor::AsyncExecutor(ProgramDesc& main_program,
const std::vector<std::string>& param_names,
TextClassDataFeed& data_feed,
unsigned int thread_num,
const platform::Place& place)
: thread_num_(thread_num),
place_(place),
main_program_(main_program),
data_feed_(data_feed) {
model_param_names_.clear();
model_param_names_.insert(model_param_names_.end(),
param_names.begin(),
param_names.end());
}
void AsyncExecutor::InitRootScope(Scope* scope) {
root_scope_ = scope;
}
void AsyncExecutor::SetMaxTrainingEpoch(int max_epoch) {
max_epoch_ = max_epoch;
void AsyncExecutor::CheckFiles(
const std::vector<std::string>& files) {
// function for users to check file formats
// should be exposed to users
}
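The body is left empty in this commit; below is a minimal hypothetical sketch, not part of the commit, of what a format check could start as — it only verifies that each listed file opens:

// Hypothetical body sketch for CheckFiles() -- not in this commit;
// assumes <fstream> is included. LOG comes from glog, already used here.
for (const auto& f : files) {
  std::ifstream fin(f.c_str());
  if (!fin.good()) {
    LOG(ERROR) << "file " << f << " cannot be opened";
  }
}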
void AsyncExecutor::SetModelPrefix(const std::string& model_prefix) {
model_prefix_ = model_prefix;
}
void AsyncExecutor::RunStartupProgram(const ProgramDesc& program,
Scope* scope) {
auto& block = program.Block(0);
for (auto& var : block.AllVars()) {
if (var->Persistable()) {
auto* ptr = scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
// LOGERR("Persistable Var Name:%s", var->Name().c_str());
}
}
std::vector<float> AsyncExecutor::RunFromFile(
const ProgramDesc& main_program,
const DataFeedDesc& data_feed_desc,
const std::vector<std::string>& filelist,
const int thread_num,
const std::vector<std::string>& fetch_var_names) {
std::vector<std::thread> threads;
std::map<std::string, int> param_dict;
std::vector<OperatorBase *> ops;
for (auto& op_desc : block.AllOps()) {
std::vector<std::string> param_name_vec = op_desc->OutputArgumentNames();
bool need_to_run = false;
for (auto& name : param_name_vec) {
if (param_dict.find(name) == param_dict.end()) {
param_dict[name] = 1;
need_to_run = true;
}
}
if (need_to_run) {
std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
OperatorBase* local_op_ptr = local_op.release();
ops.push_back(local_op_ptr);
}
/*
readerDesc: protobuf description for reader initialization
argument: class_name, batch_size, use_slot, queue_size, buffer_size, padding_index
reader:
1) each thread has a reader; the reader reads input data and
puts it into the input queue
2) each reader has a Next() interface that can fetch an instance
from the input queue
*/
// todo: should be factory method for creating datafeed
std::vector<std::shared_ptr<DataFeed> > readers;
readers.resize(thread_num);
for (unsigned int i = 0; i < readers.size(); ++i) {
readers[i] = DataFeedFactory::CreateDataFeed(data_feed_desc.name());
}
// LOGERR("There are %d parameters in startup program, %d op needs to run",
// param_dict.size(), ops.size());
for (auto& op : ops) {
op->Run(*scope, place_);
std::vector<std::shared_ptr<ExecutorThreadWorker> > workers;
workers.resize(thread_num);
for (auto& worker : workers) {
worker.reset(new ExecutorThreadWorker);
}
// LOGERR("total time for startup program: %fs", timeline.elapsed_sec());
for (auto& op : ops) {
delete op;
}
// LOGERR("run startup program done.");
}
std::unique_ptr<ProgramDesc> AsyncExecutor::LoadDescFromFile(
const std::string& f) {
std::string program_desc_str;
ReadBinaryFile(f, &program_desc_str);
std::unique_ptr<ProgramDesc> program(new ProgramDesc(program_desc_str));
return program;
}
void AsyncExecutor::SetInspectVarNames(
const std::vector<std::string>& inspect_var_names) {
inspect_var_names_.clear();
inspect_var_names_.insert(inspect_var_names_.end(),
inspect_var_names.begin(), inspect_var_names.end());
}
void AsyncExecutor::PrepareThreads(const ProgramDesc& host_program) {
workers_.resize(thread_num_);
for (int i = 0; i < thread_num_; ++i) {
workers_[i].reset(new ExecutorThreadWorker);
workers_[i]->SetThreadId(i);
workers_[i]->CreateThreadOperators(host_program);
workers_[i]->SetRootScope(root_scope_);
workers_[i]->SetPlace(place_);
workers_[i]->SetMaxTrainingEpoch(max_epoch_);
workers_[i]->CreateThreadScope(host_program);
workers_[i]->SetInspectVarNames(inspect_var_names_);
workers_[i]->SetModelParamNames(model_param_names_);
workers_[i]->SetMainProgram(host_program);
workers_[i]->SetModelPrefix(model_prefix_);
//
// new a datafeed here
workers_[i]->SetDataFeed(data_feed_);
workers_[i]->BindingDataFeedMemory();
// prepare thread resource here
for (int thidx = 0; thidx < thread_num; ++thidx) {
CreateThreads(workers[thidx].get(), main_program,
readers[thidx], fetch_var_names, root_scope_, thidx);
}
}
std::vector<float>& AsyncExecutor::Run(
const std::vector<std::string>& inspect_var_names) {
SetInspectVarNames(inspect_var_names);
threads_.clear();
// thread binding here?
if (workers_initialized_ == false) {
PrepareThreads(main_program_);
workers_initialized_ = true;
}
for (int i = 0; i < thread_num_; ++i) {
workers_[i]->Reset();
workers_[i]->SetInspectVarNames(inspect_var_names);
threads_.push_back(std::thread(&ExecutorThreadWorker::Train,
workers_[i].get()));
// start executing ops in multiple threads
for (int thidx = 0; thidx < thread_num; ++thidx) {
threads.push_back(std::thread(&ExecutorThreadWorker::TrainFiles,
workers[thidx].get()));
}
for (auto& th : threads_) {
for (auto& th : threads) {
th.join();
}
inspect_values_.clear();
inspect_values_.resize(inspect_var_names_.size(), 0);
std::vector<float> fetch_values;
fetch_values.resize(fetch_var_names.size(), 0);
std::vector<std::vector<float>*> inspect_value_vectors;
inspect_value_vectors.resize(thread_num_);
for (int i = 0; i < thread_num_; ++i) {
inspect_value_vectors[i] = &workers_[i]->GetInspectValues();
std::vector<std::vector<float>*> fetch_value_vectors;
fetch_value_vectors.resize(thread_num);
for (int i = 0; i < thread_num; ++i) {
fetch_value_vectors[i] = &workers[i]->GetFetchValues();
}
for (unsigned int i = 0; i < inspect_var_names_.size(); ++i) {
for (unsigned int i = 0; i < fetch_var_names.size(); ++i) {
float value = 0.0;
for (int j = 0; j < thread_num_; ++j) {
value += inspect_value_vectors[j]->at(i);
for (int j = 0; j < thread_num; ++j) {
value += fetch_value_vectors[j]->at(i);
}
value /= thread_num_;
inspect_values_[i] = value;
value /= thread_num;
fetch_values[i] = value;
}
return inspect_values_;
return fetch_values;
}
void AsyncExecutor::LoadInitModel() {
@@ -23,7 +23,8 @@ limitations under the License. */
#include <thread> // NOLINT
#include <vector>
#include <typeinfo>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/executor_thread_worker.h"
#include "paddle/fluid/framework/datafeed_creator.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
@@ -31,93 +32,13 @@ limitations under the License. */
namespace paddle {
namespace framework {
void CreateTensor(Variable* var, proto::VarType::Type var_type);
class ExecutorThreadWorker {
public:
ExecutorThreadWorker() {}
~ExecutorThreadWorker() {}
void CreateThreadScope(const ProgramDesc& program);
void SetThreadId(int tid);
void CreateThreadOperators(const ProgramDesc& program);
void SetRootScope(Scope* g_scope);
void SetDevice();
void AddFidSet();
void SetCommBatch(int comm_batch) { comm_batch_ = comm_batch; }
void AddTrainFile(const std::string& filename);
void SetMainProgram(const ProgramDesc& main_program_desc);
void SetPlace(const paddle::platform::Place& place);
void SetMaxTrainingEpoch(const int max_epoch);
void BindingDataFeedMemory();
void SetModelPrefix(const std::string& prefix) { model_prefix_ = prefix; }
void SetInspectVarNames(const std::vector<std::string>& inspect_var_names);
void SetModelParamNames(const std::vector<std::string>& param_names);
void SetDataFeed(DataFeed& datafeed); // NOLINT
void Train();
const char* PickOneFile();
void UpdateEpochNum();
void Reset();
void Initialize() {}
std::vector<float>& GetInspectValues() {return inspect_values_;}
protected:
// thread index
int thread_id_;
// max epoch for each thread
unsigned int max_epoch_;
// instances learned currently
int comm_batch_;
std::string model_prefix_;
std::vector<std::string> op_names_;
// local ops for forward and backward
std::vector<OperatorBase *> ops_;
// main program for training
std::unique_ptr<ProgramDesc> main_program_;
// binary data reader
std::unique_ptr<DataFeed> local_reader_;
std::vector<std::string> inspect_var_names_;
std::vector<std::string> model_param_names_;
// execution place
platform::Place place_;
// root scope for model parameters
Scope* root_scope_;
// a thread scope whose parent is the shared global (root) scope
Scope* thread_scope_;
private:
std::vector<float> inspect_values_;
};
class AsyncExecutor {
public:
explicit AsyncExecutor(ProgramDesc& main_program, // NOLINT
const std::vector<std::string>& param_names,
TextClassDataFeed& data_feed, // NOLINT
unsigned int thread_num,
const platform::Place& place);
explicit AsyncExecutor(Scope& scope, const platform::Place& place); // NOLINT
virtual ~AsyncExecutor() {}
static std::unique_ptr<ProgramDesc> LoadDescFromFile(
const std::string& filename);
void InitRootScope(Scope* scope);
void SetMaxTrainingEpoch(const int max_epoch);
Scope* GetRootScope() { return root_scope_; }
void SetBatchSize(const int batch_size) { batch_size_ = batch_size; }
void SetCommBatch(int comm_batch) {
comm_batch_ = comm_batch;
}
Scope* GetRootScope() { return &root_scope_; }
void SetModelPath(const std::string& model_path) {
model_path_ = model_path;
@@ -132,38 +53,32 @@ class AsyncExecutor {
}
void SetModelPrefix(const std::string& model_prefix);
virtual void PrepareThreads(const ProgramDesc& host_program);
void RunStartupProgram(const ProgramDesc& program, Scope* scope);
std::vector<float>& Run(const std::vector<std::string>& inspect_var_names);
std::vector<float> RunFromFile(const ProgramDesc& main_program,
const DataFeedDesc& data_feed_desc,
const std::vector<std::string>& filelist,
const int thread_num,
const std::vector<std::string>& fetch_names);
void CheckFiles(const std::vector<std::string>& files);
void LoadInitModel();
private:
void SetInspectVarNames(const std::vector<std::string>& inspect_var_names);
void CreateThreads(ExecutorThreadWorker* worker,
const ProgramDesc& main_program,
const std::shared_ptr<DataFeed>& reader,
const std::vector<std::string>& fetch_var_names,
Scope& root_scope, // NOLINT
const int thread_index);
public:
int thread_num_;
int max_epoch_;
int batch_size_;
int comm_batch_;
std::vector<std::shared_ptr<ExecutorThreadWorker> > workers_;
std::vector<std::thread> threads_;
std::vector<std::string> inspect_var_names_;
std::vector<std::string> model_param_names_;
std::string model_prefix_;
std::string model_path_;
std::string init_prog_file_;
std::string init_model_file_;
Scope* root_scope_;
Scope& root_scope_;
platform::Place place_;
private:
ProgramDesc& main_program_;
TextClassDataFeed& data_feed_;
std::vector<float> inspect_values_;
private:
static bool workers_initialized_;
};
} // namespace framework
@@ -38,17 +38,17 @@ DEFINE_bool(is_text_feed, false, "is_text_feed");
namespace paddle {
namespace framework {
std::vector<std::string> TextClassDataFeed::s_filelist_;
std::mutex TextClassDataFeed::s_locker_for_pick_file_;
unsigned int TextClassDataFeed::s_current_file_idx_ = 0;
size_t TextClassDataFeed::s_current_finished_file_cnt_ = 0;
unsigned int TextClassDataFeed::s_current_epoch_ = 0;
int TextClassDataFeed::s_current_save_epoch_ = 0;
std::mutex TextClassDataFeed::s_locker_epoch_start_;
std::condition_variable TextClassDataFeed::s_condition_epoch_start_;
bool TextClassDataFeed::s_epoch_start_flag_ = false;
void TextClassDataFeed::Init() {
std::vector<std::string> MultiSlotDataFeed::s_filelist_;
std::mutex MultiSlotDataFeed::s_locker_for_pick_file_;
unsigned int MultiSlotDataFeed::s_current_file_idx_ = 0;
size_t MultiSlotDataFeed::s_current_finished_file_cnt_ = 0;
unsigned int MultiSlotDataFeed::s_current_epoch_ = 0;
int MultiSlotDataFeed::s_current_save_epoch_ = 0;
std::mutex MultiSlotDataFeed::s_locker_epoch_start_;
std::condition_variable MultiSlotDataFeed::s_condition_epoch_start_;
bool MultiSlotDataFeed::s_epoch_start_flag_ = false;
void MultiSlotDataFeed::Init() {
// hard coding for a specific datafeed
feed_vec_.resize(2);
// feed_vec_[0].reset(new LoDTensor);
@@ -73,12 +73,12 @@ void TextClassDataFeed::Init() {
field_names_.clear();
}
TextClassDataFeed::TextClassDataFeed() {
MultiSlotDataFeed::MultiSlotDataFeed() {
Init();
}
// todo: use a more elegant implementation for this function
bool TextClassDataFeed::ReadBatch() {
bool MultiSlotDataFeed::ReadBatch() {
paddle::framework::Vector<size_t> offset;
int tlen = 0;
int llen = 0;
@@ -142,13 +142,13 @@ bool TextClassDataFeed::ReadBatch() {
return true;
}
TextClassDataFeed::TextClassDataFeed(const TextClassDataFeed& data_feed) {
MultiSlotDataFeed::MultiSlotDataFeed(const MultiSlotDataFeed& data_feed) {
Init();
SetBatchSize(data_feed.batch_size_);
SetFieldNames(data_feed.field_names_);
}
void TextClassDataFeed::AddFeedVar(Variable* feed, const std::string& name) {
void MultiSlotDataFeed::AddFeedVar(Variable* feed, const std::string& name) {
for (unsigned int i = 0; i < use_slot_alias_.size(); ++i) {
if (name == use_slot_alias_[i]) {
feed_vec_[i] = feed->GetMutable<LoDTensor>();
@@ -156,7 +156,7 @@ void TextClassDataFeed::AddFeedVar(Variable* feed, const std::string& name) {
}
}
void TextClassDataFeed::SetFileList(const char* filelist) {
void MultiSlotDataFeed::SetFileList(const char* filelist) {
s_filelist_.clear();
std::ifstream fin(filelist);
PADDLE_ENFORCE(fin.good(),
@@ -170,14 +170,14 @@ void TextClassDataFeed::SetFileList(const char* filelist) {
fin.close();
}
void TextClassDataFeed::SetFieldNames(
void MultiSlotDataFeed::SetFieldNames(
const std::vector<std::string>& field_names) {
field_names_.clear();
field_names_.insert(field_names_.end(), field_names.begin(),
field_names.end());
}
bool TextClassDataFeed::SetFile(const char* filename) {
bool MultiSlotDataFeed::SetFile(const char* filename) {
// termnum termid termid ... termid label
std::ifstream ifs(filename, std::ios::binary);
if (ifs.fail()) {
@@ -198,7 +198,7 @@ bool TextClassDataFeed::SetFile(const char* filename) {
return true;
}
void TextClassDataFeed::UpdateEpochNum() {
void MultiSlotDataFeed::UpdateEpochNum() {
s_current_finished_file_cnt_++;
if (s_current_finished_file_cnt_ >= s_filelist_.size()) {
@@ -214,25 +214,14 @@ void TextClassDataFeed::UpdateEpochNum() {
}
}
void TextClassDataFeed::StartOneEpoch() {
std::lock_guard<std::mutex> lock(s_locker_for_pick_file_);
std::random_shuffle(s_filelist_.begin(), s_filelist_.end());
s_current_file_idx_ = 0;
LOG(INFO) << "Beginning epoch " << s_current_epoch_;
{
std::lock_guard<std::mutex> lock(s_locker_epoch_start_);
s_epoch_start_flag_ = true;
}
s_condition_epoch_start_.notify_all();
void MultiSlotDataFeed::Start() {
}
void TextClassDataFeed::WaitNextEpoch() {
std::unique_lock<std::mutex> lock(s_locker_epoch_start_);
s_condition_epoch_start_.wait(lock, []{return s_epoch_start_flag_;});
int MultiSlotDataFeed::Next() {
return 0;
}
const char* TextClassDataFeed::PickOneFile() {
const char* MultiSlotDataFeed::PickOneFile() {
std::string file_to_be_processed;
std::lock_guard<std::mutex> lock(s_locker_for_pick_file_);
@@ -81,8 +81,8 @@ class DataFeed {
virtual unsigned int GetCurrentEpoch() = 0;
virtual const char *PickOneFile() = 0;
virtual void UpdateEpochNum() = 0;
virtual void StartOneEpoch() = 0;
virtual void WaitNextEpoch() = 0;
virtual void Start() = 0;
virtual int Next() = 0;
std::vector<LoDTensor*>& GetFeedVec() {
return feed_vec_;
@@ -106,13 +106,13 @@ class DataFeed {
int thread_id_;
};
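The rework replaces the epoch-driven StartOneEpoch()/WaitNextEpoch() pair with a pull-style Start()/Next() contract. A minimal consumer sketch, assuming a reader obtained from DataFeedFactory with feed vars already bound, mirrors what TrainFiles() does later in this diff:

// Sketch: consume batches through the reworked interface.
// "reader" is an assumed std::shared_ptr<DataFeed>.
reader->Start();                          // begin reading the assigned files
int cur_batch = 0;
while ((cur_batch = reader->Next()) > 0) {
  // Next() fills the bound feed tensors and returns the batch size
  // (0 signals end of data); run the thread's ops here.
}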
class TextClassDataFeed : public DataFeed {
class MultiSlotDataFeed : public DataFeed {
public:
TextClassDataFeed();
TextClassDataFeed(const TextClassDataFeed& data_feed);
MultiSlotDataFeed();
MultiSlotDataFeed(const MultiSlotDataFeed& data_feed);
public:
virtual ~TextClassDataFeed() {}
virtual ~MultiSlotDataFeed() {}
virtual void Init();
virtual bool ReadBatch();
virtual void AddFeedVar(Variable* feed, const std::string& name);
@@ -125,8 +125,8 @@ class TextClassDataFeed : public DataFeed {
void SetBatchSize(int batch) {batch_size_ = batch;}
unsigned int GetCurrentEpoch() {return s_current_epoch_;}
void UpdateEpochNum();
void StartOneEpoch();
void WaitNextEpoch();
void Start();
int Next();
public:
void SetFieldNames(const std::vector<std::string>& field_names);
@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
package paddle;
package paddle.framework;
message DataFeedDesc {
optional string name = 1;
optional int32 batch = 2 [default = 32];
repeated string field_names = 3;
}
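For reference, the generated proto2 C++ API for this message is used roughly as follows; the field values here are illustrative:

#include "paddle/fluid/framework/data_feed.pb.h"

paddle::framework::DataFeedDesc desc;
desc.set_name("MultiSlotDataFeed");  // class name looked up by DataFeedFactory
desc.set_batch(32);                  // optional; proto2 default is 32
desc.add_field_names("words");       // repeated: one entry per input slot
desc.add_field_names("label");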
@@ -12,18 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/data_feed_factory.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/data_feed.h"
namespace paddle {
namespace framework {
typedef shared_ptr<DataFeed> (*Createdata_feedFunction)();
typedef std::shared_ptr<DataFeed> (*Createdata_feedFunction)();
typedef std::unordered_map<std::string, Createdata_feedFunction> data_feedMap;
data_feedMap g_data_feed_map;
#define REGISTER_DATAFEED_CLASS(data_feed_class) \
namespace { \
shared_ptr<DataFeed> Creator_##data_feed_class() { \
return shared_ptr<DataFeed>(new data_feed_class); \
std::shared_ptr<DataFeed> Creator_##data_feed_class() { \
return std::shared_ptr<DataFeed>(new data_feed_class); \
} \
class __Registerer_##data_feed_class { \
public: \
@@ -35,8 +40,8 @@ data_feedMap g_data_feed_map;
} // namespace
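Registering a concrete feed is then a one-liner; the sketch below follows the macro's apparent pattern (the actual registration site is outside the shown hunks, so treat it as an assumption):

// Assumed registration site, e.g. at the bottom of data_feed.cc:
REGISTER_DATAFEED_CLASS(MultiSlotDataFeed);

// Afterwards the factory can instantiate the feed by its class name:
std::shared_ptr<paddle::framework::DataFeed> feed =
    paddle::framework::DataFeedFactory::CreateDataFeed("MultiSlotDataFeed");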
string DataFeedFactory::DataFeedTypeList() {
string data_feed_types;
std::string DataFeedFactory::DataFeedTypeList() {
std::string data_feed_types;
for (auto iter = g_data_feed_map.begin();
iter != g_data_feed_map.end(); ++iter) {
if (iter != g_data_feed_map.begin()) {
@@ -47,9 +52,9 @@ string DataFeedFactory::DataFeedTypeList() {
return data_feed_types;
}
shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
const char* data_feed_class) {
if (g_data_feed_map.count(string(data_feed_class)) < 1) {
std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
std::string data_feed_class) {
if (g_data_feed_map.count(data_feed_class) < 1) {
exit(-1);
}
return g_data_feed_map[data_feed_class]();
@@ -16,14 +16,15 @@ limitations under the License. */
#define PADDLE_FLUID_FRAMEWORK_DATA_FEED_FACTORY_H_
#include <string>
#include "paddle/framework/data_feed.h"
#include <memory>
#include "paddle/fluid/framework/data_feed.h"
namespace paddle {
namespace framework {
class DataFeedFactory {
public:
static std::string DataFeedTypeList();
static shared_ptr<DataFeed> CreateDataFeed(const char* data_feed_class);
static std::shared_ptr<DataFeed> CreateDataFeed(std::string data_feed_class);
};
} // namespace framework
} // namespace paddle
@@ -93,6 +93,11 @@ void ExecutorThreadWorker::CreateThreadResource(
void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) {
auto& block = program.Block(0);
PADDLE_ENFORCE_NOT_NULL(
root_scope_,
"root_scope should be set before creating thread scope");
thread_scope_ = &root_scope_->NewScope();
for (auto& var : block.AllVars()) {
if (var->Persistable()) {
@@ -107,17 +112,24 @@ void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) {
void ExecutorThreadWorker::SetDataFeed(
const std::shared_ptr<DataFeed>& datafeed) {
local_reader_ = datafeed;
thread_reader_ = datafeed;
}
void ExecutorThreadWorker::BindingDataFeedMemory() {
const std::vector<std::string>& input_feed =
thread_reader_->GetUseSlotAlias();
for (auto name : input_feed) {
local_reader_->AddFeedVar(thread_scope_->Var(name), name);
thread_reader_->AddFeedVar(thread_scope_->Var(name), name);
}
}
void ExecutorThreadWorker::SetFetchVarNames(
const std::vector<std::string>& fetch_var_names) {
fetch_var_names_.clear();
fetch_var_names_.insert(fetch_var_names_.end(),
fetch_var_names.begin(), fetch_var_names.end());
}
void ExecutorThreadWorker::SetDevice() {
// at most 48 threads can be bound currently
static unsigned priority[] = {
@@ -156,12 +168,28 @@ void ExecutorThreadWorker::SetDevice() {
void ExecutorThreadWorker::TrainFiles() {
// todo: configurable
SetDevice();
int fetch_var_num = fetch_var_names_.size();
fetch_values_.clear();
fetch_values_.resize(fetch_var_num, 0);
thread_reader_->Start();
while (int cur_batch = thread_reader_->Next()) {
int cur_batch;
while ((cur_batch = thread_reader_->Next()) > 0) {
// executor run here
for (auto& op : ops_) {
op->Run(*thread_scope_, place_);
}
float avg_inspect = 0.0;
for (int i = 0; i < fetch_var_num; ++i) {
avg_inspect = thread_scope_->FindVar(fetch_var_names_[i])
->GetMutable<LoDTensor>()
->data<float>()[0];
fetch_values_[i] += avg_inspect;
}
thread_scope_->DropKids();
}
}
@@ -43,6 +43,9 @@ class ExecutorThreadWorker {
void SetDevice();
void BindingDataFeedMemory();
void SetDataFeed(const std::shared_ptr<DataFeed>& datafeed);
void TrainFiles();
void SetFetchVarNames(const std::vector<std::string>& fetch_var_names);
std::vector<float>& GetFetchValues() {return fetch_values_;}
private:
void CreateThreadScope(const framework::ProgramDesc& program);
@@ -66,9 +69,13 @@ class ExecutorThreadWorker {
Scope* root_scope_;
// a thread scope whose parent is the shared global (root) scope
Scope* thread_scope_;
private:
std::vector<std::string> fetch_var_names_;
std::vector<float> fetch_values_;
};
} // namespace framework
} // namespace paddle
#endif // PADDLE_FLUID_FRAMEWORK_ASYNC_EXECUTOR_H_
#endif // PADDLE_FLUID_FRAMEWORK_EXECUTOR_THREAD_WORKER_H_
/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */
@@ -30,36 +30,32 @@ limitations under the License. */
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/variant.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/framework/async_executor_param.pb.h"
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/async_executor.h"
#include "paddle/fluid/framework/data_feed.h"
namespace py = pybind11;
namespace pd = paddle::framework;
namespace paddle {
namespace pybind {
using set_name_func = void (pd::DataFeedDesc::*)(const std::string&);
void BindAsyncExecutor(py::module* m) {
py::class_<framework::DataFeed>(*m, "DataFeed");
py::class_<framework::TextClassDataFeed,
framework::DataFeed>(*m, "TextDataFeed")
.def(py::init())
.def("set_filelist",
[] (framework::TextClassDataFeed &self, const char *data_list_file) {
self.SetFileList(data_list_file);
})
.def("set_batch_size", &framework::TextClassDataFeed::SetBatchSize)
.def("set_field_names", &framework::TextClassDataFeed::SetFieldNames)
.def("start_one_epoch", &framework::TextClassDataFeed::StartOneEpoch);
py::class_<pd::DataFeedDesc>(*m, "DataFeedDesc")
.def(pybind11::init<>())
.def("set_name", (set_name_func)&pd::DataFeedDesc::set_name)
.def("set_batch", &pd::DataFeedDesc::set_batch)
.def("set_field_names",
[] (pd::DataFeedDesc& self, const std::vector<std::string> &fields) {
for (auto field : fields) {
self.add_field_names(field);
}
});
py::class_<framework::AsyncExecutor>(*m, "AsyncExecutor")
.def(py::init<framework::ProgramDesc&,
std::vector<std::string>&,
framework::TextClassDataFeed&,
unsigned int,
const platform::Place&>())
.def("init_root_scope", &framework::AsyncExecutor::InitRootScope)
.def("run_startup_program", &framework::AsyncExecutor::RunStartupProgram)
.def("run", &framework::AsyncExecutor::Run);
.def(py::init<pd::Scope&, const platform::Place&>())
.def("run_from_files", &framework::AsyncExecutor::RunFromFile)
.def("check_file", &framework::AsyncExecutor::CheckFiles);
} // end BindAsyncExecutor
} // end namespace pybind
} // end namespace paddle
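Taken together, the new bindings shrink the Python-visible surface to construction plus run_from_files/check_file. On the C++ side the equivalent call sequence is roughly the following sketch; the program, filelist, and fetch names are placeholders:

// Sketch of the call path behind the new bindings; inputs are placeholders.
paddle::framework::ProgramDesc main_program;      // assumed already populated
paddle::framework::DataFeedDesc data_feed_desc;   // built as sketched earlier
paddle::framework::Scope scope;
paddle::platform::Place place = paddle::platform::CPUPlace();

paddle::framework::AsyncExecutor exe(scope, place);
std::vector<float> fetch_values = exe.RunFromFile(
    main_program, data_feed_desc,
    {"train_part_0.txt"},   // filelist (illustrative)
    4,                      // thread_num
    {"mean_cost"});         // fetch_var_names (illustrative)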
@@ -19,30 +19,26 @@ import contextlib
import six
from .framework import Program, default_main_program, Variable
from . import core
from . import Executor
from .executor import global_scope
__all__ = ['TextDataFeed', 'AsyncExecutor']
__all__ = ['MultiSlotDataFeed', 'AsyncExecutor']
g_scope = core.Scope()
class TextDataFeed():
class DataFeedDesc(object):
def __init__(self):
self.feed = core.TextDataFeed()
def set_filelist(self, filelist):
self.feed.set_filelist(filelist)
self.desc = core.DataFeedDesc()
def set_batch_size(self, batch_size):
self.feed.set_batch_size(batch_size)
def set_field_names(self, field_names):
if isinstance(field_names, Variable):
self.desc.set_batch(batch_size)
def set_field_name(self, field_names):
if isinstance(field_names, str):
field_names = [field_names]
self.desc.set_field_names(field_names)
self.feed.set_field_names(field_names)
def start_an_epoch(self):
self.feed.start_one_epoch()
class MultiSlotDataFeed(DataFeedDesc):
def __init__(self):
super(MultiSlotDataFeed, self).__init__()
self.desc.set_name("MultiSlotDataFeed")
class AsyncExecutor(object):
"""
@@ -55,45 +51,19 @@ class AsyncExecutor(object):
They have exactly the same arguments and are expected to produce the same results.
"""
def __init__(self,
program,
param_names,
data_feed,
thread_num,
place=None,
scope=None):
if program is None:
program = default_main_program()
program_desc = program.desc
if not isinstance(data_feed, TextDataFeed):
raise ValueError("data_feed for AsyncExecutor.run() type error")
def __init__(self, place=None):
if place is None:
place = core.CPUPlace()
if not isinstance(place, core.CPUPlace):
raise ValueError("AsyncExecutor only supports CPU device")
if isinstance(param_names, Variable):
param_names = [param_names]
p = core.Place()
p.set_place(place)
self.executor = core.AsyncExecutor(program_desc, param_names, data_feed.feed, thread_num, p)
def run_startup_program(self,
program=None,
scope=None):
if program is None:
program = default_startup_program()
program_desc = program._get_desc()
if scope is None:
scope = g_scope
scope = global_scope()
self.executor = core.AsyncExecutor(scope, p)
self.executor.run_startup_program(program_desc, scope)
def run(self, inspect_vars, scope=None):
def run(self, program, data_feed, filelist, thread_num, fetch):
"""
Run the program with this Executor. Feed data via the feed map, fetch results via fetch_list.
The Python executor takes a program, adds feed operators and fetch operators to this program according
@@ -136,16 +106,27 @@
>>> feed={'X': x},
>>> fetch_list=[loss.name])
"""
if inspect_vars is not None:
if isinstance(inspect_vars, Variable):
inspect_vars = [inspect_vars]
inspect_var_names = [var.name for var in inspect_vars]
if program is None:
program = default_main_program()
program_desc = program.desc
if data_feed is None:
raise ValueError('ValueError: data_feed should be provided')
if filelist is None:
raise ValueError('ValueError: filelist should be provided')
if isinstance(filelist, str):
filelist = [filelist]
if scope is None:
scope = g_scope
if not isinstance(thread_num, int):
raise TypeError('TypeError: thread_num should be a positive number')
self.executor.init_root_scope(scope)
if fetch is not None:
if isinstance(fetch, Variable):
fetch = [fetch]
fetch_var_names = [var.name for var in fetch]
evaluation = self.executor.run(inspect_var_names)
evaluation = self.executor.run_from_files(program_desc, data_feed.desc, filelist, thread_num, fetch_var_names)
return evaluation