diff --git a/python/paddle/fluid/tests/unittests/test_sample_logits.py b/python/paddle/fluid/tests/unittests/test_sample_logits.py index ed51b04dca45a29483b68bdd01dcdf7ac982d0e3..d7b2a6207e75dd8aa34402e006dac7deff80d1fd 100644 --- a/python/paddle/fluid/tests/unittests/test_sample_logits.py +++ b/python/paddle/fluid/tests/unittests/test_sample_logits.py @@ -263,7 +263,7 @@ class TestSampleLogitsOpV2(OpTest): 'remove_accidental_hits': remove_accidental_hits, 'seed': seed } - self.inputs = {'Logits': logits, 'Label': label} + self.inputs = {'Logits': logits, 'Label': label.astype(np.int64)} def set_data(self, num_classes, num_samples, seed, remove_accidental_hits): label = np.array([[6, 12, 15, 5, 1], [0, 9, 4, 1, 10], @@ -347,7 +347,7 @@ class TestSampleLogitsOpV3(OpTest): 'remove_accidental_hits': remove_accidental_hits, 'seed': seed } - self.inputs = {'Logits': logits, 'Label': label} + self.inputs = {'Logits': logits, 'Label': label.astype(np.int64)} def set_data(self, num_classes, num_samples, seed, remove_accidental_hits): label = [52, 2, 2, 17, 96, 2, 17, 96, 37, 2]