From 0707abb51b0bdbd4cddd3c0c62ce5288515217b1 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 14 Aug 2018 17:16:46 +0800 Subject: [PATCH] lookup table fix --- python/paddle/fluid/transpiler/distribute_transpiler.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 1bb86acdf83..0328c172cd0 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -877,9 +877,15 @@ class DistributeTranspiler(object): # create table param and grad var in pserver program origin_param_var = self.origin_program.global_block().vars[ self.table_name] + + zero_dim = long( + math.ceil(origin_param_var.shape[0] / len(self.pserver_endpoints))) + table_shape = list(origin_param_var.shape) + table_shape[0] = zero_dim + param_var = pserver_program.global_block().create_var( name=origin_param_var.name, - shape=origin_param_var.shape, + shape=table_shape, dtype=origin_param_var.dtype, type=core.VarDesc.VarType.SELECTED_ROWS, persistable=True) -- GitLab