From 36f3d0af220a7ad3cbcd431f5b00463e4fab7630 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 26 Sep 2017 17:24:59 -0700 Subject: [PATCH] Fix error in unit test of ModifiedHuberLossOp --- .../tests/test_modified_huber_loss_op.py | 27 ++++++++++++------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/python/paddle/v2/framework/tests/test_modified_huber_loss_op.py b/python/paddle/v2/framework/tests/test_modified_huber_loss_op.py index a7e2b57529b..18a6e9e8a40 100644 --- a/python/paddle/v2/framework/tests/test_modified_huber_loss_op.py +++ b/python/paddle/v2/framework/tests/test_modified_huber_loss_op.py @@ -5,22 +5,31 @@ from op_test import OpTest def modified_huber_loss_forward(val): if val < -1: - return -4 * val + return -4. * val elif val < 1: - return (1 - val) * (1 - val) + return (1. - val) * (1. - val) else: - return 0 + return 0. class TestModifiedHuberLossOp(OpTest): def setUp(self): self.op_type = 'modified_huber_loss' samples_num = 32 - self.inputs = { - 'X': np.random.uniform(-1, 1., (samples_num, 1)).astype('float32'), - 'Y': np.random.choice([0, 1], samples_num).reshape((samples_num, 1)) - } - product_res = self.inputs['X'] * (2 * self.inputs['Y'] - 1) + + x_np = np.random.uniform(-2., 2., (samples_num, 1)).astype('float32') + y_np = np.random.choice([0, 1], samples_num).reshape( + (samples_num, 1)).astype('float32') + product_res = x_np * (2. * y_np - 1.) + # keep away from the junction of piecewise function + for pos, val in np.ndenumerate(product_res): + while abs(val - 1.) < 0.05: + x_np[pos] = np.random.uniform(-2., 2.) + y_np[pos] = np.random.choice([0, 1]) + product_res[pos] = x_np[pos] * (2 * y_np[pos] - 1) + val = product_res[pos] + + self.inputs = {'X': x_np, 'Y': y_np} loss = np.vectorize(modified_huber_loss_forward)(product_res) self.outputs = { @@ -32,7 +41,7 @@ class TestModifiedHuberLossOp(OpTest): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Out', max_relative_error=0.005) + self.check_grad(['X'], 'Out', max_relative_error=0.01) if __name__ == '__main__': -- GitLab