提交 7e15e571 编写于 作者: Y Yang Zhang

Compile programs even for single device

上级 b4614894
......@@ -347,25 +347,25 @@ class StaticGraphAdapter(object):
startup_prog = self._startup_prog._prune(uninitialized)
self._executor.run(startup_prog)
if len(device_ids) < 2:
return prog
share_vars_from = None
if self.mode == 'eval' and 'train' in self._compiled_progs:
share_vars_from = self._compiled_progs['train']
# HACK invalidate eval program if is compiled before train program
# quite hackish, OTOH, it is generally uncommon that the eval
# program will be run before the train program
if self.mode == 'train' and 'eval' in self._compiled_progs:
del self._compiled_progs['eval']
compiled_prog = fluid.CompiledProgram(prog)
loss_name = None
if self._loss_endpoint is not None:
loss_name = self._loss_endpoint.name
compiled_prog = compiled_prog.with_data_parallel(
loss_name=loss_name, places=places,
share_vars_from=share_vars_from)
if len(device_ids) > 1:
loss_name = None
if self.mode == 'train' and self._loss_endpoint is not None:
loss_name = self._loss_endpoint.name
share_vars_from = None
if self.mode == 'eval' and 'train' in self._compiled_progs:
share_vars_from = self._compiled_progs['train']
# HACK invalidate eval program if is compiled before train program
# quite hackish, OTOH, it is generally uncommon that the eval
# program will be run before the train program
if self.mode == 'train' and 'eval' in self._compiled_progs:
del self._compiled_progs['eval']
compiled_prog = compiled_prog.with_data_parallel(
loss_name=loss_name, places=places,
share_vars_from=share_vars_from)
self._compiled_progs[self.mode] = compiled_prog
return compiled_prog
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册