Commit b7a202aa authored by dongdaxiang

add distributed optimizer factory

Parent 70a5d4f7
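The core of this commit is that the distributed optimizer now publishes its per-program table configuration as a plain Python dict, opt_info["program_configs"], keyed by str(id(program)), instead of protobuf program_config entries, and the new device-worker classes read that dict back when generating the trainer descriptor. A minimal sketch of the data shape being handed around, with a stand-in program object and made-up table ids:

# Illustrative sketch only: rough shape of the opt_info dict that
# DistributedAdam builds and that the executor / device worker consume.
# The program stand-in and the table ids below are placeholders.
class _FakeProgram(object):
    pass

train_program = _FakeProgram()
program_id = str(id(train_program))        # the same key is used on both sides

opt_info = {
    "program_configs": {
        program_id: {
            "pull_sparse": [0],            # sparse_table_index
            "push_sparse": [0],
            "pull_dense": [1],             # dense_table_index
            "push_dense": [1],
        },
    },
    "trainer": "DistMultiTrainer",
    "device_worker": "DownpourSGD",
    "optimizer": "DownpourSGD",
    # "fleet_desc" holds the pslib PSParameter proto in real use
}

# The optimized Program carries this dict as program._fleet_opt, and
# DownpourSGD.gen_worker_desc() looks up str(id(self.program_)) in
# opt_info["program_configs"] to fill downpour.program_config.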
@@ -294,14 +294,13 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
#endif
}
int FleetWrapper::RegisterClientToClientMsgHandler(
int msg_type, MsgHandlerFunc handler) {
int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
MsgHandlerFunc handler) {
#ifdef PADDLE_WITH_PSLIB
VLOG(3) << "calling FleetWrapper::RegisterClientToClientMsgHandler";
VLOG(3) << "pslib_ptr_=" << pslib_ptr_;
VLOG(3) << "_worker_ptr=" << pslib_ptr_->_worker_ptr;
pslib_ptr_->_worker_ptr->registe_client2client_msg_handler(
msg_type, handler);
pslib_ptr_->_worker_ptr->registe_client2client_msg_handler(msg_type, handler);
#else
VLOG(0) << "FleetWrapper::RegisterClientToClientMsgHandler"
<< " does nothing when no pslib";
@@ -309,11 +308,10 @@ int FleetWrapper::RegisterClientToClientMsgHandler(
return 0;
}
int FleetWrapper::SendClientToClientMsg(
int msg_type, int to_client_id, const std::string& msg) {
int FleetWrapper::SendClientToClientMsg(int msg_type, int to_client_id,
const std::string& msg) {
#ifdef PADDLE_WITH_PSLIB
pslib_ptr_->_worker_ptr->send_client2client_msg(
msg_type, to_client_id, msg);
pslib_ptr_->_worker_ptr->send_client2client_msg(msg_type, to_client_id, msg);
#else
VLOG(0) << "FleetWrapper::SendClientToClientMsg"
<< " does nothing when no pslib";
@@ -45,6 +45,7 @@ void BindFleetWrapper(py::module* m) {
.def(py::init())
.def("push_dense", &framework::FleetWrapper::PushDenseVarsSync)
.def("init_server", &framework::FleetWrapper::InitServer)
.def("run_server", &framework::FleetWrapper::RunServer)
.def("init_worker", &framework::FleetWrapper::InitWorker)
.def("stop_server", &framework::FleetWrapper::StopServer)
.def("gather_servers", &framework::FleetWrapper::GatherServers);
@@ -11,13 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
__all__ = ['DeviceWorker', 'Hogwild', 'DownpourSGD']
class DeviceWorker(object):
def __init__(self):
pass
self.program_ = None
def set_fleet_desc(self, fleet_desc):
self.fleet_desc_ = fleet_desc
def set_program(self, program):
self.program_ = program
def gen_worker_desc(self, trainer_desc):
pass
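The base class now keeps the program alongside the fleet descriptor, so a concrete worker can use both when filling the trainer proto. A minimal illustrative subclass, not part of this commit; the import path and the worker name are assumptions:

# Hypothetical subclass of the DeviceWorker base class above.
from paddle.fluid.device_worker import DeviceWorker   # assumed module path

class DemoWorker(DeviceWorker):
    def __init__(self):
        super(DemoWorker, self).__init__()    # sets self.program_ = None

    def gen_worker_desc(self, trainer_desc):
        # A real worker fills the TrainerDesc proto here, e.g. DownpourSGD
        # sets trainer_desc.device_worker_name = "DownpourWorker".
        trainer_desc.device_worker_name = "DemoWorker"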
@@ -33,7 +40,7 @@ class Hogwild(DeviceWorker):
class DownpourSGD(DeviceWorker):
def __init__(self):
super(Downpour, self).__init__()
super(DownpourSGD, self).__init__()
def gen_worker_desc(self, trainer_desc):
trainer_desc.device_worker_name = "DownpourWorker"
@@ -41,20 +48,21 @@ class DownpourSGD(DeviceWorker):
pull_thread.device_num = trainer_desc.thread_num
dense_table = pull_thread.dense_table.add()
dense_table.dense_value_name.extend(
fleet_desc.trainer_param.dense_table[0].dense_variable_name)
self.fleet_desc_.trainer_param.dense_table[0].dense_variable_name)
dense_table.table_id = \
fleet_desc.trainer_param.dense_table[0].table_id
self.fleet_desc_.trainer_param.dense_table[0].table_id
downpour = trainer_desc.downpour_param
sparse_table = downpour.sparse_table.add()
sparse_table.table_id = \
fleet_desc.trainer_param.sparse_table[0].table_id
self.fleet_desc_.trainer_param.sparse_table[0].table_id
sparse_table.sparse_key_name.extend(
fleet_desc.trainer_param.sparse_table[0].slot_key)
self.fleet_desc_.trainer_param.sparse_table[0].slot_key)
sparse_table.sparse_value_name.extend(
fleet_desc.trainer_param.sparse_table[0].slot_value)
self.fleet_desc_.trainer_param.sparse_table[0].slot_value)
sparse_table.sparse_grad_name.extend(
fleet_desc.trainer_param.sparse_table[0].slot_gradient)
sparse_table.emb_dim = fleet_desc.server_param.downpour_server_param.downpour_table_param[
self.fleet_desc_.trainer_param.sparse_table[0].slot_gradient)
sparse_table.emb_dim = \
self.fleet_desc_.server_param.downpour_server_param.downpour_table_param[
0].accessor.fea_dim - 2
sparse_table.fea_dim = sparse_table.emb_dim + 2
# TODO(guru4elephant): hard code here, need to improve
@@ -62,12 +70,49 @@ class DownpourSGD(DeviceWorker):
dense_table = downpour.dense_table.add()
dense_table.table_id = \
fleet_desc.trainer_param.dense_table[0].table_id
self.fleet_desc_.trainer_param.dense_table[0].table_id
dense_table.dense_value_name.extend(
fleet_desc.trainer_param.dense_table[0].dense_variable_name)
dense_table.dense_grad_name.extend(fleet_desc.trainer_param.dense_table[
self.fleet_desc_.trainer_param.dense_table[0].dense_variable_name)
dense_table.dense_grad_name.extend(
self.fleet_desc_.trainer_param.dense_table[
0].dense_gradient_variable_name)
downpour.skip_ops.extend(fleet_desc.trainer_param.skip_op)
downpour.skip_ops.extend(self.fleet_desc_.trainer_param.skip_op)
program_id = str(id(self.program_))
if self.program_ == None:
print("program of current device worker is not configured")
sys.exit(-1)
opt_info = self.program_._fleet_opt
program_configs = opt_info["program_configs"]
for program_id in program_configs:
if program_configs[program_id] == program_id:
pc = downpour.program_config.add()
pc.program_id = program_id
for i in program_configs[program_id]["push_sparse"]:
pc.push_sparse_table_id.extend([i])
for i in program_configs[program_id]["push_dense"]:
pc.push_dense_table_id.extend([i])
for i in program_configs[program_id]["pull_sparse"]:
pc.pull_sparse_table_id.extend([i])
for i in program_configs[program_id]["pull_dense"]:
pc.pull_dense_table_id.extend([i])
break
'''
for program_config in self.fleet_desc_.trainer_param.program_config:
if program_config.program_id == program_id:
pc = downpour.program_config.add()
pc.program_id = program_config.program_id
for i in program_config.push_sparse_table_id:
pc.push_sparse_table_id.extend([i])
for i in program_config.push_dense_table_id:
pc.push_dense_table_id.extend([i])
for i in program_config.pull_sparse_table_id:
pc.pull_sparse_table_id.extend([i])
for i in program_config.pull_dense_table_id:
pc.pull_dense_table_id.extend([i])
break
'''
class DeviceWorkerFactory(object):
@@ -632,14 +632,14 @@ class Executor(object):
scope = global_scope()
if fetch_list is None:
fetch_list = []
compiled = isinstance(program, compiler.CompiledProgram)
if not compiled:
trainer = TrainerFactory().create_trainer(program._fleet_opt)
trainer.set_program(program)
else:
trainer = TrainerFactory().create_trainer(
program.program._fleet_opt)
trainer.set_program(program.program)
if thread <= 0:
trainer.set_thread(dataset.thread_num)
else:
@@ -2707,6 +2707,7 @@ class Program(object):
# if this program has been optimized by distributed optimizer
# fleet_opt will be given a value
self._fleet_opt = None
self._program_config = None
@property
def _is_mem_optimized(self):
@@ -54,10 +54,12 @@ class Fleet(object):
else:
print("You should run DistributedOptimizer.minimize() first")
sys.exit(-1)
self._fleet_ptr.init_server(self._dist_desc_str)
ip = self._fleet_ptr.start_server()
ips = self.role_maker_.all_gather(ip)
self._fleet_ptr.gather_servers(ips, self.role_maker_.get_size())
self._fleet_ptr.init_server(self._dist_desc_str,
self.role_maker_.get_rank())
self.local_ip_ = self._fleet_ptr.run_server()
self.all_ips_ = self.role_maker_.all_gather(self.local_ip_)
self._fleet_ptr.gather_servers(self.all_ips_,
self.role_maker_.get_size())
self.role_maker_.barrier_all()
else:
print("You should run DistributedOptimizer.minimize() first")
@@ -73,8 +75,7 @@ class Fleet(object):
print("You should run DistributedOptimizer.minimize() first")
sys.exit(-1)
self.role_maker_.barrier_all()
self._fleet_ptr.init_work(self.dist_desc_str_,
self.role_maker.get_ips(),
self._fleet_ptr.init_worker(self._dist_desc_str, [0],
self.role_maker_.get_size(),
self.role_maker_.get_rank())
self.role_maker_.barrier_worker()
@@ -84,15 +84,21 @@ class DistributedAdam(DistributedOptimizerImplBase):
worker.add_sparse_table(sparse_table_index, self.learning_rate_,
prefetch_slots, prefetch_slots_emb)
dense_table_index = 1
program_configs = []
program_configs = {}
param_grads_list = []
for loss_index in range(len(losses)):
program_config = ps_param.trainer_param.program_config.add()
program_config.program_id = str(
id(losses[loss_index].block.program))
program_config.pull_sparse_table_id.extend([sparse_table_index])
program_config.push_sparse_table_id.extend([sparse_table_index])
#program_config = ps_param.trainer_param.program_config.add()
#program_config.program_id = str(
# id(losses[loss_index].block.program))
program_id = str(id(losses[loss_index].block.program))
program_configs[program_id] = {
"pull_sparse": [sparse_table_index],
"push_sparse": [sparse_table_index]
}
#program_config.pull_sparse_table_id.extend([sparse_table_index])
#program_config.push_sparse_table_id.extend([sparse_table_index])
params_grads = sorted(
fluid.backward.append_backward(losses[loss_index],
parameter_list, no_grad_set),
@@ -122,8 +128,10 @@ class DistributedAdam(DistributedOptimizerImplBase):
params, grads)
worker.add_dense_table(dense_table_index, self.learning_rate_,
params, grads)
program_config.pull_dense_table_id.extend([dense_table_index])
program_config.push_dense_table_id.extend([dense_table_index])
program_configs[program_id]["pull_dense"] = [dense_table_index]
program_configs[program_id]["push_dense"] = [dense_table_index]
#program_config.pull_dense_table_id.extend([dense_table_index])
#program_config.push_dense_table_id.extend([dense_table_index])
if len(data_norm_params) != 0 and len(data_norm_grads) != 0:
dense_table_index += 1
server.add_data_norm_table(dense_table_index,
@@ -131,20 +139,25 @@ class DistributedAdam(DistributedOptimizerImplBase):
data_norm_params, data_norm_grads)
worker.add_dense_table(dense_table_index, self.learning_rate_,
data_norm_params, data_norm_grads)
program_config.pull_dense_table_id.extend([dense_table_index])
program_config.push_dense_table_id.extend([dense_table_index])
#program_config.pull_dense_table_id.extend([dense_table_index])
#program_config.push_dense_table_id.extend([dense_table_index])
program_config[program_id]["pull_dense"].extend(
[dense_table_index])
program_config[program_id]["push_dense"].extend(
[dense_table_index])
dense_table_index += 1
program_configs.append(program_config)
#program_configs.append(program_config)
ps_param.server_param.CopyFrom(server.get_desc())
ps_param.trainer_param.CopyFrom(worker.get_desc())
for program_config in program_configs:
ps_param.trainer_param.program_config.extend([program_config])
#for program_config in program_configs:
# ps_param.trainer_param.program_config.extend([program_config])
# Todo(guru4elephant): figure out how to support more sparse parameters
# currently only support lookup_table
worker_skipped_ops = ["lookup_table", "lookup_table_grad"]
ps_param.trainer_param.skip_op.extend(worker_skipped_ops)
opt_info = {}
opt_info["program_configs"] = program_configs
opt_info["trainer"] = "DistMultiTrainer"
opt_info["device_worker"] = "DownpourSGD"
opt_info["optimizer"] = "DownpourSGD"
@@ -34,6 +34,7 @@ class TrainerDesc(object):
self.proto_desc.thread_num = mp.cpu_count()
self.fleet_desc_ = None
self.device_worker_ = None
self.program_ = None
def set_thread(self, thread_num):
self.proto_desc.thread_num = thread_num
@@ -47,6 +48,9 @@ class TrainerDesc(object):
def gen_trainer_desc(self):
pass
def set_program(self, program):
self.program_ = program
def _desc(self):
return text_format.MessageToString(self.proto_desc)
@@ -70,19 +74,5 @@ class DistMultiTrainer(TrainerDesc):
def gen_trainer_desc(self):
super(DistMultiTrainer, self).gen_trainer_desc()
self.proto_desc.class_name = "DistMultiTrainer"
self.device_worker_.set_program(self.program_)
self.device_worker_.gen_worker_desc(self.proto_desc)
def set_program_config(self, fleet_desc, program_id):
for program_config in fleet_desc.trainer_param.program_config:
if program_config.program_id == program_id:
pc = self.proto_desc.downpour_param.program_config.add()
pc.program_id = program_config.program_id
for i in program_config.push_sparse_table_id:
pc.push_sparse_table_id.extend([i])
for i in program_config.push_dense_table_id:
pc.push_dense_table_id.extend([i])
for i in program_config.pull_sparse_table_id:
pc.pull_sparse_table_id.extend([i])
for i in program_config.pull_dense_table_id:
pc.pull_dense_table_id.extend([i])
break
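With set_program_config removed, the wiring becomes: the trainer holds the program, hands it to its device worker inside gen_trainer_desc, and the worker derives the downpour program_config from program._fleet_opt on its own. A small sketch of that wiring; the absolute import paths are assumed from the relative imports in trainer_factory below, and fleet_desc / main_program are placeholders produced by the optimizer step.

# Illustrative wiring only, not part of this commit.
from paddle.fluid.trainer_desc import DistMultiTrainer   # assumed paths
from paddle.fluid.device_worker import DownpourSGD

trainer = DistMultiTrainer()
worker = DownpourSGD()
worker.set_fleet_desc(fleet_desc)       # pslib PSParameter proto (placeholder)
trainer.set_device_worker(worker)
trainer.set_fleet_desc(fleet_desc)
trainer.set_program(main_program)       # main_program._fleet_opt must be set
trainer.gen_trainer_desc()              # -> worker.set_program + gen_worker_desc
print(trainer._desc())                  # text-format TrainerDesc proto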
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .trainer_desc import MultiTrainer
from .device_worker import Hogwild
from .trainer_desc import MultiTrainer, DistMultiTrainer
from .device_worker import Hogwild, DownpourSGD
__all__ = ["TrainerFactory"]
@@ -30,13 +30,12 @@ class TrainerFactory(object):
trainer = MultiTrainer()
device_worker = Hogwild()
trainer.set_device_worker(device_worker)
trainer.gen_trainer_desc()
else:
trainer_class = opt_info["trainer"]
device_worker_class = opt_info["device_worker"]
trainer = globals()[trainer_class]()
device_worker = globals()[device_worker_class]()
device_worker.set_fleet_desc(opt_info["fleet_desc"])
trainer.set_device_worker(device_worker)
trainer.set_fleet_desc(opt_info["fleet_desc"])
trainer.gen_trainer_desc()
return trainer
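A brief usage note on the factory above, assuming the branch just above this hunk falls back to the default when no opt_info is passed:

# Hypothetical calls; create_trainer's exact signature sits outside this hunk.
local_trainer = TrainerFactory().create_trainer(None)      # MultiTrainer + Hogwild
dist_trainer = TrainerFactory().create_trainer(opt_info)   # classes named in opt_info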