提交 9c356093 编写于 作者: C chengduo 提交者: ceci3

Unified PE and compiler (#16042)

* unified PE and compiler
test=develop

* Polish code
test=develop
上级 9cc6f400
......@@ -106,13 +106,18 @@ class ParallelExecutor(object):
else framework.default_main_program()
self._compiled_program = compiler.CompiledProgram(main_program)
if share_vars_from:
assert isinstance(
share_vars_from, ParallelExecutor
), "The share_vars_from should be ParallelExecutor."
self._compiled_program.with_data_parallel(
loss_name=loss_name,
build_strategy=build_strategy,
exec_strategy=exec_strategy,
share_vars_from=share_vars_from)
share_vars_from=share_vars_from._compiled_program
if share_vars_from else None)
self._place = core.CUDAPlace(0) if use_cuda else core.CPUPlace()
self._executor = executor.Executor(self._place)
self._exe = executor.Executor(self._place)
self._compiled_program._compile(place=self._place, scope=self._scope)
def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True):
......@@ -180,11 +185,11 @@ class ParallelExecutor(object):
loss = pe.run(feed=feeder.feed(cur_batch),
fetch_list=[avg_cost.name]))
"""
return self._executor.run(program=self._compiled_program,
scope=self._scope,
feed=feed,
fetch_list=fetch_list,
return_numpy=return_numpy)
return self._exe.run(program=self._compiled_program,
scope=self._scope,
feed=feed,
fetch_list=fetch_list,
return_numpy=return_numpy)
@property
def device_count(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册