Commit 016a0687 authored by heqiaozhi

stop server

Parent 8e3fe2d7
@@ -83,6 +83,10 @@ uint64_t AsyncExecutor::StartServer() {
   return _pslib_ptr->run_server();
 }
 
+void AsyncExecutor::StopServer() {
+  _pslib_ptr->stop_server();
+}
+
 void AsyncExecutor::GatherServers(std::vector<uint64_t>& host_sign_list, int node_num) {
   _pslib_ptr->gather_servers(host_sign_list.data(), node_num);
 }
...
@@ -67,6 +67,7 @@ class AsyncExecutor {
   void InitWorker(const std::string& dist_desc, std::vector<uint64_t>& host_sign_list, int node_num, int index);
   //void ConfigWorker() {}
   uint64_t StartServer();
+  void StopServer();
   void GatherServers(std::vector<uint64_t>& host_sign_list, int node_num);
   void InitModel();
   void SaveModel(const std::string& path);
...
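Review note: StopServer() forwards to PSLib's stop_server(), mirroring how StartServer() wraps run_server(). For orientation, a minimal sketch of the distributed lifecycle as driven from the Python wrapper; dist_opt, dist_desc, and is_server() are illustrative placeholders, not part of this commit:

```python
# Illustrative driver; fluid.AsyncExecutor and config_distributed_nodes come
# from python/paddle/fluid/async_executor.py, everything else is assumed.
import paddle.fluid as fluid

exe = fluid.AsyncExecutor()
instance = exe.config_distributed_nodes(dist_opt)  # dist_opt: hypothetical options

if instance.is_server():          # hypothetical role check
    exe.init_server(dist_desc)    # start the PSLib server, then block on barriers
else:
    exe.init_worker(dist_desc)
    exe.run(program, data_feed, filelist, thread_num, fetch_var_names)
    # run() now also shuts the servers down: the first worker calls
    # stop_server(), and a trailing barrier_all() keeps every node in step.
```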
@@ -569,7 +569,6 @@ void AsyncExecutorThreadWorker::FillSparse(int table_id) {
 }
 
 void AsyncExecutorThreadWorker::PushSparse(int table_id) {
   auto slot_dim = _param_config->slot_dim;  // TODO
   auto fea_dim = _param_config->fea_dim;    // _current_train_job.fea_dim(); TODO
-
   auto& features = _features[table_id];
@@ -592,19 +591,20 @@ void AsyncExecutorThreadWorker::PushSparse(int table_id) {
     }
     Variable* g_var = thread_scope_->FindVar(_param_config->gradient_var[table_id][slot_idx - 1]);
     LoDTensor* g_tensor = g_var->GetMutable<LoDTensor>();
-    //int count = g_tensor->numel();
-    float* g = g_tensor->data<float>();
-    /*
-    if (FLAGS_scale_sparse_gradient_with_batch_size) {
-      Eigen::Map<Eigen::MatrixXf> g_mat(g, 1, tensor->numel());
-      g_mat *= _batch_size;
-    }
-    */
+    if (g_tensor == NULL) {
+      LOG(ERROR) << "var[" << _param_config->gradient_var[table_id][slot_idx - 1] << "] not found";
+      exit(-1);
+    }
+    float* g = g_tensor->data<float>();
     Variable* var = thread_scope_->FindVar(feed_vec[slot_idx]);
     LoDTensor* tensor = var->GetMutable<LoDTensor>();
+    if (tensor == NULL) {
+      LOG(ERROR) << "var[" << feed_vec[slot_idx] << "] not found";
+      exit(-1);
+    }
     int len = tensor->lod()[0].back();
-    //assert(slot_dim * len == count);
+    assert(slot_dim * len == g_tensor->numel());
     int64_t* ids = tensor->data<int64_t>();
     for (auto id_idx = 0u; id_idx < len; ++id_idx) {
       if (ids[id_idx] == 0) {
...
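Review note: the commit replaces the commented-out gradient-scaling block with explicit NULL checks, so a missing gradient or feed var now fails loudly with the offending name instead of segfaulting inside PushSparse. The same fail-fast idea could be applied one level up before training starts; a hypothetical Python-side pre-check (check_vars_exist is not part of this commit):

```python
# Hypothetical fail-fast helper mirroring the new C++ guards: verify that
# every configured var name resolves in the scope before training begins.
import paddle.fluid as fluid

def check_vars_exist(scope, var_names):
    missing = [name for name in var_names if scope.find_var(name) is None]
    if missing:
        raise RuntimeError("vars not found in scope: " + ", ".join(missing))

# Usage sketch: gradient_vars would come from the sparse-table configuration.
# check_vars_exist(fluid.global_scope(), gradient_vars)
```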
@@ -51,6 +51,7 @@ void BindAsyncExecutor(py::module* m) {
       .def("init_server", &framework::AsyncExecutor::InitServer)
       .def("init_worker", &framework::AsyncExecutor::InitWorker)
       .def("start_server", &framework::AsyncExecutor::StartServer)
+      .def("stop_server", &framework::AsyncExecutor::StopServer)
       .def("gather_servers", &framework::AsyncExecutor::GatherServers)
       .def("init_model", &framework::AsyncExecutor::InitModel)
       .def("save_model", &framework::AsyncExecutor::SaveModel);
...
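Review note: this .def exposes the new C++ method on the bound executor object, which is what the Python wrapper's run() relies on below. A minimal sketch of touching the pybind layer directly, assuming the constructor signature the wrapper uses at this point in history (core.AsyncExecutor(scope, place)):

```python
# Direct use of the pybind layer; normally this object is created and owned
# by the fluid.AsyncExecutor wrapper rather than by user code.
from paddle.fluid import core
from paddle.fluid.executor import global_scope

executor = core.AsyncExecutor(global_scope(), core.CPUPlace())
# ... init_server()/start_server() on server nodes, training on workers ...
executor.stop_server()  # new in this commit; forwards to PSLib stop_server()
```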
@@ -151,7 +151,10 @@ class AsyncExecutor(object):
         self.executor.run_from_files(program_desc,
                                      data_feed.desc(), filelist, thread_num,
                                      fetch_var_names, debug)
-        self.instance.barrier_all()
+        self.instance.barrier_all()  # worker do all things
+        if self.instance.is_first_worker():
+            self.executor.stop_server()
+        self.instance.barrier_all()  # sync
 
     def config_distributed_nodes(self, dist_opt):
@@ -164,6 +167,9 @@ class AsyncExecutor(object):
     def get_instance(self):
         return self.instance
 
+    #def stop_server(self):
+    #    self.executor.stop_server()
+
     def init_server(self, dist_desc):
         self.executor.init_server(dist_desc, self.instance._rankid)
         ip = self.executor.start_server()
@@ -174,6 +180,7 @@ class AsyncExecutor(object):
         self.instance.barrier_all()  # wait all worker start
         self.instance.barrier_all()  # wait init model
         self.instance.barrier_all()  # wait worker do all things
+        self.instance.barrier_all()  # sync
 
     def init_worker(self, dist_desc):
         self.instance.barrier_all()  # wait all server start
...
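Review note: the new trailing barrier on the server path pairs with the one added at the end of run(); barrier_all() counts must match on every node or one side blocks forever. A toy single-process simulation of the pairing using threading.Barrier (the real code synchronizes across processes, not threads):

```python
# Each barrier.wait() below stands in for one barrier_all() call from the diff.
import threading

barrier = threading.Barrier(2)  # one "server" node, one "worker" node

def server():
    for step in ("all server start", "all worker start",
                 "init model", "worker do all things", "sync"):
        barrier.wait()  # server-path barrier_all()
        print("server passed barrier:", step)

def worker():
    for step in ("all server start", "all worker start",
                 "init model", "worker do all things"):
        barrier.wait()  # worker-path barrier_all()
        print("worker passed barrier:", step)
    print("worker: first worker would call stop_server() here")
    barrier.wait()  # the new trailing barrier_all()  # sync

threads = [threading.Thread(target=f) for f in (server, worker)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Remove the server's fifth wait and the worker hangs on its final barrier after stop_server() returns, which is exactly the mismatch this commit's paired additions avoid.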