未验证 提交 29c211ee 编写于 作者: W Weilong Wu 提交者: GitHub

Support test_numpy_bridge and thread_local_has_grad (#38835)

上级 d3ba1895
......@@ -16,10 +16,11 @@ import unittest
import numpy as np
import paddle.fluid as fluid
import warnings
from paddle.fluid.framework import _test_eager_guard, _in_eager_mode
class TestImperativeNumpyBridge(unittest.TestCase):
def test_tensor_from_numpy(self):
def func_tensor_from_numpy(self):
data_np = np.array([[2, 3, 1]]).astype('float32')
with fluid.dygraph.guard(fluid.CPUPlace()):
with warnings.catch_warnings(record=True) as w:
......@@ -39,9 +40,18 @@ class TestImperativeNumpyBridge(unittest.TestCase):
self.assertTrue(np.array_equal(var2.numpy(), data_np))
data_np[0][0] = -1
self.assertEqual(data_np[0][0], -1)
self.assertNotEqual(var2[0][0].numpy()[0], -1)
if _in_eager_mode():
# eager_mode, var2 is EagerTensor, is not subscriptable
self.assertNotEqual(var2.numpy()[0][0], -1)
else:
self.assertNotEqual(var2[0][0].numpy()[0], -1)
self.assertFalse(np.array_equal(var2.numpy(), data_np))
def test_func_tensor_from_numpy(self):
with _test_eager_guard():
self.func_tensor_from_numpy()
self.func_tensor_from_numpy()
if __name__ == '__main__':
unittest.main()
......@@ -18,6 +18,7 @@ import time
import paddle.nn as nn
import numpy as np
import threading
from paddle.fluid.framework import _test_eager_guard, _in_eager_mode
class SimpleNet(nn.Layer):
......@@ -44,7 +45,7 @@ class TestCases(unittest.TestCase):
x = net(x)
self.assertFalse(x.stop_gradient)
def test_main(self):
def func_main(self):
threads = []
for _ in range(10):
threads.append(threading.Thread(target=self.thread_1_main))
......@@ -54,6 +55,11 @@ class TestCases(unittest.TestCase):
for t in threads:
t.join()
def test_main(self):
with _test_eager_guard():
self.func_main()
self.func_main()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册