From c70ea1cc30cfa54228d10e56934119e3a68e3e51 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Fri, 5 Jan 2018 19:35:52 +0800 Subject: [PATCH] add splitter --- .../paddle/v2/fluid/distribute_transpiler.py | 1 - python/paddle/v2/fluid/distributed_spliter.py | 38 +++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) create mode 100644 python/paddle/v2/fluid/distributed_spliter.py diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index e5314cf27..4c90b4a85 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -80,7 +80,6 @@ class DistributeTranspiler: # step3 send_inputs = [] - send_outputs = [] for _, splited in var2splited.iteritems(): send_inputs.extend(splited) send_outputs = self._create_vars_from_blocklist(program, param_blocks) diff --git a/python/paddle/v2/fluid/distributed_spliter.py b/python/paddle/v2/fluid/distributed_spliter.py new file mode 100644 index 000000000..e7ba53390 --- /dev/null +++ b/python/paddle/v2/fluid/distributed_spliter.py @@ -0,0 +1,38 @@ +def hash_name(varblocks, pserver_endpoints): + """ + :param varblocks: a list of VarBlock string indicating + sub blocks of variables + :return: a map of pserver endpoint -> varblock_str + """ + + def _hash_block(block_str, total): + return hash(block_str) % total + + ep2block = dict() + for varblock_str in varblocks: + if param.trainable is True and grad is not None: + server_id = _hash_block(varblock_str, len(pserver_endpoints)) + server_for_param = pserver_endpoints[server_id] + if not ep2block.has_key(server_for_param): + ep2block[server_for_param] = [] + ep2block[server_for_param].append(varblock_str) + + return ep2block + + +def round_robin(varblocks, pserver_endpoints): + assert (len(varblocks) > len(pserver_endpoints)) + + ep2block = dict() + pserver_idx = 0 + for varblock_str in varblocks: + if param.trainable is True: + server_for_param = pserver_endpoints[pserver_idx] + if not ep2block.has_key(server_for_param): + ep2block[server_for_param] = [] + ep2block[server_for_param].append(varblock_str) + + pserver_idx += 1 + if pserver_idx >= len(pserver_endpoints): + pserver_idx = 0 + return ep2block -- GitLab