提交 7b385ff2 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #4407 from Canpio/fix_huber_loss_test_error

Fix error in unit test of ModifiedHuberLossOp
...@@ -5,22 +5,31 @@ from op_test import OpTest ...@@ -5,22 +5,31 @@ from op_test import OpTest
def modified_huber_loss_forward(val): def modified_huber_loss_forward(val):
if val < -1: if val < -1:
return -4 * val return -4. * val
elif val < 1: elif val < 1:
return (1 - val) * (1 - val) return (1. - val) * (1. - val)
else: else:
return 0 return 0.
class TestModifiedHuberLossOp(OpTest): class TestModifiedHuberLossOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = 'modified_huber_loss' self.op_type = 'modified_huber_loss'
samples_num = 32 samples_num = 32
self.inputs = {
'X': np.random.uniform(-1, 1., (samples_num, 1)).astype('float32'), x_np = np.random.uniform(-2., 2., (samples_num, 1)).astype('float32')
'Y': np.random.choice([0, 1], samples_num).reshape((samples_num, 1)) y_np = np.random.choice([0, 1], samples_num).reshape(
} (samples_num, 1)).astype('float32')
product_res = self.inputs['X'] * (2 * self.inputs['Y'] - 1) product_res = x_np * (2. * y_np - 1.)
# keep away from the junction of piecewise function
for pos, val in np.ndenumerate(product_res):
while abs(val - 1.) < 0.05:
x_np[pos] = np.random.uniform(-2., 2.)
y_np[pos] = np.random.choice([0, 1])
product_res[pos] = x_np[pos] * (2 * y_np[pos] - 1)
val = product_res[pos]
self.inputs = {'X': x_np, 'Y': y_np}
loss = np.vectorize(modified_huber_loss_forward)(product_res) loss = np.vectorize(modified_huber_loss_forward)(product_res)
self.outputs = { self.outputs = {
...@@ -32,7 +41,7 @@ class TestModifiedHuberLossOp(OpTest): ...@@ -32,7 +41,7 @@ class TestModifiedHuberLossOp(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X'], 'Out', max_relative_error=0.005) self.check_grad(['X'], 'Out', max_relative_error=0.01)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册