From 2b9ff39f5f66663a60d4d33bdc2ee1da0c1ff364 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 12 Jun 2018 10:25:25 +0800 Subject: [PATCH] fix the default value prefetch_var_name_to_block_id --- paddle/fluid/operators/listen_and_serv_op.cc | 3 ++- .../fluid/transpiler/distribute_transpiler.py | 20 +++++++++++-------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 4cf2c8daa55..4d12278799f 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -340,7 +340,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr(kOptimizeBlock, "BlockID to run on server side."); AddAttr>(kPrefetchVarNameToBlockId, - "prefetch block to run on server side."); + "prefetch blocks to run on server side.") + .SetDefault({}); AddAttr("Fanin", "How many clients send to this server.") .SetDefault(1); } diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 924e5ba4f6a..2480d4e76a1 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -530,19 +530,23 @@ class DistributeTranspiler: else: assert len(prefetch_var_name_to_block_id) == 0 + attrs = { + "OptimizeBlock": pserver_program.block(1), + "endpoint": endpoint, + "Fanin": self.trainer_num, + "sync_mode": self.sync_mode, + "grad_to_block_id": grad_to_block_id + } + if len(prefetch_var_name_to_block_id) > 0: + attrs['prefetch_var_name_to_block_id'] \ + = prefetch_var_name_to_block_id + # step5 append the listen_and_serv op pserver_program.global_block().append_op( type="listen_and_serv", inputs={'X': recv_inputs}, outputs={}, - attrs={ - "OptimizeBlock": pserver_program.block(1), - "endpoint": endpoint, - "Fanin": self.trainer_num, - "prefetch_var_name_to_block_id": prefetch_var_name_to_block_id, - "sync_mode": self.sync_mode, - "grad_to_block_id": grad_to_block_id - }) + attrs=attrs) pserver_program.sync_with_cpp() return pserver_program -- GitLab