提交 8eff2d62 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Split _BinaryLogisticHead from _MultiClassHead.

Change: 139971064
上级 c348cead
......@@ -131,13 +131,10 @@ class MultiClassModelHeadTest(tf.test.TestCase):
_noop_train_op, logits=logits)
self.assertAlmostEqual(.15514446, sess.run(model_fn_ops.loss))
def testMultiClassWithInvalidNClass(self):
try:
head_lib._multi_class_head(n_classes=1)
self.fail("Softmax with no n_classes did not raise error.")
except ValueError:
# Expected
pass
def testInvalidNClasses(self):
for n_classes in (None, -1, 0, 1):
with self.assertRaisesRegexp(ValueError, "n_classes must be > 1"):
head_lib._multi_class_head(n_classes=n_classes)
class BinarySvmModelHeadTest(tf.test.TestCase):
......
......@@ -196,14 +196,17 @@ def sdca_model_fn(features, labels, mode, params):
if not isinstance(optimizer, sdca_optimizer.SDCAOptimizer):
raise ValueError("Optimizer must be of type SDCAOptimizer")
if isinstance(head, head_lib._BinarySvmHead): # pylint: disable=protected-access
# pylint: disable=protected-access
if isinstance(head, head_lib._BinarySvmHead):
loss_type = "hinge_loss"
elif isinstance(head, head_lib._MultiClassHead): # pylint: disable=protected-access
elif isinstance(
head, (head_lib._MultiClassHead, head_lib._BinaryLogisticHead)):
loss_type = "logistic_loss"
elif isinstance(head, head_lib._RegressionHead): # pylint: disable=protected-access
elif isinstance(head, head_lib._RegressionHead):
loss_type = "squared_loss"
else:
return ValueError("Unsupported head type: {}".format(head))
raise ValueError("Unsupported head type: {}".format(head))
# pylint: enable=protected-access
parent_scope = "linear"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册