未验证 提交 185a900f 编写于 作者: X xiongkun 提交者: GitHub

[ Dy2static ] infer_program may be incorrect in amp mode. (#44487)

* fix the outputs of net is x,x

* add unittest for duplicate output

* fix

* fix _infer_program use the original program not the amp program.

* get _***program_id back and avoid duplicate cache
ing

* fix
上级 0aa344f0
...@@ -243,6 +243,14 @@ class PartialProgramLayer: ...@@ -243,6 +243,14 @@ class PartialProgramLayer:
def _infer_program_id(self): def _infer_program_id(self):
return _hash_with_id(self._infer_program, self) return _hash_with_id(self._infer_program, self)
@LazyInitialized
def _infer_pure_fp16_program_id(self):
return _hash_with_id(self._infer_pure_fp16_program, self)
@LazyInitialized
def _infer_amp_program_id(self):
return _hash_with_id(self._infer_amp_program, self)
@LazyInitialized @LazyInitialized
def _train_program_id(self): def _train_program_id(self):
program_id = _hash_with_id(self._train_program, self) program_id = _hash_with_id(self._train_program, self)
...@@ -341,7 +349,7 @@ class PartialProgramLayer: ...@@ -341,7 +349,7 @@ class PartialProgramLayer:
elif _in_pure_fp16_guard(): elif _in_pure_fp16_guard():
infer_program = self._infer_pure_fp16_program infer_program = self._infer_pure_fp16_program
else: else:
infer_program = self._infer_program infer_program = self.infer_program
return infer_program.desc.block(0).op_size() return infer_program.desc.block(0).op_size()
def __call__(self, inputs): def __call__(self, inputs):
...@@ -380,14 +388,9 @@ class PartialProgramLayer: ...@@ -380,14 +388,9 @@ class PartialProgramLayer:
@property @property
def program(self): def program(self):
if self.training: if self.training:
if _in_amp_guard(): return self.train_program
return self._train_amp_program
elif _in_pure_fp16_guard():
return self._train_pure_fp16_program
else:
return self._train_program
else: else:
return self._infer_program return self.infer_program
@property @property
def program_id(self): def program_id(self):
...@@ -399,7 +402,30 @@ class PartialProgramLayer: ...@@ -399,7 +402,30 @@ class PartialProgramLayer:
else: else:
return self._train_program_id return self._train_program_id
else: else:
return self._infer_program_id if _in_amp_guard():
return self._infer_amp_program_id
elif _in_pure_fp16_guard():
return self._infer_pure_fp16_program_id
else:
return self._infer_program_id
@property
def train_program(self):
if _in_amp_guard():
return self._train_amp_program
elif _in_pure_fp16_guard():
return self._train_pure_fp16_program
else:
return self._train_program
@property
def infer_program(self):
if _in_amp_guard():
return self._infer_amp_program
elif _in_pure_fp16_guard():
return self._infer_pure_fp16_program
else:
return self._infer_program
def _prepare(self, inputs): def _prepare(self, inputs):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册