提交 e60eb1ea 编写于 作者: W wanghaoshuang

fix unitest

上级 668dc53f
......@@ -55,10 +55,10 @@ class TestNCE(OpTest):
'sampled_labels': range(num_sampled_classes)
}
self.inputs = {
'X': input,
'Input': input,
'Label': labels,
'W': weight,
'B': bias,
'Weight': weight,
'Bias': bias,
'SampleWeight': sample_weight
}
......@@ -66,11 +66,12 @@ class TestNCE(OpTest):
self.generate_data(5, 5, 4, 1, 2)
def compute(self):
out = nce(self.inputs['X'], self.inputs['W'], self.inputs['B'],
self.inputs['SampleWeight'], self.inputs['Label'],
self.attrs['num_classes'], self.attrs['num_sampled_classes'])
out = nce(self.inputs['Input'], self.inputs['Weight'],
self.inputs['Bias'], self.inputs['SampleWeight'],
self.inputs['Label'], self.attrs['num_classes'],
self.attrs['num_sampled_classes'])
self.outputs = {
'Out': out[0],
'Cost': out[0],
'SampleLogits': out[1],
'SampleLabels': out[2]
}
......@@ -84,7 +85,8 @@ class TestNCE(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(["X", "W", "B"], "Out", max_relative_error=0.02)
self.check_grad(
["Input", "Weight", "Bias"], "Cost", max_relative_error=0.02)
class TestNCECase1(TestNCE):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册