提交 67df9550 编写于 作者: Y Yang Zhang

Set `share_vars_from` for eval program

上级 57f4170d
......@@ -247,12 +247,11 @@ class StaticGraphAdapter(object):
ids = [str(i) for i in device_ids]
ids.sort()
prog_hash = '_'.join([self.mode] + ids)
compiled_prog = self._compiled_progs.get(prog_hash, None)
compiled_prog = self._compiled_progs.get(self.mode, None)
if compiled_prog is None:
compiled_prog = self._compile_and_initialize(
self._progs[self.mode], device, device_ids)
self._compiled_progs[prog_hash] = compiled_prog
self._compiled_progs[self.mode] = compiled_prog
feed = {}
input_names = [name for name in self._input_desc.keys()]
......@@ -334,8 +333,17 @@ class StaticGraphAdapter(object):
loss_name = None
if self._loss_endpoint is not None:
loss_name = self._loss_endpoint.name
share_vars_from = None
if self.mode == 'eval' and 'train' in self._compiled_progs:
share_vars_from = self._compiled_progs['train']
# HACK invalidate eval program if is compiled before train program
# quite hackish, OTOH, it is generally uncommon that the eval
# program will be run before the train program
if self.mode == 'train' and 'eval' in self._compiled_progs:
del self._compiled_progs['eval']
compiled_prog = compiled_prog.with_data_parallel(
loss_name=loss_name, places=places)
loss_name=loss_name, places=places,
share_vars_from=share_vars_from)
if self._executor is None:
self._executor = fluid.Executor(place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册