提交 91fc8f35 编写于 作者: W wangguibao

Interface rework

上级 274ec6a1
...@@ -36,7 +36,7 @@ add_subdirectory(details) ...@@ -36,7 +36,7 @@ add_subdirectory(details)
endif (NOT WIN32) endif (NOT WIN32)
# ddim lib # ddim lib
proto_library(framework_proto SRCS framework.proto) proto_library(framework_proto SRCS framework.proto)
proto_library(async_executor_param SRCS async_executor_param.proto) proto_library(async_executor_proto SRCS data_feed.proto)
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost) cc_library(ddim SRCS ddim.cc DEPS eigen3 boost)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
...@@ -138,31 +138,23 @@ cc_test(version_test SRCS version_test.cc DEPS version) ...@@ -138,31 +138,23 @@ cc_test(version_test SRCS version_test.cc DEPS version)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version)
cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto) cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto)
if(NOT WIN32)
cc_library(ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog cc_library(ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog
shape_inference data_transform lod_tensor profiler) shape_inference data_transform lod_tensor profiler)
endif(NOT WIN32)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc) cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
if (NOT WIN32)
py_proto_compile(framework_py_proto SRCS framework.proto) py_proto_compile(framework_py_proto SRCS framework.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module. # Generate an empty __init__.py to make framework_py_proto as a valid python module.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(framework_py_proto framework_py_proto_init) add_dependencies(framework_py_proto framework_py_proto_init)
if (NOT WIN32) add_custom_command(TARGET framework_py_proto POST_BUILD
add_custom_command(TARGET framework_py_proto POST_BUILD COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto COMMAND cp *.py ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto/
COMMAND cp *.py ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto/ COMMENT "Copy generated python proto into directory paddle/fluid/proto."
COMMENT "Copy generated python proto into directory paddle/fluid/proto." WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
else(NOT WIN32)
string(REPLACE "/" "\\" proto_dstpath "${PADDLE_BINARY_DIR}/python/paddle/fluid/proto/")
add_custom_command(TARGET framework_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_BINARY_DIR}/python/paddle/fluid/proto
COMMAND copy /Y *.py ${proto_dstpath}
COMMENT "Copy generated python proto into directory paddle/fluid/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
endif(NOT WIN32) endif(NOT WIN32)
cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor) cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
...@@ -176,11 +168,7 @@ if(WITH_DISTRIBUTE) ...@@ -176,11 +168,7 @@ if(WITH_DISTRIBUTE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else() else()
if(NOT WIN32) cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator)
else(NOT WIN32)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass)
endif(NOT WIN32)
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op) cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
endif() endif()
...@@ -192,10 +180,11 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS ...@@ -192,10 +180,11 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
endif() # NOT WIN32 endif() # NOT WIN32
cc_library(async_executor cc_library(async_executor
SRCS async_executor.cc data_feed.cc datafeed_creator.cc SRCS async_executor.cc data_feed.cc data_feed_factory.cc
executor_thread_worker.cc
DEPS op_registry device_context scope framework_proto glog DEPS op_registry device_context scope framework_proto glog
lod_rank_table feed_fetch_method graph_to_program_pass lod_rank_table feed_fetch_method graph_to_program_pass
async_executor_param) async_executor_proto)
cc_library(prune SRCS prune.cc DEPS framework_proto) cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
...@@ -36,43 +36,13 @@ limitations under the License. */ ...@@ -36,43 +36,13 @@ limitations under the License. */
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/inference/io.h" #include "paddle/fluid/inference/io.h"
#include "paddle/fluid/framework/executor_thread_worker.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/pybind/pybind.h" #include "paddle/fluid/pybind/pybind.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
bool AsyncExecutor::workers_initialized_ = false;
void CreateTensor(Variable* var, proto::VarType::Type var_type) {
if (var_type == proto::VarType::LOD_TENSOR) {
var->GetMutable<LoDTensor>();
} else if (var_type == proto::VarType::SELECTED_ROWS) {
var->GetMutable<SelectedRows>();
} else if (var_type == proto::VarType::FEED_MINIBATCH) {
var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarType::FETCH_LIST) {
var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarType::STEP_SCOPES) {
var->GetMutable<std::vector<Scope>>();
} else if (var_type == proto::VarType::LOD_RANK_TABLE) {
var->GetMutable<LoDRankTable>();
} else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
var->GetMutable<LoDTensorArray>();
} else if (var_type == proto::VarType::PLACE_LIST) {
var->GetMutable<platform::PlaceList>();
} else if (var_type == proto::VarType::READER) {
var->GetMutable<ReaderHolder>();
} else if (var_type == proto::VarType::RAW) {
// GetMutable will be called in operator
} else {
PADDLE_THROW(
"Variable type %d is not in "
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
"LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, RAW]",
var_type);
}
}
static void ReadBinaryFile(const std::string& filename, static void ReadBinaryFile(const std::string& filename,
std::string* content) { std::string* content) {
std::string &contents = *content; std::string &contents = *content;
...@@ -139,343 +109,100 @@ static void SaveModel( ...@@ -139,343 +109,100 @@ static void SaveModel(
} }
} // end SaveModel } // end SaveModel
void ExecutorThreadWorker::Reset() { AsyncExecutor::AsyncExecutor(Scope& scope, const platform::Place& place)
inspect_values_.clear(); : root_scope_(scope), place_(place) {}
}
void ExecutorThreadWorker::CreateThreadOperators(const ProgramDesc& program) {
auto& block = program.Block(0);
op_names_.clear();
for (auto& op_desc : block.AllOps()) {
std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
op_names_.push_back(op_desc->Type());
OperatorBase* local_op_ptr = local_op.release();
ops_.push_back(local_op_ptr);
continue;
}
}
void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) {
auto& block = program.Block(0);
thread_scope_ = &root_scope_->NewScope();
for (auto& var : block.AllVars()) {
if (var->Persistable()) {
auto* ptr = root_scope_->Var(var->Name());
CreateTensor(ptr, var->GetType());
} else {
auto* ptr = thread_scope_->Var(var->Name());
CreateTensor(ptr, var->GetType());
}
}
}
void ExecutorThreadWorker::SetDataFeed(DataFeed& datafeed) {
if (typeid(datafeed) == typeid(TextClassDataFeed)) {
local_reader_.reset(
new TextClassDataFeed(dynamic_cast<TextClassDataFeed &>(datafeed)));
local_reader_->SetThreadId(thread_id_);
}
}
void ExecutorThreadWorker::BindingDataFeedMemory() {
const std::vector<std::string>& input_feed = local_reader_->GetUseSlotAlias();
for (auto name : input_feed) {
local_reader_->AddFeedVar(thread_scope_->Var(name), name);
}
}
void ExecutorThreadWorker::SetInspectVarNames( void AsyncExecutor::CreateThreads(
const std::vector<std::string>& inspect_var_names) { ExecutorThreadWorker* worker,
inspect_var_names_.clear(); const ProgramDesc& main_program,
inspect_var_names_.insert(inspect_var_names_.end(), const std::shared_ptr<DataFeed>& reader,
inspect_var_names.begin(), inspect_var_names.end()); const std::vector<std::string>& fetch_var_names,
Scope& root_scope,
const int thread_index) {
worker->SetThreadId(thread_index);
worker->SetRootScope(&root_scope);
worker->CreateThreadResource(main_program, place_);
worker->SetDataFeed(reader);
worker->SetFetchVarNames(fetch_var_names);
worker->BindingDataFeedMemory();
} }
void ExecutorThreadWorker::SetModelParamNames( void AsyncExecutor::CheckFiles(
const std::vector<std::string>& param_names) { const std::vector<std::string>& files) {
model_param_names_ = param_names; // function for user to check file formats
} // should be exposed to users
void ExecutorThreadWorker::SetDevice() {
static unsigned priority[] = {
0, 1, 2, 3, 4, 5,
6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23,
24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35,
36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47
};
unsigned int i = this->thread_id_;
if (i < sizeof(priority) / sizeof(unsigned)) {
unsigned proc = priority[i];
cpu_set_t mask;
CPU_ZERO(&mask);
CPU_SET(proc, &mask);
if (-1 == sched_setaffinity(0, sizeof(mask), &mask)) {
LOG(ERROR) << "WARNING: Failed to set thread affinity for thread " << i;
} else {
CPU_ZERO(&mask);
if ((0 == sched_getaffinity(0, sizeof(mask), &mask))
&& CPU_ISSET(proc, &mask)) {
LOG(ERROR) << "TRACE: Thread " << i
<< " is running on processor " << proc
<< "...";
}
}
}
}
void ExecutorThreadWorker::Train() {
LOG(ERROR) << "begin to train";
SetDevice();
int inspect_var_num = inspect_var_names_.size();
inspect_values_.clear();
inspect_values_.resize(inspect_var_num, 0);
local_reader_->WaitNextEpoch();
int epoch = local_reader_->GetCurrentEpoch();
LOG(ERROR) << "epoch: " << epoch;
int batch_num = 1;
while (true) {
const char *file = local_reader_->PickOneFile();
if (file == NULL) {
break;
}
if (!local_reader_->SetFile(file)) {
break;
}
while (true) {
bool flag = local_reader_->ReadBatch();
if (!flag) {
break;
}
for (unsigned int i = 0; i < ops_.size(); ++i) {
ops_[i]->Run(*thread_scope_, place_);
}
batch_num++;
float avg_inspect = 0.0;
for (int i = 0; i < inspect_var_num; ++i) {
avg_inspect = thread_scope_->FindVar(inspect_var_names_[i])
->GetMutable<LoDTensor>()
->data<float>()[0];
inspect_values_[i] += avg_inspect;
}
thread_scope_->DropKids();
}
local_reader_->UpdateEpochNum();
LOG(ERROR) << "memory used after epoch " << epoch + 1
<< " called: " << memory::memory_usage(place_);
}
for (int i = 0; i < inspect_var_num; ++i) {
inspect_values_[i] /= batch_num;
std::string var = inspect_var_names_[i].substr(
0,
inspect_var_names_[i].find_first_of("_"));
LOG(ERROR) << "mean " << var.c_str()
<< " of epoch " << i + 1 << ": " << inspect_values_[i];
}
if (thread_id_ == 0) {
char modelfile[1024];
snprintf(&modelfile[0], sizeof(modelfile), "%s_epoch%d.model",
model_prefix_.c_str(), epoch);
std::string model_filename = std::string(modelfile);
// this save_inference_model can only save imdbtask, should make this
// general
//
// currently comment it
LOG(ERROR) << "Going to save model " << modelfile;
SaveModel(main_program_,
thread_scope_,
model_param_names_,
model_filename,
true);
}
}
void ExecutorThreadWorker::SetThreadId(int tid) {
thread_id_ = tid;
}
void ExecutorThreadWorker::SetPlace(const platform::Place& place) {
place_ = place;
}
void ExecutorThreadWorker::SetMainProgram(
const ProgramDesc& main_program_desc) {
main_program_.reset(new ProgramDesc(main_program_desc));
}
void ExecutorThreadWorker::SetRootScope(Scope* g_scope) {
root_scope_ = g_scope;
}
void ExecutorThreadWorker::SetMaxTrainingEpoch(int max_epoch) {
max_epoch_ = max_epoch;
}
AsyncExecutor::AsyncExecutor(ProgramDesc& main_program,
const std::vector<std::string>& param_names,
TextClassDataFeed& data_feed,
unsigned int thread_num,
const platform::Place& place)
: thread_num_(thread_num),
place_(place),
main_program_(main_program),
data_feed_(data_feed) {
model_param_names_.clear();
model_param_names_.insert(model_param_names_.end(),
param_names.begin(),
param_names.end());
}
void AsyncExecutor::InitRootScope(Scope* scope) {
root_scope_ = scope;
}
void AsyncExecutor::SetMaxTrainingEpoch(int max_epoch) {
max_epoch_ = max_epoch;
} }
void AsyncExecutor::SetModelPrefix(const std::string& model_prefix) { void AsyncExecutor::SetModelPrefix(const std::string& model_prefix) {
model_prefix_ = model_prefix; model_prefix_ = model_prefix;
} }
void AsyncExecutor::RunStartupProgram(const ProgramDesc& program, std::vector<float> AsyncExecutor::RunFromFile(
Scope* scope) { const ProgramDesc& main_program,
auto& block = program.Block(0); const DataFeedDesc& data_feed_desc,
for (auto& var : block.AllVars()) { const std::vector<std::string>& filelist,
if (var->Persistable()) { const int thread_num,
auto* ptr = scope->Var(var->Name()); const std::vector<std::string>& fetch_var_names) {
CreateTensor(ptr, var->GetType()); std::vector<std::thread> threads;
// LOGERR("Persistable Var Name:%s", var->Name().c_str());
}
}
std::map<std::string, int> param_dict; /*
std::vector<OperatorBase *> ops; readerDesc: protobuf description for reader initlization
for (auto& op_desc : block.AllOps()) { argument: class_name, batch_size, use_slot, queue_size, buffer_size, padding_index
std::vector<std::string> param_name_vec = op_desc->OutputArgumentNames();
bool need_to_run = false; reader:
for (auto& name : param_name_vec) { 1) each thread has a reader, reader will read input data and
if (param_dict.find(name) == param_dict.end()) { put it into input queue
param_dict[name] = 1; 2) each reader has a Next() iterface, that can fetch an instance
need_to_run = true; from the input queue
} */
} // todo: should be factory method for creating datafeed
if (need_to_run) { std::vector<std::shared_ptr<DataFeed> > readers;
std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc); readers.resize(thread_num);
OperatorBase* local_op_ptr = local_op.release(); for (unsigned int i = 0; i < readers.size(); ++i) {
ops.push_back(local_op_ptr); readers[i] = DataFeedFactory::CreateDataFeed(data_feed_desc.name());
}
} }
// LOGERR("There are %d parameters in startup program, %d op needs to run",
// param_dict.size(), ops.size());
for (auto& op : ops) { std::vector<std::shared_ptr<ExecutorThreadWorker> > workers;
op->Run(*scope, place_); workers.resize(thread_num);
for (auto& worker : workers) {
worker.reset(new ExecutorThreadWorker);
} }
// LOGERR("total time for startup program: %fs", timeline.elapsed_sec());
for (auto& op : ops) {
delete op;
}
// LOGERR("run startup program done.");
}
std::unique_ptr<ProgramDesc> AsyncExecutor::LoadDescFromFile( // prepare thread resource here
const std::string& f) { for (int thidx = 0; thidx < thread_num; ++thidx) {
std::string program_desc_str; CreateThreads(workers[thidx].get(), main_program,
ReadBinaryFile(f, &program_desc_str); readers[thidx], fetch_var_names, root_scope_, thidx);
std::unique_ptr<ProgramDesc> program(new ProgramDesc(program_desc_str));
return program;
}
void AsyncExecutor::SetInspectVarNames(
const std::vector<std::string>& inspect_var_names) {
inspect_var_names_.clear();
inspect_var_names_.insert(inspect_var_names_.end(),
inspect_var_names.begin(), inspect_var_names.end());
}
void AsyncExecutor::PrepareThreads(const ProgramDesc& host_program) {
workers_.resize(thread_num_);
for (int i = 0; i < thread_num_; ++i) {
workers_[i].reset(new ExecutorThreadWorker);
workers_[i]->SetThreadId(i);
workers_[i]->CreateThreadOperators(host_program);
workers_[i]->SetRootScope(root_scope_);
workers_[i]->SetPlace(place_);
workers_[i]->SetMaxTrainingEpoch(max_epoch_);
workers_[i]->CreateThreadScope(host_program);
workers_[i]->SetInspectVarNames(inspect_var_names_);
workers_[i]->SetModelParamNames(model_param_names_);
workers_[i]->SetMainProgram(host_program);
workers_[i]->SetModelPrefix(model_prefix_);
//
// new a datafeed here
workers_[i]->SetDataFeed(data_feed_);
workers_[i]->BindingDataFeedMemory();
} }
}
std::vector<float>& AsyncExecutor::Run( // start executing ops in multiple threads
const std::vector<std::string>& inspect_var_names) { for (int thidx = 0; thidx < thread_num; ++thidx) {
SetInspectVarNames(inspect_var_names); threads.push_back(std::thread(&ExecutorThreadWorker::TrainFiles,
threads_.clear(); workers[thidx].get()));
// thread binding here?
if (workers_initialized_ == false) {
PrepareThreads(main_program_);
workers_initialized_ = true;
}
for (int i = 0; i < thread_num_; ++i) {
workers_[i]->Reset();
workers_[i]->SetInspectVarNames(inspect_var_names);
threads_.push_back(std::thread(&ExecutorThreadWorker::Train,
workers_[i].get()));
} }
for (auto& th : threads_) { for (auto& th : threads) {
th.join(); th.join();
} }
inspect_values_.clear(); std::vector<float> fetch_values;
inspect_values_.resize(inspect_var_names_.size(), 0); fetch_values.resize(fetch_var_names.size(), 0);
std::vector<std::vector<float>*> inspect_value_vectors; std::vector<std::vector<float>*> fetch_value_vectors;
inspect_value_vectors.resize(thread_num_); fetch_value_vectors.resize(thread_num);
for (int i = 0; i < thread_num_; ++i) { for (int i = 0; i < thread_num; ++i) {
inspect_value_vectors[i] = &workers_[i]->GetInspectValues(); fetch_value_vectors[i] = &workers[i]->GetFetchValues();
} }
for (unsigned int i = 0; i < inspect_var_names_.size(); ++i) { for (unsigned int i = 0; i < fetch_var_names.size(); ++i) {
float value = 0.0; float value = 0.0;
for (int j = 0; j < thread_num_; ++j) { for (int j = 0; j < thread_num; ++j) {
value += inspect_value_vectors[j]->at(i); value += fetch_value_vectors[j]->at(i);
} }
value /= thread_num_; value /= thread_num;
inspect_values_[i] = value; fetch_values[i] = value;
} }
return inspect_values_; return fetch_values;
} }
void AsyncExecutor::LoadInitModel() { void AsyncExecutor::LoadInitModel() {
......
...@@ -23,7 +23,8 @@ limitations under the License. */ ...@@ -23,7 +23,8 @@ limitations under the License. */
#include <thread> // NOLINT #include <thread> // NOLINT
#include <vector> #include <vector>
#include <typeinfo> #include <typeinfo>
#include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/executor_thread_worker.h"
#include "paddle/fluid/framework/datafeed_creator.h" #include "paddle/fluid/framework/datafeed_creator.h"
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
...@@ -31,93 +32,13 @@ limitations under the License. */ ...@@ -31,93 +32,13 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void CreateTensor(Variable* var, proto::VarType::Type var_type);
class ExecutorThreadWorker {
public:
ExecutorThreadWorker() {}
~ExecutorThreadWorker() {}
void CreateThreadScope(const ProgramDesc& program);
void SetThreadId(int tid);
void CreateThreadOperators(const ProgramDesc& program);
void SetRootScope(Scope* g_scope);
void SetDevice();
void AddFidSet();
void SetCommBatch(int comm_batch) { comm_batch_ = comm_batch; }
void AddTrainFile(const std::string& filename);
void SetMainProgram(const ProgramDesc& main_program_desc);
void SetPlace(const paddle::platform::Place& place);
void SetMaxTrainingEpoch(const int max_epoch);
void BindingDataFeedMemory();
void SetModelPrefix(const std::string& prefix) { model_prefix_ = prefix; }
void SetInspectVarNames(const std::vector<std::string>& inspect_var_names);
void SetModelParamNames(const std::vector<std::string>& param_names);
void SetDataFeed(DataFeed& datafeed); // NOLINT
void Train();
const char* PickOneFile();
void UpdateEpochNum();
void Reset();
void Initialize() {}
std::vector<float>& GetInspectValues() {return inspect_values_;}
protected:
// thread index
int thread_id_;
// max epoch for each thread
unsigned int max_epoch_;
// instances learned currently
int comm_batch_;
std::string model_prefix_;
std::vector<std::string> op_names_;
// local ops for forward and backward
std::vector<OperatorBase *> ops_;
// main program for training
std::unique_ptr<ProgramDesc> main_program_;
// binary data reader
std::unique_ptr<DataFeed> local_reader_;
std::vector<std::string> inspect_var_names_;
std::vector<std::string> model_param_names_;
// execution place
platform::Place place_;
// root scope for model parameters
Scope* root_scope_;
// a thread scope, father scope is global score which is shared
Scope* thread_scope_;
private:
std::vector<float> inspect_values_;
};
class AsyncExecutor { class AsyncExecutor {
public: public:
explicit AsyncExecutor(ProgramDesc& main_program, // NOLINT explicit AsyncExecutor(Scope& scope, const platform::Place& place); // NOLINT
const std::vector<std::string>& param_names,
TextClassDataFeed& data_feed, // NOLINT
unsigned int thread_num,
const platform::Place& place);
virtual ~AsyncExecutor() {} virtual ~AsyncExecutor() {}
static std::unique_ptr<ProgramDesc> LoadDescFromFile( static std::unique_ptr<ProgramDesc> LoadDescFromFile(
const std::string& filename); const std::string& filename);
void InitRootScope(Scope* scope); Scope* GetRootScope() { return &root_scope_; }
void SetMaxTrainingEpoch(const int max_epoch);
Scope* GetRootScope() { return root_scope_; }
void SetBatchSize(const int batch_size) { batch_size_ = batch_size; }
void SetCommBatch(int comm_batch) {
comm_batch_ = comm_batch;
}
void SetModelPath(const std::string& model_path) { void SetModelPath(const std::string& model_path) {
model_path_ = model_path; model_path_ = model_path;
...@@ -132,38 +53,32 @@ class AsyncExecutor { ...@@ -132,38 +53,32 @@ class AsyncExecutor {
} }
void SetModelPrefix(const std::string& model_prefix); void SetModelPrefix(const std::string& model_prefix);
virtual void PrepareThreads(const ProgramDesc& host_program);
void RunStartupProgram(const ProgramDesc& program, Scope* scope); void RunStartupProgram(const ProgramDesc& program, Scope* scope);
std::vector<float>& Run(const std::vector<std::string>& inspect_var_names); std::vector<float> RunFromFile(const ProgramDesc& main_program,
const DataFeedDesc& data_feed_desc,
const std::vector<std::string>& filelist,
const int thread_num,
const std::vector<std::string>& fetch_names);
void CheckFiles(const std::vector<std::string>& files);
void LoadInitModel(); void LoadInitModel();
private: private:
void SetInspectVarNames(const std::vector<std::string>& inspect_var_names); void CreateThreads(ExecutorThreadWorker* worker,
const ProgramDesc& main_program,
const std::shared_ptr<DataFeed>& reader,
const std::vector<std::string>& fetch_var_names,
Scope& root_scope, // NOLINT
const int thread_index);
public: public:
int thread_num_;
int max_epoch_;
int batch_size_;
int comm_batch_;
std::vector<std::shared_ptr<ExecutorThreadWorker> > workers_;
std::vector<std::thread> threads_;
std::vector<std::string> inspect_var_names_;
std::vector<std::string> model_param_names_;
std::string model_prefix_; std::string model_prefix_;
std::string model_path_; std::string model_path_;
std::string init_prog_file_; std::string init_prog_file_;
std::string init_model_file_; std::string init_model_file_;
Scope* root_scope_; Scope& root_scope_;
platform::Place place_; platform::Place place_;
private:
ProgramDesc& main_program_;
TextClassDataFeed& data_feed_;
std::vector<float> inspect_values_;
private:
static bool workers_initialized_;
}; };
} // namespace framework } // namespace framework
......
...@@ -38,17 +38,17 @@ DEFINE_bool(is_text_feed, false, "is_text_feed"); ...@@ -38,17 +38,17 @@ DEFINE_bool(is_text_feed, false, "is_text_feed");
namespace paddle { namespace paddle {
namespace framework { namespace framework {
std::vector<std::string> TextClassDataFeed::s_filelist_; std::vector<std::string> MultiSlotDataFeed::s_filelist_;
std::mutex TextClassDataFeed::s_locker_for_pick_file_; std::mutex MultiSlotDataFeed::s_locker_for_pick_file_;
unsigned int TextClassDataFeed::s_current_file_idx_ = 0; unsigned int MultiSlotDataFeed::s_current_file_idx_ = 0;
size_t TextClassDataFeed::s_current_finished_file_cnt_ = 0; size_t MultiSlotDataFeed::s_current_finished_file_cnt_ = 0;
unsigned int TextClassDataFeed::s_current_epoch_ = 0; unsigned int MultiSlotDataFeed::s_current_epoch_ = 0;
int TextClassDataFeed::s_current_save_epoch_ = 0; int MultiSlotDataFeed::s_current_save_epoch_ = 0;
std::mutex TextClassDataFeed::s_locker_epoch_start_; std::mutex MultiSlotDataFeed::s_locker_epoch_start_;
std::condition_variable TextClassDataFeed::s_condition_epoch_start_; std::condition_variable MultiSlotDataFeed::s_condition_epoch_start_;
bool TextClassDataFeed::s_epoch_start_flag_ = false; bool MultiSlotDataFeed::s_epoch_start_flag_ = false;
void TextClassDataFeed::Init() { void MultiSlotDataFeed::Init() {
// hard coding for a specific datafeed // hard coding for a specific datafeed
feed_vec_.resize(2); feed_vec_.resize(2);
// feed_vec_[0].reset(new LoDTensor); // feed_vec_[0].reset(new LoDTensor);
...@@ -73,12 +73,12 @@ void TextClassDataFeed::Init() { ...@@ -73,12 +73,12 @@ void TextClassDataFeed::Init() {
field_names_.clear(); field_names_.clear();
} }
TextClassDataFeed::TextClassDataFeed() { MultiSlotDataFeed::MultiSlotDataFeed() {
Init(); Init();
} }
// todo: use elegant implemention for this function // todo: use elegant implemention for this function
bool TextClassDataFeed::ReadBatch() { bool MultiSlotDataFeed::ReadBatch() {
paddle::framework::Vector<size_t> offset; paddle::framework::Vector<size_t> offset;
int tlen = 0; int tlen = 0;
int llen = 0; int llen = 0;
...@@ -142,13 +142,13 @@ bool TextClassDataFeed::ReadBatch() { ...@@ -142,13 +142,13 @@ bool TextClassDataFeed::ReadBatch() {
return true; return true;
} }
TextClassDataFeed::TextClassDataFeed(const TextClassDataFeed& data_feed) { MultiSlotDataFeed::MultiSlotDataFeed(const MultiSlotDataFeed& data_feed) {
Init(); Init();
SetBatchSize(data_feed.batch_size_); SetBatchSize(data_feed.batch_size_);
SetFieldNames(data_feed.field_names_); SetFieldNames(data_feed.field_names_);
} }
void TextClassDataFeed::AddFeedVar(Variable* feed, const std::string& name) { void MultiSlotDataFeed::AddFeedVar(Variable* feed, const std::string& name) {
for (unsigned int i = 0; i < use_slot_alias_.size(); ++i) { for (unsigned int i = 0; i < use_slot_alias_.size(); ++i) {
if (name == use_slot_alias_[i]) { if (name == use_slot_alias_[i]) {
feed_vec_[i] = feed->GetMutable<LoDTensor>(); feed_vec_[i] = feed->GetMutable<LoDTensor>();
...@@ -156,7 +156,7 @@ void TextClassDataFeed::AddFeedVar(Variable* feed, const std::string& name) { ...@@ -156,7 +156,7 @@ void TextClassDataFeed::AddFeedVar(Variable* feed, const std::string& name) {
} }
} }
void TextClassDataFeed::SetFileList(const char* filelist) { void MultiSlotDataFeed::SetFileList(const char* filelist) {
s_filelist_.clear(); s_filelist_.clear();
std::ifstream fin(filelist); std::ifstream fin(filelist);
PADDLE_ENFORCE(fin.good(), PADDLE_ENFORCE(fin.good(),
...@@ -170,14 +170,14 @@ void TextClassDataFeed::SetFileList(const char* filelist) { ...@@ -170,14 +170,14 @@ void TextClassDataFeed::SetFileList(const char* filelist) {
fin.close(); fin.close();
} }
void TextClassDataFeed::SetFieldNames( void MultiSlotDataFeed::SetFieldNames(
const std::vector<std::string>& field_names) { const std::vector<std::string>& field_names) {
field_names_.clear(); field_names_.clear();
field_names_.insert(field_names_.end(), field_names.begin(), field_names_.insert(field_names_.end(), field_names.begin(),
field_names.end()); field_names.end());
} }
bool TextClassDataFeed::SetFile(const char* filename) { bool MultiSlotDataFeed::SetFile(const char* filename) {
// termnum termid termid ... termid label // termnum termid termid ... termid label
std::ifstream ifs(filename, std::ios::binary); std::ifstream ifs(filename, std::ios::binary);
if (ifs.fail()) { if (ifs.fail()) {
...@@ -198,7 +198,7 @@ bool TextClassDataFeed::SetFile(const char* filename) { ...@@ -198,7 +198,7 @@ bool TextClassDataFeed::SetFile(const char* filename) {
return true; return true;
} }
void TextClassDataFeed::UpdateEpochNum() { void MultiSlotDataFeed::UpdateEpochNum() {
s_current_finished_file_cnt_++; s_current_finished_file_cnt_++;
if (s_current_finished_file_cnt_ >= s_filelist_.size()) { if (s_current_finished_file_cnt_ >= s_filelist_.size()) {
...@@ -214,25 +214,14 @@ void TextClassDataFeed::UpdateEpochNum() { ...@@ -214,25 +214,14 @@ void TextClassDataFeed::UpdateEpochNum() {
} }
} }
void TextClassDataFeed::StartOneEpoch() { void MultiSlotDataFeed::Start() {
std::lock_guard<std::mutex> lock(s_locker_for_pick_file_);
std::random_shuffle(s_filelist_.begin(), s_filelist_.end());
s_current_file_idx_ = 0;
LOG(INFO) << "Beginning epoch " << s_current_epoch_;
{
std::lock_guard<std::mutex> lock(s_locker_epoch_start_);
s_epoch_start_flag_ = true;
}
s_condition_epoch_start_.notify_all();
} }
void TextClassDataFeed::WaitNextEpoch() { int MultiSlotDataFeed::Next() {
std::unique_lock<std::mutex> lock(s_locker_epoch_start_); return 0;
s_condition_epoch_start_.wait(lock, []{return s_epoch_start_flag_;});
} }
const char* TextClassDataFeed::PickOneFile() { const char* MultiSlotDataFeed::PickOneFile() {
std::string file_to_be_processed; std::string file_to_be_processed;
std::lock_guard<std::mutex> lock(s_locker_for_pick_file_); std::lock_guard<std::mutex> lock(s_locker_for_pick_file_);
......
...@@ -81,8 +81,8 @@ class DataFeed { ...@@ -81,8 +81,8 @@ class DataFeed {
virtual unsigned int GetCurrentEpoch() = 0; virtual unsigned int GetCurrentEpoch() = 0;
virtual const char *PickOneFile() = 0; virtual const char *PickOneFile() = 0;
virtual void UpdateEpochNum() = 0; virtual void UpdateEpochNum() = 0;
virtual void StartOneEpoch() = 0; virtual void Start() = 0;
virtual void WaitNextEpoch() = 0; virtual int Next() = 0;
std::vector<LoDTensor*>& GetFeedVec() { std::vector<LoDTensor*>& GetFeedVec() {
return feed_vec_; return feed_vec_;
...@@ -106,13 +106,13 @@ class DataFeed { ...@@ -106,13 +106,13 @@ class DataFeed {
int thread_id_; int thread_id_;
}; };
class TextClassDataFeed : public DataFeed { class MultiSlotDataFeed : public DataFeed {
public: public:
TextClassDataFeed(); MultiSlotDataFeed();
TextClassDataFeed(const TextClassDataFeed& data_feed); MultiSlotDataFeed(const MultiSlotDataFeed& data_feed);
public: public:
virtual ~TextClassDataFeed() {} virtual ~MultiSlotDataFeed() {}
virtual void Init(); virtual void Init();
virtual bool ReadBatch(); virtual bool ReadBatch();
virtual void AddFeedVar(Variable* feed, const std::string& name); virtual void AddFeedVar(Variable* feed, const std::string& name);
...@@ -125,8 +125,8 @@ class TextClassDataFeed : public DataFeed { ...@@ -125,8 +125,8 @@ class TextClassDataFeed : public DataFeed {
void SetBatchSize(int batch) {batch_size_ = batch;} void SetBatchSize(int batch) {batch_size_ = batch;}
unsigned int GetCurrentEpoch() {return s_current_epoch_;} unsigned int GetCurrentEpoch() {return s_current_epoch_;}
void UpdateEpochNum(); void UpdateEpochNum();
void StartOneEpoch(); void Start();
void WaitNextEpoch(); int Next();
public: public:
void SetFieldNames(const std::vector<std::string>& field_names); void SetFieldNames(const std::vector<std::string>& field_names);
......
...@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
syntax = "proto2"; syntax = "proto2";
package paddle; package paddle.framework;
message DataFeedDesc { message DataFeedDesc {
optional string name = 1; optional string name = 1;
optional int32 batch = 2 [default = 32]; optional int32 batch = 2 [default = 32];
repeated string field_names = 3;
} }
...@@ -12,18 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,18 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/framework/data_feed_factory.h" #include "paddle/fluid/framework/data_feed_factory.h"
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/data_feed.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
typedef shared_ptr<DataFeed> (*Createdata_feedFunction)(); typedef std::shared_ptr<DataFeed> (*Createdata_feedFunction)();
typedef std::unordered_map<std::string, Createdata_feedFunction> data_feedMap; typedef std::unordered_map<std::string, Createdata_feedFunction> data_feedMap;
data_feedMap g_data_feed_map; data_feedMap g_data_feed_map;
#define REGISTER_DATAFEED_CLASS(data_feed_class) \ #define REGISTER_DATAFEED_CLASS(data_feed_class) \
namespace { \ namespace { \
shared_ptr<DataFeed> Creator_##data_feed_class() { \ std::shared_ptr<DataFeed> Creator_##data_feed_class() { \
return shared_ptr<DataFeed>(new data_feed_class); \ return std::shared_ptr<DataFeed>(new data_feed_class); \
} \ } \
class __Registerer_##data_feed_class { \ class __Registerer_##data_feed_class { \
public: \ public: \
...@@ -35,8 +40,8 @@ data_feedMap g_data_feed_map; ...@@ -35,8 +40,8 @@ data_feedMap g_data_feed_map;
} // namespace } // namespace
string DataFeedFactory::DataFeedTypeList() { std::string DataFeedFactory::DataFeedTypeList() {
string data_feed_types; std::string data_feed_types;
for (auto iter = g_data_feed_map.begin(); for (auto iter = g_data_feed_map.begin();
iter != g_data_feed_map.end(); ++iter) { iter != g_data_feed_map.end(); ++iter) {
if (iter != g_data_feed_map.begin()) { if (iter != g_data_feed_map.begin()) {
...@@ -47,9 +52,9 @@ string DataFeedFactory::DataFeedTypeList() { ...@@ -47,9 +52,9 @@ string DataFeedFactory::DataFeedTypeList() {
return data_feed_types; return data_feed_types;
} }
shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed( std::shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
const char* data_feed_class) { std::string data_feed_class) {
if (g_data_feed_map.count(string(data_feed_class)) < 1) { if (g_data_feed_map.count(data_feed_class) < 1) {
exit(-1); exit(-1);
} }
return g_data_feed_map[data_feed_class](); return g_data_feed_map[data_feed_class]();
......
...@@ -16,14 +16,15 @@ limitations under the License. */ ...@@ -16,14 +16,15 @@ limitations under the License. */
#define PADDLE_FLUID_FRAMEWORK_DATA_FEED_FACTORY_H_ #define PADDLE_FLUID_FRAMEWORK_DATA_FEED_FACTORY_H_
#include <string> #include <string>
#include "paddle/framework/data_feed.h" #include <memory>
#include "paddle/fluid/framework/data_feed.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class DataFeedFactory { class DataFeedFactory {
public: public:
static std::string DataFeedTypeList(); static std::string DataFeedTypeList();
static shared_ptr<DataFeed> CreateDataFeed(const char* data_feed_class); static std::shared_ptr<DataFeed> CreateDataFeed(std::string data_feed_class);
}; };
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -93,6 +93,11 @@ void ExecutorThreadWorker::CreateThreadResource( ...@@ -93,6 +93,11 @@ void ExecutorThreadWorker::CreateThreadResource(
void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) { void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) {
auto& block = program.Block(0); auto& block = program.Block(0);
PADDLE_ENFORCE_NOT_NULL(
root_scope_,
"root_scope should be set before creating thread scope");
thread_scope_ = &root_scope_->NewScope(); thread_scope_ = &root_scope_->NewScope();
for (auto& var : block.AllVars()) { for (auto& var : block.AllVars()) {
if (var->Persistable()) { if (var->Persistable()) {
...@@ -107,17 +112,24 @@ void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) { ...@@ -107,17 +112,24 @@ void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) {
void ExecutorThreadWorker::SetDataFeed( void ExecutorThreadWorker::SetDataFeed(
const std::shared_ptr<DataFeed>& datafeed) { const std::shared_ptr<DataFeed>& datafeed) {
local_reader_ = datafeed; thread_reader_ = datafeed;
} }
void ExecutorThreadWorker::BindingDataFeedMemory() { void ExecutorThreadWorker::BindingDataFeedMemory() {
const std::vector<std::string>& input_feed = const std::vector<std::string>& input_feed =
thread_reader_->GetUseSlotAlias(); thread_reader_->GetUseSlotAlias();
for (auto name : input_feed) { for (auto name : input_feed) {
local_reader_->AddFeedVar(thread_scope_->Var(name), name); thread_reader_->AddFeedVar(thread_scope_->Var(name), name);
} }
} }
void ExecutorThreadWorker::SetFetchVarNames(
const std::vector<std::string>& fetch_var_names) {
fetch_var_names_.clear();
fetch_var_names_.insert(fetch_var_names_.end(),
fetch_var_names.begin(), fetch_var_names.end());
}
void ExecutorThreadWorker::SetDevice() { void ExecutorThreadWorker::SetDevice() {
// at most 48 threads binding currently // at most 48 threads binding currently
static unsigned priority[] = { static unsigned priority[] = {
...@@ -156,12 +168,28 @@ void ExecutorThreadWorker::SetDevice() { ...@@ -156,12 +168,28 @@ void ExecutorThreadWorker::SetDevice() {
void ExecutorThreadWorker::TrainFiles() { void ExecutorThreadWorker::TrainFiles() {
// todo: configurable // todo: configurable
SetDevice(); SetDevice();
int fetch_var_num = fetch_var_names_.size();
fetch_values_.clear();
fetch_values_.resize(fetch_var_num, 0);
thread_reader_->Start(); thread_reader_->Start();
while (int cur_batch = thread_reader_->Next()) {
int cur_batch;
while ((cur_batch = thread_reader_->Next()) > 0) {
// executor run here // executor run here
for (auto& op : ops_) { for (auto& op : ops_) {
op->Run(*thread_scope_, place_); op->Run(*thread_scope_, place_);
} }
float avg_inspect = 0.0;
for (int i = 0; i < fetch_var_num; ++i) {
avg_inspect = thread_scope_->FindVar(fetch_var_names_[i])
->GetMutable<LoDTensor>()
->data<float>()[0];
fetch_values_[i] += avg_inspect;
}
thread_scope_->DropKids(); thread_scope_->DropKids();
} }
} }
......
...@@ -43,6 +43,9 @@ class ExecutorThreadWorker { ...@@ -43,6 +43,9 @@ class ExecutorThreadWorker {
void SetDevice(); void SetDevice();
void BindingDataFeedMemory(); void BindingDataFeedMemory();
void SetDataFeed(const std::shared_ptr<DataFeed>& datafeed); void SetDataFeed(const std::shared_ptr<DataFeed>& datafeed);
void TrainFiles();
void SetFetchVarNames(const std::vector<std::string>& fetch_var_names);
std::vector<float>& GetFetchValues() {return fetch_values_;}
private: private:
void CreateThreadScope(const framework::ProgramDesc& program); void CreateThreadScope(const framework::ProgramDesc& program);
...@@ -66,9 +69,13 @@ class ExecutorThreadWorker { ...@@ -66,9 +69,13 @@ class ExecutorThreadWorker {
Scope* root_scope_; Scope* root_scope_;
// a thread scope, father scope is global score which is shared // a thread scope, father scope is global score which is shared
Scope* thread_scope_; Scope* thread_scope_;
private:
std::vector<std::string> fetch_var_names_;
std::vector<float> fetch_values_;
}; };
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
#endif // PADDLE_FLUID_FRAMEWORK_ASYNC_EXECUTOR_H_ #endif // PADDLE_FLUID_FRAMEWORK_EXECUTOR_THREAD_WORKER_H_
/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */ /* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */
...@@ -30,36 +30,32 @@ limitations under the License. */ ...@@ -30,36 +30,32 @@ limitations under the License. */
#include "paddle/fluid/inference/io.h" #include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/variant.h" #include "paddle/fluid/platform/variant.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/framework/async_executor_param.pb.h" #include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/async_executor.h" #include "paddle/fluid/framework/async_executor.h"
#include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/data_feed.h"
namespace py = pybind11; namespace py = pybind11;
namespace pd = paddle::framework;
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
using set_name_func = void (pd::DataFeedDesc::*)(const std::string&);
void BindAsyncExecutor(py::module* m) { void BindAsyncExecutor(py::module* m) {
py::class_<framework::DataFeed>(*m, "DataFeed"); py::class_<pd::DataFeedDesc>(*m, "DataFeedDesc")
py::class_<framework::TextClassDataFeed, .def(pybind11::init<>())
framework::DataFeed>(*m, "TextDataFeed") .def("set_name", (set_name_func)&pd::DataFeedDesc::set_name)
.def(py::init()) .def("set_batch", &pd::DataFeedDesc::set_batch)
.def("set_filelist", .def("set_field_names",
[] (framework::TextClassDataFeed &self, const char *data_list_file) { [] (pd::DataFeedDesc& self, const std::vector<std::string> &fields) {
self.SetFileList(data_list_file); for (auto field : fields) {
}) self.add_field_names(field);
.def("set_batch_size", &framework::TextClassDataFeed::SetBatchSize) }
.def("set_field_names", &framework::TextClassDataFeed::SetFieldNames) });
.def("start_one_epoch", &framework::TextClassDataFeed::StartOneEpoch);
py::class_<framework::AsyncExecutor>(*m, "AsyncExecutor") py::class_<framework::AsyncExecutor>(*m, "AsyncExecutor")
.def(py::init<framework::ProgramDesc&, .def(py::init<pd::Scope&, const platform::Place&>())
std::vector<std::string>&, .def("run_from_files", &framework::AsyncExecutor::RunFromFile)
framework::TextClassDataFeed&, .def("check_file", &framework::AsyncExecutor::CheckFiles);
unsigned int,
const platform::Place&>())
.def("init_root_scope", &framework::AsyncExecutor::InitRootScope)
.def("run_startup_program", &framework::AsyncExecutor::RunStartupProgram)
.def("run", &framework::AsyncExecutor::Run);
} // end BindAsyncExecutor } // end BindAsyncExecutor
} // end namespace pybind } // end namespace pybind
} // end namespace paddle } // end namespace paddle
......
...@@ -19,30 +19,26 @@ import contextlib ...@@ -19,30 +19,26 @@ import contextlib
import six import six
from .framework import Program, default_main_program, Variable from .framework import Program, default_main_program, Variable
from . import core from . import core
from . import Executor from .executor import global_scope
__all__ = ['TextDataFeed', 'AsyncExecutor'] __all__ = ['MultiSlotDataFeed', 'AsyncExecutor']
g_scope = core.Scope() g_scope = core.Scope()
class TextDataFeed(): class DataFeedDesc(object):
def __init__(self): def __init__(self):
self.feed = core.TextDataFeed() self.desc = core.DataFeedDesc()
def set_filelist(self, filelist):
self.feed.set_filelist(filelist)
def set_batch_size(self, batch_size): def set_batch_size(self, batch_size):
self.feed.set_batch_size(batch_size) self.desc.set_batch(batch_size)
def set_field_name(self, field_names):
def set_field_names(self, field_names): if isinstance(field_names, str):
if isinstance(field_names, Variable):
field_names = [field_names] field_names = [field_names]
self.desc.set_field_names(field_names)
self.feed.set_field_names(field_names) class MultiSlotDataFeed(DataFeedDesc):
def __init__(self):
def start_an_epoch(self): super(MultiSlotDataFeed, self).__init__()
self.feed.start_one_epoch() self.desc.set_name("MultiSlotDataFeed")
class AsyncExecutor(object): class AsyncExecutor(object):
""" """
...@@ -55,45 +51,19 @@ class AsyncExecutor(object): ...@@ -55,45 +51,19 @@ class AsyncExecutor(object):
They has the exactly same arguments, and expected the same results. They has the exactly same arguments, and expected the same results.
""" """
def __init__(self, def __init__(self, place=None):
program,
param_names,
data_feed,
thread_num,
place=None,
scope=None):
if program is None:
program = default_main_program()
program_desc = program.desc
if not isinstance(data_feed, TextDataFeed):
raise ValueError("data_feed for AsyncExecutor.run() type error")
if place is None: if place is None:
place = core.CPUPlace() place = core.CPUPlace()
if not isinstance(place, core.CPUPlace): if not isinstance(place, core.CPUPlace):
raise ValueError("AsyncExecutor only supports CPU device") raise ValueError("AsyncExecutor only supports CPU device")
if isinstance(param_names, Variable):
param_names = [param_names]
p = core.Place() p = core.Place()
p.set_place(place) p.set_place(place)
self.executor = core.AsyncExecutor(program_desc, param_names, data_feed.feed, thread_num, p)
def run_startup_program(self,
program=None,
scope=None):
if program is None:
program = default_startup_program()
program_desc = program._get_desc()
if scope is None: scope = global_scope()
scope = g_scope self.executor = core.AsyncExecutor(scope, p)
self.executor.run_startup_program(program_desc, scope) def run(self, program, data_feed, filelist, thread_num, fetch):
def run(self, inspect_vars, scope=None):
""" """
Run program by this Executor. Feed data by feed map, fetch result by fetch_list. Run program by this Executor. Feed data by feed map, fetch result by fetch_list.
Python executor takes a program, add feed operators and fetch operators to this program according Python executor takes a program, add feed operators and fetch operators to this program according
...@@ -136,16 +106,27 @@ class AsyncExecutor(object): ...@@ -136,16 +106,27 @@ class AsyncExecutor(object):
>>> feed={'X': x}, >>> feed={'X': x},
>>> fetch_list=[loss.name]) >>> fetch_list=[loss.name])
""" """
if inspect_vars is not None: if program is None:
if isinstance(inspect_vars, Variable): program = default_main_program()
inspect_vars = [inspect_vars] program_desc = program.desc
inspect_var_names = [var.name for var in inspect_vars]
if data_feed is None:
raise ValueError('ValueError: data_feed should be provided')
if filelist is None:
raise ValueError('ValueError: filelist should be provided')
if isinstance(filelist, str):
filelist = [filelist]
if scope is None: if not isinstance(thread_num, int):
scope = g_scope raise TypeError('TypeError: thread_num should be a positive number')
self.executor.init_root_scope(scope) if fetch is not None:
if isinstance(fetch, Variable):
fetch = [fetch]
fetch_var_names = [var.name for var in fetch]
evaluation = self.executor.run(inspect_var_names) evaluation = self.executor.run_from_files(program_desc, data_feed.desc, filelist, thread_num, fetch_var_names)
return evaluation return evaluation
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册