提交 ce15d89a 编写于 作者: Y yangyaming

Adapt to new unittest.

上级 6d4c4405
import unittest
from op_test_util import OpTestMeta
from gradient_checker import GradientChecker, create_op
from paddle.v2.framework.op import Operator
import numpy as np
from op_test import OpTest
def modified_huber_loss_forward(val):
......@@ -14,11 +12,9 @@ def modified_huber_loss_forward(val):
return 0
class TestModifiedHuberLossOp_f0(unittest.TestCase):
__metaclass__ = OpTestMeta
class TestModifiedHuberLossOp(OpTest):
def setUp(self):
self.type = 'modified_huber_loss'
self.op_type = 'modified_huber_loss'
samples_num = 32
self.inputs = {
'X': np.random.uniform(-1, 1., (samples_num, 1)).astype('float32'),
......@@ -32,22 +28,11 @@ class TestModifiedHuberLossOp_f0(unittest.TestCase):
'Out': loss.reshape((samples_num, 1))
}
def test_check_output(self):
    # Forward-pass check: runs the op configured in setUp() (op_type
    # 'modified_huber_loss' with self.inputs) and compares the computed
    # outputs against the expected values stored in self.outputs.
    # NOTE(review): relies entirely on the OpTest base-class harness;
    # no extra arguments means default tolerances are used.
    self.check_output()
class TestModifiedHuberLossGradOp(GradientChecker):
    # Legacy gradient test written against the old GradientChecker API.
    # (The commit message "Adapt to new unittest" and the surrounding diff
    # suggest this class is being replaced by the OpTest-based
    # test_check_grad — presumably it is the *removed* side of the diff.)

    def test_modified_huber_loss_b0(self):
        """Compare CPU-vs-GPU gradients and numeric-vs-analytic gradients
        for the 'modified_huber_loss' op on random inputs."""
        samples_num = 10
        # X: float32 column vector uniform in [-1, 1);
        # Y: 0/1 labels as a column vector.
        # NOTE(review): np.random.choice yields an integer array here —
        # the op presumably accepts/expects integer labels; confirm.
        inputs = {
            'X': np.random.uniform(-1, 1, (samples_num, 1)).astype('float32'),
            'Y': np.random.choice([0, 1], samples_num).reshape((samples_num, 1))
        }
        # Wire the op's input/output slots by name; 'IntermediateVal' is an
        # extra output the op produces (excluded from gradient checking below).
        op = Operator(
            "modified_huber_loss",
            X='X',
            Y='Y',
            IntermediateVal='IntermediateVal',
            Out='Out')
        # Cross-device gradient comparison, skipping non-differentiable slots.
        self.compare_grad(op, inputs, no_grad_set=set(['IntermediateVal', 'Y']))
        # Numeric-vs-analytic gradient of 'Out' with respect to 'X' only.
        self.check_grad(op, inputs, set(["X"]), "Out")
def test_check_grad(self):
    # Backward-pass check via the OpTest harness: compare the analytic
    # gradient of 'Out' w.r.t. input 'X' against a numeric estimate,
    # tolerating up to 0.5% relative error (max_relative_error=0.005).
    self.check_grad(['X'], 'Out', max_relative_error=0.005)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册