提交 6bc0efb4 编写于 作者: H heqiaozhi

refine interface

上级 575ae7c6
...@@ -24,6 +24,7 @@ from paddle.fluid.proto import data_feed_pb2 ...@@ -24,6 +24,7 @@ from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format from google.protobuf import text_format
from . import io from . import io
from .data_feed_desc import DataFeedDesc from .data_feed_desc import DataFeedDesc
from .distributed import ps_instance
__all__ = ['AsyncExecutor'] __all__ = ['AsyncExecutor']
...@@ -85,6 +86,7 @@ class AsyncExecutor(object): ...@@ -85,6 +86,7 @@ class AsyncExecutor(object):
scope = global_scope() scope = global_scope()
self.executor = core.AsyncExecutor(scope, p) self.executor = core.AsyncExecutor(scope, p)
self.instance = ps_instance.PaddlePSInstance("init_param", 1, 2)
def run(self, program, data_feed, filelist, thread_num, fetch, debug=False): def run(self, program, data_feed, filelist, thread_num, fetch, debug=False):
""" """
...@@ -149,26 +151,38 @@ class AsyncExecutor(object): ...@@ -149,26 +151,38 @@ class AsyncExecutor(object):
self.executor.run_from_files(program_desc, self.executor.run_from_files(program_desc,
data_feed.desc(), filelist, thread_num, data_feed.desc(), filelist, thread_num,
fetch_var_names, debug) fetch_var_names, debug)
self.instance.barrier_all()
def config_distributed_nodes(self, dist_opt): def config_distributed_nodes(self, dist_opt):
# get total rank # get total rank
# get rank index # get rank index
# get iplists # get iplists
# get hadoop info # get hadoop info
return pass
def get_instance(self):
def init_server(self, filename, index): return self.instance
self.executor.init_server(filename, index)
def init_server(self, dist_desc):
def init_worker(self, filename, ips, nodes_cnt, index): self.executor.init_server(dist_desc, self.instance._rankid)
self.executor.init_worker(filename, ips, nodes_cnt, index) ip = self.executor.start_server()
self.instance.set_ip(ip)
def start_server(self): self.instance.barrier_all() #wait all server start
return self.executor.start_server() ips = self.instance.gather_ips()
self.executor.gather_servers(ips, self.instance.get_node_cnt())
def gather_servers(self, ips, nodes_cnt): self.instance.barrier_all() #wait all worker start
self.executor.gather_servers(ips, nodes_cnt) self.instance.barrier_all() #wait init model
self.instance.barrier_all() #wait worker do all things
def init_worker(self, dist_desc):
self.instance.barrier_all() #wait all server start
ips = self.instance.gather_ips()
self.executor.init_worker(dist_desc, ips, self.instance.get_node_cnt(), self.instance._rankid)
self.instance.barrier_all() #wait all worker start
if self.instance.is_first_worker():
self.executor.init_model()
self.instance.barrier_all() #wait init model
def init_model(self): def init_model(self):
self.executor.init_model() self.executor.init_model()
......
...@@ -46,14 +46,20 @@ class DownpourSGD(object): ...@@ -46,14 +46,20 @@ class DownpourSGD(object):
sparse_table_index = 0 sparse_table_index = 0
# currently merge all dense parameters into one dense table # currently merge all dense parameters into one dense table
dense_table_index = 1 dense_table_index = 1
params = []
grads = []
for i in params_grads:
params.append(i[0])
for i in params_grads:
grads.append(i[1])
server.add_sparse_table(sparse_table_index, self.learning_rate_, server.add_sparse_table(sparse_table_index, self.learning_rate_,
prefetch_slots, prefetch_slots_emb) prefetch_slots, prefetch_slots_emb)
server.add_dense_table(dense_table_index, self.learning_rate_, server.add_dense_table(dense_table_index, self.learning_rate_,
params_grads[0], params_grads[1]) params, grads)
worker.add_sparse_table(sparse_table_index, self.learning_rate_, worker.add_sparse_table(sparse_table_index, self.learning_rate_,
prefetch_slots, prefetch_slots_emb) prefetch_slots, prefetch_slots_emb)
worker.add_dense_table(dense_table_index, self.learning_rate_, worker.add_dense_table(dense_table_index, self.learning_rate_,
params_grads[0], params_grads[1]) params, grads)
ps_param = pslib.PSParameter() ps_param = pslib.PSParameter()
ps_param.server_param.CopyFrom(server.get_desc()) ps_param.server_param.CopyFrom(server.get_desc())
ps_param.trainer_param.CopyFrom(worker.get_desc()) ps_param.trainer_param.CopyFrom(worker.get_desc())
...@@ -61,4 +67,4 @@ class DownpourSGD(object): ...@@ -61,4 +67,4 @@ class DownpourSGD(object):
# currently only support lookup_table # currently only support lookup_table
worker_skipped_ops = ["lookup_table", "lookup_table_grad"] worker_skipped_ops = ["lookup_table", "lookup_table_grad"]
ps_param_str = text_format.MessageToString(ps_param) ps_param_str = text_format.MessageToString(ps_param)
return [ps_param_str, worker_skipped_ops] return [ps_param, worker_skipped_ops]
from mpi4py import MPI from mpi4py import MPI
import ps_pb2 as pslib
class FileSystem(object): class FileSystem(object):
def __init__(self, fs_type="afs", def __init__(self, fs_type="afs",
...@@ -7,20 +8,23 @@ class FileSystem(object): ...@@ -7,20 +8,23 @@ class FileSystem(object):
passwd=None, passwd=None,
hadoop_bin="", hadoop_bin="",
afs_conf=None): afs_conf=None):
assert user not None assert user != None
assert passwd not None assert passwd != None
assert hadoop_bin not None assert hadoop_bin != None
fs_client = pslib.FsClientParameter() self.fs_client = pslib.FsClientParameter()
if fs_type == "afs": #if fs_type == "afs":
fs_client.fs_type = pslib.FsApiType.AFS # fs_client.fs_type = pslib.FsApiType.AFS
else: #else:
fs_client.fs_type = pslib.FsApiType.HDFS # fs_client.fs_type = pslib.FsApiType.HDFS
fs_client.uri = uri self.fs_client.uri = uri
fs_client.user = user self.fs_client.user = user
fs_client.passwd = passwd self.fs_client.passwd = passwd
fs_client.buffer_size = 0 #self.fs_client.buffer_size = 0
fs_client.afs_conf = afs_conf if not afs_conf else "" self.fs_client.hadoop_bin = hadoop_bin
#self.fs_client.afs_conf = afs_conf if not afs_conf else ""
def get_desc(self):
return self.fs_client
class MPIHelper(object): class MPIHelper(object):
def __init__(self): def __init__(self):
......
...@@ -13,24 +13,52 @@ class Worker(object): ...@@ -13,24 +13,52 @@ class Worker(object):
class DownpourServer(Server): class DownpourServer(Server):
def __init__(self): def __init__(self):
self.server_ = pslib.ServerParameter() self.server_ = pslib.ServerParameter()
self.server_.downpour_server_param.service_param.start_server_port = 0
self.server_.downpour_server_param.service_param.server_class = "DownpourBrpcPsServer"
self.server_.downpour_server_param.service_param.client_class = "DownpourBrpcPsClient"
self.server_.downpour_server_param.service_param.service_class = "DownpourPsService"
self.server_.downpour_server_param.service_param.start_server_port = 0
self.server_.downpour_server_param.service_param.server_thread_num = 12
def add_sparse_table(self, table_id, learning_rate, def add_sparse_table(self, table_id, learning_rate,
slot_key_vars, slot_value_var): slot_key_vars, slot_value_var):
table = self.server_.downpour_server_param.downpour_table_param.add() table = self.server_.downpour_server_param.downpour_table_param.add()
table.table_id = table_id table.table_id = table_id
table.table_class = "DownpourSparseTable"
table.type = pslib.PS_SPARSE_TABLE table.type = pslib.PS_SPARSE_TABLE
table.accessor.accessor_class = "DownpourFeatureValueAccessor" table.accessor.accessor_class = "DownpourFeatureValueAccessor"
table.accessor.dense_sgd_param.adam.learning_rate = learning_rate table.accessor.sparse_sgd_param.learning_rate = learning_rate
table.accessor.fea_dim = abs(reduce(lambda x, y: x * y, table.accessor.sparse_sgd_param.initial_g2sum = 3
slot_value_var[0].shape, 1)) table.accessor.sparse_sgd_param.initial_range = 1e-4
table.accessor.sparse_sgd_param.weight_bounds.extend([-10, 10])
table.accessor.embedx_dim = 8
table.accessor.embedx_threshold = 5
table.accessor.fea_dim = 11
#table.accessor.fea_dim = abs(reduce(lambda x, y: x * y,
# slot_value_var[0].shape, 1))
table.accessor.downpour_accessor_param.nonclk_coeff = 0.1
table.accessor.downpour_accessor_param.click_coeff = 2
table.accessor.downpour_accessor_param.base_threshold = 0.2
table.accessor.downpour_accessor_param.delta_threshold = 0.15
table.accessor.downpour_accessor_param.delta_keep_days = 31
table.accessor.downpour_accessor_param.show_click_decay_rate = 0.999
table.accessor.downpour_accessor_param.delete_threshold = 0.8
def add_dense_table(self, table_id, learning_rate, def add_dense_table(self, table_id, learning_rate,
param_var, grad_var): param_var, grad_var):
table = self.server_.downpour_server_param.downpour_table_param.add() table = self.server_.downpour_server_param.downpour_table_param.add()
table.table_id = table_id table.table_id = table_id
table.table_class = "DownpourDenseTable"
table.type = pslib.PS_DENSE_TABLE table.type = pslib.PS_DENSE_TABLE
table.accessor.accessor_class = "DownpourDenseValueAccessor" table.accessor.accessor_class = "DownpourDenseValueAccessor"
table.accessor.sparse_sgd_param.learning_rate = learning_rate table.accessor.dense_sgd_param.name = "adam"
table.accessor.dense_sgd_param.adam.learning_rate = learning_rate
table.accessor.dense_sgd_param.adam.avg_decay_rate = 0.999993
table.accessor.dense_sgd_param.adam.ada_decay_rate = 0.9999
table.accessor.dense_sgd_param.adam.ada_epsilon = 1e-8
table.accessor.dense_sgd_param.adam.mom_decay_rate = 0.99
table.accessor.dense_sgd_param.naive.learning_rate = 0.0002
fea_dim = 0 fea_dim = 0
for param in param_var: for param in param_var:
fea_dim += reduce(lambda x, y: x * y, param.shape, 1) fea_dim += reduce(lambda x, y: x * y, param.shape, 1)
...@@ -44,8 +72,8 @@ class DownpourWorker(Worker): ...@@ -44,8 +72,8 @@ class DownpourWorker(Worker):
def __init__(self, window): def __init__(self, window):
self.window = window self.window = window
self.worker_ = pslib.DownpourTrainerParameter() self.worker_ = pslib.DownpourTrainerParameter()
self.worker_.pull_dense_per_batch = window #self.worker_.pull_dense_per_batch = window
self.worker_.push_dense_per_batch = window #self.worker_.push_dense_per_batch = window
def add_sparse_table(self, table_id, learning_rate, def add_sparse_table(self, table_id, learning_rate,
slot_key_vars, slot_value_vars): slot_key_vars, slot_value_vars):
...@@ -62,8 +90,8 @@ class DownpourWorker(Worker): ...@@ -62,8 +90,8 @@ class DownpourWorker(Worker):
param_vars, grad_vars): param_vars, grad_vars):
table = self.worker_.dense_table.add() table = self.worker_.dense_table.add()
table.table_id = table_id table.table_id = table_id
table.dense_variable_name.extend([p.name for p in param_vars]) table.dense_variable_name.extend(filter(lambda x: x.find("embedding") == -1, [p.name for p in param_vars]))
table.dense_gradient_variable_name.extend([g.name for g in grad_vars]) table.dense_gradient_variable_name.extend(filter(lambda x: x.find("embedding") == -1, [g.name for g in grad_vars]))
def get_desc(self): def get_desc(self):
return self.worker_ return self.worker_
...@@ -531,21 +531,21 @@ _SERVERSERVICEPARAMETER = _descriptor.Descriptor( ...@@ -531,21 +531,21 @@ _SERVERSERVICEPARAMETER = _descriptor.Descriptor(
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='server_class', full_name='paddle.ServerServiceParameter.server_class', index=0, name='server_class', full_name='paddle.ServerServiceParameter.server_class', index=0,
number=1, type=9, cpp_type=9, label=1, number=1, type=9, cpp_type=9, label=1,
has_default_value=True, default_value=_b("AbacusBrpcPsServer").decode('utf-8'), has_default_value=True, default_value=_b("DownpourBrpcPsServer").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='client_class', full_name='paddle.ServerServiceParameter.client_class', index=1, name='client_class', full_name='paddle.ServerServiceParameter.client_class', index=1,
number=2, type=9, cpp_type=9, label=1, number=2, type=9, cpp_type=9, label=1,
has_default_value=True, default_value=_b("AbacusBrpcPsClient").decode('utf-8'), has_default_value=True, default_value=_b("DownpourBrpcPsClient").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='service_class', full_name='paddle.ServerServiceParameter.service_class', index=2, name='service_class', full_name='paddle.ServerServiceParameter.service_class', index=2,
number=3, type=9, cpp_type=9, label=1, number=3, type=9, cpp_type=9, label=1,
has_default_value=True, default_value=_b("AbacusPsService").decode('utf-8'), has_default_value=True, default_value=_b("DownpourPsService").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None, is_extension=False, extension_scope=None,
options=None), options=None),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册