diff --git a/python/paddle/fluid/tests/unittests/test_sampling_id_op.py b/python/paddle/fluid/tests/unittests/test_sampling_id_op.py index 708265b4576809b1f4157d54989c6138c6e5a2b0..6e0c26943aad2389a4b3341ec257c15d0562515d 100644 --- a/python/paddle/fluid/tests/unittests/test_sampling_id_op.py +++ b/python/paddle/fluid/tests/unittests/test_sampling_id_op.py @@ -25,9 +25,9 @@ class TestSamplingIdOp(OpTest): self.op_type = "sampling_id" self.use_mkldnn = False self.init_kernel_type() - self.X = np.random.random((8, 4)).astype('float32') + self.X = np.random.random((100, 10)).astype('float32') self.inputs = {"X": self.X} - self.Y = np.random.random(8).astype('float32') + self.Y = np.random.random(100).astype('float32') self.outputs = {'Out': self.Y} self.attrs = {'max': 1.0, 'min': 0.0, 'seed': 1} @@ -36,6 +36,16 @@ class TestSamplingIdOp(OpTest): y1 = self.out self.check_output_customized(self.verify_output) y2 = self.out + + # check dtype + assert y1.dtype == np.int64 + assert y2.dtype == np.int64 + + # check output is index ids of inputs + inputs_ids = np.arange(self.X.shape[1]) + assert np.isin(y1, inputs_ids).all() + assert np.isin(y2, inputs_ids).all() + self.assertTrue(np.array_equal(y1, y2)) self.assertEqual(len(y1), len(self.Y))