From d52fcaf42ecb251913d730250ac99ccab94a152c Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Fri, 26 Oct 2018 17:32:29 +0800 Subject: [PATCH] replace table init op with fake init --- .../fluid/transpiler/distribute_transpiler.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 29357f53c5..5826db292b 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -477,12 +477,23 @@ class DistributeTranspiler(object): # delete table init op if self.has_distributed_lookup_table: - trainer_table_param_init_op = [] + table_var = self.startup_program.global_block().vars[ + self.table_name] + table_param_init_op = [] for op in self.startup_program.global_block().ops: if self.table_name in op.output_arg_names: - trainer_table_param_init_op.append(op) - delete_ops(self.startup_program.global_block(), - trainer_table_param_init_op) + table_param_init_op.append(op) + init_op_num = len(table_param_init_op) + if init_op_num != 1: + raise ValueError("table init op num should be 1, now is " + str( + init_op_num)) + table_init_op = table_param_init_op[1] + self.startup_program.global_block().append_op( + type="fake_init", + inputs={}, + outputs={"Out": table_var}, + attrs={"shape": table_init_op.attr('shape')}) + delete_ops(self.startup_program.global_block(), table_param_init_op) self.origin_program.__str__() -- GitLab