未验证 提交 0bcf1365 编写于 作者: A andyjpaddle 提交者: GitHub

fix ut for pinv (#39566)

上级 644a894d
......@@ -36,6 +36,7 @@ class LinalgPinvTestCase(unittest.TestCase):
def generate_input(self):
self._input_shape = (5, 5)
np.random.seed(123)
self._input_data = np.random.random(self._input_shape).astype(
self.dtype)
......@@ -102,6 +103,7 @@ class LinalgPinvTestCase(unittest.TestCase):
class LinalgPinvTestCase1(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (4, 5)
np.random.seed(123)
self._input_data = np.random.random(self._input_shape).astype(
self.dtype)
......@@ -109,6 +111,7 @@ class LinalgPinvTestCase1(LinalgPinvTestCase):
class LinalgPinvTestCase2(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (5, 4)
np.random.seed(123)
self._input_data = np.random.random(self._input_shape).astype(
self.dtype)
......@@ -116,6 +119,7 @@ class LinalgPinvTestCase2(LinalgPinvTestCase):
class LinalgPinvTestCaseBatch1(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (3, 5, 5)
np.random.seed(123)
self._input_data = np.random.random(self._input_shape).astype(
self.dtype)
......@@ -123,6 +127,7 @@ class LinalgPinvTestCaseBatch1(LinalgPinvTestCase):
class LinalgPinvTestCaseBatch2(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (3, 4, 5)
np.random.seed(123)
self._input_data = np.random.random(self._input_shape).astype(
self.dtype)
......@@ -130,6 +135,7 @@ class LinalgPinvTestCaseBatch2(LinalgPinvTestCase):
class LinalgPinvTestCaseBatch3(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (3, 5, 4)
np.random.seed(123)
self._input_data = np.random.random(self._input_shape).astype(
self.dtype)
......@@ -137,6 +143,7 @@ class LinalgPinvTestCaseBatch3(LinalgPinvTestCase):
class LinalgPinvTestCaseBatch4(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (3, 6, 5, 4)
np.random.seed(123)
self._input_data = np.random.random(self._input_shape).astype(
self.dtype)
......@@ -144,6 +151,7 @@ class LinalgPinvTestCaseBatch4(LinalgPinvTestCase):
class LinalgPinvTestCaseBatchBig(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (2, 200, 300)
np.random.seed(123)
self._input_data = np.random.random(self._input_shape).astype(
self.dtype)
......@@ -151,6 +159,7 @@ class LinalgPinvTestCaseBatchBig(LinalgPinvTestCase):
class LinalgPinvTestCaseFP32(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (3, 5, 5)
np.random.seed(123)
self._input_data = np.random.random(self._input_shape).astype(
self.dtype)
......@@ -163,6 +172,7 @@ class LinalgPinvTestCaseFP32(LinalgPinvTestCase):
class LinalgPinvTestCaseRcond(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (3, 5, 5)
np.random.seed(123)
self._input_data = np.random.random(self._input_shape).astype(
self.dtype)
......@@ -175,6 +185,7 @@ class LinalgPinvTestCaseRcond(LinalgPinvTestCase):
class LinalgPinvTestCaseHermitian1(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (5, 5)
np.random.seed(123)
x = np.random.random(self._input_shape).astype(self.dtype) + \
1J * np.random.random(self._input_shape).astype(self.dtype)
self._input_data = x + x.transpose().conj()
......@@ -188,6 +199,7 @@ class LinalgPinvTestCaseHermitian1(LinalgPinvTestCase):
class LinalgPinvTestCaseHermitian2(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (3, 5, 5)
np.random.seed(123)
x = np.random.random(self._input_shape).astype(self.dtype) + \
1J * np.random.random(self._input_shape).astype(self.dtype)
self._input_data = x + x.transpose((0, 2, 1)).conj()
......@@ -201,6 +213,7 @@ class LinalgPinvTestCaseHermitian2(LinalgPinvTestCase):
class LinalgPinvTestCaseHermitian3(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (3, 5, 5)
np.random.seed(123)
x = np.random.random(self._input_shape).astype(self.dtype) + \
1J * np.random.random(self._input_shape).astype(self.dtype)
self._input_data = x + x.transpose((0, 2, 1)).conj()
......@@ -214,6 +227,7 @@ class LinalgPinvTestCaseHermitian3(LinalgPinvTestCase):
class LinalgPinvTestCaseHermitian4(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (5, 5)
np.random.seed(123)
x = np.random.random(self._input_shape).astype(self.dtype)
self._input_data = x + x.transpose()
......@@ -226,6 +240,7 @@ class LinalgPinvTestCaseHermitian4(LinalgPinvTestCase):
class LinalgPinvTestCaseHermitian5(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (3, 5, 5)
np.random.seed(123)
x = np.random.random(self._input_shape).astype(self.dtype)
self._input_data = x + x.transpose((0, 2, 1))
......@@ -238,6 +253,7 @@ class LinalgPinvTestCaseHermitian5(LinalgPinvTestCase):
class LinalgPinvTestCaseHermitianFP32(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (3, 5, 5)
np.random.seed(123)
x = np.random.random(self._input_shape).astype(self.dtype)
self._input_data = x + x.transpose((0, 2, 1))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册