diff --git a/paddle/fluid/operators/sampling_id_op.h b/paddle/fluid/operators/sampling_id_op.h index 01308e416a9313bad13ded4e40c79bb0550e31ed..133d3f72dbd6ab13c98d124369038309c94cba5b 100644 --- a/paddle/fluid/operators/sampling_id_op.h +++ b/paddle/fluid/operators/sampling_id_op.h @@ -53,7 +53,7 @@ class SamplingIdKernel : public framework::OpKernel { static_cast(context.Attr("min")), static_cast(context.Attr("max"))); - std::vector ids(batch_size); + std::vector ids(batch_size); for (int i = 0; i < batch_size; ++i) { T r = dist(engine); int idx = width - 1; @@ -63,7 +63,7 @@ class SamplingIdKernel : public framework::OpKernel { break; } } - ids[i] = ins_vector[idx]; + ids[i] = int64_t(idx); } std::vector out_dim; diff --git a/python/paddle/fluid/tests/unittests/test_sampling_id_op.py b/python/paddle/fluid/tests/unittests/test_sampling_id_op.py index 708265b4576809b1f4157d54989c6138c6e5a2b0..674ef2ddf44edb4246c9d952cb75b36fe3d6ddc8 100644 --- a/python/paddle/fluid/tests/unittests/test_sampling_id_op.py +++ b/python/paddle/fluid/tests/unittests/test_sampling_id_op.py @@ -25,9 +25,9 @@ class TestSamplingIdOp(OpTest): self.op_type = "sampling_id" self.use_mkldnn = False self.init_kernel_type() - self.X = np.random.random((8, 4)).astype('float32') + self.X = np.random.random((100, 10)).astype('float32') self.inputs = {"X": self.X} - self.Y = np.random.random(8).astype('float32') + self.Y = np.random.random(100).astype('int64') self.outputs = {'Out': self.Y} self.attrs = {'max': 1.0, 'min': 0.0, 'seed': 1} @@ -36,6 +36,16 @@ class TestSamplingIdOp(OpTest): y1 = self.out self.check_output_customized(self.verify_output) y2 = self.out + + # check dtype + assert y1.dtype == np.int64 + assert y2.dtype == np.int64 + + # check output is index ids of inputs + inputs_ids = np.arange(self.X.shape[1]) + assert np.isin(y1, inputs_ids).all() + assert np.isin(y2, inputs_ids).all() + self.assertTrue(np.array_equal(y1, y2)) self.assertEqual(len(y1), len(self.Y))