From bcfb82d33e431d621317f97d3c0703d9b002a8ee Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Mon, 15 Jan 2018 20:55:48 +0800 Subject: [PATCH] dist train support split selectedrows --- .../paddle/v2/fluid/distribute_transpiler.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index d17f9815cc..00fe3e68c9 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -59,6 +59,51 @@ def split_dense_variable(var_list, return blocks +def split_selected_rows(var, + pserver_count, + min_block_size=1024, + max_block_size=1048576): + assert ((len(var.shape)) <= 1) + + split_count = pserver_count + indices = var.desc.selected_rows().dims() + var_width = reduce(lambda x, y: x * y, var.shape[1:]) + row_count = len(indices) + rows_per_block = 1 + if var_width < min_block_size: + rows_per_block = 1 + split_count = row_count + else: + rows_per_block = row_count / pserver_count + if not rows_per_block % pserver_count: + rows_per_block += 1 + split_count = row_count / rows_per_block + if not row_count % rows_per_block: + split_count += 1 + blocks = [] + for block_id in xrange(split_count): + curr_block_rows = min(rows_per_block, + row_count - (block_id * rows_per_block)) + block = VarBlock(var.name, block_id, curr_block_rows) + blocks.append(block) + return blocks + + +def split_variable(var_list, + pserver_count, + min_block_size=1024, + max_block_size=1048576): + for var in var_list: + if var.type == core.VarDesc.VarType.LOD_TENSOR: + split_dense_variable(var_list, pserver_count, min_block_size, + max_block_size) + elif var.type == core.VarDesc.VarType.SELECTED_ROWS: + split_selected_rows(var_list, pserver_count, min_block_size, + max_block_size) + else: + raise TypeError("variable must be lodtensor or selected rows") + + class DistributeTranspiler: def transpile(self, optimize_ops, -- GitLab