提交 e60eb1ea 编写于 作者: W wanghaoshuang

fix unitest

上级 668dc53f
...@@ -55,10 +55,10 @@ class TestNCE(OpTest): ...@@ -55,10 +55,10 @@ class TestNCE(OpTest):
'sampled_labels': range(num_sampled_classes) 'sampled_labels': range(num_sampled_classes)
} }
self.inputs = { self.inputs = {
'X': input, 'Input': input,
'Label': labels, 'Label': labels,
'W': weight, 'Weight': weight,
'B': bias, 'Bias': bias,
'SampleWeight': sample_weight 'SampleWeight': sample_weight
} }
...@@ -66,11 +66,12 @@ class TestNCE(OpTest): ...@@ -66,11 +66,12 @@ class TestNCE(OpTest):
self.generate_data(5, 5, 4, 1, 2) self.generate_data(5, 5, 4, 1, 2)
def compute(self): def compute(self):
out = nce(self.inputs['X'], self.inputs['W'], self.inputs['B'], out = nce(self.inputs['Input'], self.inputs['Weight'],
self.inputs['SampleWeight'], self.inputs['Label'], self.inputs['Bias'], self.inputs['SampleWeight'],
self.attrs['num_classes'], self.attrs['num_sampled_classes']) self.inputs['Label'], self.attrs['num_classes'],
self.attrs['num_sampled_classes'])
self.outputs = { self.outputs = {
'Out': out[0], 'Cost': out[0],
'SampleLogits': out[1], 'SampleLogits': out[1],
'SampleLabels': out[2] 'SampleLabels': out[2]
} }
...@@ -84,7 +85,8 @@ class TestNCE(OpTest): ...@@ -84,7 +85,8 @@ class TestNCE(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(["X", "W", "B"], "Out", max_relative_error=0.02) self.check_grad(
["Input", "Weight", "Bias"], "Cost", max_relative_error=0.02)
class TestNCECase1(TestNCE): class TestNCECase1(TestNCE):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册