未验证 提交 185a900f 编写于 作者: X xiongkun 提交者: GitHub

[ Dy2static ] infer_program may be incorrect in amp mode. (#44487)

* fix the outputs of net is x,x

* add unittest for duplicate output

* fix

* fix _infer_program use the original program not the amp program.

* get _***program_id back and avoid duplicate cache
ing

* fix
上级 0aa344f0
......@@ -243,6 +243,14 @@ class PartialProgramLayer:
def _infer_program_id(self):
return _hash_with_id(self._infer_program, self)
@LazyInitialized
def _infer_pure_fp16_program_id(self):
return _hash_with_id(self._infer_pure_fp16_program, self)
@LazyInitialized
def _infer_amp_program_id(self):
return _hash_with_id(self._infer_amp_program, self)
@LazyInitialized
def _train_program_id(self):
program_id = _hash_with_id(self._train_program, self)
......@@ -341,7 +349,7 @@ class PartialProgramLayer:
elif _in_pure_fp16_guard():
infer_program = self._infer_pure_fp16_program
else:
infer_program = self._infer_program
infer_program = self.infer_program
return infer_program.desc.block(0).op_size()
def __call__(self, inputs):
......@@ -380,14 +388,9 @@ class PartialProgramLayer:
@property
def program(self):
if self.training:
if _in_amp_guard():
return self._train_amp_program
elif _in_pure_fp16_guard():
return self._train_pure_fp16_program
else:
return self._train_program
return self.train_program
else:
return self._infer_program
return self.infer_program
@property
def program_id(self):
......@@ -399,7 +402,30 @@ class PartialProgramLayer:
else:
return self._train_program_id
else:
return self._infer_program_id
if _in_amp_guard():
return self._infer_amp_program_id
elif _in_pure_fp16_guard():
return self._infer_pure_fp16_program_id
else:
return self._infer_program_id
@property
def train_program(self):
if _in_amp_guard():
return self._train_amp_program
elif _in_pure_fp16_guard():
return self._train_pure_fp16_program
else:
return self._train_program
@property
def infer_program(self):
if _in_amp_guard():
return self._infer_amp_program
elif _in_pure_fp16_guard():
return self._infer_pure_fp16_program
else:
return self._infer_program
def _prepare(self, inputs):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册