From 900f411d629dd4fe18417455055529b86f4455f2 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Thu, 25 Jan 2018 17:11:52 +0800 Subject: [PATCH] fix dist transpiler bug --- python/paddle/v2/fluid/distribute_transpiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index 934eba73b82..908810c8be1 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -225,7 +225,7 @@ class DistributeTranspiler: if len(splited_vars) <= 1: continue orig_var = program.global_block().vars[varname] - if orig_var == core.VarDesc.VarType.SELECTED_ROWS: + if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS: height_sections = [] for v in splited_vars: height_sections.append(v.shape[0]) @@ -234,7 +234,7 @@ class DistributeTranspiler: inputs={"X": orig_var}, outputs={"Out": splited_vars}, attrs={"height_sections": height_sections}) - elif orig_var == core.VarDesc.VarType.LOD_TENSOR: + elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR: sections = [] for v in splited_vars: sections.append(v.shape[0]) -- GitLab