未验证 提交 c11f5312 编写于 作者: C chengduo 提交者: GitHub

Unified PE and compiler (#16042)

* unified PE and compiler
test=develop

* Polish code
test=develop
上级 6375fe45
...@@ -106,13 +106,18 @@ class ParallelExecutor(object): ...@@ -106,13 +106,18 @@ class ParallelExecutor(object):
else framework.default_main_program() else framework.default_main_program()
self._compiled_program = compiler.CompiledProgram(main_program) self._compiled_program = compiler.CompiledProgram(main_program)
if share_vars_from:
assert isinstance(
share_vars_from, ParallelExecutor
), "The share_vars_from should be ParallelExecutor."
self._compiled_program.with_data_parallel( self._compiled_program.with_data_parallel(
loss_name=loss_name, loss_name=loss_name,
build_strategy=build_strategy, build_strategy=build_strategy,
exec_strategy=exec_strategy, exec_strategy=exec_strategy,
share_vars_from=share_vars_from) share_vars_from=share_vars_from._compiled_program
if share_vars_from else None)
self._place = core.CUDAPlace(0) if use_cuda else core.CPUPlace() self._place = core.CUDAPlace(0) if use_cuda else core.CPUPlace()
self._executor = executor.Executor(self._place) self._exe = executor.Executor(self._place)
self._compiled_program._compile(place=self._place, scope=self._scope) self._compiled_program._compile(place=self._place, scope=self._scope)
def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True): def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True):
...@@ -180,11 +185,11 @@ class ParallelExecutor(object): ...@@ -180,11 +185,11 @@ class ParallelExecutor(object):
loss = pe.run(feed=feeder.feed(cur_batch), loss = pe.run(feed=feeder.feed(cur_batch),
fetch_list=[avg_cost.name])) fetch_list=[avg_cost.name]))
""" """
return self._executor.run(program=self._compiled_program, return self._exe.run(program=self._compiled_program,
scope=self._scope, scope=self._scope,
feed=feed, feed=feed,
fetch_list=fetch_list, fetch_list=fetch_list,
return_numpy=return_numpy) return_numpy=return_numpy)
@property @property
def device_count(self): def device_count(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册