diff --git a/model.py b/model.py index c4c460239dfb242fa6f38b0a3fc129c8c0ae4822..23a111e8516a7cffe0ee68487789ed444e48442f 100644 --- a/model.py +++ b/model.py @@ -247,12 +247,11 @@ class StaticGraphAdapter(object): ids = [str(i) for i in device_ids] ids.sort() - prog_hash = '_'.join([self.mode] + ids) - compiled_prog = self._compiled_progs.get(prog_hash, None) + compiled_prog = self._compiled_progs.get(self.mode, None) if compiled_prog is None: compiled_prog = self._compile_and_initialize( self._progs[self.mode], device, device_ids) - self._compiled_progs[prog_hash] = compiled_prog + self._compiled_progs[self.mode] = compiled_prog feed = {} input_names = [name for name in self._input_desc.keys()] @@ -334,8 +333,17 @@ class StaticGraphAdapter(object): loss_name = None if self._loss_endpoint is not None: loss_name = self._loss_endpoint.name + share_vars_from = None + if self.mode == 'eval' and 'train' in self._compiled_progs: + share_vars_from = self._compiled_progs['train'] + # HACK invalidate eval program if is compiled before train program + # quite hackish, OTOH, it is generally uncommon that the eval + # program will be run before the train program + if self.mode == 'train' and 'eval' in self._compiled_progs: + del self._compiled_progs['eval'] compiled_prog = compiled_prog.with_data_parallel( - loss_name=loss_name, places=places) + loss_name=loss_name, places=places, + share_vars_from=share_vars_from) if self._executor is None: self._executor = fluid.Executor(place)