diff --git a/python/paddle/fluid/tests/unittests/ipu/test_mixed_precision_inference_ipu.py b/python/paddle/fluid/tests/unittests/ipu/test_mixed_precision_inference_ipu.py index 21bcb7b7314ab72acc618be9268b6d9906becaf9..7118466a521019fb9e854e862265119645dba487 100644 --- a/python/paddle/fluid/tests/unittests/ipu/test_mixed_precision_inference_ipu.py +++ b/python/paddle/fluid/tests/unittests/ipu/test_mixed_precision_inference_ipu.py @@ -43,12 +43,6 @@ class TestBase(IPUOpTest): self.feed_shape = [x.shape for x in self.feed_fp32.values()] self.feed_list = list(self.feed_fp32.keys()) - def dtype_check(self, program, to_fp16_var_names): - block = program.global_block() - assert len(to_fp16_var_names) > 0 - for var_name in to_fp16_var_names: - assert (block.var(var_name).dtype, paddle.float16) - def set_attrs(self): self.num_ipus = 1 self.enable_pipelining = False @@ -84,7 +78,6 @@ class TestBase(IPUOpTest): amp_list.unsupported_list = {} to_fp16_var_names = paddle.static.amp.cast_model_to_fp16( self.main_prog, amp_list, use_fp16_guard=True) - self.dtype_check(self.main_prog, to_fp16_var_names) if self.is_ipu_mode(exec_mode): place = paddle.CPUPlace() diff --git a/python/paddle/fluid/tests/unittests/ipu/test_mixed_precision_training_ipu.py b/python/paddle/fluid/tests/unittests/ipu/test_mixed_precision_training_ipu.py index a733a26d606164b98c46d14d9e598bc2747dd569..51a0e91a29c3bc320eed49e8e136455dc1fc1e75 100644 --- a/python/paddle/fluid/tests/unittests/ipu/test_mixed_precision_training_ipu.py +++ b/python/paddle/fluid/tests/unittests/ipu/test_mixed_precision_training_ipu.py @@ -55,12 +55,6 @@ class TestBase(IPUOpTest): self.enable_manual_shard = False self.batches_per_step = 1 - def dtype_check(self, program, to_fp16_var_names): - block = program.global_block() - assert len(to_fp16_var_names) > 0 - for var_name in to_fp16_var_names: - assert (block.var(var_name).dtype, paddle.float16) - @IPUOpTest.static_graph def build_model(self): x = paddle.static.data(name=self.feed_list[0], @@ -94,7 +88,6 @@ class TestBase(IPUOpTest): amp_list.unsupported_list = {} to_fp16_var_names = paddle.static.amp.cast_model_to_fp16( self.main_prog, amp_list) - self.dtype_check(self.main_prog, to_fp16_var_names) if self.is_ipu_mode(exec_mode): place = paddle.CPUPlace()