提交 9505850e 编写于 作者: X xuezhong

int type of numpy in windows default int32, need to set int64

test=develop
上级 9b24ac34
...@@ -305,7 +305,8 @@ class TestSampleLogitsOpV2(OpTest): ...@@ -305,7 +305,8 @@ class TestSampleLogitsOpV2(OpTest):
out = sample_logits(self.inputs["Logits"], self.inputs["Label"], out = sample_logits(self.inputs["Logits"], self.inputs["Label"],
self.attrs["num_samples"], self.attrs["seed"], self.attrs["num_samples"], self.attrs["seed"],
self.attrs["remove_accidental_hits"], True, self.attrs["remove_accidental_hits"], True,
self.fetched_samples, self.probabilities) self.fetched_samples.astype(np.int64),
self.probabilities)
self.outputs = { self.outputs = {
'SampledLogits': out[0], 'SampledLogits': out[0],
'Samples': out[1], 'Samples': out[1],
...@@ -365,7 +366,6 @@ class TestSampleLogitsOpV3(OpTest): ...@@ -365,7 +366,6 @@ class TestSampleLogitsOpV3(OpTest):
batch_size, num_true = label.shape batch_size, num_true = label.shape
use_custom_samples = False use_custom_samples = False
#import pdb; pdb.set_trace()
num_sampled_classes = num_samples + num_true num_sampled_classes = num_samples + num_true
logits = np.random.randn(batch_size, num_classes) logits = np.random.randn(batch_size, num_classes)
...@@ -391,7 +391,8 @@ class TestSampleLogitsOpV3(OpTest): ...@@ -391,7 +391,8 @@ class TestSampleLogitsOpV3(OpTest):
out = sample_logits(self.inputs["Logits"], self.inputs["Label"], out = sample_logits(self.inputs["Logits"], self.inputs["Label"],
self.attrs["num_samples"], self.attrs["seed"], self.attrs["num_samples"], self.attrs["seed"],
self.attrs["remove_accidental_hits"], True, self.attrs["remove_accidental_hits"], True,
self.fetched_samples, self.probabilities) self.fetched_samples.astype(np.int64),
self.probabilities)
self.outputs = { self.outputs = {
'SampledLogits': out[0], 'SampledLogits': out[0],
'Samples': out[1], 'Samples': out[1],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册