未验证 提交 0bcf1365 编写于 作者: A andyjpaddle 提交者: GitHub

fix ut for pinv (#39566)

上级 644a894d
...@@ -36,6 +36,7 @@ class LinalgPinvTestCase(unittest.TestCase): ...@@ -36,6 +36,7 @@ class LinalgPinvTestCase(unittest.TestCase):
def generate_input(self): def generate_input(self):
self._input_shape = (5, 5) self._input_shape = (5, 5)
np.random.seed(123)
self._input_data = np.random.random(self._input_shape).astype( self._input_data = np.random.random(self._input_shape).astype(
self.dtype) self.dtype)
...@@ -102,6 +103,7 @@ class LinalgPinvTestCase(unittest.TestCase): ...@@ -102,6 +103,7 @@ class LinalgPinvTestCase(unittest.TestCase):
class LinalgPinvTestCase1(LinalgPinvTestCase): class LinalgPinvTestCase1(LinalgPinvTestCase):
def generate_input(self): def generate_input(self):
self._input_shape = (4, 5) self._input_shape = (4, 5)
np.random.seed(123)
self._input_data = np.random.random(self._input_shape).astype( self._input_data = np.random.random(self._input_shape).astype(
self.dtype) self.dtype)
...@@ -109,6 +111,7 @@ class LinalgPinvTestCase1(LinalgPinvTestCase): ...@@ -109,6 +111,7 @@ class LinalgPinvTestCase1(LinalgPinvTestCase):
class LinalgPinvTestCase2(LinalgPinvTestCase): class LinalgPinvTestCase2(LinalgPinvTestCase):
def generate_input(self): def generate_input(self):
self._input_shape = (5, 4) self._input_shape = (5, 4)
np.random.seed(123)
self._input_data = np.random.random(self._input_shape).astype( self._input_data = np.random.random(self._input_shape).astype(
self.dtype) self.dtype)
...@@ -116,6 +119,7 @@ class LinalgPinvTestCase2(LinalgPinvTestCase): ...@@ -116,6 +119,7 @@ class LinalgPinvTestCase2(LinalgPinvTestCase):
class LinalgPinvTestCaseBatch1(LinalgPinvTestCase): class LinalgPinvTestCaseBatch1(LinalgPinvTestCase):
def generate_input(self): def generate_input(self):
self._input_shape = (3, 5, 5) self._input_shape = (3, 5, 5)
np.random.seed(123)
self._input_data = np.random.random(self._input_shape).astype( self._input_data = np.random.random(self._input_shape).astype(
self.dtype) self.dtype)
...@@ -123,6 +127,7 @@ class LinalgPinvTestCaseBatch1(LinalgPinvTestCase): ...@@ -123,6 +127,7 @@ class LinalgPinvTestCaseBatch1(LinalgPinvTestCase):
class LinalgPinvTestCaseBatch2(LinalgPinvTestCase): class LinalgPinvTestCaseBatch2(LinalgPinvTestCase):
def generate_input(self): def generate_input(self):
self._input_shape = (3, 4, 5) self._input_shape = (3, 4, 5)
np.random.seed(123)
self._input_data = np.random.random(self._input_shape).astype( self._input_data = np.random.random(self._input_shape).astype(
self.dtype) self.dtype)
...@@ -130,6 +135,7 @@ class LinalgPinvTestCaseBatch2(LinalgPinvTestCase): ...@@ -130,6 +135,7 @@ class LinalgPinvTestCaseBatch2(LinalgPinvTestCase):
class LinalgPinvTestCaseBatch3(LinalgPinvTestCase): class LinalgPinvTestCaseBatch3(LinalgPinvTestCase):
def generate_input(self): def generate_input(self):
self._input_shape = (3, 5, 4) self._input_shape = (3, 5, 4)
np.random.seed(123)
self._input_data = np.random.random(self._input_shape).astype( self._input_data = np.random.random(self._input_shape).astype(
self.dtype) self.dtype)
...@@ -137,6 +143,7 @@ class LinalgPinvTestCaseBatch3(LinalgPinvTestCase): ...@@ -137,6 +143,7 @@ class LinalgPinvTestCaseBatch3(LinalgPinvTestCase):
class LinalgPinvTestCaseBatch4(LinalgPinvTestCase): class LinalgPinvTestCaseBatch4(LinalgPinvTestCase):
def generate_input(self): def generate_input(self):
self._input_shape = (3, 6, 5, 4) self._input_shape = (3, 6, 5, 4)
np.random.seed(123)
self._input_data = np.random.random(self._input_shape).astype( self._input_data = np.random.random(self._input_shape).astype(
self.dtype) self.dtype)
...@@ -144,6 +151,7 @@ class LinalgPinvTestCaseBatch4(LinalgPinvTestCase): ...@@ -144,6 +151,7 @@ class LinalgPinvTestCaseBatch4(LinalgPinvTestCase):
class LinalgPinvTestCaseBatchBig(LinalgPinvTestCase): class LinalgPinvTestCaseBatchBig(LinalgPinvTestCase):
def generate_input(self): def generate_input(self):
self._input_shape = (2, 200, 300) self._input_shape = (2, 200, 300)
np.random.seed(123)
self._input_data = np.random.random(self._input_shape).astype( self._input_data = np.random.random(self._input_shape).astype(
self.dtype) self.dtype)
...@@ -151,6 +159,7 @@ class LinalgPinvTestCaseBatchBig(LinalgPinvTestCase): ...@@ -151,6 +159,7 @@ class LinalgPinvTestCaseBatchBig(LinalgPinvTestCase):
class LinalgPinvTestCaseFP32(LinalgPinvTestCase): class LinalgPinvTestCaseFP32(LinalgPinvTestCase):
def generate_input(self): def generate_input(self):
self._input_shape = (3, 5, 5) self._input_shape = (3, 5, 5)
np.random.seed(123)
self._input_data = np.random.random(self._input_shape).astype( self._input_data = np.random.random(self._input_shape).astype(
self.dtype) self.dtype)
...@@ -163,6 +172,7 @@ class LinalgPinvTestCaseFP32(LinalgPinvTestCase): ...@@ -163,6 +172,7 @@ class LinalgPinvTestCaseFP32(LinalgPinvTestCase):
class LinalgPinvTestCaseRcond(LinalgPinvTestCase): class LinalgPinvTestCaseRcond(LinalgPinvTestCase):
def generate_input(self): def generate_input(self):
self._input_shape = (3, 5, 5) self._input_shape = (3, 5, 5)
np.random.seed(123)
self._input_data = np.random.random(self._input_shape).astype( self._input_data = np.random.random(self._input_shape).astype(
self.dtype) self.dtype)
...@@ -175,6 +185,7 @@ class LinalgPinvTestCaseRcond(LinalgPinvTestCase): ...@@ -175,6 +185,7 @@ class LinalgPinvTestCaseRcond(LinalgPinvTestCase):
class LinalgPinvTestCaseHermitian1(LinalgPinvTestCase): class LinalgPinvTestCaseHermitian1(LinalgPinvTestCase):
def generate_input(self): def generate_input(self):
self._input_shape = (5, 5) self._input_shape = (5, 5)
np.random.seed(123)
x = np.random.random(self._input_shape).astype(self.dtype) + \ x = np.random.random(self._input_shape).astype(self.dtype) + \
1J * np.random.random(self._input_shape).astype(self.dtype) 1J * np.random.random(self._input_shape).astype(self.dtype)
self._input_data = x + x.transpose().conj() self._input_data = x + x.transpose().conj()
...@@ -188,6 +199,7 @@ class LinalgPinvTestCaseHermitian1(LinalgPinvTestCase): ...@@ -188,6 +199,7 @@ class LinalgPinvTestCaseHermitian1(LinalgPinvTestCase):
class LinalgPinvTestCaseHermitian2(LinalgPinvTestCase): class LinalgPinvTestCaseHermitian2(LinalgPinvTestCase):
def generate_input(self): def generate_input(self):
self._input_shape = (3, 5, 5) self._input_shape = (3, 5, 5)
np.random.seed(123)
x = np.random.random(self._input_shape).astype(self.dtype) + \ x = np.random.random(self._input_shape).astype(self.dtype) + \
1J * np.random.random(self._input_shape).astype(self.dtype) 1J * np.random.random(self._input_shape).astype(self.dtype)
self._input_data = x + x.transpose((0, 2, 1)).conj() self._input_data = x + x.transpose((0, 2, 1)).conj()
...@@ -201,6 +213,7 @@ class LinalgPinvTestCaseHermitian2(LinalgPinvTestCase): ...@@ -201,6 +213,7 @@ class LinalgPinvTestCaseHermitian2(LinalgPinvTestCase):
class LinalgPinvTestCaseHermitian3(LinalgPinvTestCase): class LinalgPinvTestCaseHermitian3(LinalgPinvTestCase):
def generate_input(self): def generate_input(self):
self._input_shape = (3, 5, 5) self._input_shape = (3, 5, 5)
np.random.seed(123)
x = np.random.random(self._input_shape).astype(self.dtype) + \ x = np.random.random(self._input_shape).astype(self.dtype) + \
1J * np.random.random(self._input_shape).astype(self.dtype) 1J * np.random.random(self._input_shape).astype(self.dtype)
self._input_data = x + x.transpose((0, 2, 1)).conj() self._input_data = x + x.transpose((0, 2, 1)).conj()
...@@ -214,6 +227,7 @@ class LinalgPinvTestCaseHermitian3(LinalgPinvTestCase): ...@@ -214,6 +227,7 @@ class LinalgPinvTestCaseHermitian3(LinalgPinvTestCase):
class LinalgPinvTestCaseHermitian4(LinalgPinvTestCase): class LinalgPinvTestCaseHermitian4(LinalgPinvTestCase):
def generate_input(self): def generate_input(self):
self._input_shape = (5, 5) self._input_shape = (5, 5)
np.random.seed(123)
x = np.random.random(self._input_shape).astype(self.dtype) x = np.random.random(self._input_shape).astype(self.dtype)
self._input_data = x + x.transpose() self._input_data = x + x.transpose()
...@@ -226,6 +240,7 @@ class LinalgPinvTestCaseHermitian4(LinalgPinvTestCase): ...@@ -226,6 +240,7 @@ class LinalgPinvTestCaseHermitian4(LinalgPinvTestCase):
class LinalgPinvTestCaseHermitian5(LinalgPinvTestCase): class LinalgPinvTestCaseHermitian5(LinalgPinvTestCase):
def generate_input(self): def generate_input(self):
self._input_shape = (3, 5, 5) self._input_shape = (3, 5, 5)
np.random.seed(123)
x = np.random.random(self._input_shape).astype(self.dtype) x = np.random.random(self._input_shape).astype(self.dtype)
self._input_data = x + x.transpose((0, 2, 1)) self._input_data = x + x.transpose((0, 2, 1))
...@@ -238,6 +253,7 @@ class LinalgPinvTestCaseHermitian5(LinalgPinvTestCase): ...@@ -238,6 +253,7 @@ class LinalgPinvTestCaseHermitian5(LinalgPinvTestCase):
class LinalgPinvTestCaseHermitianFP32(LinalgPinvTestCase): class LinalgPinvTestCaseHermitianFP32(LinalgPinvTestCase):
def generate_input(self): def generate_input(self):
self._input_shape = (3, 5, 5) self._input_shape = (3, 5, 5)
np.random.seed(123)
x = np.random.random(self._input_shape).astype(self.dtype) x = np.random.random(self._input_shape).astype(self.dtype)
self._input_data = x + x.transpose((0, 2, 1)) self._input_data = x + x.transpose((0, 2, 1))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册