From 9505850e33d6d8bf0db7851ab7973aaca5f29876 Mon Sep 17 00:00:00 2001 From: xuezhong Date: Tue, 12 Feb 2019 09:16:41 +0000 Subject: [PATCH] int type of numpy in windows default int32, need to set int64 test=develop --- python/paddle/fluid/tests/unittests/test_sample_logits.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_sample_logits.py b/python/paddle/fluid/tests/unittests/test_sample_logits.py index 7419cc513b..ed51b04dca 100644 --- a/python/paddle/fluid/tests/unittests/test_sample_logits.py +++ b/python/paddle/fluid/tests/unittests/test_sample_logits.py @@ -305,7 +305,8 @@ class TestSampleLogitsOpV2(OpTest): out = sample_logits(self.inputs["Logits"], self.inputs["Label"], self.attrs["num_samples"], self.attrs["seed"], self.attrs["remove_accidental_hits"], True, - self.fetched_samples, self.probabilities) + self.fetched_samples.astype(np.int64), + self.probabilities) self.outputs = { 'SampledLogits': out[0], 'Samples': out[1], @@ -365,7 +366,6 @@ class TestSampleLogitsOpV3(OpTest): batch_size, num_true = label.shape use_custom_samples = False - #import pdb; pdb.set_trace() num_sampled_classes = num_samples + num_true logits = np.random.randn(batch_size, num_classes) @@ -391,7 +391,8 @@ class TestSampleLogitsOpV3(OpTest): out = sample_logits(self.inputs["Logits"], self.inputs["Label"], self.attrs["num_samples"], self.attrs["seed"], self.attrs["remove_accidental_hits"], True, - self.fetched_samples, self.probabilities) + self.fetched_samples.astype(np.int64), + self.probabilities) self.outputs = { 'SampledLogits': out[0], 'Samples': out[1], -- GitLab