未验证 提交 7a64d48f 编写于 作者: T tangwei12 提交者: GitHub

fix test_save_load with pickle (#14410)

* fix test_save_load with pickle

test=develop

* fix test_save_load with pickle

test=develop

* fix test_save_load with pickle

test=develop
上级 d3aed98d
......@@ -26,6 +26,7 @@ from multiprocessing import Process
from functools import reduce
import numpy as np
import pickle
import unittest
import six
......@@ -166,7 +167,10 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
io.save_persistables(startup_exe, model_dir, trainer_prog)
var = np.array(fluid.global_scope().find_var('__fc_b__').get_tensor())
print(np.ravel(var).tolist())
if six.PY2:
print(pickle.dumps(np.ravel(var).tolist()))
else:
sys.stdout.buffer.write(pickle.dumps(np.ravel(var).tolist()))
if __name__ == "__main__":
......
......@@ -65,14 +65,14 @@ class TestDistSaveLoadDense2x2(TestDistBase):
shutil.rmtree(model_dir)
local_np = np.array(eval(local_var[0]))
train0_np = np.array(eval(tr0_var[0]))
train1_np = np.array(eval(tr1_var[0]))
local_np = np.array(local_var)
train0_np = np.array(tr0_var)
train1_np = np.array(tr1_var)
self.assertAlmostEqual(local_np.all(), train0_np.all(), delta=delta)
self.assertAlmostEqual(local_np.all(), train1_np.all(), delta=delta)
self.assertAlmostEqual(train0_np.all(), train1_np.all(), delta=delta)
@unittest.skip(reason="CI fail")
def test_dist(self):
need_envs = {
"IS_DISTRIBUTED": '0',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册