未验证 提交 d0b0e274 编写于 作者: H huangjun12 提交者: GitHub

refine huber loss unittest, test=develop (#24263)

上级 356f5ee2
......@@ -89,14 +89,17 @@ class TestHuberLossOpError(unittest.TestCase):
xr = fluid.data(name='xr', shape=[None, 6], dtype="float32")
lw = np.random.random((6, 6)).astype("float32")
lr = fluid.data(name='lr', shape=[None, 6], dtype="float32")
self.assertRaises(TypeError, fluid.layers.huber_loss, xw, lr)
self.assertRaises(TypeError, fluid.layers.huber_loss, xr, lw)
delta = 1.0
self.assertRaises(TypeError, fluid.layers.huber_loss, xr, lw, delta)
self.assertRaises(TypeError, fluid.layers.huber_loss, xw, lr, delta)
# the dtype of input and label must be float32 or float64
xw2 = fluid.data(name='xw2', shape=[None, 6], dtype="int32")
lw2 = fluid.data(name='lw2', shape=[None, 6], dtype="int32")
self.assertRaises(TypeError, fluid.layers.huber_loss, xw2, lr)
self.assertRaises(TypeError, fluid.layers.huber_loss, xr, lw2)
self.assertRaises(TypeError, fluid.layers.huber_loss, xw2, lr,
delta)
self.assertRaises(TypeError, fluid.layers.huber_loss, xr, lw2,
delta)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册