From 185a900fd143fceca36bbf7f97e28d50646fd41e Mon Sep 17 00:00:00 2001 From: xiongkun Date: Thu, 21 Jul 2022 20:14:23 +0800 Subject: [PATCH] [ Dy2static ] infer_program may be incorrect in amp mode. (#44487) * fix the outputs of net is x,x * add unittest for duplicate output * fix * fix _infer_program use the original program not the amp program. * get _***program_id back and avoid duplicate cache ing * fix --- .../dygraph_to_static/partial_program.py | 44 +++++++++++++++---- 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py index 4faa4a098e..da7cdc7f8f 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py @@ -243,6 +243,14 @@ class PartialProgramLayer: def _infer_program_id(self): return _hash_with_id(self._infer_program, self) + @LazyInitialized + def _infer_pure_fp16_program_id(self): + return _hash_with_id(self._infer_pure_fp16_program, self) + + @LazyInitialized + def _infer_amp_program_id(self): + return _hash_with_id(self._infer_amp_program, self) + @LazyInitialized def _train_program_id(self): program_id = _hash_with_id(self._train_program, self) @@ -341,7 +349,7 @@ class PartialProgramLayer: elif _in_pure_fp16_guard(): infer_program = self._infer_pure_fp16_program else: - infer_program = self._infer_program + infer_program = self.infer_program return infer_program.desc.block(0).op_size() def __call__(self, inputs): @@ -380,14 +388,9 @@ class PartialProgramLayer: @property def program(self): if self.training: - if _in_amp_guard(): - return self._train_amp_program - elif _in_pure_fp16_guard(): - return self._train_pure_fp16_program - else: - return self._train_program + return self.train_program else: - return self._infer_program + return self.infer_program @property def program_id(self): @@ -399,7 +402,30 @@ class PartialProgramLayer: else: return self._train_program_id else: - return self._infer_program_id + if _in_amp_guard(): + return self._infer_amp_program_id + elif _in_pure_fp16_guard(): + return self._infer_pure_fp16_program_id + else: + return self._infer_program_id + + @property + def train_program(self): + if _in_amp_guard(): + return self._train_amp_program + elif _in_pure_fp16_guard(): + return self._train_pure_fp16_program + else: + return self._train_program + + @property + def infer_program(self): + if _in_amp_guard(): + return self._infer_amp_program + elif _in_pure_fp16_guard(): + return self._infer_pure_fp16_program + else: + return self._infer_program def _prepare(self, inputs): """ -- GitLab