提交 c5742f79 编写于 作者: X xuezhong

set label type to int64 to pass windows test

test=develop
上级 9505850e
......@@ -263,7 +263,7 @@ class TestSampleLogitsOpV2(OpTest):
'remove_accidental_hits': remove_accidental_hits,
'seed': seed
}
self.inputs = {'Logits': logits, 'Label': label}
self.inputs = {'Logits': logits, 'Label': label.astype(np.int64)}
def set_data(self, num_classes, num_samples, seed, remove_accidental_hits):
label = np.array([[6, 12, 15, 5, 1], [0, 9, 4, 1, 10],
......@@ -347,7 +347,7 @@ class TestSampleLogitsOpV3(OpTest):
'remove_accidental_hits': remove_accidental_hits,
'seed': seed
}
self.inputs = {'Logits': logits, 'Label': label}
self.inputs = {'Logits': logits, 'Label': label.astype(np.int64)}
def set_data(self, num_classes, num_samples, seed, remove_accidental_hits):
label = [52, 2, 2, 17, 96, 2, 17, 96, 37, 2]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册