Commit 6bc0efb4, authored by: H heqiaozhi

refine interface

Parent 575ae7c6
......@@ -24,6 +24,7 @@ from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format
from . import io
from .data_feed_desc import DataFeedDesc
from .distributed import ps_instance
__all__ = ['AsyncExecutor']
......@@ -85,6 +86,7 @@ class AsyncExecutor(object):
# (tail of AsyncExecutor.__init__ — the `def` lies above this diff hunk)
# Bind the global scope and construct the native C++ async executor.
scope = global_scope()
self.executor = core.AsyncExecutor(scope, p)
# NOTE(review): magic arguments ("init_param", 1, 2) to PaddlePSInstance —
# their meaning is not visible in this chunk; confirm against ps_instance.
self.instance = ps_instance.PaddlePSInstance("init_param", 1, 2)
def run(self, program, data_feed, filelist, thread_num, fetch, debug=False):
"""
......@@ -149,26 +151,38 @@ class AsyncExecutor(object):
self.executor.run_from_files(program_desc,
data_feed.desc(), filelist, thread_num,
fetch_var_names, debug)
self.instance.barrier_all()
def config_distributed_nodes(self, dist_opt):
    """Configure distributed-run metadata from *dist_opt*.

    Placeholder — still to be implemented:
      * obtain total rank
      * obtain this node's rank index
      * obtain the ip list
      * obtain hadoop info
    """
    return None
def init_server(self, filename, index):
    # Thin wrapper: delegate server-side initialisation to the native
    # executor. *filename* presumably points at the server config and
    # *index* is this node's rank — TODO confirm with core.AsyncExecutor.
    self.executor.init_server(filename, index)
def init_worker(self, filename, ips, nodes_cnt, index):
    # Thin wrapper: delegate worker-side initialisation to the native
    # executor. *ips*/*nodes_cnt* describe the server endpoints; *index*
    # is this worker's rank — assumption, verify against the core API.
    self.executor.init_worker(filename, ips, nodes_cnt, index)
def start_server(self):
    # Start the parameter server via the native executor and return its
    # result (the newer init_server flow below treats it as the server's
    # ip/endpoint — see set_ip(ip)).
    return self.executor.start_server()
def gather_servers(self, ips, nodes_cnt):
    """Hand the collected server endpoints to the native executor.

    Args:
        ips: gathered list of server ip/endpoints.
        nodes_cnt: total number of participating nodes.
    """
    # Fix: the original had a dead `pass` statement after this call;
    # it is removed (no behavioral change).
    self.executor.gather_servers(ips, nodes_cnt)
def get_instance(self):
    """Accessor for the PaddlePSInstance bound to this executor."""
    ps_instance_obj = self.instance
    return ps_instance_obj
def init_server(self, dist_desc):
    # Server-side bring-up choreography. Statement order matters: each
    # barrier_all() pairs with a matching barrier on the worker side.
    self.executor.init_server(dist_desc, self.instance._rankid)
    # start_server() returns this server's endpoint; publish it so
    # workers can gather it.
    ip = self.executor.start_server()
    self.instance.set_ip(ip)
    self.instance.barrier_all()  # wait all server start
    ips = self.instance.gather_ips()
    self.executor.gather_servers(ips, self.instance.get_node_cnt())
    self.instance.barrier_all()  # wait all worker start
    self.instance.barrier_all()  # wait init model
    # NOTE(review): this last barrier appears to pair with the
    # barrier_all() at the end of run() — confirm every worker calls
    # run() exactly once, otherwise the server deadlocks here.
    self.instance.barrier_all()  # wait worker do all things
def init_worker(self, dist_desc):
    # Worker-side counterpart of init_server(): the three barriers here
    # pair, in order, with the first three barriers there.
    self.instance.barrier_all()  # wait all server start
    ips = self.instance.gather_ips()
    self.executor.init_worker(dist_desc, ips, self.instance.get_node_cnt(), self.instance._rankid)
    self.instance.barrier_all()  # wait all worker start
    if self.instance.is_first_worker():
        # Only the first worker initialises the global model, so every
        # worker must reach the next barrier before training proceeds.
        self.executor.init_model()
    self.instance.barrier_all()  # wait init model
def init_model(self):
    # Thin wrapper: delegate model initialisation to the native
    # executor; normally invoked only by the first worker (see
    # init_worker()).
    self.executor.init_model()
......
......@@ -46,14 +46,20 @@ class DownpourSGD(object):
# (interior of a DownpourSGD method — the `def` lies above this diff hunk)
sparse_table_index = 0
# currently merge all dense parameters into one dense table
dense_table_index = 1
# Split (param, grad) pairs into two parallel lists.
# NOTE(review): the two loops could be one loop, or simply
# `params, grads = map(list, zip(*params_grads))`.
params = []
grads = []
for i in params_grads:
    params.append(i[0])
for i in params_grads:
    grads.append(i[1])
server.add_sparse_table(sparse_table_index, self.learning_rate_,
                        prefetch_slots, prefetch_slots_emb)
# NOTE(review): diff residue below — both the pre-change argument line
# (`params_grads[0], params_grads[1]`) and the post-change one
# (`params, grads`) are present; only the latter survives the commit.
server.add_dense_table(dense_table_index, self.learning_rate_,
                       params_grads[0], params_grads[1])
                       params, grads)
worker.add_sparse_table(sparse_table_index, self.learning_rate_,
                        prefetch_slots, prefetch_slots_emb)
worker.add_dense_table(dense_table_index, self.learning_rate_,
                       params_grads[0], params_grads[1])
                       params, grads)
# Assemble the parameter-server proto from the two table descriptors.
ps_param = pslib.PSParameter()
ps_param.server_param.CopyFrom(server.get_desc())
ps_param.trainer_param.CopyFrom(worker.get_desc())
......@@ -61,4 +67,4 @@ class DownpourSGD(object):
# currently only support lookup_table
worker_skipped_ops = ["lookup_table", "lookup_table_grad"]
ps_param_str = text_format.MessageToString(ps_param)
# NOTE(review): diff residue — pre-change return (text proto string) and
# post-change return (proto object) both appear; only the second
# survives the commit, and `ps_param_str` becomes unused then.
return [ps_param_str, worker_skipped_ops]
return [ps_param, worker_skipped_ops]
from mpi4py import MPI
import ps_pb2 as pslib
class FileSystem(object):
# NOTE(review): this span is scraped git-diff residue — it contains BOTH
# the pre-change and the post-change bodies of FileSystem.__init__
# interleaved, plus a raw hunk header line; only the `self.fs_client`
# half exists in the post-change file.
def __init__(self, fs_type="afs",
......@@ -7,20 +8,23 @@ class FileSystem(object):
             passwd=None,
             hadoop_bin="",
             afs_conf=None):
    # --- pre-change half (removed by the commit) ---
    # NOTE(review): `assert user not None` is a SyntaxError (should be
    # `is not None`); this is why the commit rewrote these asserts.
    assert user not None
    assert passwd not None
    assert hadoop_bin not None
    fs_client = pslib.FsClientParameter()
    if fs_type == "afs":
        fs_client.fs_type = pslib.FsApiType.AFS
    else:
        fs_client.fs_type = pslib.FsApiType.HDFS
    fs_client.uri = uri
    fs_client.user = user
    fs_client.passwd = passwd
    fs_client.buffer_size = 0
    # NOTE(review): condition looks inverted — assigns afs_conf only when
    # it is falsy; presumably meant `afs_conf if afs_conf else ""`.
    fs_client.afs_conf = afs_conf if not afs_conf else ""
    # --- post-change half (added by the commit) ---
    # NOTE(review): prefer `is not None` over `!= None` (PEP 8), and
    # `assert` disappears under `python -O` — raising ValueError on bad
    # credentials would be safer.
    assert user != None
    assert passwd != None
    assert hadoop_bin != None
    self.fs_client = pslib.FsClientParameter()
    #if fs_type == "afs":
    #    fs_client.fs_type = pslib.FsApiType.AFS
    #else:
    #    fs_client.fs_type = pslib.FsApiType.HDFS
    self.fs_client.uri = uri
    self.fs_client.user = user
    self.fs_client.passwd = passwd
    #self.fs_client.buffer_size = 0
    self.fs_client.hadoop_bin = hadoop_bin
    #self.fs_client.afs_conf = afs_conf if not afs_conf else ""
def get_desc(self):
    """Return the FsClientParameter proto built in __init__."""
    client_desc = self.fs_client
    return client_desc
class MPIHelper(object):
def __init__(self):
......
......@@ -13,24 +13,52 @@ class Worker(object):
class DownpourServer(Server):
def __init__(self):
    """Build a ServerParameter proto pre-configured for Downpour.

    Fix: `start_server_port` was assigned the same value twice
    (diff artifact); the duplicate assignment is dropped.
    """
    self.server_ = pslib.ServerParameter()
    service = self.server_.downpour_server_param.service_param
    # Port 0 presumably lets the service pick a free port at start-up —
    # TODO confirm against the brpc PS server behavior.
    service.start_server_port = 0
    service.server_class = "DownpourBrpcPsServer"
    service.client_class = "DownpourBrpcPsClient"
    service.service_class = "DownpourPsService"
    service.server_thread_num = 12
def add_sparse_table(self, table_id, learning_rate,
                     slot_key_vars, slot_value_var):
    # Register a DownpourSparseTable (embedding storage) on the server.
    # NOTE(review): this span mixes pre- and post-change diff lines; in
    # particular `fea_dim` is assigned twice (shape-derived, then the
    # hard-coded 11) — only the later assignment takes effect.
    table = self.server_.downpour_server_param.downpour_table_param.add()
    table.table_id = table_id
    table.table_class = "DownpourSparseTable"
    table.type = pslib.PS_SPARSE_TABLE
    table.accessor.accessor_class = "DownpourFeatureValueAccessor"
    # pre-change lines (removed by the commit):
    table.accessor.dense_sgd_param.adam.learning_rate = learning_rate
    table.accessor.fea_dim = abs(reduce(lambda x, y: x * y,
                                 slot_value_var[0].shape, 1))
    # post-change sparse-SGD configuration:
    table.accessor.sparse_sgd_param.learning_rate = learning_rate
    table.accessor.sparse_sgd_param.initial_g2sum = 3
    table.accessor.sparse_sgd_param.initial_range = 1e-4
    table.accessor.sparse_sgd_param.weight_bounds.extend([-10, 10])
    table.accessor.embedx_dim = 8
    table.accessor.embedx_threshold = 5
    # Hard-coded feature dim; the shape-derived version is kept below,
    # commented out.
    table.accessor.fea_dim = 11
    #table.accessor.fea_dim = abs(reduce(lambda x, y: x * y,
    #                             slot_value_var[0].shape, 1))
    table.accessor.downpour_accessor_param.nonclk_coeff = 0.1
    table.accessor.downpour_accessor_param.click_coeff = 2
    table.accessor.downpour_accessor_param.base_threshold = 0.2
    table.accessor.downpour_accessor_param.delta_threshold = 0.15
    table.accessor.downpour_accessor_param.delta_keep_days = 31
    table.accessor.downpour_accessor_param.show_click_decay_rate = 0.999
    table.accessor.downpour_accessor_param.delete_threshold = 0.8
def add_dense_table(self, table_id, learning_rate,
                    param_var, grad_var):
    # Register a DownpourDenseTable (dense parameter storage, Adam SGD)
    # on the server.
    table = self.server_.downpour_server_param.downpour_table_param.add()
    table.table_id = table_id
    table.table_class = "DownpourDenseTable"
    table.type = pslib.PS_DENSE_TABLE
    table.accessor.accessor_class = "DownpourDenseValueAccessor"
    # NOTE(review): setting sparse_sgd_param on a dense accessor looks
    # like a leftover from add_sparse_table — confirm it is intended.
    table.accessor.sparse_sgd_param.learning_rate = learning_rate
    table.accessor.dense_sgd_param.name = "adam"
    table.accessor.dense_sgd_param.adam.learning_rate = learning_rate
    table.accessor.dense_sgd_param.adam.avg_decay_rate = 0.999993
    table.accessor.dense_sgd_param.adam.ada_decay_rate = 0.9999
    table.accessor.dense_sgd_param.adam.ada_epsilon = 1e-8
    table.accessor.dense_sgd_param.adam.mom_decay_rate = 0.99
    table.accessor.dense_sgd_param.naive.learning_rate = 0.0002
    # Total dense feature dim = sum of element counts of all params.
    fea_dim = 0
    for param in param_var:
        fea_dim += reduce(lambda x, y: x * y, param.shape, 1)
    # (method body continues beyond this diff hunk)
......@@ -44,8 +72,8 @@ class DownpourWorker(Worker):
def __init__(self, window):
    # *window* controls the per-batch pull/push cadence for dense
    # parameters; the corresponding proto fields were disabled by the
    # commit.
    self.window = window
    self.worker_ = pslib.DownpourTrainerParameter()
    # pre-change diff lines (active assignments, removed by the commit):
    self.worker_.pull_dense_per_batch = window
    self.worker_.push_dense_per_batch = window
    # post-change: disabled — presumably the fields were removed or
    # renamed in the pslib proto; confirm before re-enabling.
    #self.worker_.pull_dense_per_batch = window
    #self.worker_.push_dense_per_batch = window
def add_sparse_table(self, table_id, learning_rate,
slot_key_vars, slot_value_vars):
......@@ -62,8 +90,8 @@ class DownpourWorker(Worker):
param_vars, grad_vars):
    # (signature start lies above this diff hunk)
    table = self.worker_.dense_table.add()
    table.table_id = table_id
    # pre-change lines: register every param/grad name on the table.
    table.dense_variable_name.extend([p.name for p in param_vars])
    table.dense_gradient_variable_name.extend([g.name for g in grad_vars])
    # post-change lines: embedding parameters are served by the sparse
    # table, so names containing "embedding" are filtered out here.
    table.dense_variable_name.extend(filter(lambda x: x.find("embedding") == -1, [p.name for p in param_vars]))
    table.dense_gradient_variable_name.extend(filter(lambda x: x.find("embedding") == -1, [g.name for g in grad_vars]))
def get_desc(self):
    """Return the DownpourTrainerParameter proto for this worker."""
    trainer_desc = self.worker_
    return trainer_desc
......@@ -531,21 +531,21 @@ _SERVERSERVICEPARAMETER = _descriptor.Descriptor(
_descriptor.FieldDescriptor(
name='server_class', full_name='paddle.ServerServiceParameter.server_class', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=True, default_value=_b("AbacusBrpcPsServer").decode('utf-8'),
has_default_value=True, default_value=_b("DownpourBrpcPsServer").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='client_class', full_name='paddle.ServerServiceParameter.client_class', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=True, default_value=_b("AbacusBrpcPsClient").decode('utf-8'),
has_default_value=True, default_value=_b("DownpourBrpcPsClient").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='service_class', full_name='paddle.ServerServiceParameter.service_class', index=2,
number=3, type=9, cpp_type=9, label=1,
has_default_value=True, default_value=_b("AbacusPsService").decode('utf-8'),
has_default_value=True, default_value=_b("DownpourPsService").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register