From 82d924984f75cecce2626ecb7376f2424b50aaae Mon Sep 17 00:00:00 2001
From: typhoonzero
Date: Mon, 29 Jan 2018 14:45:10 +0800
Subject: [PATCH] update dist transpiler

---
 python/paddle/v2/fluid/distribute_transpiler.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py
index abcad899bf..4e54ab806b 100644
--- a/python/paddle/v2/fluid/distribute_transpiler.py
+++ b/python/paddle/v2/fluid/distribute_transpiler.py
@@ -221,7 +221,7 @@ class DistributeTranspiler:
             if len(splited_vars) <= 1:
                 continue
             orig_var = program.global_block().vars[varname]
-            if orig_var == core.VarDesc.VarType.SELECTED_ROWS:
+            if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS:
                 height_sections = []
                 for v in splited_vars:
                     height_sections.append(v.shape[0])
@@ -230,7 +230,7 @@ class DistributeTranspiler:
                     inputs={"X": orig_var},
                     outputs={"Out": splited_vars},
                     attrs={"height_sections": height_sections})
-            elif orig_var == core.VarDesc.VarType.LOD_TENSOR:
+            elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR:
                 sections = []
                 for v in splited_vars:
                     sections.append(v.shape[0])
@@ -470,8 +470,7 @@ class DistributeTranspiler:
         # Append the recv op
         pserver_program.global_block().append_op(
             type="recv",
-            inputs={"RX": self.param_grad_ep_mapping[endpoint]["grads"]
-                    },  # grads to recv
+            inputs={},
             outputs={},
             attrs={
                 "OptimizeBlock": optimize_sub_program.global_block(),
--
GitLab
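
The first two hunks fix a type check that compared the Variable object itself against a core.VarDesc.VarType enum member, a comparison that is always false, so neither the SELECTED_ROWS nor the LOD_TENSOR split branch could ever run; the fix compares the variable's .type attribute instead. Below is a minimal, self-contained sketch of that bug class using hypothetical stand-in classes (VarType, Variable, and the example name "w_0@GRAD" are illustrative, not Paddle's real API):

    # Illustrative sketch only; mimics the check fixed in the first two hunks
    # with stand-in classes instead of Paddle's framework objects.
    from enum import Enum


    class VarType(Enum):
        """Stand-in for core.VarDesc.VarType."""
        LOD_TENSOR = 7
        SELECTED_ROWS = 8


    class Variable(object):
        """Stand-in for a fluid Variable that carries its VarType."""

        def __init__(self, name, var_type):
            self.name = name
            self.type = var_type


    orig_var = Variable("w_0@GRAD", VarType.SELECTED_ROWS)

    # Pre-patch check: a Variable instance never equals an enum member,
    # so this condition is always false.
    assert (orig_var == VarType.SELECTED_ROWS) is False

    # Post-patch check: compare the variable's type attribute.
    assert orig_var.type == VarType.SELECTED_ROWS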