diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py
index 2ebaab3b1024878e28ae7064bfc5c3d1d091ad94..517418da1cf2f745ee5578e3c2b118394db7fae7 100644
--- a/python/paddle/fluid/parallel_executor.py
+++ b/python/paddle/fluid/parallel_executor.py
@@ -106,13 +106,18 @@ class ParallelExecutor(object):
             else framework.default_main_program()
 
         self._compiled_program = compiler.CompiledProgram(main_program)
+        if share_vars_from:
+            assert isinstance(
+                share_vars_from, ParallelExecutor
+            ), "The share_vars_from should be ParallelExecutor."
         self._compiled_program.with_data_parallel(
             loss_name=loss_name,
             build_strategy=build_strategy,
             exec_strategy=exec_strategy,
-            share_vars_from=share_vars_from)
+            share_vars_from=share_vars_from._compiled_program
+            if share_vars_from else None)
         self._place = core.CUDAPlace(0) if use_cuda else core.CPUPlace()
-        self._executor = executor.Executor(self._place)
+        self._exe = executor.Executor(self._place)
         self._compiled_program._compile(place=self._place, scope=self._scope)
 
     def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True):
@@ -180,11 +185,11 @@ class ParallelExecutor(object):
               loss = pe.run(feed=feeder.feed(cur_batch),
                             fetch_list=[avg_cost.name]))
         """
-        return self._executor.run(program=self._compiled_program,
-                                  scope=self._scope,
-                                  feed=feed,
-                                  fetch_list=fetch_list,
-                                  return_numpy=return_numpy)
+        return self._exe.run(program=self._compiled_program,
+                             scope=self._scope,
+                             feed=feed,
+                             fetch_list=fetch_list,
+                             return_numpy=return_numpy)
 
     @property
     def device_count(self):
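
For context, a minimal sketch of how the code path touched above is typically exercised: a second `ParallelExecutor` built for an evaluation program reuses the training executor's variables via `share_vars_from`, which is why the constructor now asserts the argument is a `ParallelExecutor` and forwards its `_compiled_program` to `with_data_parallel`. The toy linear-regression program, the variable names, and the CPU settings below are illustrative assumptions, not part of this patch.

```python
# Illustrative sketch (assumed toy program, not part of this patch): a
# test-time ParallelExecutor reuses the training executor's parameters
# through share_vars_from, which the patched constructor now validates.
import os
import numpy as np
import paddle.fluid as fluid

os.environ.setdefault('CPU_NUM', '2')  # number of CPU "devices" for data parallelism

train_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
    pred = fluid.layers.fc(input=x, size=1)
    loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))
    # Clone the forward graph for evaluation before adding optimizer ops.
    test_prog = train_prog.clone(for_test=True)
    fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

place = fluid.CPUPlace()
fluid.Executor(place).run(startup_prog)

train_exe = fluid.ParallelExecutor(use_cuda=False,
                                   loss_name=loss.name,
                                   main_program=train_prog)
# share_vars_from must be a ParallelExecutor (the new assert); internally its
# _compiled_program is what gets forwarded to with_data_parallel().
test_exe = fluid.ParallelExecutor(use_cuda=False,
                                  main_program=test_prog,
                                  share_vars_from=train_exe)

feed = {'x': np.random.rand(4, 13).astype('float32'),
        'y': np.random.rand(4, 1).astype('float32')}
train_loss, = train_exe.run(feed=feed, fetch_list=[loss.name])
test_loss, = test_exe.run(feed=feed, fetch_list=[loss.name])
```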