diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py index 4faa4a098e0163c907990d6efb0fac60affa5808..da7cdc7f8f525bbf039a64c25fae9f4dcf5ac39c 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py @@ -243,6 +243,14 @@ class PartialProgramLayer: def _infer_program_id(self): return _hash_with_id(self._infer_program, self) + @LazyInitialized + def _infer_pure_fp16_program_id(self): + return _hash_with_id(self._infer_pure_fp16_program, self) + + @LazyInitialized + def _infer_amp_program_id(self): + return _hash_with_id(self._infer_amp_program, self) + @LazyInitialized def _train_program_id(self): program_id = _hash_with_id(self._train_program, self) @@ -341,7 +349,7 @@ class PartialProgramLayer: elif _in_pure_fp16_guard(): infer_program = self._infer_pure_fp16_program else: - infer_program = self._infer_program + infer_program = self.infer_program return infer_program.desc.block(0).op_size() def __call__(self, inputs): @@ -380,14 +388,9 @@ class PartialProgramLayer: @property def program(self): if self.training: - if _in_amp_guard(): - return self._train_amp_program - elif _in_pure_fp16_guard(): - return self._train_pure_fp16_program - else: - return self._train_program + return self.train_program else: - return self._infer_program + return self.infer_program @property def program_id(self): @@ -399,7 +402,30 @@ class PartialProgramLayer: else: return self._train_program_id else: - return self._infer_program_id + if _in_amp_guard(): + return self._infer_amp_program_id + elif _in_pure_fp16_guard(): + return self._infer_pure_fp16_program_id + else: + return self._infer_program_id + + @property + def train_program(self): + if _in_amp_guard(): + return self._train_amp_program + elif _in_pure_fp16_guard(): + return self._train_pure_fp16_program + else: + return self._train_program + + @property + def infer_program(self): + if _in_amp_guard(): + return self._infer_amp_program + elif _in_pure_fp16_guard(): + return self._infer_pure_fp16_program + else: + return self._infer_program def _prepare(self, inputs): """