diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 1bb86acdf8398fff63e5f55148ddb43b6b4da5be..0328c172cd02b9016509a1a39b58ee152b3f9554 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -877,9 +877,15 @@ class DistributeTranspiler(object): # create table param and grad var in pserver program origin_param_var = self.origin_program.global_block().vars[ self.table_name] + + zero_dim = long( + math.ceil(origin_param_var.shape[0] / len(self.pserver_endpoints))) + table_shape = list(origin_param_var.shape) + table_shape[0] = zero_dim + param_var = pserver_program.global_block().create_var( name=origin_param_var.name, - shape=origin_param_var.shape, + shape=table_shape, dtype=origin_param_var.dtype, type=core.VarDesc.VarType.SELECTED_ROWS, persistable=True)