Commit 8a335b50 authored by dongdaxiang

add downpour device_worker pb configuration

Parent 24a80011
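For context, a minimal sketch of how the reworked run() entry point introduced below is meant to be called. The setup names (the place, the feed proto path, the file list) are illustrative assumptions, not part of this commit:

    import paddle.fluid as fluid

    place = fluid.CPUPlace()
    exe = fluid.AsyncExecutor(place)

    data_feed = fluid.DataFeedDesc('data_feed.proto')  # hypothetical feed config
    filelist = ['train_data/part-%d' % i for i in range(10)]

    # With no distributed instance configured, run() builds a local
    # MultiTrainer with the Hogwild device worker; after init_worker()
    # it builds a DistMultiTrainer with the Downpour worker instead.
    exe.run(fluid.default_main_program(), data_feed, filelist,
            thread_num=10, fetch=[], debug=False)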
@@ -59,7 +59,6 @@ message TableParameter {
   optional int64 table_id = 1;
   repeated string dense_value_name = 2;
   repeated string dense_grad_name = 3;
-  repeated int32 dense_table_size = 4;
   repeated int32 push_dense_wait_times = 5;
   // sparse table only
   repeated string sparse_key_name = 6;
...
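As a rough illustration of the message this hunk trims, the snippet below fills a TableParameter by hand. The generated-module path is an assumption inferred from the Python imports later in this commit, and all variable names are made up:

    from paddle.fluid.proto import trainer_desc_pb2

    table = trainer_desc_pb2.TableParameter()
    table.table_id = 0
    # dense side: one parameter/gradient variable name per entry
    table.dense_value_name.extend(['fc_0.w_0', 'fc_0.b_0'])       # hypothetical
    table.dense_grad_name.extend(['fc_0.w_0@GRAD', 'fc_0.b_0@GRAD'])
    table.push_dense_wait_times.append(10)
    # sparse side: slot names feeding the embedding table
    table.sparse_key_name.extend(['slot_0', 'slot_1'])
    print(table)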
@@ -24,6 +24,7 @@ from paddle.fluid.proto import data_feed_pb2
 from google.protobuf import text_format
 from . import io
 from .data_feed_desc import DataFeedDesc
+from .trainer_desc import TrainerDesc, MultiTrainer, DistMultiTrainer
 from .distributed import ps_instance
 from .contrib.utils import hdfs_utils as hdfs
@@ -89,6 +90,38 @@ class AsyncExecutor(object):
         self.executor = core.AsyncExecutor(scope, p)
         self.instance = None
+    def run(self, program, data_feed, filelist, thread_num, fetch, debug=False):
+        if program is None:
+            program = default_main_program()
+        program_desc = program.desc
+        if data_feed is None:
+            raise ValueError('data_feed should be provided')
+        if filelist is None:
+            raise ValueError('filelist should be provided')
+        if isinstance(filelist, str):
+            filelist = [filelist]
+        if not isinstance(thread_num, int):
+            raise TypeError('thread_num should be an int')
+        is_local = self.instance is None
+        # define a trainer and a device_worker here
+        if is_local:
+            trainer = MultiTrainer(dataset=data_feed, worker="Hogwild")
+        else:
+            trainer = DistMultiTrainer(
+                dataset=data_feed, worker="Downpour", fleet_desc=self.dist_desc)
+        trainer.set_thread(thread_num)
+        trainer.set_filelist(filelist)
+        trainer.set_data_feed(data_feed)
+        self.executor.run_from_files(program_desc, trainer._desc(), debug)
+    '''
     def run(self,
             program,
             data_feed,
...
@@ -160,6 +193,7 @@ class AsyncExecutor(object):
         self.executor.run_from_files(program_desc,
                                      data_feed.desc(), filelist, thread_num,
                                      fetch_var_names, mode, debug)
+    '''
     def download_data(self,
                       afs_path,
@@ -250,6 +284,7 @@ class AsyncExecutor(object):
             raise ValueError(
                 'instance is None, please run config_distributed_nodes init instance'
             )
+        self.dist_desc = dist_desc
         self.executor.init_server(dist_desc, self.instance._rankid)
         ip = self.executor.start_server()
         self.instance.set_ip(ip)
@@ -270,6 +305,8 @@ class AsyncExecutor(object):
             raise ValueError(
                 'instance is None, please run config_distributed_nodes init instance'
             )
+        self.dist_desc = dist_desc
         place = core.CPUPlace()
         executor = Executor(place)
         executor.run(startup_program)
...
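The one-line additions above stash the distributed descriptor on the executor so that run() can later hand it to DistMultiTrainer as fleet_desc. A rough sketch of the implied call order; the ps_instance helper is_server() is an assumption, as it is not shown in this diff:

    exe = fluid.AsyncExecutor(fluid.CPUPlace())
    instance = exe.config_distributed_nodes()
    if instance.is_server():
        exe.init_server(dist_desc)                   # records self.dist_desc
    else:
        exe.init_worker(dist_desc, startup_program)  # records self.dist_desc
        exe.run(main_program, data_feed, filelist,
                thread_num=10, fetch=[])             # builds DistMultiTrainer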
@@ -13,6 +13,7 @@
 # limitations under the License.
 from paddle.fluid.proto import trainer_desc_pb2
+import ps_pb2 as pslib
 from google.protobuf import text_format

 __all__ = ['TrainerDesc', 'MultiTrainer', 'DistMultiTrainer']
@@ -42,7 +43,7 @@ class TrainerDesc(object):

 class MultiTrainer(TrainerDesc):
-    def __init__(self, worker="Hogwild"):
+    def __init__(self, dataset=None, worker="Hogwild"):
         super(MultiTrainer, self).__init__()
         if worker == "Hogwild":
             self.proto_desc.device_worker_name = worker + "Worker"
@@ -53,11 +54,39 @@ class MultiTrainer(TrainerDesc):

 class DistMultiTrainer(TrainerDesc):
-    def __init__(self, worker='Downpour'):
+    def __init__(self, dataset=None, worker='Downpour', fleet_desc=None):
         super(DistMultiTrainer, self).__init__()
         if worker == "Downpour":
             self.proto_desc.device_worker_name = worker + "Worker"
             self.proto_desc.class_name = "DistMultiTrainer"
+            self.proto_desc.data_feed.CopyFrom(dataset.proto_desc)
+            downpour = self.proto_desc.downpour_param.add()
+            # sparse table: slot keys/values/gradients for the embedding table
+            sparse_table = downpour.sparse_table.add()
+            sparse_table.table_id = \
+                fleet_desc.trainer_param.sparse_table.table_id
+            sparse_table.sparse_key_name.extend(
+                fleet_desc.trainer_param.sparse_table.slot_key)
+            sparse_table.sparse_value_name.extend(
+                fleet_desc.trainer_param.sparse_table.slot_value)
+            sparse_table.sparse_grad_name.extend(
+                fleet_desc.trainer_param.sparse_table.slot_gradient)
+            sparse_table.emb_dim = (fleet_desc.server_param.downpour_server_param
+                                    .downpour_table_param.accessor.fea_dim - 2)
+            sparse_table.fea_dim = sparse_table.emb_dim + 2
+            sparse_table.label_var_name = "click"
+            # dense table: dense parameter and gradient variable names
+            dense_table = downpour.dense_table.add()
+            dense_table.table_id = \
+                fleet_desc.trainer_param.dense_table.table_id
+            dense_table.dense_value_name.extend(
+                fleet_desc.trainer_param.dense_table.dense_variable_name)
+            dense_table.dense_grad_name.extend(
+                fleet_desc.trainer_param.dense_table.dense_gradient_name)
+            downpour.skipped_ops.extend(fleet_desc.trainer_param.skip_op)
+            print(str(self.proto_desc))
         else:
             raise ValueError('DeviceWorker %s is not supported '
                              'in DistMultiTrainer' % worker)
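One step above is worth spelling out: the sparse table's emb_dim is derived as the server accessor's fea_dim minus 2, and fea_dim is then set back to emb_dim plus 2. The assumption (not stated in the commit) is that the server-side feature vector carries two bookkeeping slots, commonly show and click counts, alongside the embedding; e.g. with accessor fea_dim = 11 the worker gets emb_dim = 9, and fea_dim = 9 + 2 round-trips to 11.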