提交 bcfb82d3 编写于 作者: T typhoonzero

dist train support split selectedrows

上级 0237b7e9
...@@ -59,6 +59,51 @@ def split_dense_variable(var_list, ...@@ -59,6 +59,51 @@ def split_dense_variable(var_list,
return blocks return blocks
def split_selected_rows(var,
pserver_count,
min_block_size=1024,
max_block_size=1048576):
assert ((len(var.shape)) <= 1)
split_count = pserver_count
indices = var.desc.selected_rows().dims()
var_width = reduce(lambda x, y: x * y, var.shape[1:])
row_count = len(indices)
rows_per_block = 1
if var_width < min_block_size:
rows_per_block = 1
split_count = row_count
else:
rows_per_block = row_count / pserver_count
if not rows_per_block % pserver_count:
rows_per_block += 1
split_count = row_count / rows_per_block
if not row_count % rows_per_block:
split_count += 1
blocks = []
for block_id in xrange(split_count):
curr_block_rows = min(rows_per_block,
row_count - (block_id * rows_per_block))
block = VarBlock(var.name, block_id, curr_block_rows)
blocks.append(block)
return blocks
def split_variable(var_list,
pserver_count,
min_block_size=1024,
max_block_size=1048576):
for var in var_list:
if var.type == core.VarDesc.VarType.LOD_TENSOR:
split_dense_variable(var_list, pserver_count, min_block_size,
max_block_size)
elif var.type == core.VarDesc.VarType.SELECTED_ROWS:
split_selected_rows(var_list, pserver_count, min_block_size,
max_block_size)
else:
raise TypeError("variable must be lodtensor or selected rows")
class DistributeTranspiler: class DistributeTranspiler:
def transpile(self, def transpile(self,
optimize_ops, optimize_ops,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册