Commit 929a9e80 authored by wangguibao

Google naming conventions

Parent c555948c
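The renames throughout this diff apply the Google C++ naming conventions: methods move from snake_case to CamelCase (e.g. set_thread_id → SetThreadId), class data members drop the leading underscore in favor of a trailing one (_thread_id → thread_id_), and static members go from _s_name to s_name_ (e.g. _s_current_epoch → s_current_epoch_). A minimal sketch of the pattern on a hypothetical class (illustrative only, not part of this commit):

// Before: leading-underscore members, snake_case methods.
class WorkerOld {
 public:
  void set_thread_id(int tid) { _thread_id = tid; }
 private:
  int _thread_id;
};

// After: CamelCase methods, trailing-underscore members (Google C++ style).
class WorkerNew {
 public:
  void SetThreadId(int tid) { thread_id_ = tid; }
 private:
  int thread_id_;
};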
...@@ -37,13 +37,13 @@ limitations under the License. */ ...@@ -37,13 +37,13 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
std::mutex ExecutorThreadWorker::_s_locker_for_pick_file; std::mutex ExecutorThreadWorker::s_locker_for_pick_file_;
unsigned int ExecutorThreadWorker::_s_current_file_idx = 0; unsigned int ExecutorThreadWorker::s_current_file_idx_ = 0;
size_t ExecutorThreadWorker::_s_current_finished_file_cnt = 0; size_t ExecutorThreadWorker::s_current_finished_file_cnt_ = 0;
unsigned int ExecutorThreadWorker::_s_current_epoch = 0; unsigned int ExecutorThreadWorker::s_current_epoch_ = 0;
int ExecutorThreadWorker::_s_current_save_epoch = 0; int ExecutorThreadWorker::s_current_save_epoch_ = 0;
bool ExecutorThreadWorker::_s_is_first_worker = false; bool ExecutorThreadWorker::s_is_first_worker_ = false;
std::vector<std::string> ExecutorThreadWorker::_s_thread_filelist; std::vector<std::string> ExecutorThreadWorker::s_thread_filelist_;
void CreateTensor(Variable* var, proto::VarType::Type var_type) { void CreateTensor(Variable* var, proto::VarType::Type var_type) {
if (var_type == proto::VarType::LOD_TENSOR) { if (var_type == proto::VarType::LOD_TENSOR) {
...@@ -142,33 +142,33 @@ static void save_model( ...@@ -142,33 +142,33 @@ static void save_model(
} // end save_model } // end save_model
void ExecutorThreadWorker::add_train_file(const std::string& file) { void ExecutorThreadWorker::AddTrainFile(const std::string& file) {
_s_thread_filelist.push_back(file); s_thread_filelist_.push_back(file);
} }
void ExecutorThreadWorker::create_thread_operators(const ProgramDesc& program) { void ExecutorThreadWorker::CreateThreadOperators(const ProgramDesc& program) {
auto& block = program.Block(0); auto& block = program.Block(0);
_op_names.clear(); op_names_.clear();
for (auto& op_desc : block.AllOps()) { for (auto& op_desc : block.AllOps()) {
std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc); std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
_op_names.push_back(op_desc->Type()); op_names_.push_back(op_desc->Type());
OperatorBase* local_op_ptr = local_op.release(); OperatorBase* local_op_ptr = local_op.release();
_ops.push_back(local_op_ptr); ops_.push_back(local_op_ptr);
continue; continue;
} }
} }
void ExecutorThreadWorker::create_thread_scope(const ProgramDesc& program) { void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) {
auto& block = program.Block(0); auto& block = program.Block(0);
_thread_scope = &_root_scope->NewScope(); thread_scope_ = &root_scope_->NewScope();
for (auto& var : block.AllVars()) { for (auto& var : block.AllVars()) {
if (var->Persistable()) { if (var->Persistable()) {
auto* ptr = _root_scope->Var(var->Name()); auto* ptr = root_scope_->Var(var->Name());
CreateTensor(ptr, var->GetType()); CreateTensor(ptr, var->GetType());
// LOGERR("create Persistable var[%s] finished", // LOGERR("create Persistable var[%s] finished",
// var->Name().c_str()); // var->Name().c_str());
} else { } else {
auto* ptr = _thread_scope->Var(var->Name()); auto* ptr = thread_scope_->Var(var->Name());
CreateTensor(ptr, var->GetType()); CreateTensor(ptr, var->GetType());
// LOGERR("create unpersistable var[%s] finished", // LOGERR("create unpersistable var[%s] finished",
// var->Name().c_str()); // var->Name().c_str());
...@@ -176,33 +176,33 @@ void ExecutorThreadWorker::create_thread_scope(const ProgramDesc& program) { ...@@ -176,33 +176,33 @@ void ExecutorThreadWorker::create_thread_scope(const ProgramDesc& program) {
} }
} }
void ExecutorThreadWorker::set_datafeed(const std::shared_ptr<DataFeed>& datafeed) { void ExecutorThreadWorker::SetDataFeed(const std::shared_ptr<DataFeed>& datafeed) {
_local_reader = datafeed; local_reader_ = datafeed;
} }
void ExecutorThreadWorker::binding_datafeed_memory() { void ExecutorThreadWorker::BindingDataFeedMemory() {
const std::vector<std::string>& input_feed = _local_reader->get_use_slot_alias(); const std::vector<std::string>& input_feed = local_reader_->GetUseSlotAlias();
for (auto name : input_feed) { for (auto name : input_feed) {
_local_reader->add_feed_var(_thread_scope->Var(name), name); local_reader_->AddFeedVar(thread_scope_->Var(name), name);
} }
} }
void ExecutorThreadWorker::set_inspect_var_name( void ExecutorThreadWorker::SetInspectVarName(
const std::string& inspect_var_name) { const std::string& inspect_var_name) {
_inspect_var_name = inspect_var_name; inspect_var_name_ = inspect_var_name;
} }
void ExecutorThreadWorker::set_model_param_names( void ExecutorThreadWorker::SetModelParamNames(
const std::vector<std::string>& param_names) { const std::vector<std::string>& param_names) {
_model_param_names = param_names; model_param_names_ = param_names;
} }
void ExecutorThreadWorker::set_sparse_comm_data( void ExecutorThreadWorker::SetSparseCommData(
const std::map<std::string, int>& param_names) { const std::map<std::string, int>& param_names) {
_sparse_comm_data = param_names; sparse_comm_data_ = param_names;
} }
void ExecutorThreadWorker::set_device() { void ExecutorThreadWorker::SetDevice() {
static unsigned priority[] = { static unsigned priority[] = {
0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5,
6, 7, 8, 9, 10, 11, 6, 7, 8, 9, 10, 11,
...@@ -214,7 +214,7 @@ void ExecutorThreadWorker::set_device() { ...@@ -214,7 +214,7 @@ void ExecutorThreadWorker::set_device() {
42, 43, 44, 45, 46, 47 42, 43, 44, 45, 46, 47
}; };
unsigned int i = this->_thread_id; unsigned int i = this->thread_id_;
if (i < sizeof(priority) / sizeof(unsigned)) { if (i < sizeof(priority) / sizeof(unsigned)) {
unsigned proc = priority[i]; unsigned proc = priority[i];
...@@ -235,55 +235,55 @@ void ExecutorThreadWorker::set_device() { ...@@ -235,55 +235,55 @@ void ExecutorThreadWorker::set_device() {
} }
} }
void ExecutorThreadWorker::update_epoch_num() { void ExecutorThreadWorker::UpdateEpochNum() {
_s_current_finished_file_cnt++; s_current_finished_file_cnt_++;
if (_s_current_finished_file_cnt >= _s_thread_filelist.size()) { if (s_current_finished_file_cnt_ >= s_thread_filelist_.size()) {
_s_current_finished_file_cnt = 0; s_current_finished_file_cnt_ = 0;
_s_current_epoch++; s_current_epoch_++;
} }
} }
const char* ExecutorThreadWorker::pick_one_file() { const char* ExecutorThreadWorker::PickOneFile() {
std::string file_to_be_preocessed; std::string file_to_be_preocessed;
std::lock_guard<std::mutex> lock(_s_locker_for_pick_file); std::lock_guard<std::mutex> lock(s_locker_for_pick_file_);
if (_s_current_file_idx >= _s_thread_filelist.size()) { if (s_current_file_idx_ >= s_thread_filelist_.size()) {
std::random_shuffle(_s_thread_filelist.begin(), std::random_shuffle(s_thread_filelist_.begin(),
_s_thread_filelist.end()); s_thread_filelist_.end());
_s_current_file_idx = 0; s_current_file_idx_ = 0;
// _s_current_epoch++; //example: when one file, one thread, it's bug // s_current_epoch_++; //example: when one file, one thread, it's bug
LOG(ERROR) << "thread " << _thread_id LOG(ERROR) << "thread " << thread_id_
<< ": finish traing for epoch " << _s_current_epoch + 1; << ": finish traing for epoch " << s_current_epoch_ + 1;
} }
file_to_be_preocessed = _s_thread_filelist[_s_current_file_idx]; file_to_be_preocessed = s_thread_filelist_[s_current_file_idx_];
_s_current_file_idx++; s_current_file_idx_++;
return file_to_be_preocessed.c_str(); return file_to_be_preocessed.c_str();
} }
void ExecutorThreadWorker::train() { void ExecutorThreadWorker::Train() {
LOG(ERROR) << "begin to train"; LOG(ERROR) << "begin to train";
set_device(); SetDevice();
#ifdef LOCAL_PROF #ifdef LOCAL_PROF
std::vector<double> op_total_time; std::vector<double> op_total_time;
std::vector<std::string> op_name; std::vector<std::string> op_name;
// int total_batch = 0; // int total_batch = 0;
for (auto& op : _ops) { for (auto& op : ops_) {
op_name.push_back(op->Type()); op_name.push_back(op->Type());
} }
op_total_time.resize(_ops.size()); op_total_time.resize(ops_.size());
for (int i = 0; i < op_total_time.size(); ++i) { for (int i = 0; i < op_total_time.size(); ++i) {
op_total_time[i] = 0.0; op_total_time[i] = 0.0;
} }
#endif #endif
std::string inspect_key = "inspect"; std::string inspect_key = "inspect";
if (!_inspect_var_name.empty()) { if (!inspect_var_name_.empty()) {
inspect_key = _inspect_var_name.substr(0, inspect_key = inspect_var_name_.substr(0,
_inspect_var_name.find_first_of('_')); inspect_var_name_.find_first_of('_'));
} }
for (unsigned i = 0; i < _max_epoch; ++i) { for (unsigned i = 0; i < max_epoch_; ++i) {
LOG(ERROR) << "epoch: " << i; LOG(ERROR) << "epoch: " << i;
#ifdef LOCAL_PROF #ifdef LOCAL_PROF
Timer timeline; Timer timeline;
...@@ -292,14 +292,14 @@ void ExecutorThreadWorker::train() { ...@@ -292,14 +292,14 @@ void ExecutorThreadWorker::train() {
#endif #endif
float total_inspect = 0; float total_inspect = 0;
int batch_num = 1; int batch_num = 1;
while (i == _s_current_epoch) { while (i == s_current_epoch_) {
const char* filename = pick_one_file(); const char* filename = PickOneFile();
_local_reader->set_file(filename); local_reader_->SetFile(filename);
while (true) { while (true) {
#ifdef LOCAL_PROF #ifdef LOCAL_PROF
timeline.start(); timeline.start();
#endif #endif
bool flag = _local_reader->read_batch(); bool flag = local_reader_->ReadBatch();
if (!flag) { if (!flag) {
break; break;
} }
...@@ -312,11 +312,11 @@ void ExecutorThreadWorker::train() { ...@@ -312,11 +312,11 @@ void ExecutorThreadWorker::train() {
break; break;
} }
for (unsigned int i = 0; i < _ops.size(); ++i) { for (unsigned int i = 0; i < ops_.size(); ++i) {
#ifdef LOCAL_PROF #ifdef LOCAL_PROF
timeline.start(); timeline.start();
#endif #endif
_ops[i]->Run(*_thread_scope, _place); ops_[i]->Run(*thread_scope_, place_);
#ifdef LOCAL_PROF #ifdef LOCAL_PROF
timeline.pause(); timeline.pause();
op_total_time[i] += timeline.elapsed_sec(); op_total_time[i] += timeline.elapsed_sec();
...@@ -325,17 +325,17 @@ void ExecutorThreadWorker::train() { ...@@ -325,17 +325,17 @@ void ExecutorThreadWorker::train() {
} }
batch_num++; batch_num++;
float avg_inspect = 0.0; float avg_inspect = 0.0;
if (!_inspect_var_name.empty()) { if (!inspect_var_name_.empty()) {
avg_inspect = _thread_scope->FindVar(_inspect_var_name) avg_inspect = thread_scope_->FindVar(inspect_var_name_)
->GetMutable<LoDTensor>() ->GetMutable<LoDTensor>()
->data<float>()[0]; ->data<float>()[0];
} }
total_inspect += avg_inspect; total_inspect += avg_inspect;
_thread_scope->DropKids(); thread_scope_->DropKids();
} }
update_epoch_num(); UpdateEpochNum();
LOG(ERROR) << "memory used after epoch " << i + 1 LOG(ERROR) << "memory used after epoch " << i + 1
<< " called: " << memory::memory_usage(_place); << " called: " << memory::memory_usage(place_);
#ifdef LOCAL_PROF #ifdef LOCAL_PROF
for (int i = 0; i < op_total_time.size(); ++i) { for (int i = 0; i < op_total_time.size(); ++i) {
...@@ -354,12 +354,12 @@ void ExecutorThreadWorker::train() { ...@@ -354,12 +354,12 @@ void ExecutorThreadWorker::train() {
LOG(ERROR) << "mean " << inspect_key.c_str() LOG(ERROR) << "mean " << inspect_key.c_str()
<< " of epoch " << i + 1 << ": " << total_inspect / batch_num; << " of epoch " << i + 1 << ": " << total_inspect / batch_num;
#endif #endif
if (_thread_id == 0) { if (thread_id_ == 0) {
char modelfile[1024]; char modelfile[1024];
snprintf(&modelfile[0], snprintf(&modelfile[0],
sizeof(modelfile), sizeof(modelfile),
"%s_epoch%d.model", "%s_epoch%d.model",
_model_prefix.c_str(), model_prefix_.c_str(),
i); i);
std::string model_filename = std::string(modelfile); std::string model_filename = std::string(modelfile);
// this save_inference_model can only save imdbtask, should make this // this save_inference_model can only save imdbtask, should make this
...@@ -367,55 +367,55 @@ void ExecutorThreadWorker::train() { ...@@ -367,55 +367,55 @@ void ExecutorThreadWorker::train() {
// //
// currently comment it // currently comment it
LOG(ERROR) << "Going to save model " << modelfile; LOG(ERROR) << "Going to save model " << modelfile;
save_model(_main_program, save_model(main_program_,
_thread_scope, thread_scope_,
_model_param_names, model_param_names_,
model_filename, model_filename,
true); true);
} }
} }
} }
void ExecutorThreadWorker::set_thread_id(int tid) { void ExecutorThreadWorker::SetThreadId(int tid) {
_thread_id = tid; thread_id_ = tid;
} }
void ExecutorThreadWorker::set_place(const platform::Place& place) { void ExecutorThreadWorker::SetPlace(const platform::Place& place) {
_place = place; place_ = place;
} }
void ExecutorThreadWorker::set_main_program( void ExecutorThreadWorker::SetMainProgram(
const ProgramDesc& main_program_desc) { const ProgramDesc& main_program_desc) {
_main_program.reset(new ProgramDesc(main_program_desc)); main_program_.reset(new ProgramDesc(main_program_desc));
} }
void ExecutorThreadWorker::set_root_scope(Scope* g_scope) { void ExecutorThreadWorker::SetRootScope(Scope* g_scope) {
_root_scope = g_scope; root_scope_ = g_scope;
} }
void ExecutorThreadWorker::set_max_training_epoch(int max_epoch) { void ExecutorThreadWorker::SetMaxTrainingEpoch(int max_epoch) {
_max_epoch = max_epoch; max_epoch_ = max_epoch;
} }
MultiExecutor::MultiExecutor(const platform::Place& place) : _place(place) {} MultiExecutor::MultiExecutor(const platform::Place& place) : place_(place) {}
void MultiExecutor::init_root_scope(Scope* scope) { void MultiExecutor::InitRootScope(Scope* scope) {
_root_scope = scope; root_scope_ = scope;
} }
void MultiExecutor::set_max_training_epoch(int max_epoch) { void MultiExecutor::SetMaxTrainingEpoch(int max_epoch) {
_max_epoch = max_epoch; max_epoch_ = max_epoch;
} }
void MultiExecutor::set_datafeed_name(const char* feedname) { void MultiExecutor::SetDataFeedName(const char* feedname) {
_feed_name = std::string(feedname); feed_name_ = std::string(feedname);
} }
void MultiExecutor::set_model_prefix(const std::string& model_prefix) { void MultiExecutor::SetModelPrefix(const std::string& model_prefix) {
_model_prefix = model_prefix; model_prefix_ = model_prefix;
} }
void MultiExecutor::run_startup_program(const ProgramDesc& program, void MultiExecutor::RunStartupProgram(const ProgramDesc& program,
Scope* scope) { Scope* scope) {
auto& block = program.Block(0); auto& block = program.Block(0);
for (auto& var : block.AllVars()) { for (auto& var : block.AllVars()) {
...@@ -447,7 +447,7 @@ void MultiExecutor::run_startup_program(const ProgramDesc& program, ...@@ -447,7 +447,7 @@ void MultiExecutor::run_startup_program(const ProgramDesc& program,
// param_dict.size(), ops.size()); // param_dict.size(), ops.size());
for (auto& op : ops) { for (auto& op : ops) {
op->Run(*scope, _place); op->Run(*scope, place_);
} }
// LOGERR("total time for startup program: %fs", timeline.elapsed_sec()); // LOGERR("total time for startup program: %fs", timeline.elapsed_sec());
for (auto& op : ops) { for (auto& op : ops) {
...@@ -456,7 +456,7 @@ void MultiExecutor::run_startup_program(const ProgramDesc& program, ...@@ -456,7 +456,7 @@ void MultiExecutor::run_startup_program(const ProgramDesc& program,
// LOGERR("run startup program done."); // LOGERR("run startup program done.");
} }
std::unique_ptr<ProgramDesc> MultiExecutor::load_desc_from_file( std::unique_ptr<ProgramDesc> MultiExecutor::LoadDescFromFile(
const std::string& f) { const std::string& f) {
std::string program_desc_str; std::string program_desc_str;
read_binary_file(f, &program_desc_str); read_binary_file(f, &program_desc_str);
...@@ -464,102 +464,102 @@ std::unique_ptr<ProgramDesc> MultiExecutor::load_desc_from_file( ...@@ -464,102 +464,102 @@ std::unique_ptr<ProgramDesc> MultiExecutor::load_desc_from_file(
return program; return program;
} }
void MultiExecutor::set_dense_comm_tensor( void MultiExecutor::SetDenseCommTensor(
const std::vector<std::string>& dense_comm_tensor) { const std::vector<std::string>& dense_comm_tensor) {
_dense_comm_tensor.resize(dense_comm_tensor.size()); dense_comm_tensor_.resize(dense_comm_tensor.size());
for (unsigned int i = 0; i < dense_comm_tensor.size(); ++i) { for (unsigned int i = 0; i < dense_comm_tensor.size(); ++i) {
_dense_comm_tensor[i] = dense_comm_tensor[i]; dense_comm_tensor_[i] = dense_comm_tensor[i];
} }
} }
void MultiExecutor::set_sparse_comm_tensor( void MultiExecutor::SetSparseCommTensor(
const std::vector<std::string>& sparse_comm_tensor) { const std::vector<std::string>& sparse_comm_tensor) {
_sparse_comm_tensor.resize(sparse_comm_tensor.size()); sparse_comm_tensor_.resize(sparse_comm_tensor.size());
for (unsigned int i = 0; i < sparse_comm_tensor.size(); ++i) { for (unsigned int i = 0; i < sparse_comm_tensor.size(); ++i) {
_sparse_comm_tensor[i] = sparse_comm_tensor[i]; sparse_comm_tensor_[i] = sparse_comm_tensor[i];
} }
} }
void MultiExecutor::set_sparse_comm_data( void MultiExecutor::SetSparseCommData(
const std::map<std::string, int>& sparse_comm_data) { const std::map<std::string, int>& sparse_comm_data) {
_sparse_comm_data = sparse_comm_data; sparse_comm_data_ = sparse_comm_data;
LOG(INFO) << "Sparse comm data: " << _sparse_comm_data.size(); LOG(INFO) << "Sparse comm data: " << sparse_comm_data_.size();
} }
void MultiExecutor::set_filelist(const char* filelist) { void MultiExecutor::SetFileList(const char* filelist) {
_filelist.clear(); filelist_.clear();
std::ifstream fin(filelist); std::ifstream fin(filelist);
std::string filename; std::string filename;
while (fin >> filename) { while (fin >> filename) {
LOG(ERROR) << "add " << filename.c_str() << " to filelist"; LOG(ERROR) << "add " << filename.c_str() << " to filelist";
_filelist.push_back(filename); filelist_.push_back(filename);
} }
fin.close(); fin.close();
} }
void MultiExecutor::set_filelist(std::vector<std::string> tfiles) { void MultiExecutor::SetFileList(std::vector<std::string> tfiles) {
_filelist.clear(); filelist_.clear();
_filelist.insert(_filelist.end(), tfiles.begin(), tfiles.end()); filelist_.insert(filelist_.end(), tfiles.begin(), tfiles.end());
return; return;
} }
void MultiExecutor::set_inspect_var_name(const std::string& inspect_var_name) { void MultiExecutor::SetInspectVarName(const std::string& inspect_var_name) {
_inspect_var_name = inspect_var_name; inspect_var_name_ = inspect_var_name;
} }
void MultiExecutor::set_param_names(const std::vector<std::string>& param_names) { void MultiExecutor::SetParamNames(const std::vector<std::string>& param_names) {
_model_param_names = param_names; model_param_names_ = param_names;
} }
void MultiExecutor::set_thread_num(const int thread_num) { void MultiExecutor::SetThreadNum(const int thread_num) {
_thread_num = thread_num; thread_num_ = thread_num;
} }
void MultiExecutor::prepare_threads(const ProgramDesc& host_program) { void MultiExecutor::PrepareThreads(const ProgramDesc& host_program) {
_workers.resize(_thread_num); workers_.resize(thread_num_);
for (unsigned i = 0; i < _thread_num; ++i) { for (unsigned i = 0; i < thread_num_; ++i) {
_workers[i].reset(new ExecutorThreadWorker); workers_[i].reset(new ExecutorThreadWorker);
_workers[i]->set_thread_id(i); workers_[i]->SetThreadId(i);
_workers[i]->create_thread_operators(host_program); workers_[i]->CreateThreadOperators(host_program);
_workers[i]->set_root_scope(_root_scope); workers_[i]->SetRootScope(root_scope_);
_workers[i]->set_place(_place); workers_[i]->SetPlace(place_);
_workers[i]->set_max_training_epoch(_max_epoch); workers_[i]->SetMaxTrainingEpoch(max_epoch_);
_workers[i]->create_thread_scope(host_program); workers_[i]->CreateThreadScope(host_program);
_workers[i]->set_inspect_var_name(_inspect_var_name); workers_[i]->SetInspectVarName(inspect_var_name_);
_workers[i]->set_model_param_names(_model_param_names); workers_[i]->SetModelParamNames(model_param_names_);
_workers[i]->set_sparse_comm_data(_sparse_comm_data); workers_[i]->SetSparseCommData(sparse_comm_data_);
_workers[i]->set_main_program(host_program); workers_[i]->SetMainProgram(host_program);
_workers[i]->set_model_prefix(_model_prefix); workers_[i]->SetModelPrefix(model_prefix_);
} }
for (unsigned i = 0; i < _filelist.size(); ++i) { for (unsigned i = 0; i < filelist_.size(); ++i) {
// suppose at least one trainer thread here, and // suppose at least one trainer thread here, and
// filelist is static so that we only add filelist once // filelist is static so that we only add filelist once
_workers[0]->add_train_file(_filelist[i]); workers_[0]->AddTrainFile(filelist_[i]);
} }
// mpi_wrapper::ModelParam model_param(true); // mpi_wrapper::ModelParam model_param(true);
// _workers[0]->register_parallel_training_param(model_param); // workers_[0]->register_parallel_training_param(model_param);
for (unsigned i = 0; i < _thread_num; ++i) { for (unsigned i = 0; i < thread_num_; ++i) {
// new a datafeed here // new a datafeed here
std::shared_ptr<DataFeed> local_feed = create_datafeed(_feed_name.c_str()); std::shared_ptr<DataFeed> local_feed = CreateDataFeed(feed_name_.c_str());
local_feed->init(_data_feed_param); local_feed->Init(data_feed_param_);
local_feed->set_batch_size(_batch_size); local_feed->SetBatchSize(batch_size_);
_workers[i]->set_datafeed(local_feed); workers_[i]->SetDataFeed(local_feed);
_workers[i]->binding_datafeed_memory(); workers_[i]->BindingDataFeedMemory();
_workers[i]->set_thread_id(i); workers_[i]->SetThreadId(i);
} }
} }
void MultiExecutor::run_multi_executor(const ProgramDesc& host_program) { void MultiExecutor::RunMultiExecutor(const ProgramDesc& host_program) {
// thread binding here? // thread binding here?
prepare_threads(host_program); PrepareThreads(host_program);
for (unsigned i = 0; i < _thread_num; ++i) { for (unsigned i = 0; i < thread_num_; ++i) {
_threads.push_back(std::thread(&ExecutorThreadWorker::train, threads_.push_back(std::thread(&ExecutorThreadWorker::Train,
_workers[i].get())); workers_[i].get()));
} }
for (auto& th : _threads) { for (auto& th : threads_) {
th.join(); th.join();
} }
} }
......
...@@ -36,137 +36,129 @@ class ExecutorThreadWorker { ...@@ -36,137 +36,129 @@ class ExecutorThreadWorker {
public: public:
ExecutorThreadWorker() {} ExecutorThreadWorker() {}
virtual ~ExecutorThreadWorker() {} virtual ~ExecutorThreadWorker() {}
void create_thread_scope(const framework::ProgramDesc& program); void CreateThreadScope(const framework::ProgramDesc& program);
void set_datafeed(const DataFeed& datafeed); void SetDataFeed(const DataFeed& datafeed);
void set_thread_id(int tid); void SetThreadId(int tid);
void create_thread_operators(const framework::ProgramDesc& program); void CreateThreadOperators(const framework::ProgramDesc& program);
void set_root_scope(Scope* g_scope); void SetRootScope(Scope* g_scope);
void set_device(); void SetDevice();
virtual void add_fid_set(); virtual void AddFidSet();
void set_comm_batch(int comm_batch) { _comm_batch = comm_batch; } void SetCommBatch(int comm_batch) { comm_batch_ = comm_batch; }
void add_train_file(const std::string& filename); void AddTrainFile(const std::string& filename);
void set_main_program(const ProgramDesc& main_program_desc); void SetMainProgram(const ProgramDesc& main_program_desc);
void set_place(const paddle::platform::Place& place); void SetPlace(const paddle::platform::Place& place);
void set_max_training_epoch(const int max_epoch); void SetMaxTrainingEpoch(const int max_epoch);
void binding_datafeed_memory(); void BindingDataFeedMemory();
void set_model_prefix(const std::string& prefix) { _model_prefix = prefix; } void SetModelPrefix(const std::string& prefix) { model_prefix_ = prefix; }
void set_inspect_var_name(const std::string& inspect_var_name); void SetInspectVarName(const std::string& inspect_var_name);
void set_model_param_names(const std::vector<std::string>& param_names); void SetModelParamNames(const std::vector<std::string>& param_names);
void set_sparse_comm_data(const std::map<std::string, int>& param_names); void SetSparseCommData(const std::map<std::string, int>& param_names);
void set_datafeed(const std::shared_ptr<DataFeed>& datafeed); void SetDataFeed(const std::shared_ptr<DataFeed>& datafeed);
virtual void mpi_train(); void Train();
void gpu_train(); virtual const char* PickOneFile();
void train(); void UpdateEpochNum();
virtual const char* pick_one_file();
void update_epoch_num(); virtual void SetDenseCommTensor(
virtual void set_dense_comm_tensor(
const std::vector<std::string>& param_names) {} const std::vector<std::string>& param_names) {}
virtual void initialize() {} virtual void Initialize() {}
public: public:
static std::mutex _s_locker_for_pick_file; static std::mutex s_locker_for_pick_file_;
static unsigned int _s_current_file_idx; static unsigned int s_current_file_idx_;
static size_t _s_current_finished_file_cnt; static size_t s_current_finished_file_cnt_;
static unsigned int _s_current_epoch; static unsigned int s_current_epoch_;
static int _s_current_save_epoch; static int s_current_save_epoch_;
static std::vector<std::string> _s_thread_filelist; // filelist static std::vector<std::string> s_thread_filelist_; // filelist
static bool _s_is_first_worker; static bool s_is_first_worker_;
protected: protected:
// thread index // thread index
int _thread_id; int thread_id_;
// current training file
int _cur_fileidx;
// max epoch for each thread // max epoch for each thread
unsigned int _max_epoch; unsigned int max_epoch_;
// instances learned currently // instances learned currently
int _comm_batch; int comm_batch_;
std::string _model_prefix; std::string model_prefix_;
std::vector<std::string> _op_names; std::vector<std::string> op_names_;
// local ops for forward and backward // local ops for forward and backward
std::vector<OperatorBase *> _ops; std::vector<OperatorBase *> ops_;
// main program for training // main program for training
std::unique_ptr<framework::ProgramDesc> _main_program; std::unique_ptr<framework::ProgramDesc> main_program_;
// binary data reader // binary data reader
std::shared_ptr<DataFeed> _local_reader; std::shared_ptr<DataFeed> local_reader_;
std::string _inspect_var_name; std::string inspect_var_name_;
std::vector<std::string> _model_param_names; std::vector<std::string> model_param_names_;
std::map<std::string, int> _sparse_comm_data; std::map<std::string, int> sparse_comm_data_;
std::vector<int> _ids_buffer;
// execution place // execution place
platform::Place _place; platform::Place place_;
// root scope for model parameters // root scope for model parameters
Scope* _root_scope; Scope* root_scope_;
// a thread scope, father scope is global score which is shared // a thread scope, father scope is global score which is shared
Scope* _thread_scope; Scope* thread_scope_;
}; };
class MultiExecutor { class MultiExecutor {
public: public:
explicit MultiExecutor(const platform::Place& place); explicit MultiExecutor(const platform::Place& place);
virtual ~MultiExecutor() {} virtual ~MultiExecutor() {}
static std::unique_ptr<ProgramDesc> load_desc_from_file( static std::unique_ptr<ProgramDesc> LoadDescFromFile(
const std::string& filename); const std::string& filename);
void init_root_scope(Scope* scope); void InitRootScope(Scope* scope);
void set_inspect_var_name(const std::string& inspect_var_name); void SetInspectVarName(const std::string& inspect_var_name);
void set_param_names(const std::vector<std::string>& param_names); void SetParamNames(const std::vector<std::string>& param_names);
void set_max_training_epoch(const int max_epoch); void SetMaxTrainingEpoch(const int max_epoch);
Scope* get_root_scope() { return _root_scope; } Scope* GetRootScope() { return root_scope_; }
void set_thread_num(const int thread_num); void SetThreadNum(const int thread_num);
void set_batch_size(const int batch_size) { _batch_size = batch_size; } void SetBatchSize(const int batch_size) { batch_size_ = batch_size; }
void set_filelist(const char* filelist); void SetFileList(const char* filelist);
void set_filelist(const std::vector<std::string> filelist); void SetFileList(const std::vector<std::string> filelist);
void set_datafeed_name(const char* feedname); void SetDataFeedName(const char* feedname);
void set_data_feed_param(const datafeed::DataFeedParameter& feed_param) { void SetDataFeedParam(const datafeed::DataFeedParameter& feed_param) {
_data_feed_param = feed_param; data_feed_param_ = feed_param;
} }
void set_comm_batch(int comm_batch) { void SetCommBatch(int comm_batch) {
_comm_batch = comm_batch; comm_batch_ = comm_batch;
} }
void set_model_prefix(const std::string& model_prefix); void SetModelPrefix(const std::string& model_prefix);
void set_dense_comm_tensor(const std::vector<std::string>& dense_comm_tensor); void SetDenseCommTensor(const std::vector<std::string>& dense_comm_tensor);
void set_sparse_comm_tensor( void SetSparseCommTensor(
const std::vector<std::string>& sparse_comm_tensor); const std::vector<std::string>& sparse_comm_tensor);
void set_sparse_comm_data(const std::map<std::string, int>& sparse_comm_data); void SetSparseCommData(const std::map<std::string, int>& sparse_comm_data);
virtual void prepare_threads(const framework::ProgramDesc& host_program); virtual void PrepareThreads(const framework::ProgramDesc& host_program);
void run_startup_program(const framework::ProgramDesc& program, void RunStartupProgram(const framework::ProgramDesc& program,
framework::Scope* scope); framework::Scope* scope);
void run_multi_executor(const ProgramDesc& host_program); void RunMultiExecutor(const ProgramDesc& host_program);
public: public:
unsigned int _thread_num; unsigned int thread_num_;
datafeed::DataFeedParameter _data_feed_param; datafeed::DataFeedParameter data_feed_param_;
int _max_epoch; int max_epoch_;
int _batch_size; int batch_size_;
int _comm_batch; int comm_batch_;
std::vector<std::shared_ptr<ExecutorThreadWorker> > _workers; std::vector<std::shared_ptr<ExecutorThreadWorker> > workers_;
std::vector<std::thread> _threads; std::vector<std::thread> threads_;
std::vector<std::string> _filelist; std::vector<std::string> filelist_;
std::string _inspect_var_name; std::string inspect_var_name_;
std::vector<std::string> _model_param_names; std::vector<std::string> model_param_names_;
std::vector<std::string> _dense_comm_tensor; std::vector<std::string> dense_comm_tensor_;
std::vector<std::string> _sparse_comm_tensor; std::vector<std::string> sparse_comm_tensor_;
std::map<std::string, int> _sparse_comm_data; std::map<std::string, int> sparse_comm_data_;
int node_num; std::string model_prefix_;
std::string _model_prefix; std::string feed_name_;
ProgramDesc _host_program; Scope* root_scope_;
std::string _feed_name; platform::Place place_;
Scope* _root_scope;
platform::Place _place;
}; };
} // namespace framework } // namespace framework
......
...@@ -38,111 +38,114 @@ DEFINE_bool(is_text_feed, false, "is_text_feed"); ...@@ -38,111 +38,114 @@ DEFINE_bool(is_text_feed, false, "is_text_feed");
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void TextClassDataFeed::init(const datafeed::DataFeedParameter& feed_param) { void TextClassDataFeed::Init(const datafeed::DataFeedParameter& feed_param) {
// hard coding for a specific datafeed // hard coding for a specific datafeed
_feed_vec.resize(2); feed_vec_.resize(2);
// _feed_vec[0].reset(new LoDTensor); // feed_vec_[0].reset(new LoDTensor);
// _feed_vec[1].reset(new LoDTensor); // feed_vec_[1].reset(new LoDTensor);
_all_slot_ids = {0, 1}; all_slot_ids_ = {0, 1};
_use_slot_ids = {0, 1}; use_slot_ids_ = {0, 1};
_use_slot_alias = {"words", "label"}; use_slot_alias_ = {"words", "label"};
_file_content_buffer_host.reset(new char[200*1024*1024],
file_content_buffer_host_.reset(new char[200*1024*1024],
[](char *p) {delete[] p;}); [](char *p) {delete[] p;});
_file_content_buffer = _file_content_buffer_host.get(); file_content_buffer_ = file_content_buffer_host_.get();
_file_content_buffer_ptr = _file_content_buffer; file_content_buffer_ptr_ = file_content_buffer_;
_batch_id_host.reset(new int[10240*1024],
batch_id_host_.reset(new int[10240*1024],
[](int *p) {delete[] p;}); // max word num in a batch [](int *p) {delete[] p;}); // max word num in a batch
_label_host.reset(new int[10240], batch_id_buffer_ = batch_id_host_.get();
label_host_.reset(new int[10240],
[](int *p) {delete[] p;}); // max label in a batch [](int *p) {delete[] p;}); // max label in a batch
_batch_id_buffer = _batch_id_host.get(); label_ptr_ = label_host_.get();
_label_ptr = _label_host.get();
} }
// todo: use elegant implemention for this function // todo: use elegant implemention for this function
bool TextClassDataFeed::read_batch() { bool TextClassDataFeed::ReadBatch() {
paddle::framework::Vector<size_t> offset; paddle::framework::Vector<size_t> offset;
int tlen = 0; int tlen = 0;
int llen = 0; int llen = 0;
int inst_idx = 0; int inst_idx = 0;
offset.resize(_batch_size + 1); offset.resize(batch_size_ + 1);
offset[0] = 0; offset[0] = 0;
while (inst_idx < _batch_size) { while (inst_idx < batch_size_) {
int ptr_offset = 0; int ptr_offset = 0;
if (_file_content_buffer_ptr - _file_content_buffer >= _file_size) { if (file_content_buffer_ptr_ - file_content_buffer_ >= file_size_) {
break; break;
} }
memcpy(reinterpret_cast<char *>(&llen), memcpy(reinterpret_cast<char *>(&llen),
_file_content_buffer_ptr + ptr_offset, file_content_buffer_ptr_ + ptr_offset,
sizeof(int)); sizeof(int));
ptr_offset += sizeof(int); ptr_offset += sizeof(int);
memcpy(reinterpret_cast<char *>(_batch_id_buffer + tlen), memcpy(reinterpret_cast<char *>(batch_id_buffer_ + tlen),
_file_content_buffer_ptr + ptr_offset, file_content_buffer_ptr_ + ptr_offset,
llen * sizeof(int)); llen * sizeof(int));
tlen += llen; tlen += llen;
offset[inst_idx + 1] = offset[inst_idx] + llen; offset[inst_idx + 1] = offset[inst_idx] + llen;
ptr_offset += sizeof(int) * llen; ptr_offset += sizeof(int) * llen;
memcpy(reinterpret_cast<char *>(_label_ptr + inst_idx), memcpy(reinterpret_cast<char *>(label_ptr_ + inst_idx),
_file_content_buffer_ptr + ptr_offset, file_content_buffer_ptr_ + ptr_offset,
sizeof(int)); sizeof(int));
ptr_offset += sizeof(int); ptr_offset += sizeof(int);
_file_content_buffer_ptr += ptr_offset; file_content_buffer_ptr_ += ptr_offset;
inst_idx++; inst_idx++;
} }
if (inst_idx != _batch_size) { if (inst_idx != batch_size_) {
return false; return false;
} }
LoD input_lod{offset}; LoD input_lod{offset};
paddle::framework::Vector<size_t> label_offset; paddle::framework::Vector<size_t> label_offset;
label_offset.resize(_batch_size + 1); label_offset.resize(batch_size_ + 1);
for (int i = 0; i <= _batch_size; ++i) { for (int i = 0; i <= batch_size_; ++i) {
label_offset[i] = i; label_offset[i] = i;
} }
LoD label_lod{label_offset}; LoD label_lod{label_offset};
int64_t* input_ptr = _feed_vec[0]->mutable_data<int64_t>( int64_t* input_ptr = feed_vec_[0]->mutable_data<int64_t>(
{static_cast<int64_t>(offset.back()), 1}, {static_cast<int64_t>(offset.back()), 1},
platform::CPUPlace()); platform::CPUPlace());
int64_t* label_ptr = _feed_vec[1]->mutable_data<int64_t>({_batch_size, 1}, int64_t* label_ptr = feed_vec_[1]->mutable_data<int64_t>({batch_size_, 1},
platform::CPUPlace()); platform::CPUPlace());
for (unsigned int i = 0; i < offset.back(); ++i) { for (unsigned int i = 0; i < offset.back(); ++i) {
input_ptr[i] = static_cast<int64_t>(_batch_id_buffer[i]); input_ptr[i] = static_cast<int64_t>(batch_id_buffer_[i]);
} }
for (int i = 0; i < _batch_size; ++i) { for (int i = 0; i < batch_size_; ++i) {
label_ptr[i] = static_cast<int64_t>(_label_ptr[i]); label_ptr[i] = static_cast<int64_t>(label_ptr_[i]);
} }
_feed_vec[0]->set_lod(input_lod); feed_vec_[0]->set_lod(input_lod);
_feed_vec[1]->set_lod(label_lod); feed_vec_[1]->set_lod(label_lod);
return true; return true;
} }
void TextClassDataFeed::add_feed_var(Variable* feed, const std::string& name) { void TextClassDataFeed::AddFeedVar(Variable* feed, const std::string& name) {
for (unsigned int i = 0; i < _use_slot_alias.size(); ++i) { for (unsigned int i = 0; i < use_slot_alias_.size(); ++i) {
if (name == _use_slot_alias[i]) { if (name == use_slot_alias_[i]) {
_feed_vec[i] = feed->GetMutable<LoDTensor>(); feed_vec_[i] = feed->GetMutable<LoDTensor>();
} }
} }
} }
bool TextClassDataFeed::set_file(const char* filename) { bool TextClassDataFeed::SetFile(const char* filename) {
// termnum termid termid ... termid label // termnum termid termid ... termid label
int filesize = read_whole_file(filename, _file_content_buffer); int filesize = ReadWholeFile(filename, file_content_buffer_);
// todo , remove magic number // todo , remove magic number
if (filesize < 0 || filesize >= 1024 * 1024 * 1024) { if (filesize < 0 || filesize >= 1024 * 1024 * 1024) {
return false; return false;
} }
_file_content_buffer_ptr = _file_content_buffer; file_content_buffer_ptr_ = file_content_buffer_;
_file_size = filesize; file_size_ = filesize;
return true; return true;
} }
int TextClassDataFeed::read_whole_file(const std::string& filename, int TextClassDataFeed::ReadWholeFile(const std::string& filename,
char* buffer) { char* buffer) {
std::ifstream ifs(filename.c_str(), std::ios::binary); std::ifstream ifs(filename.c_str(), std::ios::binary);
if (ifs.fail()) { if (ifs.fail()) {
......
...@@ -35,47 +35,6 @@ limitations under the License. */ ...@@ -35,47 +35,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
typedef uint64_t FeatureKey;
struct FeatureItem {
FeatureItem() {}
FeatureItem(FeatureKey sign_, uint16_t slot_) {
sign() = sign_;
slot() = slot_;
}
FeatureKey& sign() {
return *(reinterpret_cast<FeatureKey*>(sign_buffer()));
}
const FeatureKey& sign() const {
return *(const FeatureKey*)sign_buffer();
}
uint16_t& slot() {
return _slot;
}
const uint16_t& slot() const {
return _slot;
}
private:
char _sign[sizeof(FeatureKey)];
uint16_t _slot;
char* sign_buffer() const {
return (char *)_sign;
}
};
// Record(average:14031B) is smaller than Sample(average:16530B)
struct Record {
int show, click;
std::vector<FeatureItem> feas;
std::string lineid;
std::string tags;
};
struct Gauc { struct Gauc {
int show, click; int show, click;
uint64_t fea; uint64_t fea;
...@@ -89,241 +48,83 @@ struct Instance { ...@@ -89,241 +48,83 @@ struct Instance {
std::vector<Gauc> gauc_vec; std::vector<Gauc> gauc_vec;
}; };
struct Sample {
uint64_t label;
std::map<uint16_t, std::vector<uint64_t>> feas;
bool from_string(const std::string& input, const std::set<uint32_t>& slots) {
size_t end = input.find_first_of(' ');
if (end == std::string::npos) {
LOG(ERROR) << "[ERROR] Fail in parsing:" << input;
return false;
}
label = input[end + 3] - '0';
CHECK(label == 0 || label == 1) << "invalid label:" << label;
std::stringstream ss(input);
std::string token;
uint16_t slot_id = 0;
uint64_t feature_id = 0;
int num_nonfeas_token = 0;
std::ostringstream os;
while (ss >> token) {
size_t end = token.find_first_of(':');
if (end == std::string::npos) {
++num_nonfeas_token;
continue;
}
try {
slot_id = stoi(token.substr(end + 1));
} catch (...) {
LOG(ERROR) << "Error in parsing slot id:" << token;
return false;
}
try {
feature_id = stoull(token.substr(0, end));
} catch (...) {
LOG(ERROR) << "Error in parsing feature id:" << token;
return false;
}
if (slot_id <= 0) {
LOG(ERROR) << "invalid slot:" << slot_id << " feasign:" << feature_id
<< " line:" << input;
return false;
}
if (slots.find(slot_id) == slots.end()) {
continue;
}
feas[slot_id].push_back(feature_id);
}
if (num_nonfeas_token != 4) {
LOG(ERROR) << "Format error. Invalid number of non-feasign token:"
<< num_nonfeas_token;
return false;
}
return true;
}
};
struct TeacherStudentSample {
uint64_t label;
std::map<uint16_t, std::vector<uint64_t>> feas;
float q_score;
void print() {
LOG(ERROR) << "label: " << label << " score: " << q_score;
for (auto &slot : feas) {
for (auto &fea : slot.second) {
LOG(ERROR) << "slot: " << slot.first << " fea: " << fea;
}
}
}
bool from_string(const std::string& input,
const std::set<uint32_t>& slots,
Gauc& gauc) { // NOLINT
size_t end = input.find_first_of(' ');
if (end == std::string::npos) {
LOG(ERROR) << "[ERROR] Fail in parsing:" << input;
return false;
}
label = input[end + 3] - '0';
CHECK(label == 0 || label == 1) << "invalid label:" << label;
gauc.show = 1;
gauc.click = label;
gauc.lineid = input.substr(0, end);
gauc.fea = 0;
size_t dnn_start = input.find("*");
if (dnn_start == std::string::npos) {
q_score = -1.0;
} else {
dnn_start += 1;
size_t dnn_end = input.find(' ', dnn_start);
q_score = static_cast<float>(
atof(input.substr(dnn_start, dnn_end - dnn_start).c_str()));
}
size_t head_pos = input.find("\t");
std::string head = input.substr(0, head_pos);
std::stringstream ss(head);
std::string token;
uint16_t slot_id = 0;
uint64_t feature_id = 0;
int num_nonfeas_token = 0;
std::ostringstream os;
while (ss >> token) {
size_t end = token.find_first_of(':');
if (end == std::string::npos) {
++num_nonfeas_token;
continue;
}
try {
slot_id = stoi(token.substr(end + 1));
} catch (...) {
LOG(ERROR) << "Error in parsing slot id:" << token;
return false;
}
try {
feature_id = stoull(token.substr(0, end));
} catch (...) {
LOG(ERROR) << "Error in parsing feature id:" << token;
return false;
}
if (slot_id <= 0) {
LOG(ERROR) << "invalid slot:" << slot_id << " feasign:" << feature_id
<< " line:" << input;
return false;
}
if (slots.find(slot_id) == slots.end()) {
continue;
}
if (slot_id == 6048) {
gauc.fea = feature_id;
}
feas[slot_id].push_back(feature_id);
}
if (num_nonfeas_token != 4) {
LOG(ERROR) << "Format error. Invalid number of non-feasign token:"
<< num_nonfeas_token;
return false;
}
return true;
}
};
class DataFeed { class DataFeed {
public: public:
DataFeed() {} DataFeed() {}
virtual ~DataFeed() {} virtual ~DataFeed() {}
virtual void init(const datafeed::DataFeedParameter& feed_param) = 0; virtual void Init(const datafeed::DataFeedParameter& feed_param) = 0;
/* /*
* This function will be used to check file format. * This function will be used to check file format.
* Considering that this function may be used alone, * Considering that this function may be used alone,
* it does not check anything. * it does not check anything.
* */ * */
virtual bool check_file(const char* filename) = 0; virtual bool CheckFile(const char* filename) = 0;
virtual bool set_file(const char* filename) = 0; virtual bool SetFile(const char* filename) = 0;
virtual bool read_batch() = 0; virtual bool ReadBatch() = 0;
virtual const std::vector<uint16_t>& get_all_slot_ids() { virtual const std::vector<uint16_t>& GetAllSlotIds() {
return _all_slot_ids; return all_slot_ids_;
} }
virtual const std::vector<uint16_t>& get_use_slot_ids() { virtual const std::vector<uint16_t>& GetUseSlotIds() {
return _use_slot_ids; return use_slot_ids_;
} }
virtual const std::vector<std::string>& get_use_slot_alias() { virtual const std::vector<std::string>& GetUseSlotAlias() {
return _use_slot_alias; return use_slot_alias_;
} }
virtual void add_feed_var(Variable* var, virtual void AddFeedVar(Variable* var,
const std::string& name) = 0; const std::string& name) = 0;
virtual void bind_scope(Scope* scope) = 0; virtual void BindScope(Scope* scope) = 0;
virtual void set_batch_size(int batch) { _default_batch_size = batch; } virtual void SetBatchSize(int batch) { default_batch_size_ = batch; }
virtual int get_batch_size() { return _batch_size; } virtual int GetBatchSize() { return batch_size_; }
virtual void set_buffer_size(int buffer_size) {} virtual void SetBufferSize(int buffer_size) {}
std::vector<LoDTensor*>& get_feed_vec() { std::vector<LoDTensor*>& GetFeedVec() {
return _feed_vec; return feed_vec_;
} }
virtual std::vector<LoDTensor*>& get_feed_vec(const Instance& ins) { virtual std::vector<LoDTensor*>& GetFeedVec(const Instance& ins) {
LOG(ERROR) << "use defalut get_feed_vec"; LOG(ERROR) << "use defalut get_feed_vec";
return _feed_vec; return feed_vec_;
} }
protected: protected:
std::vector<uint16_t> _all_slot_ids; std::vector<uint16_t> all_slot_ids_;
std::vector<uint16_t> _use_slot_ids; std::vector<uint16_t> use_slot_ids_;
std::vector<std::string> _use_slot_alias; std::vector<std::string> use_slot_alias_;
std::vector<LoDTensor*> _feed_vec; std::vector<LoDTensor*> feed_vec_;
int _default_batch_size; int default_batch_size_;
int _batch_size; int batch_size_;
}; };
class TextClassDataFeed : public DataFeed { class TextClassDataFeed : public DataFeed {
public: public:
virtual ~TextClassDataFeed() {} virtual ~TextClassDataFeed() {}
virtual void init(const datafeed::DataFeedParameter& feed_param); virtual void Init(const datafeed::DataFeedParameter& feed_param);
virtual bool read_batch(); virtual bool ReadBatch();
virtual void add_feed_var(Variable* feed, const std::string& name); virtual void AddFeedVar(Variable* feed, const std::string& name);
virtual void bind_scope(Scope* scope) {} virtual void BindScope(Scope* scope) {}
virtual bool set_file(const char* filename); virtual bool SetFile(const char* filename);
virtual bool check_file(const char* filename) { virtual bool CheckFile(const char* filename) {
// TODO(xxx) // TODO(xxx)
return false; return false;
} }
void set_batch_size(int batch) {_batch_size = batch;} void SetBatchSize(int batch) {batch_size_ = batch;}
private: private:
int read_whole_file(const std::string& filename, char* buffer); int ReadWholeFile(const std::string& filename, char* buffer);
char* _file_content_buffer; char* file_content_buffer_;
char* _file_content_buffer_ptr; char* file_content_buffer_ptr_;
int* _batch_id_buffer; int* batch_id_buffer_;
int* _label_ptr; int* label_ptr_;
int _file_size; int file_size_;
std::vector<std::string> _names; std::vector<std::string> names_;
std::shared_ptr<char> _file_content_buffer_host; std::shared_ptr<char> file_content_buffer_host_;
std::shared_ptr<int> _batch_id_host; std::shared_ptr<int> batch_id_host_;
std::shared_ptr<int> _label_host; std::shared_ptr<int> label_host_;
}; };
} // namespace framework } // namespace framework
......
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/framework/datafeed_creator.h" #include "paddle/fluid/framework/datafeed_creator.h"
std::shared_ptr<paddle::framework::DataFeed> create_datafeed( std::shared_ptr<paddle::framework::DataFeed> CreateDataFeed(
const char* datafeed_class) { const char* datafeed_class) {
if (strcmp(datafeed_class, "TextClass") == 0) { if (strcmp(datafeed_class, "TextClass") == 0) {
return std::shared_ptr<paddle::framework::DataFeed>( return std::shared_ptr<paddle::framework::DataFeed>(
......
...@@ -17,6 +17,6 @@ limitations under the License. */ ...@@ -17,6 +17,6 @@ limitations under the License. */
#include <memory> #include <memory>
#include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/data_feed.h"
std::shared_ptr<paddle::framework::DataFeed> create_datafeed( std::shared_ptr<paddle::framework::DataFeed> CreateDataFeed(
const char* datafeed_class); const char* datafeed_class);
#endif // PADDLE_FLUID_FRAMEWORK_DATAFEED_CREATOR_H_ #endif // PADDLE_FLUID_FRAMEWORK_DATAFEED_CREATOR_H_