From 2fdbc1ce65be36770f840e405222fc6f222d0d50 Mon Sep 17 00:00:00 2001
From: Yancey
Date: Wed, 20 Jun 2018 16:15:20 +0800
Subject: [PATCH] hidden bcast_params call in dist train (#11575)

---
 benchmark/fluid/fluid_benchmark.py       |  2 --
 python/paddle/fluid/parallel_executor.py | 10 +++++++++-
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/benchmark/fluid/fluid_benchmark.py b/benchmark/fluid/fluid_benchmark.py
index 2450c2d777..ece1102dce 100644
--- a/benchmark/fluid/fluid_benchmark.py
+++ b/benchmark/fluid/fluid_benchmark.py
@@ -264,8 +264,6 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
                     break
             else:
                 loss, = exe.run([avg_loss.name], feed=feeder.feed(data))
-            if args.update_method == "pserver":
-                exe.bcast_params()
             if args.use_reader_op:
                 num_samples += args.batch_size * args.gpus
             else:
diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py
index afa6d91145..25cc1355d5 100644
--- a/python/paddle/fluid/parallel_executor.py
+++ b/python/paddle/fluid/parallel_executor.py
@@ -71,7 +71,6 @@ class ParallelExecutor(object):
                  num_trainers=1,
                  trainer_id=0,
                  **kwargs):
-
         if len(kwargs) != 0:
             err_msg = ""
             for key in kwargs:
@@ -130,6 +129,11 @@ class ParallelExecutor(object):
             main = main_program
         main = main if main else framework.default_main_program()
         scope = executor.global_scope()
+        # FIXME(Yancey1989): it's a temporary approach to determinate the distribute
+        # train program, call self.bcast_param() at the end of each mini-batch.
+        self.is_dist = True if "recv" in [
+            op.type for op in main.global_block().ops
+        ] else False
 
         if share_vars_from and not isinstance(share_vars_from,
                                               ParallelExecutor):
@@ -262,6 +266,10 @@ class ParallelExecutor(object):
         fetch_var_name = '@FETCHED_VAR_NAME@'
         self.executor.run(fetch_list, fetch_var_name)
         arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()
+
+        if self.is_dist:
+            self.bcast_params()
+
         return [arr[i] for i in range(len(arr))]
 
     def bcast_params(self):
--
GitLab