提交 eb6a941f 编写于 作者: W wangguibao

Fix async_executor interfaces: 1) Remove all protobufs; 2) Stop after each epoch

上级 1d239cc8
...@@ -40,13 +40,8 @@ limitations under the License. */ ...@@ -40,13 +40,8 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
std::mutex ExecutorThreadWorker::s_locker_for_pick_file_;
unsigned int ExecutorThreadWorker::s_current_file_idx_ = 0; bool AsyncExecutor::workers_initialized_ = false;
size_t ExecutorThreadWorker::s_current_finished_file_cnt_ = 0;
unsigned int ExecutorThreadWorker::s_current_epoch_ = 0;
int ExecutorThreadWorker::s_current_save_epoch_ = 0;
bool ExecutorThreadWorker::s_is_first_worker_ = false;
std::vector<std::string> ExecutorThreadWorker::s_thread_filelist_;
void CreateTensor(Variable* var, proto::VarType::Type var_type) { void CreateTensor(Variable* var, proto::VarType::Type var_type) {
if (var_type == proto::VarType::LOD_TENSOR) { if (var_type == proto::VarType::LOD_TENSOR) {
...@@ -124,7 +119,6 @@ static void SaveModel( ...@@ -124,7 +119,6 @@ static void SaveModel(
{{"X", {var->Name()}}}, {{"X", {var->Name()}}},
{}, {},
attrs); attrs);
save_op->Run(*scope, place); save_op->Run(*scope, place);
} else { } else {
paralist.push_back(var->Name()); paralist.push_back(var->Name());
...@@ -140,15 +134,14 @@ static void SaveModel( ...@@ -140,15 +134,14 @@ static void SaveModel(
{{"X", paralist}}, {{"X", paralist}},
{}, {},
attrs); attrs);
save_op->Run(*scope, place); save_op->Run(*scope, place);
} }
} // end SaveModel } // end SaveModel
void ExecutorThreadWorker::Reset() {
void ExecutorThreadWorker::AddTrainFile(const std::string& file) { inspect_values_.clear();
s_thread_filelist_.push_back(file);
} }
void ExecutorThreadWorker::CreateThreadOperators(const ProgramDesc& program) { void ExecutorThreadWorker::CreateThreadOperators(const ProgramDesc& program) {
auto& block = program.Block(0); auto& block = program.Block(0);
op_names_.clear(); op_names_.clear();
...@@ -175,8 +168,12 @@ void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) { ...@@ -175,8 +168,12 @@ void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) {
} }
} }
void ExecutorThreadWorker::SetDataFeed(const std::shared_ptr<DataFeed>& datafeed) { void ExecutorThreadWorker::SetDataFeed(DataFeed& datafeed) {
local_reader_ = datafeed; if (typeid(datafeed) == typeid(TextClassDataFeed)) {
local_reader_.reset(
new TextClassDataFeed(dynamic_cast<TextClassDataFeed &>(datafeed)));
local_reader_->SetThreadId(thread_id_);
}
} }
void ExecutorThreadWorker::BindingDataFeedMemory() { void ExecutorThreadWorker::BindingDataFeedMemory() {
...@@ -186,9 +183,11 @@ void ExecutorThreadWorker::BindingDataFeedMemory() { ...@@ -186,9 +183,11 @@ void ExecutorThreadWorker::BindingDataFeedMemory() {
} }
} }
void ExecutorThreadWorker::SetInspectVarName( void ExecutorThreadWorker::SetInspectVarNames(
const std::string& inspect_var_name) { const std::vector<std::string>& inspect_var_names) {
inspect_var_name_ = inspect_var_name; inspect_var_names_.clear();
inspect_var_names_.insert(inspect_var_names_.end(),
inspect_var_names.begin(), inspect_var_names.end());
} }
void ExecutorThreadWorker::SetModelParamNames( void ExecutorThreadWorker::SetModelParamNames(
...@@ -196,11 +195,6 @@ void ExecutorThreadWorker::SetModelParamNames( ...@@ -196,11 +195,6 @@ void ExecutorThreadWorker::SetModelParamNames(
model_param_names_ = param_names; model_param_names_ = param_names;
} }
void ExecutorThreadWorker::SetSparseCommData(
const std::map<std::string, int>& param_names) {
sparse_comm_data_ = param_names;
}
void ExecutorThreadWorker::SetDevice() { void ExecutorThreadWorker::SetDevice() {
static unsigned priority[] = { static unsigned priority[] = {
0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5,
...@@ -228,150 +222,90 @@ void ExecutorThreadWorker::SetDevice() { ...@@ -228,150 +222,90 @@ void ExecutorThreadWorker::SetDevice() {
CPU_ZERO(&mask); CPU_ZERO(&mask);
if ((0 == sched_getaffinity(0, sizeof(mask), &mask)) if ((0 == sched_getaffinity(0, sizeof(mask), &mask))
&& CPU_ISSET(proc, &mask)) { && CPU_ISSET(proc, &mask)) {
LOG(ERROR) << "TRACE: Thread " << i << " is running on processor " << proc << "..."; LOG(ERROR) << "TRACE: Thread " << i
<< " is running on processor " << proc
<< "...";
} }
} }
} }
} }
void ExecutorThreadWorker::UpdateEpochNum() {
s_current_finished_file_cnt_++;
if (s_current_finished_file_cnt_ >= s_thread_filelist_.size()) {
s_current_finished_file_cnt_ = 0;
s_current_epoch_++;
}
}
const char* ExecutorThreadWorker::PickOneFile() {
std::string file_to_be_preocessed;
std::lock_guard<std::mutex> lock(s_locker_for_pick_file_);
if (s_current_file_idx_ >= s_thread_filelist_.size()) {
std::random_shuffle(s_thread_filelist_.begin(),
s_thread_filelist_.end());
s_current_file_idx_ = 0;
// s_current_epoch_++; //example: when one file, one thread, it's bug
LOG(ERROR) << "thread " << thread_id_
<< ": finish traing for epoch " << s_current_epoch_ + 1;
}
file_to_be_preocessed = s_thread_filelist_[s_current_file_idx_];
s_current_file_idx_++;
return file_to_be_preocessed.c_str();
}
void ExecutorThreadWorker::Train() { void ExecutorThreadWorker::Train() {
LOG(ERROR) << "begin to train"; LOG(ERROR) << "begin to train";
SetDevice(); SetDevice();
#ifdef LOCAL_PROF
std::vector<double> op_total_time;
std::vector<std::string> op_name;
// int total_batch = 0;
for (auto& op : ops_) {
op_name.push_back(op->Type());
}
op_total_time.resize(ops_.size());
for (int i = 0; i < op_total_time.size(); ++i) {
op_total_time[i] = 0.0;
}
#endif
std::string inspect_key = "inspect";
if (!inspect_var_name_.empty()) {
inspect_key = inspect_var_name_.substr(0,
inspect_var_name_.find_first_of('_'));
}
for (unsigned i = 0; i < max_epoch_; ++i) { int inspect_var_num = inspect_var_names_.size();
LOG(ERROR) << "epoch: " << i; inspect_values_.clear();
#ifdef LOCAL_PROF inspect_values_.resize(inspect_var_num, 0);
Timer timeline;
double total_time = 0.0; local_reader_->WaitNextEpoch();
double read_time = 0.0; int epoch = local_reader_->GetCurrentEpoch();
#endif
float total_inspect = 0; LOG(ERROR) << "epoch: " << epoch;
int batch_num = 1;
while (i == s_current_epoch_) { int batch_num = 1;
const char* filename = PickOneFile();
local_reader_->SetFile(filename); while (true) {
while (true) { const char *file = local_reader_->PickOneFile();
#ifdef LOCAL_PROF if (file == NULL) {
timeline.start(); break;
#endif }
bool flag = local_reader_->ReadBatch();
if (!flag) { if (!local_reader_->SetFile(file)) {
break; break;
} }
#ifdef LOCAL_PROF
timeline.pause(); while (true) {
read_time += timeline.elapsed_sec(); bool flag = local_reader_->ReadBatch();
total_time += timeline.elapsed_sec(); if (!flag) {
#endif break;
if (!flag) {
break;
}
for (unsigned int i = 0; i < ops_.size(); ++i) {
#ifdef LOCAL_PROF
timeline.start();
#endif
ops_[i]->Run(*thread_scope_, place_);
#ifdef LOCAL_PROF
timeline.pause();
op_total_time[i] += timeline.elapsed_sec();
total_time += timeline.elapsed_sec();
#endif
}
batch_num++;
float avg_inspect = 0.0;
if (!inspect_var_name_.empty()) {
avg_inspect = thread_scope_->FindVar(inspect_var_name_)
->GetMutable<LoDTensor>()
->data<float>()[0];
}
total_inspect += avg_inspect;
thread_scope_->DropKids();
} }
UpdateEpochNum();
LOG(ERROR) << "memory used after epoch " << i + 1 for (unsigned int i = 0; i < ops_.size(); ++i) {
<< " called: " << memory::memory_usage(place_); ops_[i]->Run(*thread_scope_, place_);
#ifdef LOCAL_PROF
for (int i = 0; i < op_total_time.size(); ++i) {
std::cerr << "op_name:[" << i << "][" << op_name[i] << "]"
<< " op_mean_time:[" << op_total_time[i] << "s]"
<< std::endl;
} }
std::cerr << "read time: " << read_time << "s" << std::endl; batch_num++;
#endif
} float avg_inspect = 0.0;
#ifdef LOCAL_PROF for (int i = 0; i < inspect_var_num; ++i) {
LOG(ERROR) << "mean " << inspect_key.c_str() avg_inspect = thread_scope_->FindVar(inspect_var_names_[i])
<< " of epoch " << i + 1 << ": " << total_inspect / batch_num ->GetMutable<LoDTensor>()
<< ", total_time: " << total_time; ->data<float>()[0];
#else inspect_values_[i] += avg_inspect;
LOG(ERROR) << "mean " << inspect_key.c_str() }
<< " of epoch " << i + 1 << ": " << total_inspect / batch_num; thread_scope_->DropKids();
#endif
if (thread_id_ == 0) {
char modelfile[1024];
snprintf(&modelfile[0],
sizeof(modelfile),
"%s_epoch%d.model",
model_prefix_.c_str(),
i);
std::string model_filename = std::string(modelfile);
// this save_inference_model can only save imdbtask, should make this
// general
//
// currently comment it
LOG(ERROR) << "Going to save model " << modelfile;
SaveModel(main_program_,
thread_scope_,
model_param_names_,
model_filename,
true);
} }
local_reader_->UpdateEpochNum();
LOG(ERROR) << "memory used after epoch " << epoch + 1
<< " called: " << memory::memory_usage(place_);
}
for (int i = 0; i < inspect_var_num; ++i) {
inspect_values_[i] /= batch_num;
std::string var = inspect_var_names_[i].substr(
0,
inspect_var_names_[i].find_first_of("_"));
LOG(ERROR) << "mean " << var.c_str()
<< " of epoch " << i + 1 << ": " << inspect_values_[i];
}
if (thread_id_ == 0) {
char modelfile[1024];
snprintf(&modelfile[0], sizeof(modelfile), "%s_epoch%d.model",
model_prefix_.c_str(), epoch);
std::string model_filename = std::string(modelfile);
// this save_inference_model can only save imdbtask, should make this
// general
//
// currently comment it
LOG(ERROR) << "Going to save model " << modelfile;
SaveModel(main_program_,
thread_scope_,
model_param_names_,
model_filename,
true);
} }
} }
...@@ -396,7 +330,20 @@ void ExecutorThreadWorker::SetMaxTrainingEpoch(int max_epoch) { ...@@ -396,7 +330,20 @@ void ExecutorThreadWorker::SetMaxTrainingEpoch(int max_epoch) {
max_epoch_ = max_epoch; max_epoch_ = max_epoch;
} }
AsyncExecutor::AsyncExecutor(const platform::Place& place) : place_(place) {} AsyncExecutor::AsyncExecutor(ProgramDesc& main_program,
const std::vector<std::string>& param_names,
TextClassDataFeed& data_feed,
unsigned int thread_num,
const platform::Place& place)
: thread_num_(thread_num),
place_(place),
main_program_(main_program),
data_feed_(data_feed) {
model_param_names_.clear();
model_param_names_.insert(model_param_names_.end(),
param_names.begin(),
param_names.end());
}
void AsyncExecutor::InitRootScope(Scope* scope) { void AsyncExecutor::InitRootScope(Scope* scope) {
root_scope_ = scope; root_scope_ = scope;
...@@ -406,10 +353,6 @@ void AsyncExecutor::SetMaxTrainingEpoch(int max_epoch) { ...@@ -406,10 +353,6 @@ void AsyncExecutor::SetMaxTrainingEpoch(int max_epoch) {
max_epoch_ = max_epoch; max_epoch_ = max_epoch;
} }
void AsyncExecutor::SetDataFeedName(const char* feedname) {
feed_name_ = std::string(feedname);
}
void AsyncExecutor::SetModelPrefix(const std::string& model_prefix) { void AsyncExecutor::SetModelPrefix(const std::string& model_prefix) {
model_prefix_ = model_prefix; model_prefix_ = model_prefix;
} }
...@@ -463,60 +406,16 @@ std::unique_ptr<ProgramDesc> AsyncExecutor::LoadDescFromFile( ...@@ -463,60 +406,16 @@ std::unique_ptr<ProgramDesc> AsyncExecutor::LoadDescFromFile(
return program; return program;
} }
void AsyncExecutor::SetDenseCommTensor( void AsyncExecutor::SetInspectVarNames(
const std::vector<std::string>& dense_comm_tensor) { const std::vector<std::string>& inspect_var_names) {
dense_comm_tensor_.resize(dense_comm_tensor.size()); inspect_var_names_.clear();
for (unsigned int i = 0; i < dense_comm_tensor.size(); ++i) { inspect_var_names_.insert(inspect_var_names_.end(),
dense_comm_tensor_[i] = dense_comm_tensor[i]; inspect_var_names.begin(), inspect_var_names.end());
}
}
void AsyncExecutor::SetSparseCommTensor(
const std::vector<std::string>& sparse_comm_tensor) {
sparse_comm_tensor_.resize(sparse_comm_tensor.size());
for (unsigned int i = 0; i < sparse_comm_tensor.size(); ++i) {
sparse_comm_tensor_[i] = sparse_comm_tensor[i];
}
}
void AsyncExecutor::SetSparseCommData(
const std::map<std::string, int>& sparse_comm_data) {
sparse_comm_data_ = sparse_comm_data;
LOG(INFO) << "Sparse comm data: " << sparse_comm_data_.size();
}
void AsyncExecutor::SetFileList(const char* filelist) {
filelist_.clear();
std::ifstream fin(filelist);
std::string filename;
while (fin >> filename) {
LOG(ERROR) << "add " << filename.c_str() << " to filelist";
filelist_.push_back(filename);
}
fin.close();
}
void AsyncExecutor::SetFileList(std::vector<std::string> tfiles) {
filelist_.clear();
filelist_.insert(filelist_.end(), tfiles.begin(), tfiles.end());
return;
}
void AsyncExecutor::SetInspectVarName(const std::string& inspect_var_name) {
inspect_var_name_ = inspect_var_name;
}
void AsyncExecutor::SetParamNames(const std::vector<std::string>& param_names) {
model_param_names_ = param_names;
}
void AsyncExecutor::SetThreadNum(const int thread_num) {
thread_num_ = thread_num;
} }
void AsyncExecutor::PrepareThreads(const ProgramDesc& host_program) { void AsyncExecutor::PrepareThreads(const ProgramDesc& host_program) {
workers_.resize(thread_num_); workers_.resize(thread_num_);
for (unsigned i = 0; i < thread_num_; ++i) { for (int i = 0; i < thread_num_; ++i) {
workers_[i].reset(new ExecutorThreadWorker); workers_[i].reset(new ExecutorThreadWorker);
workers_[i]->SetThreadId(i); workers_[i]->SetThreadId(i);
workers_[i]->CreateThreadOperators(host_program); workers_[i]->CreateThreadOperators(host_program);
...@@ -524,34 +423,31 @@ void AsyncExecutor::PrepareThreads(const ProgramDesc& host_program) { ...@@ -524,34 +423,31 @@ void AsyncExecutor::PrepareThreads(const ProgramDesc& host_program) {
workers_[i]->SetPlace(place_); workers_[i]->SetPlace(place_);
workers_[i]->SetMaxTrainingEpoch(max_epoch_); workers_[i]->SetMaxTrainingEpoch(max_epoch_);
workers_[i]->CreateThreadScope(host_program); workers_[i]->CreateThreadScope(host_program);
workers_[i]->SetInspectVarName(inspect_var_name_); workers_[i]->SetInspectVarNames(inspect_var_names_);
workers_[i]->SetModelParamNames(model_param_names_); workers_[i]->SetModelParamNames(model_param_names_);
workers_[i]->SetSparseCommData(sparse_comm_data_);
workers_[i]->SetMainProgram(host_program); workers_[i]->SetMainProgram(host_program);
workers_[i]->SetModelPrefix(model_prefix_); workers_[i]->SetModelPrefix(model_prefix_);
} //
for (unsigned i = 0; i < filelist_.size(); ++i) {
// suppose at least one trainer thread here, and
// filelist is static so that we only add filelist once
workers_[0]->AddTrainFile(filelist_[i]);
}
for (unsigned i = 0; i < thread_num_; ++i) {
// new a datafeed here // new a datafeed here
std::shared_ptr<DataFeed> local_feed = CreateDataFeed(feed_name_.c_str()); workers_[i]->SetDataFeed(data_feed_);
local_feed->Init();
local_feed->SetBatchSize(batch_size_);
workers_[i]->SetDataFeed(local_feed);
workers_[i]->BindingDataFeedMemory(); workers_[i]->BindingDataFeedMemory();
workers_[i]->SetThreadId(i);
} }
} }
void AsyncExecutor::RunAsyncExecutor(const ProgramDesc& host_program) { std::vector<float>& AsyncExecutor::Run(
const std::vector<std::string>& inspect_var_names) {
SetInspectVarNames(inspect_var_names);
threads_.clear();
// thread binding here? // thread binding here?
PrepareThreads(host_program); if (workers_initialized_ == false) {
for (unsigned i = 0; i < thread_num_; ++i) { PrepareThreads(main_program_);
workers_initialized_ = true;
}
for (int i = 0; i < thread_num_; ++i) {
workers_[i]->Reset();
workers_[i]->SetInspectVarNames(inspect_var_names);
threads_.push_back(std::thread(&ExecutorThreadWorker::Train, threads_.push_back(std::thread(&ExecutorThreadWorker::Train,
workers_[i].get())); workers_[i].get()));
} }
...@@ -559,6 +455,27 @@ void AsyncExecutor::RunAsyncExecutor(const ProgramDesc& host_program) { ...@@ -559,6 +455,27 @@ void AsyncExecutor::RunAsyncExecutor(const ProgramDesc& host_program) {
for (auto& th : threads_) { for (auto& th : threads_) {
th.join(); th.join();
} }
inspect_values_.clear();
inspect_values_.resize(inspect_var_names_.size(), 0);
std::vector<std::vector<float>*> inspect_value_vectors;
inspect_value_vectors.resize(thread_num_);
for (int i = 0; i < thread_num_; ++i) {
inspect_value_vectors[i] = &workers_[i]->GetInspectValues();
}
for (unsigned int i = 0; i < inspect_var_names_.size(); ++i) {
float value = 0.0;
for (int j = 0; j < thread_num_; ++j) {
value += inspect_value_vectors[j]->at(i);
}
value /= thread_num_;
inspect_values_[i] = value;
}
return inspect_values_;
} }
void AsyncExecutor::LoadInitModel() { void AsyncExecutor::LoadInitModel() {
......
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <vector> #include <vector>
#include <typeinfo>
#include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/datafeed_creator.h" #include "paddle/fluid/framework/datafeed_creator.h"
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
...@@ -36,10 +37,9 @@ class ExecutorThreadWorker { ...@@ -36,10 +37,9 @@ class ExecutorThreadWorker {
public: public:
ExecutorThreadWorker() {} ExecutorThreadWorker() {}
~ExecutorThreadWorker() {} ~ExecutorThreadWorker() {}
void CreateThreadScope(const framework::ProgramDesc& program); void CreateThreadScope(const ProgramDesc& program);
void SetDataFeed(const DataFeed& datafeed);
void SetThreadId(int tid); void SetThreadId(int tid);
void CreateThreadOperators(const framework::ProgramDesc& program); void CreateThreadOperators(const ProgramDesc& program);
void SetRootScope(Scope* g_scope); void SetRootScope(Scope* g_scope);
void SetDevice(); void SetDevice();
void AddFidSet(); void AddFidSet();
...@@ -52,25 +52,16 @@ class ExecutorThreadWorker { ...@@ -52,25 +52,16 @@ class ExecutorThreadWorker {
void SetModelPrefix(const std::string& prefix) { model_prefix_ = prefix; } void SetModelPrefix(const std::string& prefix) { model_prefix_ = prefix; }
void SetInspectVarName(const std::string& inspect_var_name); void SetInspectVarNames(const std::vector<std::string>& inspect_var_names);
void SetModelParamNames(const std::vector<std::string>& param_names); void SetModelParamNames(const std::vector<std::string>& param_names);
void SetSparseCommData(const std::map<std::string, int>& param_names); void SetDataFeed(DataFeed& datafeed); // NOLINT
void SetDataFeed(const std::shared_ptr<DataFeed>& datafeed);
void Train(); void Train();
const char* PickOneFile(); const char* PickOneFile();
void UpdateEpochNum(); void UpdateEpochNum();
void Reset();
void SetDenseCommTensor(const std::vector<std::string>& param_names) {}
void Initialize() {} void Initialize() {}
std::vector<float>& GetInspectValues() {return inspect_values_;}
public:
static std::mutex s_locker_for_pick_file_;
static unsigned int s_current_file_idx_;
static size_t s_current_finished_file_cnt_;
static unsigned int s_current_epoch_;
static int s_current_save_epoch_;
static std::vector<std::string> s_thread_filelist_; // filelist
static bool s_is_first_worker_;
protected: protected:
// thread index // thread index
...@@ -88,14 +79,13 @@ class ExecutorThreadWorker { ...@@ -88,14 +79,13 @@ class ExecutorThreadWorker {
std::vector<OperatorBase *> ops_; std::vector<OperatorBase *> ops_;
// main program for training // main program for training
std::unique_ptr<framework::ProgramDesc> main_program_; std::unique_ptr<ProgramDesc> main_program_;
// binary data reader // binary data reader
std::shared_ptr<DataFeed> local_reader_; std::unique_ptr<DataFeed> local_reader_;
std::string inspect_var_name_; std::vector<std::string> inspect_var_names_;
std::vector<std::string> model_param_names_; std::vector<std::string> model_param_names_;
std::map<std::string, int> sparse_comm_data_;
// execution place // execution place
platform::Place place_; platform::Place place_;
...@@ -105,24 +95,26 @@ class ExecutorThreadWorker { ...@@ -105,24 +95,26 @@ class ExecutorThreadWorker {
// a thread scope, father scope is global score which is shared // a thread scope, father scope is global score which is shared
Scope* thread_scope_; Scope* thread_scope_;
private:
std::vector<float> inspect_values_;
}; };
class AsyncExecutor { class AsyncExecutor {
public: public:
explicit AsyncExecutor(const platform::Place& place); explicit AsyncExecutor(ProgramDesc& main_program, // NOLINT
const std::vector<std::string>& param_names,
TextClassDataFeed& data_feed, // NOLINT
unsigned int thread_num,
const platform::Place& place);
virtual ~AsyncExecutor() {} virtual ~AsyncExecutor() {}
static std::unique_ptr<ProgramDesc> LoadDescFromFile( static std::unique_ptr<ProgramDesc> LoadDescFromFile(
const std::string& filename); const std::string& filename);
void InitRootScope(Scope* scope); void InitRootScope(Scope* scope);
void SetInspectVarName(const std::string& inspect_var_name);
void SetParamNames(const std::vector<std::string>& param_names);
void SetMaxTrainingEpoch(const int max_epoch); void SetMaxTrainingEpoch(const int max_epoch);
Scope* GetRootScope() { return root_scope_; } Scope* GetRootScope() { return root_scope_; }
void SetThreadNum(const int thread_num);
void SetBatchSize(const int batch_size) { batch_size_ = batch_size; } void SetBatchSize(const int batch_size) { batch_size_ = batch_size; }
void SetFileList(const char* filelist);
void SetFileList(const std::vector<std::string> filelist);
void SetDataFeedName(const char* feedname);
void SetCommBatch(int comm_batch) { void SetCommBatch(int comm_batch) {
comm_batch_ = comm_batch; comm_batch_ = comm_batch;
} }
...@@ -140,37 +132,38 @@ class AsyncExecutor { ...@@ -140,37 +132,38 @@ class AsyncExecutor {
} }
void SetModelPrefix(const std::string& model_prefix); void SetModelPrefix(const std::string& model_prefix);
void SetDenseCommTensor(const std::vector<std::string>& dense_comm_tensor); virtual void PrepareThreads(const ProgramDesc& host_program);
void SetSparseCommTensor( void RunStartupProgram(const ProgramDesc& program, Scope* scope);
const std::vector<std::string>& sparse_comm_tensor); std::vector<float>& Run(const std::vector<std::string>& inspect_var_names);
void SetSparseCommData(const std::map<std::string, int>& sparse_comm_data);
virtual void PrepareThreads(const framework::ProgramDesc& host_program);
void RunStartupProgram(const framework::ProgramDesc& program,
framework::Scope* scope);
void RunAsyncExecutor(const ProgramDesc& host_program);
void LoadInitModel(); void LoadInitModel();
private:
void SetInspectVarNames(const std::vector<std::string>& inspect_var_names);
public: public:
unsigned int thread_num_; int thread_num_;
int max_epoch_; int max_epoch_;
int batch_size_; int batch_size_;
int comm_batch_; int comm_batch_;
std::vector<std::shared_ptr<ExecutorThreadWorker> > workers_; std::vector<std::shared_ptr<ExecutorThreadWorker> > workers_;
std::vector<std::thread> threads_; std::vector<std::thread> threads_;
std::vector<std::string> filelist_; std::vector<std::string> inspect_var_names_;
std::string inspect_var_name_;
std::vector<std::string> model_param_names_; std::vector<std::string> model_param_names_;
std::vector<std::string> dense_comm_tensor_;
std::vector<std::string> sparse_comm_tensor_;
std::map<std::string, int> sparse_comm_data_;
std::string model_prefix_; std::string model_prefix_;
std::string model_path_; std::string model_path_;
std::string init_prog_file_; std::string init_prog_file_;
std::string init_model_file_; std::string init_model_file_;
std::string feed_name_;
Scope* root_scope_; Scope* root_scope_;
platform::Place place_; platform::Place place_;
private:
ProgramDesc& main_program_;
TextClassDataFeed& data_feed_;
std::vector<float> inspect_values_;
private:
static bool workers_initialized_;
}; };
} // namespace framework } // namespace framework
......
...@@ -38,6 +38,16 @@ DEFINE_bool(is_text_feed, false, "is_text_feed"); ...@@ -38,6 +38,16 @@ DEFINE_bool(is_text_feed, false, "is_text_feed");
namespace paddle { namespace paddle {
namespace framework { namespace framework {
std::vector<std::string> TextClassDataFeed::s_filelist_;
std::mutex TextClassDataFeed::s_locker_for_pick_file_;
unsigned int TextClassDataFeed::s_current_file_idx_ = 0;
size_t TextClassDataFeed::s_current_finished_file_cnt_ = 0;
unsigned int TextClassDataFeed::s_current_epoch_ = 0;
int TextClassDataFeed::s_current_save_epoch_ = 0;
std::mutex TextClassDataFeed::s_locker_epoch_start_;
std::condition_variable TextClassDataFeed::s_condition_epoch_start_;
bool TextClassDataFeed::s_epoch_start_flag_ = false;
void TextClassDataFeed::Init() { void TextClassDataFeed::Init() {
// hard coding for a specific datafeed // hard coding for a specific datafeed
feed_vec_.resize(2); feed_vec_.resize(2);
...@@ -59,6 +69,12 @@ void TextClassDataFeed::Init() { ...@@ -59,6 +69,12 @@ void TextClassDataFeed::Init() {
label_host_.reset(new int[10240], label_host_.reset(new int[10240],
[](int *p) {delete[] p;}); // max label in a batch [](int *p) {delete[] p;}); // max label in a batch
label_ptr_ = label_host_.get(); label_ptr_ = label_host_.get();
field_names_.clear();
}
TextClassDataFeed::TextClassDataFeed() {
Init();
} }
// todo: use elegant implemention for this function // todo: use elegant implemention for this function
...@@ -69,6 +85,7 @@ bool TextClassDataFeed::ReadBatch() { ...@@ -69,6 +85,7 @@ bool TextClassDataFeed::ReadBatch() {
int inst_idx = 0; int inst_idx = 0;
offset.resize(batch_size_ + 1); offset.resize(batch_size_ + 1);
offset[0] = 0; offset[0] = 0;
while (inst_idx < batch_size_) { while (inst_idx < batch_size_) {
int ptr_offset = 0; int ptr_offset = 0;
if (file_content_buffer_ptr_ - file_content_buffer_ >= file_size_) { if (file_content_buffer_ptr_ - file_content_buffer_ >= file_size_) {
...@@ -125,6 +142,12 @@ bool TextClassDataFeed::ReadBatch() { ...@@ -125,6 +142,12 @@ bool TextClassDataFeed::ReadBatch() {
return true; return true;
} }
TextClassDataFeed::TextClassDataFeed(const TextClassDataFeed& data_feed) {
Init();
SetBatchSize(data_feed.batch_size_);
SetFieldNames(data_feed.field_names_);
}
void TextClassDataFeed::AddFeedVar(Variable* feed, const std::string& name) { void TextClassDataFeed::AddFeedVar(Variable* feed, const std::string& name) {
for (unsigned int i = 0; i < use_slot_alias_.size(); ++i) { for (unsigned int i = 0; i < use_slot_alias_.size(); ++i) {
if (name == use_slot_alias_[i]) { if (name == use_slot_alias_[i]) {
...@@ -133,30 +156,99 @@ void TextClassDataFeed::AddFeedVar(Variable* feed, const std::string& name) { ...@@ -133,30 +156,99 @@ void TextClassDataFeed::AddFeedVar(Variable* feed, const std::string& name) {
} }
} }
void TextClassDataFeed::SetFileList(const char* filelist) {
s_filelist_.clear();
std::ifstream fin(filelist);
PADDLE_ENFORCE(fin.good(),
"Opening file %s fail",
filelist);
std::string filename;
while (fin >> filename) {
LOG(ERROR) << "add " << filename.c_str() << " to filelist";
s_filelist_.push_back(filename);
}
fin.close();
}
void TextClassDataFeed::SetFieldNames(
const std::vector<std::string>& field_names) {
field_names_.clear();
field_names_.insert(field_names_.end(), field_names.begin(),
field_names.end());
}
bool TextClassDataFeed::SetFile(const char* filename) { bool TextClassDataFeed::SetFile(const char* filename) {
// termnum termid termid ... termid label // termnum termid termid ... termid label
int filesize = ReadWholeFile(filename, file_content_buffer_); std::ifstream ifs(filename, std::ios::binary);
// todo , remove magic number if (ifs.fail()) {
return false;
}
ifs.seekg(0, std::ios::end);
int filesize = ifs.tellg();
ifs.seekg(0, std::ios::beg);
ifs.read(file_content_buffer_, filesize);
if (filesize < 0 || filesize >= 1024 * 1024 * 1024) { if (filesize < 0 || filesize >= 1024 * 1024 * 1024) {
return false; return false;
} }
file_content_buffer_ptr_ = file_content_buffer_; file_content_buffer_ptr_ = file_content_buffer_;
file_size_ = filesize; file_size_ = filesize;
// todo , remove magic number
return true; return true;
} }
int TextClassDataFeed::ReadWholeFile(const std::string& filename, void TextClassDataFeed::UpdateEpochNum() {
char* buffer) { s_current_finished_file_cnt_++;
std::ifstream ifs(filename.c_str(), std::ios::binary);
if (ifs.fail()) { if (s_current_finished_file_cnt_ >= s_filelist_.size()) {
return -1; s_current_finished_file_cnt_ = 0;
s_current_epoch_++;
#if 1
LOG(WARNING) << "UpdateEpochNum: epoch = " << s_current_epoch_;
#endif
{
std::lock_guard<std::mutex> lock(s_locker_epoch_start_);
s_epoch_start_flag_ = false;
}
} }
}
ifs.seekg(0, std::ios::end); void TextClassDataFeed::StartOneEpoch() {
int file_size = ifs.tellg(); std::lock_guard<std::mutex> lock(s_locker_for_pick_file_);
ifs.seekg(0, std::ios::beg); std::random_shuffle(s_filelist_.begin(), s_filelist_.end());
ifs.read(buffer, file_size); s_current_file_idx_ = 0;
return file_size; LOG(INFO) << "Beginning epoch " << s_current_epoch_;
{
std::lock_guard<std::mutex> lock(s_locker_epoch_start_);
s_epoch_start_flag_ = true;
}
s_condition_epoch_start_.notify_all();
}
void TextClassDataFeed::WaitNextEpoch() {
std::unique_lock<std::mutex> lock(s_locker_epoch_start_);
s_condition_epoch_start_.wait(lock, []{return s_epoch_start_flag_;});
}
const char* TextClassDataFeed::PickOneFile() {
std::string file_to_be_processed;
std::lock_guard<std::mutex> lock(s_locker_for_pick_file_);
// One epoch has run over
// Wait for next epoch
if (s_current_file_idx_ >= s_filelist_.size()) {
LOG(ERROR) << "thread " << thread_id_
<< ": finish traing for epoch " << s_current_epoch_ + 1;
return NULL;
}
file_to_be_processed = s_filelist_[s_current_file_idx_];
s_current_file_idx_++;
return file_to_be_processed.c_str();
} }
} // namespace framework } // namespace framework
......
...@@ -47,24 +47,9 @@ struct Instance { ...@@ -47,24 +47,9 @@ struct Instance {
std::vector<Gauc> gauc_vec; std::vector<Gauc> gauc_vec;
}; };
class DataFeed {
DataFeed() {}
virtual ~DataFeed() {}
};
class BlockingQueueDataFeed : DataFeed {
BlockingQueueDataFeed() {}
virtual ~BlockingQueueDataFeed() {}
};
class ThreadedDataFeed : DataFeed {
ThreadedDataFeed() {}
virtual ~ThreadedDataFeed() {}
};
class DataFeed { class DataFeed {
public: public:
DataFeed() {} DataFeed() : default_batch_size_(1), batch_size_(0), thread_id_(0) {}
virtual ~DataFeed() {} virtual ~DataFeed() {}
virtual void Init() = 0; virtual void Init() = 0;
/* /*
...@@ -93,6 +78,11 @@ class DataFeed { ...@@ -93,6 +78,11 @@ class DataFeed {
virtual void SetBatchSize(int batch) { default_batch_size_ = batch; } virtual void SetBatchSize(int batch) { default_batch_size_ = batch; }
virtual int GetBatchSize() { return batch_size_; } virtual int GetBatchSize() { return batch_size_; }
virtual void SetBufferSize(int buffer_size) {} virtual void SetBufferSize(int buffer_size) {}
virtual unsigned int GetCurrentEpoch() = 0;
virtual const char *PickOneFile() = 0;
virtual void UpdateEpochNum() = 0;
virtual void StartOneEpoch() = 0;
virtual void WaitNextEpoch() = 0;
std::vector<LoDTensor*>& GetFeedVec() { std::vector<LoDTensor*>& GetFeedVec() {
return feed_vec_; return feed_vec_;
...@@ -103,6 +93,9 @@ class DataFeed { ...@@ -103,6 +93,9 @@ class DataFeed {
return feed_vec_; return feed_vec_;
} }
int GetThreadId() {return thread_id_;}
void SetThreadId(int thread_id) {thread_id_ = thread_id;}
protected: protected:
std::vector<uint16_t> all_slot_ids_; std::vector<uint16_t> all_slot_ids_;
std::vector<uint16_t> use_slot_ids_; std::vector<uint16_t> use_slot_ids_;
...@@ -110,9 +103,14 @@ class DataFeed { ...@@ -110,9 +103,14 @@ class DataFeed {
std::vector<LoDTensor*> feed_vec_; std::vector<LoDTensor*> feed_vec_;
int default_batch_size_; int default_batch_size_;
int batch_size_; int batch_size_;
int thread_id_;
}; };
class TextClassDataFeed : public DataFeed { class TextClassDataFeed : public DataFeed {
public:
TextClassDataFeed();
TextClassDataFeed(const TextClassDataFeed& data_feed);
public: public:
virtual ~TextClassDataFeed() {} virtual ~TextClassDataFeed() {}
virtual void Init(); virtual void Init();
...@@ -120,25 +118,45 @@ class TextClassDataFeed : public DataFeed { ...@@ -120,25 +118,45 @@ class TextClassDataFeed : public DataFeed {
virtual void AddFeedVar(Variable* feed, const std::string& name); virtual void AddFeedVar(Variable* feed, const std::string& name);
virtual void BindScope(Scope* scope) {} virtual void BindScope(Scope* scope) {}
virtual bool SetFile(const char* filename); virtual bool SetFile(const char* filename);
virtual bool CheckFile(const char* filename) { virtual bool CheckFile(const char* filename) {
// TODO(xxx) // TODO(xxx)
return false; return false;
} }
void SetBatchSize(int batch) {batch_size_ = batch;} void SetBatchSize(int batch) {batch_size_ = batch;}
unsigned int GetCurrentEpoch() {return s_current_epoch_;}
void UpdateEpochNum();
void StartOneEpoch();
void WaitNextEpoch();
public:
void SetFieldNames(const std::vector<std::string>& field_names);
public:
static void SetFileList(const char* filelist);
private:
const char* PickOneFile();
private: private:
int ReadWholeFile(const std::string& filename, char* buffer);
char* file_content_buffer_; char* file_content_buffer_;
char* file_content_buffer_ptr_; char* file_content_buffer_ptr_;
int* batch_id_buffer_; int* batch_id_buffer_;
int* label_ptr_; int* label_ptr_;
int file_size_; int file_size_;
std::vector<std::string> names_; std::vector<std::string> field_names_;
std::shared_ptr<char> file_content_buffer_host_; std::shared_ptr<char> file_content_buffer_host_;
std::shared_ptr<int> batch_id_host_; std::shared_ptr<int> batch_id_host_;
std::shared_ptr<int> label_host_; std::shared_ptr<int> label_host_;
static std::vector<std::string> s_filelist_;
static std::mutex s_locker_for_pick_file_;
static unsigned int s_current_file_idx_;
static size_t s_current_finished_file_cnt_;
static unsigned int s_current_epoch_;
static int s_current_save_epoch_;
static std::mutex s_locker_epoch_start_;
static std::condition_variable s_condition_epoch_start_;
static bool s_epoch_start_flag_;
}; };
} // namespace framework } // namespace framework
......
...@@ -21,7 +21,10 @@ limitations under the License. */ ...@@ -21,7 +21,10 @@ limitations under the License. */
#ifdef _XOPEN_SOURCE #ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE #undef _XOPEN_SOURCE
#endif #endif
#include <vector>
#include <string>
#include "paddle/fluid/pybind/async_executor_py.h"
#include "google/protobuf/text_format.h" #include "google/protobuf/text_format.h"
#include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/io/zero_copy_stream_impl.h"
#include "paddle/fluid/inference/io.h" #include "paddle/fluid/inference/io.h"
...@@ -29,58 +32,36 @@ limitations under the License. */ ...@@ -29,58 +32,36 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/framework/async_executor_param.pb.h" #include "paddle/fluid/framework/async_executor_param.pb.h"
#include "paddle/fluid/framework/async_executor.h" #include "paddle/fluid/framework/async_executor.h"
#include "paddle/fluid/pybind/async_executor_py.h" #include "paddle/fluid/framework/data_feed.h"
namespace py = pybind11; namespace py = pybind11;
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
void BindAsyncExecutor(py::module* m) { void BindAsyncExecutor(py::module* m) {
py::class_<paddle::AsyncExecutorParameter>(*m, "AsyncExecutorParameter") py::class_<framework::DataFeed>(*m, "DataFeed");
.def(py::init<>()) py::class_<framework::TextClassDataFeed,
.def("parse", framework::DataFeed>(*m, "TextDataFeed")
[](paddle::AsyncExecutorParameter &self, const std::string &conf_file) { .def(py::init())
int file_descriptor = open(conf_file.c_str(), O_RDONLY); .def("set_filelist",
google::protobuf::io::FileInputStream file_input(file_descriptor); [] (framework::TextClassDataFeed &self, const char *data_list_file) {
google::protobuf::TextFormat::Parse(&file_input, &self); self.SetFileList(data_list_file);
close(file_descriptor); })
} .def("set_batch_size", &framework::TextClassDataFeed::SetBatchSize)
); .def("set_field_names", &framework::TextClassDataFeed::SetFieldNames)
py::class_<framework::AsyncExecutor>(*m, "AsyncExecutor") .def("start_one_epoch", &framework::TextClassDataFeed::StartOneEpoch);
.def(py::init<const platform::Place&>())
.def("init",
[](framework::AsyncExecutor &self,
paddle::AsyncExecutorParameter &parameter,
framework::Scope *scope) {
paddle::BaseParameter base_param = parameter.base_param();
// TODO Extract parameter list from python side, instead of py::class_<framework::AsyncExecutor>(*m, "AsyncExecutor")
// providing them in confgurations manually .def(py::init<framework::ProgramDesc&,
std::vector<std::string> param_names; std::vector<std::string>&,
for (int i = 0; i < base_param.model_param_names_size(); ++i) { framework::TextClassDataFeed&,
param_names.push_back(base_param.model_param_names(i)); unsigned int,
} const platform::Place&>())
paddle::framework::InitDevices(false); .def("init_root_scope", &framework::AsyncExecutor::InitRootScope)
self.InitRootScope(scope);
self.SetThreadNum(base_param.thread_num());
self.SetMaxTrainingEpoch(base_param.max_epoch());
self.SetFileList(base_param.filelist().c_str());
self.SetBatchSize(base_param.batch_size());
self.SetDataFeedName(base_param.datafeed_class().c_str());
self.SetInspectVarName(base_param.inspect_var_name());
self.SetParamNames(param_names);
self.SetModelPath(base_param.model_path());
self.SetModelPrefix(base_param.model_prefix());
self.SetInitProgFile(base_param.init_prog_file());
self.SetInitModelFile(base_param.init_model_file());
return;
}
)
.def("run_startup_program", &framework::AsyncExecutor::RunStartupProgram) .def("run_startup_program", &framework::AsyncExecutor::RunStartupProgram)
.def("load_init_model", &framework::AsyncExecutor::LoadInitModel) .def("run", &framework::AsyncExecutor::Run);
.def("run", &framework::AsyncExecutor::RunAsyncExecutor);
} // end BindAsyncExecutor } // end BindAsyncExecutor
} // end namespace framework } // end namespace pybind
} // end namespace paddle } // end namespace paddle
/* vim: set expandtab ts=2 sw=2 sts=2 tw=80: */ /* vim: set expandtab ts=2 sw=2 sts=2 tw=80: */
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#ifndef PADDLE_FLUID_PYBIND_ASYNC_EXECUTOR_PY_H_ #ifndef PADDLE_FLUID_PYBIND_ASYNC_EXECUTOR_PY_H_
#define PADDLE_FLUID_PYBIND_ASYNC_EXECUTOR_PY_H_ #define PADDLE_FLUID_PYBIND_ASYNC_EXECUTOR_PY_H_
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11; namespace py = pybind11;
......
...@@ -21,22 +21,28 @@ from .framework import Program, default_main_program, Variable ...@@ -21,22 +21,28 @@ from .framework import Program, default_main_program, Variable
from . import core from . import core
from . import Executor from . import Executor
__all__ = ['AsyncExecutorParameter', 'AsyncExecutor'] __all__ = ['TextDataFeed', 'AsyncExecutor']
g_scope = core.Scope() g_scope = core.Scope()
class AsyncExecutorParameter(object): class TextDataFeed():
"""
AsyncExecutor configure parameter
Args:
None
"""
def __init__(self): def __init__(self):
self.parameter = core.AsyncExecutorParameter() self.feed = core.TextDataFeed()
def set_filelist(self, filelist):
self.feed.set_filelist(filelist)
def set_batch_size(self, batch_size):
self.feed.set_batch_size(batch_size)
def set_field_names(self, field_names):
if isinstance(field_names, Variable):
field_names = [field_names]
self.feed.set_field_names(field_names)
def parse(self, conf_file): def start_an_epoch(self):
self.parameter.parse(conf_file) self.feed.start_one_epoch()
class AsyncExecutor(object): class AsyncExecutor(object):
""" """
...@@ -50,39 +56,31 @@ class AsyncExecutor(object): ...@@ -50,39 +56,31 @@ class AsyncExecutor(object):
""" """
def __init__(self, def __init__(self,
async_executor_parameter, program,
place, param_names,
scope): data_feed,
if not isinstance(async_executor_parameter, AsyncExecutorParameter): thread_num,
raise TypeError( place=None,
"AsyncExecutor requires AsyncExecutorParameter as its parameter. " scope=None):
"But you passed in %s" %s (type(async_executor_parameter)) if program is None:
) program = default_main_program()
program_desc = program.desc
self.place = place
p = core.Place()
p.set_place(place)
self.executor = core.AsyncExecutor(p)
self.executor.init(async_executor_parameter.parameter, scope)
self._closed = False
self.parameter = async_executor_parameter.parameter
def close(self): if not isinstance(data_feed, TextDataFeed):
""" raise ValueError("data_feed for AsyncExecutor.run() type error")
Close this executor.
You can no long use this executor after calling this method. if place is None:
For the distributed training, this method would free the resource on PServers related to place = core.CPUPlace()
the current Trainer. if not isinstance(place, core.CPUPlace):
raise ValueError("AsyncExecutor only supports CPU device")
if isinstance(param_names, Variable):
param_names = [param_names]
p = core.Place()
p.set_place(place)
self.executor = core.AsyncExecutor(program_desc, param_names, data_feed.feed, thread_num, p)
Example:
>>> cpu = core.CPUPlace()
>>> exe = Executor(cpu)
>>> ...
>>> exe.close()
"""
if not self._closed:
self._closed = True
def run_startup_program(self, def run_startup_program(self,
program=None, program=None,
scope=None): scope=None):
...@@ -94,8 +92,8 @@ class AsyncExecutor(object): ...@@ -94,8 +92,8 @@ class AsyncExecutor(object):
scope = g_scope scope = g_scope
self.executor.run_startup_program(program_desc, scope) self.executor.run_startup_program(program_desc, scope)
def run(self, program=None, scope=None): def run(self, inspect_vars, scope=None):
""" """
Run program by this Executor. Feed data by feed map, fetch result by fetch_list. Run program by this Executor. Feed data by feed map, fetch result by fetch_list.
Python executor takes a program, add feed operators and fetch operators to this program according Python executor takes a program, add feed operators and fetch operators to this program according
...@@ -138,23 +136,16 @@ class AsyncExecutor(object): ...@@ -138,23 +136,16 @@ class AsyncExecutor(object):
>>> feed={'X': x}, >>> feed={'X': x},
>>> fetch_list=[loss.name]) >>> fetch_list=[loss.name])
""" """
if inspect_vars is not None:
if self._closed: if isinstance(inspect_vars, Variable):
raise RuntimeError("Attempted to use a closed Executor") inspect_vars = [inspect_vars]
inspect_var_names = [var.name for var in inspect_vars]
if program is None:
program = default_main_program()
program_desc = program.desc
if not isinstance(program, Program):
raise TypeError(
"Executor requires Program as its Parameter. But you passed in %s"
% (type(program)))
if scope is None: if scope is None:
scope = g_scope scope = g_scope
self.executor.run(program.desc)
def load_init_model(self): self.executor.init_root_scope(scope)
return self.executor.load_init_model()
evaluation = self.executor.run(inspect_var_names)
return evaluation
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册