未验证 提交 a1cdbad1 编写于 作者: A Allen Guo 提交者: GitHub

rm fp16 dtype_check (#46739) (#46866)

...@@ -43,12 +43,6 @@ class TestBase(IPUOpTest): ...@@ -43,12 +43,6 @@ class TestBase(IPUOpTest):
self.feed_shape = [x.shape for x in self.feed_fp32.values()] self.feed_shape = [x.shape for x in self.feed_fp32.values()]
self.feed_list = list(self.feed_fp32.keys()) self.feed_list = list(self.feed_fp32.keys())
def dtype_check(self, program, to_fp16_var_names):
block = program.global_block()
assert len(to_fp16_var_names) > 0
for var_name in to_fp16_var_names:
assert (block.var(var_name).dtype, paddle.float16)
def set_attrs(self): def set_attrs(self):
self.num_ipus = 1 self.num_ipus = 1
self.enable_pipelining = False self.enable_pipelining = False
...@@ -84,7 +78,6 @@ class TestBase(IPUOpTest): ...@@ -84,7 +78,6 @@ class TestBase(IPUOpTest):
amp_list.unsupported_list = {} amp_list.unsupported_list = {}
to_fp16_var_names = paddle.static.amp.cast_model_to_fp16( to_fp16_var_names = paddle.static.amp.cast_model_to_fp16(
self.main_prog, amp_list, use_fp16_guard=True) self.main_prog, amp_list, use_fp16_guard=True)
self.dtype_check(self.main_prog, to_fp16_var_names)
if self.is_ipu_mode(exec_mode): if self.is_ipu_mode(exec_mode):
place = paddle.CPUPlace() place = paddle.CPUPlace()
......
...@@ -55,12 +55,6 @@ class TestBase(IPUOpTest): ...@@ -55,12 +55,6 @@ class TestBase(IPUOpTest):
self.enable_manual_shard = False self.enable_manual_shard = False
self.batches_per_step = 1 self.batches_per_step = 1
def dtype_check(self, program, to_fp16_var_names):
block = program.global_block()
assert len(to_fp16_var_names) > 0
for var_name in to_fp16_var_names:
assert (block.var(var_name).dtype, paddle.float16)
@IPUOpTest.static_graph @IPUOpTest.static_graph
def build_model(self): def build_model(self):
x = paddle.static.data(name=self.feed_list[0], x = paddle.static.data(name=self.feed_list[0],
...@@ -94,7 +88,6 @@ class TestBase(IPUOpTest): ...@@ -94,7 +88,6 @@ class TestBase(IPUOpTest):
amp_list.unsupported_list = {} amp_list.unsupported_list = {}
to_fp16_var_names = paddle.static.amp.cast_model_to_fp16( to_fp16_var_names = paddle.static.amp.cast_model_to_fp16(
self.main_prog, amp_list) self.main_prog, amp_list)
self.dtype_check(self.main_prog, to_fp16_var_names)
if self.is_ipu_mode(exec_mode): if self.is_ipu_mode(exec_mode):
place = paddle.CPUPlace() place = paddle.CPUPlace()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
反馈
建议
客服 返回
顶部