From 9c3560931cb6ab8bdb8fa25a01360e1e881d8da5 Mon Sep 17 00:00:00 2001 From: chengduo Date: Tue, 5 Mar 2019 07:35:55 -0600 Subject: [PATCH] Unified PE and compiler (#16042) * unified PE and compiler test=develop * Polish code test=develop --- python/paddle/fluid/parallel_executor.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index 2ebaab3b10..517418da1c 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -106,13 +106,18 @@ class ParallelExecutor(object): else framework.default_main_program() self._compiled_program = compiler.CompiledProgram(main_program) + if share_vars_from: + assert isinstance( + share_vars_from, ParallelExecutor + ), "The share_vars_from should be ParallelExecutor." self._compiled_program.with_data_parallel( loss_name=loss_name, build_strategy=build_strategy, exec_strategy=exec_strategy, - share_vars_from=share_vars_from) + share_vars_from=share_vars_from._compiled_program + if share_vars_from else None) self._place = core.CUDAPlace(0) if use_cuda else core.CPUPlace() - self._executor = executor.Executor(self._place) + self._exe = executor.Executor(self._place) self._compiled_program._compile(place=self._place, scope=self._scope) def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True): @@ -180,11 +185,11 @@ class ParallelExecutor(object): loss = pe.run(feed=feeder.feed(cur_batch), fetch_list=[avg_cost.name])) """ - return self._executor.run(program=self._compiled_program, - scope=self._scope, - feed=feed, - fetch_list=fetch_list, - return_numpy=return_numpy) + return self._exe.run(program=self._compiled_program, + scope=self._scope, + feed=feed, + fetch_list=fetch_list, + return_numpy=return_numpy) @property def device_count(self): -- GitLab