Commit 016a0687, authored by heqiaozhi

stop server

Parent: 8e3fe2d7
......@@ -83,6 +83,10 @@ uint64_t AsyncExecutor::StartServer() {
  return _pslib_ptr->run_server();
}
void AsyncExecutor::StopServer() {
  _pslib_ptr->stop_server();
}
void AsyncExecutor::GatherServers(std::vector<uint64_t>& host_sign_list, int node_num) {
  _pslib_ptr->gather_servers(host_sign_list.data(), node_num);
}
......
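Note: StopServer() simply forwards to the pslib wrapper's stop_server(), completing the start_server / gather_servers / stop_server lifecycle. As a rough sketch of how these entry points are driven from Python through the bindings added further down (async_exe is assumed to be a fluid AsyncExecutor wrapper; dist_desc, rank_id, host_sign_list and node_num are placeholders, not part of this commit):

    # Sketch only: call order of the bound methods on the underlying
    # core.AsyncExecutor that the Python wrapper holds as `async_exe.executor`.
    exe = async_exe.executor
    exe.init_server(dist_desc, rank_id)            # InitServer
    host_sign = exe.start_server()                 # StartServer, as in `ip = self.executor.start_server()` below
    exe.gather_servers(host_sign_list, node_num)   # GatherServers
    # ... workers run training via run_from_files(...) ...
    exe.stop_server()                              # StopServer, added in this commit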
......@@ -67,6 +67,7 @@ class AsyncExecutor {
  void InitWorker(const std::string& dist_desc, std::vector<uint64_t>& host_sign_list, int node_num, int index);
  //void ConfigWorker() {}
  uint64_t StartServer();
  void StopServer();
  void GatherServers(std::vector<uint64_t>& host_sign_list, int node_num);
  void InitModel();
  void SaveModel(const std::string& path);
......
......@@ -569,7 +569,6 @@ void AsyncExecutorThreadWorker::FillSparse(int table_id) {
}
void AsyncExecutorThreadWorker::PushSparse(int table_id) {
  auto slot_dim = _param_config->slot_dim;  // TODO
  auto fea_dim = _param_config->fea_dim;  // _current_train_job.fea_dim(); TODO
  auto& features = _features[table_id];
......@@ -592,19 +591,20 @@ void AsyncExecutorThreadWorker::PushSparse(int table_id) {
    }
    Variable* g_var = thread_scope_->FindVar(_param_config->gradient_var[table_id][slot_idx - 1]);
    LoDTensor* g_tensor = g_var->GetMutable<LoDTensor>();
    //int count = g_tensor->numel();
    float* g = g_tensor->data<float>();
    /*
    if (FLAGS_scale_sparse_gradient_with_batch_size) {
      Eigen::Map<Eigen::MatrixXf> g_mat(g, 1, tensor->numel());
      g_mat *= _batch_size;
    if (g_tensor == NULL) {
      LOG(ERROR) << "var[" << _param_config->gradient_var[table_id][slot_idx - 1] << "] not found";
      exit(-1);
    }
    */
    float* g = g_tensor->data<float>();
    Variable* var = thread_scope_->FindVar(feed_vec[slot_idx]);
    LoDTensor* tensor = var->GetMutable<LoDTensor>();
    if (tensor == NULL) {
      LOG(ERROR) << "var[" << feed_vec[slot_idx] << "] not found";
      exit(-1);
    }
    int len = tensor->lod()[0].back();
    //assert(slot_dim * len == count);
    assert(slot_dim * len == g_tensor->numel());
    int64_t* ids = tensor->data<int64_t>();
    for (auto id_idx = 0u; id_idx < len; ++id_idx) {
      if (ids[id_idx] == 0) {
......
......@@ -51,6 +51,7 @@ void BindAsyncExecutor(py::module* m) {
.def("init_server", &framework::AsyncExecutor::InitServer)
.def("init_worker", &framework::AsyncExecutor::InitWorker)
.def("start_server", &framework::AsyncExecutor::StartServer)
.def("stop_server", &framework::AsyncExecutor::StopServer)
.def("gather_servers", &framework::AsyncExecutor::GatherServers)
.def("init_model", &framework::AsyncExecutor::InitModel)
.def("save_model", &framework::AsyncExecutor::SaveModel);
......
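With the extra .def entry the new method is reachable from Python on the bound core type; a quick sanity check (illustrative only, run inside a Paddle installation built with this change):

    # Illustrative check that the new binding is exposed alongside the existing ones.
    from paddle.fluid import core
    assert hasattr(core.AsyncExecutor, "stop_server")
    assert hasattr(core.AsyncExecutor, "gather_servers")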
......@@ -151,7 +151,10 @@ class AsyncExecutor(object):
        self.executor.run_from_files(program_desc,
                                     data_feed.desc(), filelist, thread_num,
                                     fetch_var_names, debug)
        self.instance.barrier_all()
        self.instance.barrier_all()  # worker do all things
        if self.instance.is_first_worker():
            self.executor.stop_server()
        self.instance.barrier_all()  # sync
    def config_distributed_nodes(self, dist_opt):
......@@ -164,6 +167,9 @@ class AsyncExecutor(object):
    def get_instance(self):
        return self.instance
    # def stop_server(self):
    #     self.executor.stop_server()
    def init_server(self, dist_desc):
        self.executor.init_server(dist_desc, self.instance._rankid)
        ip = self.executor.start_server()
......@@ -174,6 +180,7 @@ class AsyncExecutor(object):
        self.instance.barrier_all()  # wait all worker start
        self.instance.barrier_all()  # wait init model
        self.instance.barrier_all()  # wait worker do all things
        self.instance.barrier_all()  # sync
    def init_worker(self, dist_desc):
        self.instance.barrier_all()  # wait all server start
......
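Taken together, the worker-side and server-side changes form a small shutdown handshake: after the "worker do all things" barrier only the first worker issues stop_server, a final barrier keeps every node alive until the stop has been issued, and the server side waits through a matching extra barrier_all(). A linear restatement, for illustration only (the real logic lives in run() and init_server() above; `instance` is the node-manager helper from the diff and `executor` the bound core.AsyncExecutor):

    # Illustration of the barrier/stop ordering introduced by this commit.
    def worker_side(instance, executor):
        instance.barrier_all()              # "worker do all things": training finished everywhere
        if instance.is_first_worker():
            executor.stop_server()          # exactly one worker asks pslib to stop the server
        instance.barrier_all()              # "sync": no worker exits before the stop is issued

    def server_side(instance):
        instance.barrier_all()              # "wait worker do all things"
        instance.barrier_all()              # "sync": pairs with the workers' final barrier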