提交 929a9e80 编写于 作者: W wangguibao

Google naming conventions

上级 c555948c
......@@ -37,13 +37,13 @@ limitations under the License. */
namespace paddle {
namespace framework {
// Definitions of ExecutorThreadWorker's static data members: first under the
// leading-underscore names, then the same members under trailing-underscore
// names. Being static, one copy of each is shared by every worker instance.
// Mutex locked while a worker picks its next file.
std::mutex ExecutorThreadWorker::_s_locker_for_pick_file;
// Position in the file list of the next file to hand out.
unsigned int ExecutorThreadWorker::_s_current_file_idx = 0;
// Number of files finished in the current epoch.
size_t ExecutorThreadWorker::_s_current_finished_file_cnt = 0;
// Epoch counter; advanced once the finished-file count reaches the list size.
unsigned int ExecutorThreadWorker::_s_current_epoch = 0;
// Not referenced by the code visible in this chunk.
int ExecutorThreadWorker::_s_current_save_epoch = 0;
// Not referenced by the code visible in this chunk.
bool ExecutorThreadWorker::_s_is_first_worker = false;
// Training files handed out to the workers.
std::vector<std::string> ExecutorThreadWorker::_s_thread_filelist;
// Same members, trailing-underscore spelling.
std::mutex ExecutorThreadWorker::s_locker_for_pick_file_;
unsigned int ExecutorThreadWorker::s_current_file_idx_ = 0;
size_t ExecutorThreadWorker::s_current_finished_file_cnt_ = 0;
unsigned int ExecutorThreadWorker::s_current_epoch_ = 0;
int ExecutorThreadWorker::s_current_save_epoch_ = 0;
bool ExecutorThreadWorker::s_is_first_worker_ = false;
std::vector<std::string> ExecutorThreadWorker::s_thread_filelist_;
void CreateTensor(Variable* var, proto::VarType::Type var_type) {
if (var_type == proto::VarType::LOD_TENSOR) {
......@@ -142,33 +142,33 @@ static void save_model(
} // end save_model
// Appends `file` to the static training-file list. The list is a static
// member, so the entry is visible to every worker, not only to this one.
void ExecutorThreadWorker::add_train_file(const std::string& file) {
_s_thread_filelist.push_back(file);
void ExecutorThreadWorker::AddTrainFile(const std::string& file) {
s_thread_filelist_.push_back(file);
}
// Instantiates one operator per op description in block 0 of `program` and
// records it, together with its type name, in this worker.
// The op-name list is cleared first; the operator list is only appended to.
void ExecutorThreadWorker::create_thread_operators(const ProgramDesc& program) {
void ExecutorThreadWorker::CreateThreadOperators(const ProgramDesc& program) {
auto& block = program.Block(0);
_op_names.clear();
op_names_.clear();
for (auto& op_desc : block.AllOps()) {
std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
_op_names.push_back(op_desc->Type());
op_names_.push_back(op_desc->Type());
// Ownership leaves the unique_ptr here and the raw pointer is stored.
// NOTE(review): no matching delete is visible in this chunk - confirm who
// frees these operators.
OperatorBase* local_op_ptr = local_op.release();
_ops.push_back(local_op_ptr);
ops_.push_back(local_op_ptr);
// Last statement of the loop body, so this `continue` has no effect.
continue;
}
}
void ExecutorThreadWorker::create_thread_scope(const ProgramDesc& program) {
void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) {
auto& block = program.Block(0);
_thread_scope = &_root_scope->NewScope();
thread_scope_ = &root_scope_->NewScope();
for (auto& var : block.AllVars()) {
if (var->Persistable()) {
auto* ptr = _root_scope->Var(var->Name());
auto* ptr = root_scope_->Var(var->Name());
CreateTensor(ptr, var->GetType());
// LOGERR("create Persistable var[%s] finished",
// var->Name().c_str());
} else {
auto* ptr = _thread_scope->Var(var->Name());
auto* ptr = thread_scope_->Var(var->Name());
CreateTensor(ptr, var->GetType());
// LOGERR("create unpersistable var[%s] finished",
// var->Name().c_str());
......@@ -176,33 +176,33 @@ void ExecutorThreadWorker::create_thread_scope(const ProgramDesc& program) {
}
}
// Stores `datafeed` as this worker's local reader. The shared_ptr is copied,
// so the worker shares ownership of the feed with the caller.
void ExecutorThreadWorker::set_datafeed(const std::shared_ptr<DataFeed>& datafeed) {
_local_reader = datafeed;
void ExecutorThreadWorker::SetDataFeed(const std::shared_ptr<DataFeed>& datafeed) {
local_reader_ = datafeed;
}
void ExecutorThreadWorker::binding_datafeed_memory() {
const std::vector<std::string>& input_feed = _local_reader->get_use_slot_alias();
void ExecutorThreadWorker::BindingDataFeedMemory() {
const std::vector<std::string>& input_feed = local_reader_->GetUseSlotAlias();
for (auto name : input_feed) {
_local_reader->add_feed_var(_thread_scope->Var(name), name);
local_reader_->AddFeedVar(thread_scope_->Var(name), name);
}
}
// Stores a copy of the name of the variable to inspect.
void ExecutorThreadWorker::set_inspect_var_name(
void ExecutorThreadWorker::SetInspectVarName(
const std::string& inspect_var_name) {
_inspect_var_name = inspect_var_name;
inspect_var_name_ = inspect_var_name;
}
// Stores a copy of the model parameter name list, replacing any previous one.
void ExecutorThreadWorker::set_model_param_names(
void ExecutorThreadWorker::SetModelParamNames(
const std::vector<std::string>& param_names) {
_model_param_names = param_names;
model_param_names_ = param_names;
}
// Stores a copy of the name -> int map describing sparse communication data,
// replacing any previous map.
void ExecutorThreadWorker::set_sparse_comm_data(
void ExecutorThreadWorker::SetSparseCommData(
const std::map<std::string, int>& param_names) {
_sparse_comm_data = param_names;
sparse_comm_data_ = param_names;
}
void ExecutorThreadWorker::set_device() {
void ExecutorThreadWorker::SetDevice() {
static unsigned priority[] = {
0, 1, 2, 3, 4, 5,
6, 7, 8, 9, 10, 11,
......@@ -214,7 +214,7 @@ void ExecutorThreadWorker::set_device() {
42, 43, 44, 45, 46, 47
};
unsigned int i = this->_thread_id;
unsigned int i = this->thread_id_;
if (i < sizeof(priority) / sizeof(unsigned)) {
unsigned proc = priority[i];
......@@ -235,55 +235,55 @@ void ExecutorThreadWorker::set_device() {
}
}
// Counts one finished file; once as many files have finished as the file list
// holds, resets that count and advances the shared epoch counter.
// NOTE(review): the counters are static (shared by all workers) and this
// function does not take the pick-file mutex, while the training loop calls
// it from every worker thread - confirm whether synchronization is needed.
void ExecutorThreadWorker::update_epoch_num() {
_s_current_finished_file_cnt++;
void ExecutorThreadWorker::UpdateEpochNum() {
s_current_finished_file_cnt_++;
if (_s_current_finished_file_cnt >= _s_thread_filelist.size()) {
_s_current_finished_file_cnt = 0;
_s_current_epoch++;
if (s_current_finished_file_cnt_ >= s_thread_filelist_.size()) {
s_current_finished_file_cnt_ = 0;
s_current_epoch_++;
}
}
const char* ExecutorThreadWorker::pick_one_file() {
const char* ExecutorThreadWorker::PickOneFile() {
std::string file_to_be_preocessed;
std::lock_guard<std::mutex> lock(_s_locker_for_pick_file);
if (_s_current_file_idx >= _s_thread_filelist.size()) {
std::random_shuffle(_s_thread_filelist.begin(),
_s_thread_filelist.end());
_s_current_file_idx = 0;
// _s_current_epoch++; //example: when one file, one thread, it's bug
LOG(ERROR) << "thread " << _thread_id
<< ": finish traing for epoch " << _s_current_epoch + 1;
std::lock_guard<std::mutex> lock(s_locker_for_pick_file_);
if (s_current_file_idx_ >= s_thread_filelist_.size()) {
std::random_shuffle(s_thread_filelist_.begin(),
s_thread_filelist_.end());
s_current_file_idx_ = 0;
// s_current_epoch_++; //example: when one file, one thread, it's bug
LOG(ERROR) << "thread " << thread_id_
<< ": finish traing for epoch " << s_current_epoch_ + 1;
}
file_to_be_preocessed = _s_thread_filelist[_s_current_file_idx];
file_to_be_preocessed = s_thread_filelist_[s_current_file_idx_];
_s_current_file_idx++;
s_current_file_idx_++;
return file_to_be_preocessed.c_str();
}
void ExecutorThreadWorker::train() {
void ExecutorThreadWorker::Train() {
LOG(ERROR) << "begin to train";
set_device();
SetDevice();
#ifdef LOCAL_PROF
std::vector<double> op_total_time;
std::vector<std::string> op_name;
// int total_batch = 0;
for (auto& op : _ops) {
for (auto& op : ops_) {
op_name.push_back(op->Type());
}
op_total_time.resize(_ops.size());
op_total_time.resize(ops_.size());
for (int i = 0; i < op_total_time.size(); ++i) {
op_total_time[i] = 0.0;
}
#endif
std::string inspect_key = "inspect";
if (!_inspect_var_name.empty()) {
inspect_key = _inspect_var_name.substr(0,
_inspect_var_name.find_first_of('_'));
if (!inspect_var_name_.empty()) {
inspect_key = inspect_var_name_.substr(0,
inspect_var_name_.find_first_of('_'));
}
for (unsigned i = 0; i < _max_epoch; ++i) {
for (unsigned i = 0; i < max_epoch_; ++i) {
LOG(ERROR) << "epoch: " << i;
#ifdef LOCAL_PROF
Timer timeline;
......@@ -292,14 +292,14 @@ void ExecutorThreadWorker::train() {
#endif
float total_inspect = 0;
int batch_num = 1;
while (i == _s_current_epoch) {
const char* filename = pick_one_file();
_local_reader->set_file(filename);
while (i == s_current_epoch_) {
const char* filename = PickOneFile();
local_reader_->SetFile(filename);
while (true) {
#ifdef LOCAL_PROF
timeline.start();
#endif
bool flag = _local_reader->read_batch();
bool flag = local_reader_->ReadBatch();
if (!flag) {
break;
}
......@@ -312,11 +312,11 @@ void ExecutorThreadWorker::train() {
break;
}
for (unsigned int i = 0; i < _ops.size(); ++i) {
for (unsigned int i = 0; i < ops_.size(); ++i) {
#ifdef LOCAL_PROF
timeline.start();
#endif
_ops[i]->Run(*_thread_scope, _place);
ops_[i]->Run(*thread_scope_, place_);
#ifdef LOCAL_PROF
timeline.pause();
op_total_time[i] += timeline.elapsed_sec();
......@@ -325,17 +325,17 @@ void ExecutorThreadWorker::train() {
}
batch_num++;
float avg_inspect = 0.0;
if (!_inspect_var_name.empty()) {
avg_inspect = _thread_scope->FindVar(_inspect_var_name)
if (!inspect_var_name_.empty()) {
avg_inspect = thread_scope_->FindVar(inspect_var_name_)
->GetMutable<LoDTensor>()
->data<float>()[0];
}
total_inspect += avg_inspect;
_thread_scope->DropKids();
thread_scope_->DropKids();
}
update_epoch_num();
UpdateEpochNum();
LOG(ERROR) << "memory used after epoch " << i + 1
<< " called: " << memory::memory_usage(_place);
<< " called: " << memory::memory_usage(place_);
#ifdef LOCAL_PROF
for (int i = 0; i < op_total_time.size(); ++i) {
......@@ -354,12 +354,12 @@ void ExecutorThreadWorker::train() {
LOG(ERROR) << "mean " << inspect_key.c_str()
<< " of epoch " << i + 1 << ": " << total_inspect / batch_num;
#endif
if (_thread_id == 0) {
if (thread_id_ == 0) {
char modelfile[1024];
snprintf(&modelfile[0],
sizeof(modelfile),
"%s_epoch%d.model",
_model_prefix.c_str(),
model_prefix_.c_str(),
i);
std::string model_filename = std::string(modelfile);
// this save_inference_model can only save imdbtask, should make this
......@@ -367,55 +367,55 @@ void ExecutorThreadWorker::train() {
//
// currently comment it
LOG(ERROR) << "Going to save model " << modelfile;
save_model(_main_program,
_thread_scope,
_model_param_names,
save_model(main_program_,
thread_scope_,
model_param_names_,
model_filename,
true);
}
}
}
// Stores the index of this worker's thread.
void ExecutorThreadWorker::set_thread_id(int tid) {
_thread_id = tid;
void ExecutorThreadWorker::SetThreadId(int tid) {
thread_id_ = tid;
}
// Stores a copy of the place that is passed to each operator's Run().
void ExecutorThreadWorker::set_place(const platform::Place& place) {
_place = place;
void ExecutorThreadWorker::SetPlace(const platform::Place& place) {
place_ = place;
}
// Replaces the worker's main program with a newly allocated copy of
// `main_program_desc`; the worker owns the copy through its unique_ptr.
void ExecutorThreadWorker::set_main_program(
void ExecutorThreadWorker::SetMainProgram(
const ProgramDesc& main_program_desc) {
_main_program.reset(new ProgramDesc(main_program_desc));
main_program_.reset(new ProgramDesc(main_program_desc));
}
// Stores the root scope pointer. Only the pointer is kept; nothing in this
// function takes ownership of the scope.
void ExecutorThreadWorker::set_root_scope(Scope* g_scope) {
_root_scope = g_scope;
void ExecutorThreadWorker::SetRootScope(Scope* g_scope) {
root_scope_ = g_scope;
}
// Stores the number of epochs the training loop runs.
// The member is declared `unsigned int`, so a negative `max_epoch` converts
// to a very large value.
void ExecutorThreadWorker::set_max_training_epoch(int max_epoch) {
_max_epoch = max_epoch;
void ExecutorThreadWorker::SetMaxTrainingEpoch(int max_epoch) {
max_epoch_ = max_epoch;
}
// Constructs an executor for `place`. Only the place member is initialized
// here; every other member is left to its own default initialization.
MultiExecutor::MultiExecutor(const platform::Place& place) : _place(place) {}
MultiExecutor::MultiExecutor(const platform::Place& place) : place_(place) {}
// Stores the root scope pointer that is later handed to every worker.
void MultiExecutor::init_root_scope(Scope* scope) {
_root_scope = scope;
void MultiExecutor::InitRootScope(Scope* scope) {
root_scope_ = scope;
}
// Stores the epoch count that is forwarded to each worker when the worker
// threads are prepared.
void MultiExecutor::set_max_training_epoch(int max_epoch) {
_max_epoch = max_epoch;
void MultiExecutor::SetMaxTrainingEpoch(int max_epoch) {
max_epoch_ = max_epoch;
}
// Stores the data feed name used to create each worker's data feed.
// `feedname` must be a non-null, NUL-terminated string: a std::string is
// constructed directly from it.
void MultiExecutor::set_datafeed_name(const char* feedname) {
_feed_name = std::string(feedname);
void MultiExecutor::SetDataFeedName(const char* feedname) {
feed_name_ = std::string(feedname);
}
// Stores the model file name prefix that is forwarded to each worker.
void MultiExecutor::set_model_prefix(const std::string& model_prefix) {
_model_prefix = model_prefix;
void MultiExecutor::SetModelPrefix(const std::string& model_prefix) {
model_prefix_ = model_prefix;
}
void MultiExecutor::run_startup_program(const ProgramDesc& program,
void MultiExecutor::RunStartupProgram(const ProgramDesc& program,
Scope* scope) {
auto& block = program.Block(0);
for (auto& var : block.AllVars()) {
......@@ -447,7 +447,7 @@ void MultiExecutor::run_startup_program(const ProgramDesc& program,
// param_dict.size(), ops.size());
for (auto& op : ops) {
op->Run(*scope, _place);
op->Run(*scope, place_);
}
// LOGERR("total time for startup program: %fs", timeline.elapsed_sec());
for (auto& op : ops) {
......@@ -456,7 +456,7 @@ void MultiExecutor::run_startup_program(const ProgramDesc& program,
// LOGERR("run startup program done.");
}
std::unique_ptr<ProgramDesc> MultiExecutor::load_desc_from_file(
std::unique_ptr<ProgramDesc> MultiExecutor::LoadDescFromFile(
const std::string& f) {
std::string program_desc_str;
read_binary_file(f, &program_desc_str);
......@@ -464,102 +464,102 @@ std::unique_ptr<ProgramDesc> MultiExecutor::load_desc_from_file(
return program;
}
void MultiExecutor::set_dense_comm_tensor(
void MultiExecutor::SetDenseCommTensor(
const std::vector<std::string>& dense_comm_tensor) {
_dense_comm_tensor.resize(dense_comm_tensor.size());
dense_comm_tensor_.resize(dense_comm_tensor.size());
for (unsigned int i = 0; i < dense_comm_tensor.size(); ++i) {
_dense_comm_tensor[i] = dense_comm_tensor[i];
dense_comm_tensor_[i] = dense_comm_tensor[i];
}
}
void MultiExecutor::set_sparse_comm_tensor(
void MultiExecutor::SetSparseCommTensor(
const std::vector<std::string>& sparse_comm_tensor) {
_sparse_comm_tensor.resize(sparse_comm_tensor.size());
sparse_comm_tensor_.resize(sparse_comm_tensor.size());
for (unsigned int i = 0; i < sparse_comm_tensor.size(); ++i) {
_sparse_comm_tensor[i] = sparse_comm_tensor[i];
sparse_comm_tensor_[i] = sparse_comm_tensor[i];
}
}
// Stores a copy of the sparse communication map and logs how many entries it
// holds.
void MultiExecutor::set_sparse_comm_data(
void MultiExecutor::SetSparseCommData(
const std::map<std::string, int>& sparse_comm_data) {
_sparse_comm_data = sparse_comm_data;
LOG(INFO) << "Sparse comm data: " << _sparse_comm_data.size();
sparse_comm_data_ = sparse_comm_data;
LOG(INFO) << "Sparse comm data: " << sparse_comm_data_.size();
}
void MultiExecutor::set_filelist(const char* filelist) {
_filelist.clear();
void MultiExecutor::SetFileList(const char* filelist) {
filelist_.clear();
std::ifstream fin(filelist);
std::string filename;
while (fin >> filename) {
LOG(ERROR) << "add " << filename.c_str() << " to filelist";
_filelist.push_back(filename);
filelist_.push_back(filename);
}
fin.close();
}
// Replaces the stored file list with the contents of `tfiles`.
// The argument is taken by value and its elements are then copied into the
// member, so the list is copied twice when the caller passes an lvalue.
void MultiExecutor::set_filelist(std::vector<std::string> tfiles) {
_filelist.clear();
_filelist.insert(_filelist.end(), tfiles.begin(), tfiles.end());
void MultiExecutor::SetFileList(std::vector<std::string> tfiles) {
filelist_.clear();
filelist_.insert(filelist_.end(), tfiles.begin(), tfiles.end());
return;
}
// Stores the name of the variable to inspect; it is forwarded to each worker
// when the worker threads are prepared.
void MultiExecutor::set_inspect_var_name(const std::string& inspect_var_name) {
_inspect_var_name = inspect_var_name;
void MultiExecutor::SetInspectVarName(const std::string& inspect_var_name) {
inspect_var_name_ = inspect_var_name;
}
// Stores a copy of the model parameter names; they are forwarded to each
// worker when the worker threads are prepared.
void MultiExecutor::set_param_names(const std::vector<std::string>& param_names) {
_model_param_names = param_names;
void MultiExecutor::SetParamNames(const std::vector<std::string>& param_names) {
model_param_names_ = param_names;
}
// Stores the number of worker threads to create.
// The member is declared `unsigned int`, so a negative `thread_num` converts
// to a very large value.
void MultiExecutor::set_thread_num(const int thread_num) {
_thread_num = thread_num;
void MultiExecutor::SetThreadNum(const int thread_num) {
thread_num_ = thread_num;
}
void MultiExecutor::prepare_threads(const ProgramDesc& host_program) {
_workers.resize(_thread_num);
for (unsigned i = 0; i < _thread_num; ++i) {
_workers[i].reset(new ExecutorThreadWorker);
_workers[i]->set_thread_id(i);
_workers[i]->create_thread_operators(host_program);
_workers[i]->set_root_scope(_root_scope);
_workers[i]->set_place(_place);
_workers[i]->set_max_training_epoch(_max_epoch);
_workers[i]->create_thread_scope(host_program);
_workers[i]->set_inspect_var_name(_inspect_var_name);
_workers[i]->set_model_param_names(_model_param_names);
_workers[i]->set_sparse_comm_data(_sparse_comm_data);
_workers[i]->set_main_program(host_program);
_workers[i]->set_model_prefix(_model_prefix);
void MultiExecutor::PrepareThreads(const ProgramDesc& host_program) {
workers_.resize(thread_num_);
for (unsigned i = 0; i < thread_num_; ++i) {
workers_[i].reset(new ExecutorThreadWorker);
workers_[i]->SetThreadId(i);
workers_[i]->CreateThreadOperators(host_program);
workers_[i]->SetRootScope(root_scope_);
workers_[i]->SetPlace(place_);
workers_[i]->SetMaxTrainingEpoch(max_epoch_);
workers_[i]->CreateThreadScope(host_program);
workers_[i]->SetInspectVarName(inspect_var_name_);
workers_[i]->SetModelParamNames(model_param_names_);
workers_[i]->SetSparseCommData(sparse_comm_data_);
workers_[i]->SetMainProgram(host_program);
workers_[i]->SetModelPrefix(model_prefix_);
}
for (unsigned i = 0; i < _filelist.size(); ++i) {
for (unsigned i = 0; i < filelist_.size(); ++i) {
// suppose at least one trainer thread here, and
// filelist is static so that we only add filelist once
_workers[0]->add_train_file(_filelist[i]);
workers_[0]->AddTrainFile(filelist_[i]);
}
// mpi_wrapper::ModelParam model_param(true);
// _workers[0]->register_parallel_training_param(model_param);
// workers_[0]->register_parallel_training_param(model_param);
for (unsigned i = 0; i < _thread_num; ++i) {
for (unsigned i = 0; i < thread_num_; ++i) {
// new a datafeed here
std::shared_ptr<DataFeed> local_feed = create_datafeed(_feed_name.c_str());
local_feed->init(_data_feed_param);
local_feed->set_batch_size(_batch_size);
_workers[i]->set_datafeed(local_feed);
_workers[i]->binding_datafeed_memory();
_workers[i]->set_thread_id(i);
std::shared_ptr<DataFeed> local_feed = CreateDataFeed(feed_name_.c_str());
local_feed->Init(data_feed_param_);
local_feed->SetBatchSize(batch_size_);
workers_[i]->SetDataFeed(local_feed);
workers_[i]->BindingDataFeedMemory();
workers_[i]->SetThreadId(i);
}
}
void MultiExecutor::run_multi_executor(const ProgramDesc& host_program) {
void MultiExecutor::RunMultiExecutor(const ProgramDesc& host_program) {
// thread binding here?
prepare_threads(host_program);
for (unsigned i = 0; i < _thread_num; ++i) {
_threads.push_back(std::thread(&ExecutorThreadWorker::train,
_workers[i].get()));
PrepareThreads(host_program);
for (unsigned i = 0; i < thread_num_; ++i) {
threads_.push_back(std::thread(&ExecutorThreadWorker::Train,
workers_[i].get()));
}
for (auto& th : _threads) {
for (auto& th : threads_) {
th.join();
}
}
......
......@@ -36,137 +36,129 @@ class ExecutorThreadWorker {
public:
ExecutorThreadWorker() {}
virtual ~ExecutorThreadWorker() {}
void create_thread_scope(const framework::ProgramDesc& program);
void set_datafeed(const DataFeed& datafeed);
void set_thread_id(int tid);
void create_thread_operators(const framework::ProgramDesc& program);
void set_root_scope(Scope* g_scope);
void set_device();
virtual void add_fid_set();
void set_comm_batch(int comm_batch) { _comm_batch = comm_batch; }
void add_train_file(const std::string& filename);
void set_main_program(const ProgramDesc& main_program_desc);
void set_place(const paddle::platform::Place& place);
void set_max_training_epoch(const int max_epoch);
void binding_datafeed_memory();
void set_model_prefix(const std::string& prefix) { _model_prefix = prefix; }
void set_inspect_var_name(const std::string& inspect_var_name);
void set_model_param_names(const std::vector<std::string>& param_names);
void set_sparse_comm_data(const std::map<std::string, int>& param_names);
void set_datafeed(const std::shared_ptr<DataFeed>& datafeed);
virtual void mpi_train();
void gpu_train();
void train();
virtual const char* pick_one_file();
void update_epoch_num();
virtual void set_dense_comm_tensor(
void CreateThreadScope(const framework::ProgramDesc& program);
void SetDataFeed(const DataFeed& datafeed);
void SetThreadId(int tid);
void CreateThreadOperators(const framework::ProgramDesc& program);
void SetRootScope(Scope* g_scope);
void SetDevice();
virtual void AddFidSet();
void SetCommBatch(int comm_batch) { comm_batch_ = comm_batch; }
void AddTrainFile(const std::string& filename);
void SetMainProgram(const ProgramDesc& main_program_desc);
void SetPlace(const paddle::platform::Place& place);
void SetMaxTrainingEpoch(const int max_epoch);
void BindingDataFeedMemory();
void SetModelPrefix(const std::string& prefix) { model_prefix_ = prefix; }
void SetInspectVarName(const std::string& inspect_var_name);
void SetModelParamNames(const std::vector<std::string>& param_names);
void SetSparseCommData(const std::map<std::string, int>& param_names);
void SetDataFeed(const std::shared_ptr<DataFeed>& datafeed);
void Train();
virtual const char* PickOneFile();
void UpdateEpochNum();
virtual void SetDenseCommTensor(
const std::vector<std::string>& param_names) {}
virtual void initialize() {}
virtual void Initialize() {}
public:
static std::mutex _s_locker_for_pick_file;
static unsigned int _s_current_file_idx;
static size_t _s_current_finished_file_cnt;
static unsigned int _s_current_epoch;
static int _s_current_save_epoch;
static std::vector<std::string> _s_thread_filelist; // filelist
static bool _s_is_first_worker;
static std::mutex s_locker_for_pick_file_;
static unsigned int s_current_file_idx_;
static size_t s_current_finished_file_cnt_;
static unsigned int s_current_epoch_;
static int s_current_save_epoch_;
static std::vector<std::string> s_thread_filelist_; // filelist
static bool s_is_first_worker_;
protected:
// thread index
int _thread_id;
// current training file
int _cur_fileidx;
int thread_id_;
// max epoch for each thread
unsigned int _max_epoch;
unsigned int max_epoch_;
// instances learned currently
int _comm_batch;
std::string _model_prefix;
std::vector<std::string> _op_names;
int comm_batch_;
std::string model_prefix_;
std::vector<std::string> op_names_;
// local ops for forward and backward
std::vector<OperatorBase *> _ops;
std::vector<OperatorBase *> ops_;
// main program for training
std::unique_ptr<framework::ProgramDesc> _main_program;
std::unique_ptr<framework::ProgramDesc> main_program_;
// binary data reader
std::shared_ptr<DataFeed> _local_reader;
std::shared_ptr<DataFeed> local_reader_;
std::string _inspect_var_name;
std::vector<std::string> _model_param_names;
std::map<std::string, int> _sparse_comm_data;
std::vector<int> _ids_buffer;
std::string inspect_var_name_;
std::vector<std::string> model_param_names_;
std::map<std::string, int> sparse_comm_data_;
// execution place
platform::Place _place;
platform::Place place_;
// root scope for model parameters
Scope* _root_scope;
Scope* root_scope_;
// a thread scope; its parent is the root (global) scope, which is shared
Scope* _thread_scope;
Scope* thread_scope_;
};
class MultiExecutor {
public:
explicit MultiExecutor(const platform::Place& place);
virtual ~MultiExecutor() {}
static std::unique_ptr<ProgramDesc> load_desc_from_file(
static std::unique_ptr<ProgramDesc> LoadDescFromFile(
const std::string& filename);
void init_root_scope(Scope* scope);
void set_inspect_var_name(const std::string& inspect_var_name);
void set_param_names(const std::vector<std::string>& param_names);
void set_max_training_epoch(const int max_epoch);
Scope* get_root_scope() { return _root_scope; }
void set_thread_num(const int thread_num);
void set_batch_size(const int batch_size) { _batch_size = batch_size; }
void set_filelist(const char* filelist);
void set_filelist(const std::vector<std::string> filelist);
void set_datafeed_name(const char* feedname);
void set_data_feed_param(const datafeed::DataFeedParameter& feed_param) {
_data_feed_param = feed_param;
void InitRootScope(Scope* scope);
void SetInspectVarName(const std::string& inspect_var_name);
void SetParamNames(const std::vector<std::string>& param_names);
void SetMaxTrainingEpoch(const int max_epoch);
Scope* GetRootScope() { return root_scope_; }
void SetThreadNum(const int thread_num);
void SetBatchSize(const int batch_size) { batch_size_ = batch_size; }
void SetFileList(const char* filelist);
void SetFileList(const std::vector<std::string> filelist);
void SetDataFeedName(const char* feedname);
void SetDataFeedParam(const datafeed::DataFeedParameter& feed_param) {
data_feed_param_ = feed_param;
}
void set_comm_batch(int comm_batch) {
_comm_batch = comm_batch;
void SetCommBatch(int comm_batch) {
comm_batch_ = comm_batch;
}
void set_model_prefix(const std::string& model_prefix);
void set_dense_comm_tensor(const std::vector<std::string>& dense_comm_tensor);
void set_sparse_comm_tensor(
void SetModelPrefix(const std::string& model_prefix);
void SetDenseCommTensor(const std::vector<std::string>& dense_comm_tensor);
void SetSparseCommTensor(
const std::vector<std::string>& sparse_comm_tensor);
void set_sparse_comm_data(const std::map<std::string, int>& sparse_comm_data);
virtual void prepare_threads(const framework::ProgramDesc& host_program);
void run_startup_program(const framework::ProgramDesc& program,
void SetSparseCommData(const std::map<std::string, int>& sparse_comm_data);
virtual void PrepareThreads(const framework::ProgramDesc& host_program);
void RunStartupProgram(const framework::ProgramDesc& program,
framework::Scope* scope);
void run_multi_executor(const ProgramDesc& host_program);
void RunMultiExecutor(const ProgramDesc& host_program);
public:
unsigned int _thread_num;
datafeed::DataFeedParameter _data_feed_param;
int _max_epoch;
int _batch_size;
int _comm_batch;
std::vector<std::shared_ptr<ExecutorThreadWorker> > _workers;
std::vector<std::thread> _threads;
std::vector<std::string> _filelist;
std::string _inspect_var_name;
std::vector<std::string> _model_param_names;
std::vector<std::string> _dense_comm_tensor;
std::vector<std::string> _sparse_comm_tensor;
std::map<std::string, int> _sparse_comm_data;
int node_num;
std::string _model_prefix;
ProgramDesc _host_program;
std::string _feed_name;
Scope* _root_scope;
platform::Place _place;
unsigned int thread_num_;
datafeed::DataFeedParameter data_feed_param_;
int max_epoch_;
int batch_size_;
int comm_batch_;
std::vector<std::shared_ptr<ExecutorThreadWorker> > workers_;
std::vector<std::thread> threads_;
std::vector<std::string> filelist_;
std::string inspect_var_name_;
std::vector<std::string> model_param_names_;
std::vector<std::string> dense_comm_tensor_;
std::vector<std::string> sparse_comm_tensor_;
std::map<std::string, int> sparse_comm_data_;
std::string model_prefix_;
std::string feed_name_;
Scope* root_scope_;
platform::Place place_;
};
} // namespace framework
......
......@@ -38,111 +38,114 @@ DEFINE_bool(is_text_feed, false, "is_text_feed");
namespace paddle {
namespace framework {
// Sets up the feed for a fixed two-slot layout ("words", "label") and
// allocates the host buffers the batch reader works with. `feed_param` is
// not read anywhere in this function.
void TextClassDataFeed::init(const datafeed::DataFeedParameter& feed_param) {
void TextClassDataFeed::Init(const datafeed::DataFeedParameter& feed_param) {
// hard coding for a specific datafeed
_feed_vec.resize(2);
// _feed_vec[0].reset(new LoDTensor);
// _feed_vec[1].reset(new LoDTensor);
_all_slot_ids = {0, 1};
_use_slot_ids = {0, 1};
_use_slot_alias = {"words", "label"};
// Raw file contents: 200 * 1024 * 1024 bytes, freed with delete[].
_file_content_buffer_host.reset(new char[200*1024*1024],
feed_vec_.resize(2);
// feed_vec_[0].reset(new LoDTensor);
// feed_vec_[1].reset(new LoDTensor);
all_slot_ids_ = {0, 1};
use_slot_ids_ = {0, 1};
use_slot_alias_ = {"words", "label"};
file_content_buffer_host_.reset(new char[200*1024*1024],
[](char *p) {delete[] p;});
// The read cursor starts at the beginning of the content buffer.
_file_content_buffer = _file_content_buffer_host.get();
_file_content_buffer_ptr = _file_content_buffer;
_batch_id_host.reset(new int[10240*1024],
file_content_buffer_ = file_content_buffer_host_.get();
file_content_buffer_ptr_ = file_content_buffer_;
batch_id_host_.reset(new int[10240*1024],
[](int *p) {delete[] p;}); // max word num in a batch
_label_host.reset(new int[10240],
batch_id_buffer_ = batch_id_host_.get();
label_host_.reset(new int[10240],
[](int *p) {delete[] p;}); // max label in a batch
_batch_id_buffer = _batch_id_host.get();
_label_ptr = _label_host.get();
label_ptr_ = label_host_.get();
}
// todo: use elegant implemention for this function
bool TextClassDataFeed::read_batch() {
bool TextClassDataFeed::ReadBatch() {
paddle::framework::Vector<size_t> offset;
int tlen = 0;
int llen = 0;
int inst_idx = 0;
offset.resize(_batch_size + 1);
offset.resize(batch_size_ + 1);
offset[0] = 0;
while (inst_idx < _batch_size) {
while (inst_idx < batch_size_) {
int ptr_offset = 0;
if (_file_content_buffer_ptr - _file_content_buffer >= _file_size) {
if (file_content_buffer_ptr_ - file_content_buffer_ >= file_size_) {
break;
}
memcpy(reinterpret_cast<char *>(&llen),
_file_content_buffer_ptr + ptr_offset,
file_content_buffer_ptr_ + ptr_offset,
sizeof(int));
ptr_offset += sizeof(int);
memcpy(reinterpret_cast<char *>(_batch_id_buffer + tlen),
_file_content_buffer_ptr + ptr_offset,
memcpy(reinterpret_cast<char *>(batch_id_buffer_ + tlen),
file_content_buffer_ptr_ + ptr_offset,
llen * sizeof(int));
tlen += llen;
offset[inst_idx + 1] = offset[inst_idx] + llen;
ptr_offset += sizeof(int) * llen;
memcpy(reinterpret_cast<char *>(_label_ptr + inst_idx),
_file_content_buffer_ptr + ptr_offset,
memcpy(reinterpret_cast<char *>(label_ptr_ + inst_idx),
file_content_buffer_ptr_ + ptr_offset,
sizeof(int));
ptr_offset += sizeof(int);
_file_content_buffer_ptr += ptr_offset;
file_content_buffer_ptr_ += ptr_offset;
inst_idx++;
}
if (inst_idx != _batch_size) {
if (inst_idx != batch_size_) {
return false;
}
LoD input_lod{offset};
paddle::framework::Vector<size_t> label_offset;
label_offset.resize(_batch_size + 1);
for (int i = 0; i <= _batch_size; ++i) {
label_offset.resize(batch_size_ + 1);
for (int i = 0; i <= batch_size_; ++i) {
label_offset[i] = i;
}
LoD label_lod{label_offset};
int64_t* input_ptr = _feed_vec[0]->mutable_data<int64_t>(
int64_t* input_ptr = feed_vec_[0]->mutable_data<int64_t>(
{static_cast<int64_t>(offset.back()), 1},
platform::CPUPlace());
int64_t* label_ptr = _feed_vec[1]->mutable_data<int64_t>({_batch_size, 1},
int64_t* label_ptr = feed_vec_[1]->mutable_data<int64_t>({batch_size_, 1},
platform::CPUPlace());
for (unsigned int i = 0; i < offset.back(); ++i) {
input_ptr[i] = static_cast<int64_t>(_batch_id_buffer[i]);
input_ptr[i] = static_cast<int64_t>(batch_id_buffer_[i]);
}
for (int i = 0; i < _batch_size; ++i) {
label_ptr[i] = static_cast<int64_t>(_label_ptr[i]);
for (int i = 0; i < batch_size_; ++i) {
label_ptr[i] = static_cast<int64_t>(label_ptr_[i]);
}
_feed_vec[0]->set_lod(input_lod);
_feed_vec[1]->set_lod(label_lod);
feed_vec_[0]->set_lod(input_lod);
feed_vec_[1]->set_lod(label_lod);
return true;
}
// Binds `feed` to the slot whose alias equals `name`: the variable's
// LoDTensor is stored at the matching index of the feed vector. A name that
// matches no alias leaves the feed vector unchanged.
void TextClassDataFeed::add_feed_var(Variable* feed, const std::string& name) {
for (unsigned int i = 0; i < _use_slot_alias.size(); ++i) {
if (name == _use_slot_alias[i]) {
_feed_vec[i] = feed->GetMutable<LoDTensor>();
void TextClassDataFeed::AddFeedVar(Variable* feed, const std::string& name) {
for (unsigned int i = 0; i < use_slot_alias_.size(); ++i) {
if (name == use_slot_alias_[i]) {
feed_vec_[i] = feed->GetMutable<LoDTensor>();
}
}
}
bool TextClassDataFeed::set_file(const char* filename) {
bool TextClassDataFeed::SetFile(const char* filename) {
// termnum termid termid ... termid label
int filesize = read_whole_file(filename, _file_content_buffer);
int filesize = ReadWholeFile(filename, file_content_buffer_);
// todo , remove magic number
if (filesize < 0 || filesize >= 1024 * 1024 * 1024) {
return false;
}
_file_content_buffer_ptr = _file_content_buffer;
_file_size = filesize;
file_content_buffer_ptr_ = file_content_buffer_;
file_size_ = filesize;
return true;
}
int TextClassDataFeed::read_whole_file(const std::string& filename,
int TextClassDataFeed::ReadWholeFile(const std::string& filename,
char* buffer) {
std::ifstream ifs(filename.c_str(), std::ios::binary);
if (ifs.fail()) {
......
......@@ -35,47 +35,6 @@ limitations under the License. */
namespace paddle {
namespace framework {
using FeatureKey = uint64_t;

/// A feature signature paired with the id of the slot it belongs to.
///
/// The signature is kept in a raw byte array rather than in a FeatureKey
/// member, so the struct is only as large as the two fields together; the
/// accessors reinterpret those bytes as a FeatureKey.
struct FeatureItem {
  /// Leaves both fields uninitialized.
  FeatureItem() {}

  /// Stores `sign_` as the signature and `slot_` as the slot id.
  FeatureItem(FeatureKey sign_, uint16_t slot_) : _slot(slot_) {
    sign() = sign_;
  }

  /// Mutable access to the signature.
  FeatureKey& sign() {
    return *reinterpret_cast<FeatureKey*>(sign_buffer());
  }

  /// Read-only access to the signature.
  const FeatureKey& sign() const {
    return *reinterpret_cast<const FeatureKey*>(sign_buffer());
  }

  /// Mutable access to the slot id.
  uint16_t& slot() { return _slot; }

  /// Read-only access to the slot id.
  const uint16_t& slot() const { return _slot; }

 private:
  char _sign[sizeof(FeatureKey)];
  uint16_t _slot;

  /// Start of the signature bytes; constness is dropped so that both the
  /// const and the non-const accessor can share this helper.
  char* sign_buffer() const { return const_cast<char*>(&_sign[0]); }
};
// Record(average:14031B) is smaller than Sample(average:16530B)
/// One input record: show and click counts, its features, and the line id
/// and tags carried as strings.
struct Record {
  int show;
  int click;
  std::vector<FeatureItem> feas;
  std::string lineid;
  std::string tags;
};
struct Gauc {
int show, click;
uint64_t fea;
......@@ -89,241 +48,83 @@ struct Instance {
std::vector<Gauc> gauc_vec;
};
struct Sample {
uint64_t label;
std::map<uint16_t, std::vector<uint64_t>> feas;
bool from_string(const std::string& input, const std::set<uint32_t>& slots) {
size_t end = input.find_first_of(' ');
if (end == std::string::npos) {
LOG(ERROR) << "[ERROR] Fail in parsing:" << input;
return false;
}
label = input[end + 3] - '0';
CHECK(label == 0 || label == 1) << "invalid label:" << label;
std::stringstream ss(input);
std::string token;
uint16_t slot_id = 0;
uint64_t feature_id = 0;
int num_nonfeas_token = 0;
std::ostringstream os;
while (ss >> token) {
size_t end = token.find_first_of(':');
if (end == std::string::npos) {
++num_nonfeas_token;
continue;
}
try {
slot_id = stoi(token.substr(end + 1));
} catch (...) {
LOG(ERROR) << "Error in parsing slot id:" << token;
return false;
}
try {
feature_id = stoull(token.substr(0, end));
} catch (...) {
LOG(ERROR) << "Error in parsing feature id:" << token;
return false;
}
if (slot_id <= 0) {
LOG(ERROR) << "invalid slot:" << slot_id << " feasign:" << feature_id
<< " line:" << input;
return false;
}
if (slots.find(slot_id) == slots.end()) {
continue;
}
feas[slot_id].push_back(feature_id);
}
if (num_nonfeas_token != 4) {
LOG(ERROR) << "Format error. Invalid number of non-feasign token:"
<< num_nonfeas_token;
return false;
}
return true;
}
};
struct TeacherStudentSample {
uint64_t label;
std::map<uint16_t, std::vector<uint64_t>> feas;
float q_score;
void print() {
LOG(ERROR) << "label: " << label << " score: " << q_score;
for (auto &slot : feas) {
for (auto &fea : slot.second) {
LOG(ERROR) << "slot: " << slot.first << " fea: " << fea;
}
}
}
bool from_string(const std::string& input,
const std::set<uint32_t>& slots,
Gauc& gauc) { // NOLINT
size_t end = input.find_first_of(' ');
if (end == std::string::npos) {
LOG(ERROR) << "[ERROR] Fail in parsing:" << input;
return false;
}
label = input[end + 3] - '0';
CHECK(label == 0 || label == 1) << "invalid label:" << label;
gauc.show = 1;
gauc.click = label;
gauc.lineid = input.substr(0, end);
gauc.fea = 0;
size_t dnn_start = input.find("*");
if (dnn_start == std::string::npos) {
q_score = -1.0;
} else {
dnn_start += 1;
size_t dnn_end = input.find(' ', dnn_start);
q_score = static_cast<float>(
atof(input.substr(dnn_start, dnn_end - dnn_start).c_str()));
}
size_t head_pos = input.find("\t");
std::string head = input.substr(0, head_pos);
std::stringstream ss(head);
std::string token;
uint16_t slot_id = 0;
uint64_t feature_id = 0;
int num_nonfeas_token = 0;
std::ostringstream os;
while (ss >> token) {
size_t end = token.find_first_of(':');
if (end == std::string::npos) {
++num_nonfeas_token;
continue;
}
try {
slot_id = stoi(token.substr(end + 1));
} catch (...) {
LOG(ERROR) << "Error in parsing slot id:" << token;
return false;
}
try {
feature_id = stoull(token.substr(0, end));
} catch (...) {
LOG(ERROR) << "Error in parsing feature id:" << token;
return false;
}
if (slot_id <= 0) {
LOG(ERROR) << "invalid slot:" << slot_id << " feasign:" << feature_id
<< " line:" << input;
return false;
}
if (slots.find(slot_id) == slots.end()) {
continue;
}
if (slot_id == 6048) {
gauc.fea = feature_id;
}
feas[slot_id].push_back(feature_id);
}
if (num_nonfeas_token != 4) {
LOG(ERROR) << "Format error. Invalid number of non-feasign token:"
<< num_nonfeas_token;
return false;
}
return true;
}
};
// Abstract base class for data feeds: declares the file / batch / feed-variable
// operations that concrete feeds implement, and holds the shared slot-id,
// alias, feed-tensor and batch-size state.
//
// NOTE(review): every declaration below appears twice -- once with the old
// snake_case name / leading-underscore member, and once with the PascalCase
// name / trailing-underscore member. Presumably this is a rendered diff whose
// +/- markers were lost; only one spelling of each pair can be kept for the
// class to compile.
class DataFeed {
public:
DataFeed() {}
virtual ~DataFeed() {}
// Initializes the feed from the given parameter message.
virtual void init(const datafeed::DataFeedParameter& feed_param) = 0;
virtual void Init(const datafeed::DataFeedParameter& feed_param) = 0;
/*
* This function will be used to check file format.
* Considering that this function may be used alone,
* it does not check anything.
* */
virtual bool check_file(const char* filename) = 0;
virtual bool set_file(const char* filename) = 0;
virtual bool read_batch() = 0;
virtual const std::vector<uint16_t>& get_all_slot_ids() {
return _all_slot_ids;
virtual bool CheckFile(const char* filename) = 0;
virtual bool SetFile(const char* filename) = 0;
virtual bool ReadBatch() = 0;
// Accessors returning references to the stored slot-id / alias vectors.
virtual const std::vector<uint16_t>& GetAllSlotIds() {
return all_slot_ids_;
}
virtual const std::vector<uint16_t>& get_use_slot_ids() {
return _use_slot_ids;
virtual const std::vector<uint16_t>& GetUseSlotIds() {
return use_slot_ids_;
}
virtual const std::vector<std::string>& get_use_slot_alias() {
return _use_slot_alias;
virtual const std::vector<std::string>& GetUseSlotAlias() {
return use_slot_alias_;
}
virtual void add_feed_var(Variable* var,
virtual void AddFeedVar(Variable* var,
const std::string& name) = 0;
virtual void bind_scope(Scope* scope) = 0;
virtual void set_batch_size(int batch) { _default_batch_size = batch; }
virtual int get_batch_size() { return _batch_size; }
virtual void set_buffer_size(int buffer_size) {}
virtual void BindScope(Scope* scope) = 0;
// NOTE(review): the setter writes default_batch_size_ while the getter reads
// batch_size_, so a value set here is not what GetBatchSize() returns unless
// something else copies it over -- confirm this is intended.
virtual void SetBatchSize(int batch) { default_batch_size_ = batch; }
virtual int GetBatchSize() { return batch_size_; }
// Default implementation ignores the requested buffer size.
virtual void SetBufferSize(int buffer_size) {}
std::vector<LoDTensor*>& get_feed_vec() {
return _feed_vec;
std::vector<LoDTensor*>& GetFeedVec() {
return feed_vec_;
}
virtual std::vector<LoDTensor*>& get_feed_vec(const Instance& ins) {
// Default implementation: logs an error and returns the stored feed vector
// without using `ins`.
virtual std::vector<LoDTensor*>& GetFeedVec(const Instance& ins) {
LOG(ERROR) << "use defalut get_feed_vec";
return _feed_vec;
return feed_vec_;
}
protected:
std::vector<uint16_t> _all_slot_ids;
std::vector<uint16_t> _use_slot_ids;
std::vector<std::string> _use_slot_alias;
std::vector<LoDTensor*> _feed_vec;
int _default_batch_size;
int _batch_size;
std::vector<uint16_t> all_slot_ids_;
std::vector<uint16_t> use_slot_ids_;
std::vector<std::string> use_slot_alias_;
// Raw LoDTensor pointers handed out by GetFeedVec().
std::vector<LoDTensor*> feed_vec_;
int default_batch_size_;
int batch_size_;
};
// DataFeed implementation declared for text-classification input.
//
// NOTE(review): as in the base class, each declaration appears in both its old
// (snake_case / leading underscore) and new (PascalCase / trailing underscore)
// spelling -- presumably a rendered diff with the +/- markers lost.
class TextClassDataFeed : public DataFeed {
public:
virtual ~TextClassDataFeed() {}
virtual void init(const datafeed::DataFeedParameter& feed_param);
virtual bool read_batch();
virtual void add_feed_var(Variable* feed, const std::string& name);
virtual void bind_scope(Scope* scope) {}
virtual bool set_file(const char* filename);
virtual void Init(const datafeed::DataFeedParameter& feed_param);
virtual bool ReadBatch();
virtual void AddFeedVar(Variable* feed, const std::string& name);
// Scope binding is a no-op for this feed.
virtual void BindScope(Scope* scope) {}
virtual bool SetFile(const char* filename);
virtual bool check_file(const char* filename) {
// Not implemented yet: always returns false.
virtual bool CheckFile(const char* filename) {
// TODO(xxx)
return false;
}
void set_batch_size(int batch) {_batch_size = batch;}
// Stores the value directly into batch_size_ (the base-class version writes
// default_batch_size_ instead).
void SetBatchSize(int batch) {batch_size_ = batch;}
private:
int read_whole_file(const std::string& filename, char* buffer);
char* _file_content_buffer;
char* _file_content_buffer_ptr;
int* _batch_id_buffer;
int* _label_ptr;
int _file_size;
std::vector<std::string> _names;
std::shared_ptr<char> _file_content_buffer_host;
std::shared_ptr<int> _batch_id_host;
std::shared_ptr<int> _label_host;
// Reads `filename` into `buffer`. NOTE(review): presumably returns the file
// size, with a negative value on failure -- confirm against the definition.
int ReadWholeFile(const std::string& filename, char* buffer);
char* file_content_buffer_;
char* file_content_buffer_ptr_;
int* batch_id_buffer_;
int* label_ptr_;
int file_size_;
std::vector<std::string> names_;
// NOTE(review): presumably these own the storage that the raw pointers above
// point into -- confirm in the implementation file.
std::shared_ptr<char> file_content_buffer_host_;
std::shared_ptr<int> batch_id_host_;
std::shared_ptr<int> label_host_;
};
} // namespace framework
......
......@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/framework/datafeed_creator.h"
std::shared_ptr<paddle::framework::DataFeed> create_datafeed(
std::shared_ptr<paddle::framework::DataFeed> CreateDataFeed(
const char* datafeed_class) {
if (strcmp(datafeed_class, "TextClass") == 0) {
return std::shared_ptr<paddle::framework::DataFeed>(
......
......@@ -17,6 +17,6 @@ limitations under the License. */
#include <memory>
#include "paddle/fluid/framework/data_feed.h"
// Factory for DataFeed instances, selected by the class name string
// `datafeed_class`.
// NOTE(review): the old (create_datafeed) and new (CreateDataFeed) first lines
// of this prototype both appear below -- presumably a rendered diff; keep only
// one. The result for an unrecognized class name is not visible here.
std::shared_ptr<paddle::framework::DataFeed> create_datafeed(
std::shared_ptr<paddle::framework::DataFeed> CreateDataFeed(
const char* datafeed_class);
#endif // PADDLE_FLUID_FRAMEWORK_DATAFEED_CREATOR_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册