未验证 提交 1d8c82b6 编写于 作者: K Kai Song 提交者: GitHub

fix strided_slice ut (#53553)

* fix strided_slice ut

* remove check_dygraph
上级 a5a0e8fe
...@@ -554,10 +554,10 @@ class TestStridedSliceOp_strides_Tensor(OpTest): ...@@ -554,10 +554,10 @@ class TestStridedSliceOp_strides_Tensor(OpTest):
class TestStridedSliceAPI(unittest.TestCase): class TestStridedSliceAPI(unittest.TestCase):
def test_1(self): def test_1(self):
input = np.random.random([3, 4, 5, 6]).astype("float64") input = np.random.random([3, 4, 5, 6]).astype("float64")
minus_1 = paddle.tensor.fill_constant([1], "int32", -1) minus_1 = paddle.tensor.fill_constant([], "int32", -1)
minus_3 = paddle.tensor.fill_constant([1], "int32", -3) minus_3 = paddle.tensor.fill_constant([], "int32", -3)
starts = paddle.static.data(name='starts', shape=[3], dtype='int32') starts = paddle.static.data(name='starts', shape=[3], dtype='int32')
ends = paddle.static.data(name='ends', shape=[3], dtype='int64') ends = paddle.static.data(name='ends', shape=[3], dtype='int32')
strides = paddle.static.data(name='strides', shape=[3], dtype='int32') strides = paddle.static.data(name='strides', shape=[3], dtype='int32')
x = paddle.static.data( x = paddle.static.data(
...@@ -1032,10 +1032,10 @@ class TestStrideSliceFP16Op(OpTest): ...@@ -1032,10 +1032,10 @@ class TestStrideSliceFP16Op(OpTest):
} }
def test_check_output(self): def test_check_output(self):
self.check_output(check_eager=True) self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad({'Input'}, 'Out', check_eager=True) self.check_grad({'Input'}, 'Out')
def initTestCase(self): def initTestCase(self):
self.input = np.random.rand(100) self.input = np.random.rand(100)
...@@ -1069,10 +1069,10 @@ class TestStrideSliceBF16Op(OpTest): ...@@ -1069,10 +1069,10 @@ class TestStrideSliceBF16Op(OpTest):
} }
def test_check_output(self): def test_check_output(self):
self.check_output(check_eager=True) self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad({'Input'}, 'Out', check_eager=True) self.check_grad({'Input'}, 'Out')
def initTestCase(self): def initTestCase(self):
self.input = np.random.rand(100) self.input = np.random.rand(100)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册