Commit b7a202aa authored by dongdaxiang

add distributed optimizer factory

Parent 70a5d4f7
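In substance, this commit replaces the protobuf program_config entries that DistributedAdam used to append to ps_param.trainer_param with a plain Python dict, program_configs, which is attached to Program._fleet_opt and consumed by DownpourSGD.gen_worker_desc. A minimal sketch of the dict shape, inferred from the hunks below (the program-id key and table indices are illustrative placeholders, not values from the commit):

# Sketch only -- the shape of opt_info["program_configs"] introduced here.
# Keys are str(id(program)); the table ids below are placeholder values.
program_configs = {
    "140009988776655": {
        "pull_sparse": [0],   # sparse table ids the worker pulls from
        "push_sparse": [0],   # sparse table ids gradients are pushed to
        "pull_dense": [1],    # dense table ids the worker pulls from
        "push_dense": [1],    # dense table ids gradients are pushed to
    },
}

opt_info = {
    "program_configs": program_configs,
    "trainer": "DistMultiTrainer",
    "device_worker": "DownpourSGD",
    "optimizer": "DownpourSGD",
    # the real opt_info also carries "fleet_desc" (the pslib PS descriptor)
}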
@@ -294,14 +294,13 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
 #endif
 }
 
-int FleetWrapper::RegisterClientToClientMsgHandler(
-    int msg_type, MsgHandlerFunc handler) {
+int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
+                                                   MsgHandlerFunc handler) {
 #ifdef PADDLE_WITH_PSLIB
   VLOG(3) << "calling FleetWrapper::RegisterClientToClientMsgHandler";
   VLOG(3) << "pslib_ptr_=" << pslib_ptr_;
   VLOG(3) << "_worker_ptr=" << pslib_ptr_->_worker_ptr;
-  pslib_ptr_->_worker_ptr->registe_client2client_msg_handler(
-      msg_type, handler);
+  pslib_ptr_->_worker_ptr->registe_client2client_msg_handler(msg_type, handler);
 #else
   VLOG(0) << "FleetWrapper::RegisterClientToClientMsgHandler"
           << " does nothing when no pslib";
@@ -309,11 +308,10 @@ int FleetWrapper::RegisterClientToClientMsgHandler(
   return 0;
 }
 
-int FleetWrapper::SendClientToClientMsg(
-    int msg_type, int to_client_id, const std::string& msg) {
+int FleetWrapper::SendClientToClientMsg(int msg_type, int to_client_id,
+                                        const std::string& msg) {
 #ifdef PADDLE_WITH_PSLIB
-  pslib_ptr_->_worker_ptr->send_client2client_msg(
-      msg_type, to_client_id, msg);
+  pslib_ptr_->_worker_ptr->send_client2client_msg(msg_type, to_client_id, msg);
 #else
   VLOG(0) << "FleetWrapper::SendClientToClientMsg"
           << " does nothing when no pslib";
...
@@ -45,6 +45,7 @@ void BindFleetWrapper(py::module* m) {
       .def(py::init())
      .def("push_dense", &framework::FleetWrapper::PushDenseVarsSync)
      .def("init_server", &framework::FleetWrapper::InitServer)
+      .def("run_server", &framework::FleetWrapper::RunServer)
      .def("init_worker", &framework::FleetWrapper::InitWorker)
      .def("stop_server", &framework::FleetWrapper::StopServer)
      .def("gather_servers", &framework::FleetWrapper::GatherServers);
...
@@ -11,13 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import sys
 __all__ = ['DeviceWorker', 'Hogwild', 'DownpourSGD']
 
 
 class DeviceWorker(object):
     def __init__(self):
-        pass
+        self.program_ = None
+
+    def set_fleet_desc(self, fleet_desc):
+        self.fleet_desc_ = fleet_desc
+
+    def set_program(self, program):
+        self.program_ = program
 
     def gen_worker_desc(self, trainer_desc):
         pass
@@ -33,7 +40,7 @@ class Hogwild(DeviceWorker):
 
 class DownpourSGD(DeviceWorker):
     def __init__(self):
-        super(Downpour, self).__init__()
+        super(DownpourSGD, self).__init__()
 
     def gen_worker_desc(self, trainer_desc):
         trainer_desc.device_worker_name = "DownpourWorker"
@@ -41,20 +48,21 @@ class DownpourSGD(DeviceWorker):
         pull_thread.device_num = trainer_desc.thread_num
         dense_table = pull_thread.dense_table.add()
         dense_table.dense_value_name.extend(
-            fleet_desc.trainer_param.dense_table[0].dense_variable_name)
+            self.fleet_desc_.trainer_param.dense_table[0].dense_variable_name)
         dense_table.table_id = \
-            fleet_desc.trainer_param.dense_table[0].table_id
+            self.fleet_desc_.trainer_param.dense_table[0].table_id
         downpour = trainer_desc.downpour_param
         sparse_table = downpour.sparse_table.add()
         sparse_table.table_id = \
-            fleet_desc.trainer_param.sparse_table[0].table_id
+            self.fleet_desc_.trainer_param.sparse_table[0].table_id
         sparse_table.sparse_key_name.extend(
-            fleet_desc.trainer_param.sparse_table[0].slot_key)
+            self.fleet_desc_.trainer_param.sparse_table[0].slot_key)
         sparse_table.sparse_value_name.extend(
-            fleet_desc.trainer_param.sparse_table[0].slot_value)
+            self.fleet_desc_.trainer_param.sparse_table[0].slot_value)
         sparse_table.sparse_grad_name.extend(
-            fleet_desc.trainer_param.sparse_table[0].slot_gradient)
-        sparse_table.emb_dim = fleet_desc.server_param.downpour_server_param.downpour_table_param[
+            self.fleet_desc_.trainer_param.sparse_table[0].slot_gradient)
+        sparse_table.emb_dim = \
+            self.fleet_desc_.server_param.downpour_server_param.downpour_table_param[
             0].accessor.fea_dim - 2
         sparse_table.fea_dim = sparse_table.emb_dim + 2
         # TODO(guru4elephant): hard code here, need to improve
@@ -62,12 +70,49 @@ class DownpourSGD(DeviceWorker):
         dense_table = downpour.dense_table.add()
         dense_table.table_id = \
-            fleet_desc.trainer_param.dense_table[0].table_id
+            self.fleet_desc_.trainer_param.dense_table[0].table_id
         dense_table.dense_value_name.extend(
-            fleet_desc.trainer_param.dense_table[0].dense_variable_name)
-        dense_table.dense_grad_name.extend(fleet_desc.trainer_param.dense_table[
+            self.fleet_desc_.trainer_param.dense_table[0].dense_variable_name)
+        dense_table.dense_grad_name.extend(
+            self.fleet_desc_.trainer_param.dense_table[
             0].dense_gradient_variable_name)
-        downpour.skip_ops.extend(fleet_desc.trainer_param.skip_op)
+        downpour.skip_ops.extend(self.fleet_desc_.trainer_param.skip_op)
+        program_id = str(id(self.program_))
+        if self.program_ == None:
+            print("program of current device worker is not configured")
+            sys.exit(-1)
+        opt_info = self.program_._fleet_opt
+        program_configs = opt_info["program_configs"]
+        for program_id in program_configs:
+            if program_configs[program_id] == program_id:
+                pc = downpour.program_config.add()
+                pc.program_id = program_id
+                for i in program_configs[program_id]["push_sparse"]:
+                    pc.push_sparse_table_id.extend([i])
+                for i in program_configs[program_id]["push_dense"]:
+                    pc.push_dense_table_id.extend([i])
+                for i in program_configs[program_id]["pull_sparse"]:
+                    pc.pull_sparse_table_id.extend([i])
+                for i in program_configs[program_id]["pull_dense"]:
+                    pc.pull_dense_table_id.extend([i])
+                break
+        '''
+        for program_config in self.fleet_desc_.trainer_param.program_config:
+            if program_config.program_id == program_id:
+                pc = downpour.program_config.add()
+                pc.program_id = program_config.program_id
+                for i in program_config.push_sparse_table_id:
+                    pc.push_sparse_table_id.extend([i])
+                for i in program_config.push_dense_table_id:
+                    pc.push_dense_table_id.extend([i])
+                for i in program_config.pull_sparse_table_id:
+                    pc.pull_sparse_table_id.extend([i])
+                for i in program_config.pull_dense_table_id:
+                    pc.pull_dense_table_id.extend([i])
+                break
+        '''
 
 class DeviceWorkerFactory(object):
...
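Taken together, the DeviceWorker changes above feed the worker its configuration through two setters rather than a fleet_desc argument. A rough usage sketch, assuming fleet_desc and program come from the distributed optimizer as wired up elsewhere in this commit:

# Usage sketch; names follow the diff above, not an authoritative API reference.
worker = DownpourSGD()
worker.set_fleet_desc(fleet_desc)     # pslib parameter-server descriptor
worker.set_program(program)           # program must carry _fleet_opt
worker.gen_worker_desc(trainer_desc)  # fills pull_dense_param / downpour_param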
@@ -632,14 +632,14 @@ class Executor(object):
         scope = global_scope()
         if fetch_list is None:
             fetch_list = []
 
         compiled = isinstance(program, compiler.CompiledProgram)
         if not compiled:
             trainer = TrainerFactory().create_trainer(program._fleet_opt)
+            trainer.set_program(program)
         else:
             trainer = TrainerFactory().create_trainer(
                 program.program._fleet_opt)
+            trainer.set_program(program.program)
         if thread <= 0:
             trainer.set_thread(dataset.thread_num)
         else:
...
@@ -2707,6 +2707,7 @@ class Program(object):
         # if this program has been optimized by distributed optimizer
         # fleet_opt will be given a value
         self._fleet_opt = None
+        self._program_config = None
 
     @property
     def _is_mem_optimized(self):
...
@@ -54,10 +54,12 @@ class Fleet(object):
            else:
                print("You should run DistributedOptimizer.minimize() first")
                sys.exit(-1)
-            self._fleet_ptr.init_server(self._dist_desc_str)
-            ip = self._fleet_ptr.start_server()
-            ips = self.role_maker_.all_gather(ip)
-            self._fleet_ptr.gather_servers(ips, self.role_maker_.get_size())
+            self._fleet_ptr.init_server(self._dist_desc_str,
+                                        self.role_maker_.get_rank())
+            self.local_ip_ = self._fleet_ptr.run_server()
+            self.all_ips_ = self.role_maker_.all_gather(self.local_ip_)
+            self._fleet_ptr.gather_servers(self.all_ips_,
+                                           self.role_maker_.get_size())
            self.role_maker_.barrier_all()
        else:
            print("You should run DistributedOptimizer.minimize() first")
@@ -73,8 +75,7 @@ class Fleet(object):
                print("You should run DistributedOptimizer.minimize() first")
                sys.exit(-1)
            self.role_maker_.barrier_all()
-            self._fleet_ptr.init_work(self.dist_desc_str_,
-                                      self.role_maker.get_ips(),
+            self._fleet_ptr.init_worker(self._dist_desc_str, [0],
                                        self.role_maker_.get_size(),
                                        self.role_maker_.get_rank())
            self.role_maker_.barrier_worker()
...
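With run_server now exposed through pybind, the server-side startup in Fleet.init_server follows the order below. This is only a sketch of the calls shown in the hunk above, with comments on what each step appears to do:

# Sketch of the revised server startup sequence (names taken from the diff).
self._fleet_ptr.init_server(self._dist_desc_str,
                            self.role_maker_.get_rank())       # rank is now passed in
self.local_ip_ = self._fleet_ptr.run_server()                  # start the PS, get local endpoint
self.all_ips_ = self.role_maker_.all_gather(self.local_ip_)    # exchange endpoints across nodes
self._fleet_ptr.gather_servers(self.all_ips_,
                               self.role_maker_.get_size())    # register the full server list
self.role_maker_.barrier_all()                                 # wait until every node is ready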
@@ -84,15 +84,21 @@ class DistributedAdam(DistributedOptimizerImplBase):
         worker.add_sparse_table(sparse_table_index, self.learning_rate_,
                                 prefetch_slots, prefetch_slots_emb)
         dense_table_index = 1
-        program_configs = []
+        program_configs = {}
         param_grads_list = []
+
         for loss_index in range(len(losses)):
-            program_config = ps_param.trainer_param.program_config.add()
-            program_config.program_id = str(
-                id(losses[loss_index].block.program))
-            program_config.pull_sparse_table_id.extend([sparse_table_index])
-            program_config.push_sparse_table_id.extend([sparse_table_index])
+            #program_config = ps_param.trainer_param.program_config.add()
+            #program_config.program_id = str(
+            #    id(losses[loss_index].block.program))
+            program_id = str(id(losses[loss_index].block.program))
+            program_configs[program_id] = {
+                "pull_sparse": [sparse_table_index],
+                "push_sparse": [sparse_table_index]
+            }
+            #program_config.pull_sparse_table_id.extend([sparse_table_index])
+            #program_config.push_sparse_table_id.extend([sparse_table_index])
             params_grads = sorted(
                 fluid.backward.append_backward(losses[loss_index],
                                                parameter_list, no_grad_set),
@@ -122,8 +128,10 @@ class DistributedAdam(DistributedOptimizerImplBase):
                                    params, grads)
             worker.add_dense_table(dense_table_index, self.learning_rate_,
                                    params, grads)
-            program_config.pull_dense_table_id.extend([dense_table_index])
-            program_config.push_dense_table_id.extend([dense_table_index])
+            program_configs[program_id]["pull_dense"] = [dense_table_index]
+            program_configs[program_id]["push_dense"] = [dense_table_index]
+            #program_config.pull_dense_table_id.extend([dense_table_index])
+            #program_config.push_dense_table_id.extend([dense_table_index])
             if len(data_norm_params) != 0 and len(data_norm_grads) != 0:
                 dense_table_index += 1
                 server.add_data_norm_table(dense_table_index,
@@ -131,20 +139,25 @@ class DistributedAdam(DistributedOptimizerImplBase):
                                            data_norm_params, data_norm_grads)
                 worker.add_dense_table(dense_table_index, self.learning_rate_,
                                        data_norm_params, data_norm_grads)
-                program_config.pull_dense_table_id.extend([dense_table_index])
-                program_config.push_dense_table_id.extend([dense_table_index])
+                #program_config.pull_dense_table_id.extend([dense_table_index])
+                #program_config.push_dense_table_id.extend([dense_table_index])
+                program_config[program_id]["pull_dense"].extend(
+                    [dense_table_index])
+                program_config[program_id]["push_dense"].extend(
+                    [dense_table_index])
                 dense_table_index += 1
-            program_configs.append(program_config)
+            #program_configs.append(program_config)
         ps_param.server_param.CopyFrom(server.get_desc())
         ps_param.trainer_param.CopyFrom(worker.get_desc())
-        for program_config in program_configs:
-            ps_param.trainer_param.program_config.extend([program_config])
+        #for program_config in program_configs:
+        #    ps_param.trainer_param.program_config.extend([program_config])
         # Todo(guru4elephant): figure out how to support more sparse parameters
         # currently only support lookup_table
         worker_skipped_ops = ["lookup_table", "lookup_table_grad"]
         ps_param.trainer_param.skip_op.extend(worker_skipped_ops)
+
         opt_info = {}
+        opt_info["program_configs"] = program_configs
         opt_info["trainer"] = "DistMultiTrainer"
         opt_info["device_worker"] = "DownpourSGD"
         opt_info["optimizer"] = "DownpourSGD"
...
@@ -34,6 +34,7 @@ class TrainerDesc(object):
         self.proto_desc.thread_num = mp.cpu_count()
         self.fleet_desc_ = None
         self.device_worker_ = None
+        self.program_ = None
 
     def set_thread(self, thread_num):
         self.proto_desc.thread_num = thread_num
@@ -47,6 +48,9 @@ class TrainerDesc(object):
     def gen_trainer_desc(self):
         pass
 
+    def set_program(self, program):
+        self.program_ = program
+
     def _desc(self):
         return text_format.MessageToString(self.proto_desc)
@@ -70,19 +74,5 @@ class DistMultiTrainer(TrainerDesc):
     def gen_trainer_desc(self):
         super(DistMultiTrainer, self).gen_trainer_desc()
         self.proto_desc.class_name = "DistMultiTrainer"
+        self.device_worker_.set_program(self.program_)
         self.device_worker_.gen_worker_desc(self.proto_desc)
-
-    def set_program_config(self, fleet_desc, program_id):
-        for program_config in fleet_desc.trainer_param.program_config:
-            if program_config.program_id == program_id:
-                pc = self.proto_desc.downpour_param.program_config.add()
-                pc.program_id = program_config.program_id
-                for i in program_config.push_sparse_table_id:
-                    pc.push_sparse_table_id.extend([i])
-                for i in program_config.push_dense_table_id:
-                    pc.push_dense_table_id.extend([i])
-                for i in program_config.pull_sparse_table_id:
-                    pc.pull_sparse_table_id.extend([i])
-                for i in program_config.pull_dense_table_id:
-                    pc.pull_dense_table_id.extend([i])
-                break
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .trainer_desc import MultiTrainer
-from .device_worker import Hogwild
+from .trainer_desc import MultiTrainer, DistMultiTrainer
+from .device_worker import Hogwild, DownpourSGD
 
 __all__ = ["TrainerFactory"]
@@ -30,13 +30,12 @@ class TrainerFactory(object):
             trainer = MultiTrainer()
             device_worker = Hogwild()
             trainer.set_device_worker(device_worker)
-            trainer.gen_trainer_desc()
         else:
             trainer_class = opt_info["trainer"]
             device_worker_class = opt_info["device_worker"]
             trainer = globals()[trainer_class]()
             device_worker = globals()[device_worker_class]()
+            device_worker.set_fleet_desc(opt_info["fleet_desc"])
             trainer.set_device_worker(device_worker)
             trainer.set_fleet_desc(opt_info["fleet_desc"])
-            trainer.gen_trainer_desc()
         return trainer
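End to end, the factory now only builds and wires the objects; descriptor generation happens after the executor has attached the program. A sketch of the intended flow, combining the executor, trainer_desc, and factory hunks above (not a verbatim excerpt from the commit):

# Sketch of the flow after this commit (simplified; dataset/thread setup omitted).
trainer = TrainerFactory().create_trainer(program._fleet_opt)  # picks DistMultiTrainer + DownpourSGD
trainer.set_program(program)            # new in this commit; program carries _fleet_opt
trainer.set_thread(dataset.thread_num)
trainer.gen_trainer_desc()              # DistMultiTrainer forwards the program to the device
                                        # worker, which emits the downpour program_config entries
desc = trainer._desc()                  # serialized TrainerDesc proto text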