From 67df9550eb9111f18eb7b4204a1ebde15b790090 Mon Sep 17 00:00:00 2001 From: Yang Zhang Date: Fri, 3 Jan 2020 17:39:38 +0800 Subject: [PATCH] Set `share_vars_from` for eval program --- model.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/model.py b/model.py index c4c4602..23a111e 100644 --- a/model.py +++ b/model.py @@ -247,12 +247,11 @@ class StaticGraphAdapter(object): ids = [str(i) for i in device_ids] ids.sort() - prog_hash = '_'.join([self.mode] + ids) - compiled_prog = self._compiled_progs.get(prog_hash, None) + compiled_prog = self._compiled_progs.get(self.mode, None) if compiled_prog is None: compiled_prog = self._compile_and_initialize( self._progs[self.mode], device, device_ids) - self._compiled_progs[prog_hash] = compiled_prog + self._compiled_progs[self.mode] = compiled_prog feed = {} input_names = [name for name in self._input_desc.keys()] @@ -334,8 +333,17 @@ class StaticGraphAdapter(object): loss_name = None if self._loss_endpoint is not None: loss_name = self._loss_endpoint.name + share_vars_from = None + if self.mode == 'eval' and 'train' in self._compiled_progs: + share_vars_from = self._compiled_progs['train'] + # HACK invalidate eval program if is compiled before train program + # quite hackish, OTOH, it is generally uncommon that the eval + # program will be run before the train program + if self.mode == 'train' and 'eval' in self._compiled_progs: + del self._compiled_progs['eval'] compiled_prog = compiled_prog.with_data_parallel( - loss_name=loss_name, places=places) + loss_name=loss_name, places=places, + share_vars_from=share_vars_from) if self._executor is None: self._executor = fluid.Executor(place) -- GitLab