Commit ceac9df8 authored by dongdaxiang

fix code style for incubator

Parent: aa46caf3
@@ -26,8 +26,8 @@ class DeviceWorker(object):
         """
         Init.
         """
-        self.program_ = None
-        self.infer_ = None
+        self._program = None
+        self._infer = None
 
     def _set_infer(self, infer=False):
         """
@@ -36,7 +36,7 @@ class DeviceWorker(object):
         Args:
             infer(bool): whether to do inference
         """
-        self.infer_ = infer
+        self._infer = infer
 
     def _set_fleet_desc(self, fleet_desc):
         """
@@ -45,7 +45,7 @@ class DeviceWorker(object):
         Args:
             fleet_desc(PSParameter): pslib.PSParameter object
         """
-        self.fleet_desc_ = fleet_desc
+        self._fleet_desc = fleet_desc
 
     def _set_program(self, program):
         """
@@ -54,7 +54,7 @@ class DeviceWorker(object):
         Args:
             program(Program): a Program object
         """
-        self.program_ = program
+        self._program = program
 
     def _gen_worker_desc(self, trainer_desc):
         """
@@ -88,7 +88,7 @@ class Hogwild(DeviceWorker):
             trainer_desc(TrainerDesc): a TrainerDesc object
         """
         trainer_desc.device_worker_name = "HogwildWorker"
-        if self.infer_:
+        if self._infer:
             # just ignore feed op for inference model
             trainer_desc.hogwild_param.skip_ops.extend(["feed"])
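Every hunk in this commit is the same mechanical rename: trailing-underscore attributes (self.infer_) become leading-underscore attributes (self._infer), matching the PEP 8 convention that a single leading underscore marks an internal attribute, while a trailing underscore is reserved for avoiding keyword clashes. A minimal sketch of the two styles (class and attribute names here are illustrative, not Paddle's):

```python
class NamingDemo(object):
    def __init__(self):
        # Style removed by this commit: trailing underscore. PEP 8 reserves
        # this form for dodging keyword collisions (e.g. class_, type_).
        self.infer_ = False
        # Style introduced by this commit: leading underscore, the
        # conventional marker for an internal, non-public attribute.
        self._infer = False
```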
@@ -113,11 +113,11 @@ class DownpourSGD(DeviceWorker):
             trainer_desc(TrainerDesc): a TrainerDesc object
         """
         dense_table_set = set()
-        program_id = str(id(self.program_))
-        if self.program_ == None:
+        program_id = str(id(self._program))
+        if self._program == None:
             print("program of current device worker is not configured")
             exit(-1)
-        opt_info = self.program_._fleet_opt
+        opt_info = self._program._fleet_opt
         program_configs = opt_info["program_configs"]
         downpour = trainer_desc.downpour_param
@@ -140,7 +140,7 @@ class DownpourSGD(DeviceWorker):
         trainer_desc.device_worker_name = "DownpourWorker"
         pull_thread = trainer_desc.pull_dense_param
         pull_thread.device_num = trainer_desc.thread_num
-        for i in self.fleet_desc_.trainer_param.dense_table:
+        for i in self._fleet_desc.trainer_param.dense_table:
             if i.table_id in dense_table_set:
                 dense_table = pull_thread.dense_table.add()
                 dense_table.dense_value_name.extend(i.dense_variable_name)
@@ -148,29 +148,29 @@ class DownpourSGD(DeviceWorker):
                     i.table_id
         sparse_table = downpour.sparse_table.add()
         sparse_table.table_id = \
-            self.fleet_desc_.trainer_param.sparse_table[0].table_id
+            self._fleet_desc.trainer_param.sparse_table[0].table_id
         sparse_table.sparse_key_name.extend(
-            self.fleet_desc_.trainer_param.sparse_table[0].slot_key)
+            self._fleet_desc.trainer_param.sparse_table[0].slot_key)
         sparse_table.sparse_value_name.extend(
-            self.fleet_desc_.trainer_param.sparse_table[0].slot_value)
+            self._fleet_desc.trainer_param.sparse_table[0].slot_value)
         sparse_table.sparse_grad_name.extend(
-            self.fleet_desc_.trainer_param.sparse_table[0].slot_gradient)
+            self._fleet_desc.trainer_param.sparse_table[0].slot_gradient)
         sparse_table.emb_dim = \
-            self.fleet_desc_.server_param.downpour_server_param.downpour_table_param[
+            self._fleet_desc.server_param.downpour_server_param.downpour_table_param[
                 0].accessor.fea_dim - 2
         sparse_table.fea_dim = sparse_table.emb_dim + 2
         # TODO(guru4elephant): hard code here, need to improve
         sparse_table.label_var_name = "click"
-        for i in self.fleet_desc_.trainer_param.dense_table:
+        for i in self._fleet_desc.trainer_param.dense_table:
             if i.table_id in dense_table_set:
                 dense_table = downpour.dense_table.add()
                 dense_table.table_id = i.table_id
                 dense_table.dense_value_name.extend(i.dense_variable_name)
                 dense_table.dense_grad_name.extend(
                     i.dense_gradient_variable_name)
-        downpour.skip_ops.extend(self.fleet_desc_.trainer_param.skip_op)
-        if self.infer_:
+        downpour.skip_ops.extend(self._fleet_desc.trainer_param.skip_op)
+        if self._infer:
             downpour.push_dense = False
             downpour.push_sparse = False
......
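One style issue survives the rename in the hunk above: `if self._program == None:` compares with `==` instead of the identity test (the same pattern appears later in `_get_ips` and in DistMultiTrainer). A short sketch of the idiomatic form, not part of this commit:

```python
program = None

# "== None" calls __eq__, which user classes may override arbitrarily;
# "is None" checks identity against the None singleton and is the
# idiomatic (PEP 8 recommended) test.
if program is None:
    print("program of current device worker is not configured")
```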
@@ -23,10 +23,10 @@ class RoleMakerBase(object):
     """
 
     def __init__(self):
-        self.role_maker_name_ = ""
-        self.trainer_endpoints_ = []
-        self.pserver_endpoints_ = []
-        self.role_is_generated_ = False
+        self._role_maker_name = ""
+        self._trainer_endpoints = []
+        self._pserver_endpoints = []
+        self._role_is_generated = False
 
     def _is_worker(self):
         """
@@ -45,20 +45,20 @@ class RoleMakerBase(object):
         return get local ip
         """
         import socket
-        self.ip_ = socket.gethostbyname(socket.gethostname())
-        return self.ip_
+        self._ip = socket.gethostbyname(socket.gethostname())
+        return self._ip
 
     def _get_trainer_endpoints(self):
         """
         return trainer endpoints
         """
-        return self.trainer_endpoints_
+        return self._trainer_endpoints
 
     def _get_pserver_endpoints(self):
         """
         return pserver endpoints
         """
-        return self.pserver_endpoints_
+        return self._pserver_endpoints
 
     def _generate_role(self):
         """
@@ -76,59 +76,59 @@ class MPIRoleMaker(RoleMakerBase):
     def __init__(self):
         super(MPIRoleMaker, self).__init__()
         from mpi4py import MPI
-        self.comm_ = MPI.COMM_WORLD
+        self._comm = MPI.COMM_WORLD
         self.MPI = MPI
-        self.ips_ = None
+        self._ips = None
 
     def _get_rank(self):
         """
         return rank
         """
-        self.rank_ = self.comm_.Get_rank()
-        return self.rank_
+        self._rank = self._comm.Get_rank()
+        return self._rank
 
     def _get_size(self):
         """
         return size
         """
-        self.size_ = self.comm_.Get_size()
-        return self.size_
+        self._size = self._comm.Get_size()
+        return self._size
 
     def _all_gather(self, obj):
         """
         all_gather(obj) will call MPI's allgather function
         """
         self._barrier_all()
-        return self.comm_.allgather(obj)
+        return self._comm.allgather(obj)
 
     def _worker_gather(self, obj):
         """
         worker_gather(obj) will call MPI's allgather function
         """
         if self._is_worker():
-            self.node_type_comm_.barrier()
-            return self.node_type_comm_.allgather(obj)
+            self._node_type_comm.barrier()
+            return self._node_type_comm.allgather(obj)
         return None
 
     def _barrier_all(self):
         """
         barrier_all() will call MPI's barrier_all function
         """
-        self.comm_.barrier()
+        self._comm.barrier()
 
     def _get_ips(self):
         """
         collect current distributed job's ip list
         """
-        if self.ips_ == None:
-            self.ips_ = self.comm_.allgather(self._get_local_ip())
-        return self.ips_
+        if self._ips == None:
+            self._ips = self._comm.allgather(self._get_local_ip())
+        return self._ips
 
     def _finalize(self):
         """
         finalize the current MPI instance.
         """
-        self.comm_.finalize()
+        self._comm.finalize()
 
 
 class MPISymetricRoleMaker(MPIRoleMaker):
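MPIRoleMaker is a thin wrapper over a few mpi4py collectives: `Get_rank()`/`Get_size()` identify the process, `comm.barrier()` blocks until every rank arrives, and `comm.allgather(obj)` gives every rank a list with one contribution per process, which is how `_get_ips` assembles the job's IP list. A self-contained sketch (assumes mpi4py is installed; run with something like `mpirun -np 4 python demo.py`):

```python
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()  # this process's id, 0 .. size-1
size = comm.Get_size()  # total number of processes in the job

comm.barrier()  # what _barrier_all does: wait for every rank
# What _get_ips does: every rank contributes one value and
# receives the full list, in rank order.
ips = comm.allgather("10.0.0.%d" % rank)
if rank == 0:
    print(ips)  # with -np 4: ['10.0.0.0', '10.0.0.1', '10.0.0.2', '10.0.0.3']
```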
@@ -140,11 +140,11 @@ class MPISymetricRoleMaker(MPIRoleMaker):
 
     def __init__(self):
         super(MPISymetricRoleMaker, self).__init__()
-        self.node_type_ = None
-        self.proc_per_node_ = 2
+        self._node_type = None
+        self._proc_per_node = 2
 
     def _check_role_generation(self):
-        if not self.role_is_generated_:
+        if not self._role_is_generated:
             sys.stderr.write("generate_role() should be called first")
             sys.exit(-1)
             return False
@@ -163,7 +163,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
         return whether current process is worker assigned by role maker
         """
         if self._check_role_generation():
-            return self.node_type_ == 1
+            return self._node_type == 1
         return False
 
     def _is_server(self):
@@ -171,7 +171,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
         return whether current process is server assigned by role maker
         """
         if self._check_role_generation():
-            return self.node_type_ == 0
+            return self._node_type == 0
         return False
 
     def _worker_num(self):
@@ -197,7 +197,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
         return the index of worker
         """
         if self._check_role_generation():
-            return self.rank_ / self.proc_per_node_
+            return self._rank / self._proc_per_node
         return 0
 
     def _server_index(self):
@@ -205,7 +205,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
         return the index of server
         """
         if self._check_role_generation():
-            return self.rank_ / self.proc_per_node_
+            return self._rank / self._proc_per_node
         return 0
 
     def _barrier_worker(self):
@@ -214,7 +214,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
         """
         if self._check_role_generation():
             if self._is_worker():
-                self.node_type_comm_.barrier()
+                self._node_type_comm.barrier()
 
     def _barrier_server(self):
         """
@@ -222,20 +222,20 @@ class MPISymetricRoleMaker(MPIRoleMaker):
         """
         if self._check_role_generation():
             if self._is_server():
-                self.node_type_comm_.barrier()
+                self._node_type_comm.barrier()
 
     def _generate_role(self):
         """
         generate currently process's role
         """
-        if not self.role_is_generated_:
+        if not self._role_is_generated:
             # TODO(guru4elephant): only allow to be called once
-            self.trainer_endpoints_ = self._get_ips()
-            self.pserver_endpoints_ = self._get_ips()
-            if 0 == self._get_rank() % self.proc_per_node_ % 2:
-                self.node_type_ = 0
+            self._trainer_endpoints = self._get_ips()
+            self._pserver_endpoints = self._get_ips()
+            if 0 == self._get_rank() % self._proc_per_node % 2:
+                self._node_type = 0
             else:
-                self.node_type_ = 1
-            self.node_type_comm_ = self.comm_.Split(self.node_type_)
-            self.role_is_generated_ = True
+                self._node_type = 1
+            self._node_type_comm = self._comm.Split(self._node_type)
+            self._role_is_generated = True
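_generate_role assigns roles by rank parity: with `_proc_per_node = 2`, even ranks take `node_type 0` (server) and odd ranks `node_type 1` (worker), and `comm.Split(color)` then yields one sub-communicator per role, which is what lets `_barrier_worker` and `_barrier_server` synchronize each group independently. One caveat the diff preserves: `_worker_index` and `_server_index` return `self._rank / self._proc_per_node`, which is float division on Python 3; `//` would keep it an integer. A sketch of the split logic alone (illustrative, outside Paddle):

```python
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
proc_per_node = 2

# Mirror the parity rule in the diff: even -> server, odd -> worker.
node_type = 0 if rank % proc_per_node % 2 == 0 else 1
# Split() groups ranks by "color": servers get one communicator,
# workers another, so each role can barrier among itself.
node_type_comm = comm.Split(node_type)
index = rank // proc_per_node  # integer division; '/' yields a float on Python 3
print("rank %d -> %s #%d (group of %d)" %
      (rank, "worker" if node_type else "server", index, node_type_comm.Get_size()))
```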
@@ -64,9 +64,9 @@ class Fleet(object):
 
     def __init__(self):
         self._opt_info = None  # for fleet only
-        self.role_maker_ = None
-        self.local_ip_ = 0
-        self.is_initialized_ = False
+        self._role_maker = None
+        self._local_ip = 0
+        self._is_initialized = False
 
     def init(self):
         # TODO(guru4elephant)
@@ -78,22 +78,22 @@ class Fleet(object):
         current node's role, e.g. worker, server, etc.
         """
         if not self.is_initialized_:
-            self.role_maker_ = MPISymetricRoleMaker()
-            self.role_maker_._generate_role()
+            self._role_maker = MPISymetricRoleMaker()
+            self._role_maker._generate_role()
             self._fleet_ptr = fluid.core.Fleet()
-            self.is_initialized_ = True
+            self._is_initialized = True
 
     def stop(self):
         """
         stop(): will be called after a user finishes his/her training task. Fleet instance will be
         destroyed when stop() is called.
         """
-        self.role_maker_._barrier_worker()
-        if self.role_maker_._is_first_worker():
+        self._role_maker._barrier_worker()
+        if self._role_maker._is_first_worker():
             self._fleet_ptr.stop_server()
-        self.role_maker_._barrier_worker()
-        self.role_maker_._barrier_all()
-        self.role_maker_._finalize()
+        self._role_maker._barrier_worker()
+        self._role_maker._barrier_all()
+        self._role_maker._finalize()
 
     def init_pserver(self):
         """
@@ -110,15 +110,15 @@ class Fleet(object):
             sys.exit(-1)
         self._fleet_ptr.init_server(self._dist_desc_str,
                                     self.role_maker_._get_rank())
-        self.local_ip_ = self._fleet_ptr.run_server()
+        self._local_ip = self._fleet_ptr.run_server()
         # barrier_all for init_server
-        self.role_maker_._barrier_all()
-        self.all_ips_ = self.role_maker_._all_gather(self.local_ip_)
-        self._fleet_ptr.gather_servers(self.all_ips_,
-                                       self.role_maker_._get_size())
+        self._role_maker._barrier_all()
+        self._all_ips = self._role_maker._all_gather(self.local_ip_)
+        self._fleet_ptr.gather_servers(self._all_ips,
+                                       self._role_maker._get_size())
         # barrier_all for init_worker, wait all workers start
-        self.role_maker_._barrier_all()
+        self._role_maker._barrier_all()
         else:
             print("You should run DistributedOptimizer.minimize() first")
             sys.exit(-1)
@@ -151,21 +151,21 @@ class Fleet(object):
             print("You should run DistributedOptimizer.minimize() first")
             sys.exit(-1)
         # barrier_all for init_server, wait for server starts
-        self.role_maker_._barrier_all()
-        self.all_ips_ = self.role_maker_._all_gather(self.local_ip_)
-        self._fleet_ptr.init_worker(self._dist_desc_str, self.all_ips_,
-                                    self.role_maker_._get_size(),
-                                    self.role_maker_._get_rank())
+        self._role_maker._barrier_all()
+        self._all_ips = self._role_maker._all_gather(self.local_ip_)
+        self._fleet_ptr.init_worker(self._dist_desc_str, self._all_ips,
+                                    self._role_maker._get_size(),
+                                    self._role_maker._get_rank())
         # barrier_all for init_worker
-        self.role_maker_._barrier_all()
+        self._role_maker._barrier_all()
         # prepare for client to client communication
         info = self._fleet_ptr.get_clients_info()
-        all_info = self.role_maker_._worker_gather(info[0])
+        all_info = self._role_maker._worker_gather(info[0])
         self._fleet_ptr.gather_clients(all_info)
         self._fleet_ptr.create_client2client_connection()
         # barrier for init model
-        self.role_maker_._barrier_worker()
-        if self.role_maker_._is_first_worker():
+        self._role_maker._barrier_worker()
+        if self._role_maker._is_first_worker():
             tables = self._dist_desc.trainer_param.dense_table
             for prog, scope in zip(programs, scopes):
                 prog_id = str(id(prog))
@@ -192,7 +192,7 @@ class Fleet(object):
                         int(table.table_id),
                         var_name_list)
         # barrier for init model done
-        self.role_maker_._barrier_worker()
+        self._role_maker._barrier_worker()
         else:
             print("You should run DistributedOptimizer.minimize() first")
             sys.exit(-1)
@@ -201,39 +201,39 @@ class Fleet(object):
         """
         return the number of current job's worker num
         """
-        return self.role_maker_._worker_num()
+        return self._role_maker._worker_num()
 
     def get_server_num(self):
         """
         return the number of current job's server num
         """
-        return self.role_maker_._server_num()
+        return self._role_maker._server_num()
 
     def get_worker_index(self):
         """
         return the mpi rank of current worker
         """
-        return self.role_maker_._worker_index()
+        return self._role_maker._worker_index()
 
     def is_worker(self):
         """
         return whether current node is a worker
         """
-        return self.role_maker_._is_worker()
+        return self._role_maker._is_worker()
 
     def is_server(self):
         """
         return whether current node is pserver
         """
-        return self.role_maker_._is_server()
+        return self._role_maker._is_server()
 
     def init_pserver_model(self):
         """
         init pserver model called from pserver
         """
-        if self.role_maker_._is_first_worker():
+        if self._role_maker._is_first_worker():
             self._fleet_ptr.init_model()
-        self.role_maker_._barrier_worker()
+        self._role_maker._barrier_worker()
 
     def save_pserver_model(self, save_path):
         """
......
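init_pserver and init_worker implement a small barrier choreography: servers start and publish their IPs, a barrier guarantees every server is listening before the allgather, and a second barrier holds everyone until all workers have connected. Note that a few call sites in these hunks still read the old names (`self.is_initialized_` in init(), `self.local_ip_` inside both allgather calls, `self.role_maker_._get_rank()` in init_server), which would raise AttributeError now that __init__ only sets the underscore-prefixed attributes. A schematic of the server-side ordering, with `fleet_ptr` standing in for fluid.core.Fleet():

```python
def init_pserver_schematic(role_maker, fleet_ptr, dist_desc_str):
    """Sketch of the ordering in Fleet.init_pserver, not a real API."""
    fleet_ptr.init_server(dist_desc_str, role_maker._get_rank())
    local_ip = fleet_ptr.run_server()           # this server is now listening
    role_maker._barrier_all()                   # wait: every server is up
    all_ips = role_maker._all_gather(local_ip)  # everyone learns every endpoint
    fleet_ptr.gather_servers(all_ips, role_maker._get_size())
    role_maker._barrier_all()                   # wait: every worker has connected
```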
@@ -42,13 +42,13 @@ class DownpourServer(Server):
     """
 
     def __init__(self):
-        self.server_ = pslib.ServerParameter()
-        self.server_.downpour_server_param.service_param.start_server_port = 0
-        self.server_.downpour_server_param.service_param.server_class = "DownpourBrpcPsServer"
-        self.server_.downpour_server_param.service_param.client_class = "DownpourBrpcPsClient"
-        self.server_.downpour_server_param.service_param.service_class = "DownpourPsService"
-        self.server_.downpour_server_param.service_param.start_server_port = 0
-        self.server_.downpour_server_param.service_param.server_thread_num = 12
+        self._server = pslib.ServerParameter()
+        self._server.downpour_server_param.service_param.start_server_port = 0
+        self._server.downpour_server_param.service_param.server_class = "DownpourBrpcPsServer"
+        self._server.downpour_server_param.service_param.client_class = "DownpourBrpcPsClient"
+        self._server.downpour_server_param.service_param.service_class = "DownpourPsService"
+        self._server.downpour_server_param.service_param.start_server_port = 0
+        self._server.downpour_server_param.service_param.server_thread_num = 12
 
     def add_sparse_table(self, table_id, learning_rate, slot_key_vars,
                          slot_value_var):
@@ -62,7 +62,7 @@ class DownpourServer(Server):
         Returns:
             return None
         """
-        table = self.server_.downpour_server_param.downpour_table_param.add()
+        table = self._server.downpour_server_param.downpour_table_param.add()
         table.table_id = table_id
         table.table_class = "DownpourSparseTable"
         table.type = pslib.PS_SPARSE_TABLE
@@ -123,7 +123,7 @@ class DownpourServer(Server):
         Returns:
             return None
         """
-        table = self.server_.downpour_server_param.downpour_table_param.add()
+        table = self._server.downpour_server_param.downpour_table_param.add()
         table.table_id = table_id
         table.table_class = "DownpourDenseTable"
         table.type = pslib.PS_DENSE_TABLE
@@ -140,7 +140,7 @@ class DownpourServer(Server):
         """
         Return downpour server program_desc
         """
-        return self.server_
+        return self._server
 
 
 class DownpourWorker(Worker):
@@ -155,7 +155,7 @@ class DownpourWorker(Worker):
     def __init__(self, window):
         self.window = window
-        self.worker_ = pslib.DownpourTrainerParameter()
+        self._worker = pslib.DownpourTrainerParameter()
 
     def add_sparse_table(self, table_id, learning_rate, slot_key_vars,
                          slot_value_vars):
@@ -187,7 +187,7 @@ class DownpourWorker(Worker):
         Returns:
             return None
         """
-        table = self.worker_.dense_table.add()
+        table = self._worker.dense_table.add()
         table.table_id = table_id
         table.dense_variable_name.extend(
             filter(lambda x: x.find("embedding") == -1,
@@ -200,4 +200,4 @@ class DownpourWorker(Worker):
         """
         Return downpour worker program_desc
         """
-        return self.worker_
+        return self._worker
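DownpourServer and DownpourWorker do nothing but fill in pslib protobuf messages field by field and hand them back via get_desc(). Since pslib's generated module is Paddle-internal, here is the same pattern on a stand-in message from the protobuf well-known types (the field names below are borrowed from the diff purely for illustration):

```python
from google.protobuf import text_format
from google.protobuf.struct_pb2 import Struct  # stand-in for pslib.ServerParameter

desc = Struct()
# Assigning through nested accessors mutates the message in place,
# just as self._server.downpour_server_param.service_param... does above.
desc.fields["server_class"].string_value = "DownpourBrpcPsServer"
desc.fields["server_thread_num"].number_value = 12

# Any protobuf message renders to readable text this way; the same
# google.protobuf.text_format machinery appears in TrainerDesc._desc().
print(text_format.MessageToString(desc))
```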
@@ -24,9 +24,9 @@ from .node import DownpourWorker, DownpourServer
 
 class DistributedOptimizerImplBase(object):
     def __init__(self, optimizer):
-        self.optimizer_ = optimizer
-        self.learning_rate_ = optimizer._learning_rate
-        self.regularization_ = optimizer.regularization
+        self._optimizer = optimizer
+        self._learning_rate = optimizer._learning_rate
+        self._regularization = optimizer.regularization
 
     def minimize(self,
                  losses,
@@ -41,7 +41,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
         # todo(guru4elephant): add more optimizers here as argument
         # todo(guru4elephant): make learning_rate as a variable
         super(DistributedAdam, self).__init__(optimizer)
-        self.window_ = 1
+        self._window = 1
         self.type = "downpour"
         self.data_norm_name = [
             ".batch_size", ".batch_square_sum", ".batch_sum",
@@ -79,9 +79,9 @@ class DistributedAdam(DistributedOptimizerImplBase):
         server = DownpourServer()
         worker = DownpourWorker(self.window_)
         sparse_table_index = 0
-        server.add_sparse_table(sparse_table_index, self.learning_rate_,
+        server.add_sparse_table(sparse_table_index, self._learning_rate,
                                 prefetch_slots, prefetch_slots_emb)
-        worker.add_sparse_table(sparse_table_index, self.learning_rate_,
+        worker.add_sparse_table(sparse_table_index, self._learning_rate,
                                 prefetch_slots, prefetch_slots_emb)
         dense_table_index = 1
         program_configs = {}
@@ -124,9 +124,9 @@ class DistributedAdam(DistributedOptimizerImplBase):
                     data_norm_grads.append(i[1])
                 if not is_data_norm_data:
                     grads.append(i[1])
-            server.add_dense_table(dense_table_index, self.learning_rate_,
+            server.add_dense_table(dense_table_index, self._learning_rate,
                                    params, grads)
-            worker.add_dense_table(dense_table_index, self.learning_rate_,
+            worker.add_dense_table(dense_table_index, self._learning_rate,
                                    params, grads)
             program_configs[program_id]["pull_dense"] = [dense_table_index]
             program_configs[program_id]["push_dense"] = [dense_table_index]
@@ -135,9 +135,9 @@ class DistributedAdam(DistributedOptimizerImplBase):
             if len(data_norm_params) != 0 and len(data_norm_grads) != 0:
                 dense_table_index += 1
                 server.add_data_norm_table(dense_table_index,
-                                           self.learning_rate_,
+                                           self._learning_rate,
                                            data_norm_params, data_norm_grads)
-                worker.add_dense_table(dense_table_index, self.learning_rate_,
+                worker.add_dense_table(dense_table_index, self._learning_rate,
                                        data_norm_params, data_norm_grads)
                 #program_config.pull_dense_table_id.extend([dense_table_index])
                 #program_config.push_dense_table_id.extend([dense_table_index])
......
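DistributedAdam keeps a fixed table-id scheme: the sparse embedding table is always table 0, the first dense table is 1, and an optional data-norm table takes the next free index; server and worker must register each table under the same id for pulls and pushes to line up. (Also note `worker = DownpourWorker(self.window_)` in the hunk above still reads the old attribute that __init__ now sets as `self._window`.) A toy sketch of the bookkeeping, with illustrative names:

```python
def assign_table_ids(has_data_norm):
    """Mirror DistributedAdam's table-index bookkeeping (sketch only)."""
    tables = {"sparse": 0}       # embedding table is hard-wired to id 0
    dense_table_index = 1        # dense params/grads start at id 1
    tables["dense"] = dense_table_index
    if has_data_norm:            # batch-norm statistics get the next free id
        dense_table_index += 1
        tables["data_norm"] = dense_table_index
    return tables

print(assign_table_ids(True))   # {'sparse': 0, 'dense': 1, 'data_norm': 2}
```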
@@ -28,10 +28,10 @@ class TrainerDesc(object):
         import multiprocessing as mp
         # set default thread num == cpu count
         self.proto_desc.thread_num = mp.cpu_count()
-        self.fleet_desc_ = None
-        self.device_worker_ = None
-        self.program_ = None
-        self.infer_ = False
+        self._fleet_desc = None
+        self._device_worker = None
+        self._program = None
+        self._infer = False
 
     def _set_fetch_var_and_info(self, fetch_vars, fetch_info, print_period):
         for i, v in enumerate(fetch_vars):
@@ -47,19 +47,19 @@ class TrainerDesc(object):
         self.proto_desc.thread_num = thread_num
 
     def _set_device_worker(self, device_worker):
-        self.device_worker_ = device_worker
+        self._device_worker = device_worker
 
     def _set_infer(self, infer):
-        self.infer_ = infer
+        self._infer = infer
 
     def _set_fleet_desc(self, fleet_desc):
-        self.fleet_desc_ = fleet_desc
+        self._fleet_desc = fleet_desc
 
     def _gen_trainer_desc(self):
         pass
 
     def _set_program(self, program):
-        self.program_ = program
+        self._program = program
 
     def _desc(self):
         from google.protobuf import text_format
@@ -73,13 +73,13 @@ class MultiTrainer(TrainerDesc):
     def _set_program(self, program):
         super(MultiTrainer, self)._set_program(program)
-        self.program_ = program
+        self._program = program
 
     def _gen_trainer_desc(self):
         super(MultiTrainer, self)._gen_trainer_desc()
         self.proto_desc.class_name = "MultiTrainer"
-        self.device_worker_._set_infer(self.infer_)
-        self.device_worker_._gen_worker_desc(self.proto_desc)
+        self._device_worker._set_infer(self.infer_)
+        self._device_worker._gen_worker_desc(self.proto_desc)
 
 
 class DistMultiTrainer(TrainerDesc):
@@ -89,13 +89,13 @@ class DistMultiTrainer(TrainerDesc):
     def _set_program(self, program):
         super(DistMultiTrainer, self)._set_program(program)
-        self.program_ = program
+        self._program = program
 
     def _gen_trainer_desc(self):
         super(DistMultiTrainer, self)._gen_trainer_desc()
         self.proto_desc.class_name = "DistMultiTrainer"
-        if self.program_ == None:
+        if self._program == None:
             raise RuntimeError("None Program")
-        self.device_worker_._set_infer(self.infer_)
-        self.device_worker_._set_program(self.program_)
-        self.device_worker_._gen_worker_desc(self.proto_desc)
+        self._device_worker._set_infer(self.infer_)
+        self._device_worker._set_program(self.program_)
+        self._device_worker._gen_worker_desc(self.proto_desc)
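_gen_trainer_desc composes the TrainerDesc protobuf in layers: the base class fills the common fields, the subclass sets class_name, and the device worker appends its own section; _desc() then serializes the message via google.protobuf's text_format. The new bodies still read `self.infer_` and `self.program_`, two more trailing-underscore survivors of the rename. A schematic of the DistMultiTrainer flow, with stand-in objects rather than the real protobuf:

```python
def gen_trainer_desc_schematic(proto_desc, device_worker, program, infer):
    """Sketch of DistMultiTrainer._gen_trainer_desc's control flow."""
    proto_desc.class_name = "DistMultiTrainer"
    if program is None:                      # the diff itself writes "== None"
        raise RuntimeError("None Program")
    device_worker._set_infer(infer)          # diff reads self.infer_ (old name)
    device_worker._set_program(program)      # diff reads self.program_ (old name)
    device_worker._gen_worker_desc(proto_desc)  # worker appends its own fields
```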