提交 7e15e571 编写于 作者: Y Yang Zhang

Compile programs even for single device

上级 b4614894
...@@ -347,8 +347,11 @@ class StaticGraphAdapter(object): ...@@ -347,8 +347,11 @@ class StaticGraphAdapter(object):
startup_prog = self._startup_prog._prune(uninitialized) startup_prog = self._startup_prog._prune(uninitialized)
self._executor.run(startup_prog) self._executor.run(startup_prog)
if len(device_ids) < 2: compiled_prog = fluid.CompiledProgram(prog)
return prog if len(device_ids) > 1:
loss_name = None
if self.mode == 'train' and self._loss_endpoint is not None:
loss_name = self._loss_endpoint.name
share_vars_from = None share_vars_from = None
if self.mode == 'eval' and 'train' in self._compiled_progs: if self.mode == 'eval' and 'train' in self._compiled_progs:
...@@ -359,13 +362,10 @@ class StaticGraphAdapter(object): ...@@ -359,13 +362,10 @@ class StaticGraphAdapter(object):
if self.mode == 'train' and 'eval' in self._compiled_progs: if self.mode == 'train' and 'eval' in self._compiled_progs:
del self._compiled_progs['eval'] del self._compiled_progs['eval']
compiled_prog = fluid.CompiledProgram(prog)
loss_name = None
if self._loss_endpoint is not None:
loss_name = self._loss_endpoint.name
compiled_prog = compiled_prog.with_data_parallel( compiled_prog = compiled_prog.with_data_parallel(
loss_name=loss_name, places=places, loss_name=loss_name, places=places,
share_vars_from=share_vars_from) share_vars_from=share_vars_from)
self._compiled_progs[self.mode] = compiled_prog self._compiled_progs[self.mode] = compiled_prog
return compiled_prog return compiled_prog
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册