From c7848aca556d1984391edb35a212fdae41709e63 Mon Sep 17 00:00:00 2001 From: WeiXin Date: Thu, 20 May 2021 17:09:17 +0800 Subject: [PATCH] [Cherry-Pick]fix test_paddle_save_load and test_paddle_save_load_binary (#32949) (#33008) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit test_paddle_save_load 单测随机挂:使用np.ndarray生成随机数组,可能生成nan,造成做对比时结果不匹配(nan != nan)。改为np.random.randn生成随机数组。 test_paddle_save_load_binary随机挂: 如果一个字符串不能解析为Program,windows上会有超时风险。解决方法:不在windows平台不加载'不能解析为Program的字符串'。 原始PR:#32949 --- .../fluid/tests/unittests/test_paddle_save_load.py | 9 ++++----- .../unittests/test_paddle_save_load_binary.py | 14 ++++++++------ 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py index 3a5c43b2bab..be2a6a653cc 100644 --- a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py @@ -412,11 +412,10 @@ class TestSaveLoadAny(unittest.TestCase): ] obj2 = {'k1': obj1, 'k2': state_dict, 'epoch': 123} obj3 = (paddle.randn( - [5, 4], dtype='float32'), np.ndarray( - [3, 4], dtype="float32"), { - "state_dict": state_dict, - "opt": state_dict - }) + [5, 4], dtype='float32'), np.random.randn(3, 4).astype("float32"), { + "state_dict": state_dict, + "opt": state_dict + }) obj4 = (np.random.randn(5, 6), (123, )) path1 = "test_save_load_any_complex_object_dygraph/obj1" diff --git a/python/paddle/fluid/tests/unittests/test_paddle_save_load_binary.py b/python/paddle/fluid/tests/unittests/test_paddle_save_load_binary.py index 8b508d5c9ae..7385da56bea 100644 --- a/python/paddle/fluid/tests/unittests/test_paddle_save_load_binary.py +++ b/python/paddle/fluid/tests/unittests/test_paddle_save_load_binary.py @@ -19,6 +19,7 @@ import numpy as np import os import sys import six +import platform import paddle import paddle.nn as nn @@ -162,12 +163,13 @@ class TestSaveLoadBinaryFormat(unittest.TestCase): with self.assertRaises(NotImplementedError): path = 'test_save_load_error/temp' paddle.save({}, path, use_binary_format=True) - - with self.assertRaises(ValueError): - path = 'test_save_load_error/temp' - with open(path, "w") as f: - f.write('\0') - paddle.load(path) + # On the Windows platform, when parsing a string that can't be parsed as a `Program`, `desc_.ParseFromString` has a timeout risk. + if 'Windows' != platform.system(): + with self.assertRaises(ValueError): + path = 'test_save_load_error/temp' + with open(path, "w") as f: + f.write('\0') + paddle.load(path) with self.assertRaises(ValueError): temp_lod = fluid.core.LoDTensor() -- GitLab