diff --git a/python/paddle/fluid/tests/unittests/test_sample_logits.py b/python/paddle/fluid/tests/unittests/test_sample_logits.py index 7419cc513bee1f4ce3695922cad6491a37a9b7b2..ed51b04dca45a29483b68bdd01dcdf7ac982d0e3 100644 --- a/python/paddle/fluid/tests/unittests/test_sample_logits.py +++ b/python/paddle/fluid/tests/unittests/test_sample_logits.py @@ -305,7 +305,8 @@ class TestSampleLogitsOpV2(OpTest): out = sample_logits(self.inputs["Logits"], self.inputs["Label"], self.attrs["num_samples"], self.attrs["seed"], self.attrs["remove_accidental_hits"], True, - self.fetched_samples, self.probabilities) + self.fetched_samples.astype(np.int64), + self.probabilities) self.outputs = { 'SampledLogits': out[0], 'Samples': out[1], @@ -365,7 +366,6 @@ class TestSampleLogitsOpV3(OpTest): batch_size, num_true = label.shape use_custom_samples = False - #import pdb; pdb.set_trace() num_sampled_classes = num_samples + num_true logits = np.random.randn(batch_size, num_classes) @@ -391,7 +391,8 @@ class TestSampleLogitsOpV3(OpTest): out = sample_logits(self.inputs["Logits"], self.inputs["Label"], self.attrs["num_samples"], self.attrs["seed"], self.attrs["remove_accidental_hits"], True, - self.fetched_samples, self.probabilities) + self.fetched_samples.astype(np.int64), + self.probabilities) self.outputs = { 'SampledLogits': out[0], 'Samples': out[1],