提交 0707abb5 编写于 作者: T tangwei12

lookup table fix

上级 83c85f34
...@@ -877,9 +877,15 @@ class DistributeTranspiler(object): ...@@ -877,9 +877,15 @@ class DistributeTranspiler(object):
# create table param and grad var in pserver program # create table param and grad var in pserver program
origin_param_var = self.origin_program.global_block().vars[ origin_param_var = self.origin_program.global_block().vars[
self.table_name] self.table_name]
zero_dim = long(
math.ceil(origin_param_var.shape[0] / len(self.pserver_endpoints)))
table_shape = list(origin_param_var.shape)
table_shape[0] = zero_dim
param_var = pserver_program.global_block().create_var( param_var = pserver_program.global_block().create_var(
name=origin_param_var.name, name=origin_param_var.name,
shape=origin_param_var.shape, shape=table_shape,
dtype=origin_param_var.dtype, dtype=origin_param_var.dtype,
type=core.VarDesc.VarType.SELECTED_ROWS, type=core.VarDesc.VarType.SELECTED_ROWS,
persistable=True) persistable=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册