From 16027ea111cfa8023d9d9caadcbae2eb0b41fd70 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 29 May 2018 15:51:04 +0800 Subject: [PATCH] use block.clone_variable instead of _clone_var --- .../fluid/transpiler/distribute_transpiler.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index aed51dcda64..867c4bf3c88 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -689,15 +689,6 @@ class DistributeTranspiler: def _create_table_optimize_block(self, pserver_index, pserver_program, pre_block_idx, grad_to_block_id): - def _clone_var(block, var, persistable=True): - assert isinstance(var, Variable) - return block.create_var( - name=var.name, - shape=var.shape, - dtype=var.dtype, - type=var.type, - persistable=persistable) - # STEP: create table optimize block # create table param and grad var in pserver program origin_param_var = self.origin_program.global_block().vars[ @@ -708,11 +699,11 @@ class DistributeTranspiler: dtype=origin_param_var.dtype, type=core.VarDesc.VarType.SELECTED_ROWS, persistable=True) - grad_var = _clone_var( - pserver_program.global_block(), + # parameter must be selected rows + param_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS) + grad_var = pserver_program.global_block().clone_variable( self.origin_program.global_block().vars[grad_var_name( - self.table_name)], - persistable=False) + self.table_name)]) # create table optimize block in pserver program table_opt_op = [ -- GitLab