提交 8eff2d62 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Split _BinaryLogisticHead from _MultiClassHead.

Change: 139971064
上级 c348cead
......@@ -102,16 +102,18 @@ def _multi_class_head(n_classes, label_name=None, weight_column_name=None,
Raises:
ValueError: if n_classes is < 2
"""
if n_classes < 2:
raise ValueError("n_classes must be > 1 for classification.")
if (n_classes is None) or (n_classes < 2):
raise ValueError(
"n_classes must be > 1 for classification: %s." % n_classes)
if n_classes == 2:
loss_fn = _log_loss_with_two_classes
else:
loss_fn = _softmax_cross_entropy_loss
return _MultiClassHead(train_loss_fn=loss_fn,
eval_loss_fn=loss_fn,
n_classes=n_classes,
return _BinaryLogisticHead(label_name=label_name,
weight_column_name=weight_column_name,
enable_centered_bias=enable_centered_bias,
head_name=head_name,
thresholds=thresholds)
return _MultiClassHead(n_classes=n_classes,
label_name=label_name,
weight_column_name=weight_column_name,
enable_centered_bias=enable_centered_bias,
......@@ -268,7 +270,6 @@ class _RegressionHead(_Head):
"""See `_Head`."""
_check_mode_valid(mode)
_check_logits_input_not_supported(logits, logits_input)
predictions = self._predictions(logits)
if (mode == model_fn.ModeKeys.INFER) or (labels is None):
loss = None
train_op = None
......@@ -278,15 +279,14 @@ class _RegressionHead(_Head):
train_op = (None if train_op_fn is None or mode == model_fn.ModeKeys.EVAL
else self._train_op(features, labels, train_op_fn, logits))
eval_metric_ops = self._eval_metric_ops(features, labels, logits)
signature_fn = self._signature_fn()
return model_fn.ModelFnOps(
mode=mode,
predictions=predictions,
predictions=self._predictions(logits),
loss=loss,
train_op=train_op,
eval_metric_ops=eval_metric_ops,
signature_fn=signature_fn)
signature_fn=self._signature_fn())
def _training_loss(self, features, labels, logits, name="training_loss"):
"""Returns training loss tensor for this head.
......@@ -403,18 +403,22 @@ class _RegressionHead(_Head):
self._weight_column_name)}
class _MultiClassHead(_Head):
"""_Head for classification."""
def _log_loss_with_two_classes(logits, labels):
# sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels.
if len(labels.get_shape()) == 1:
labels = array_ops.expand_dims(labels, dim=[1])
return nn.sigmoid_cross_entropy_with_logits(
logits, math_ops.to_float(labels))
def __init__(self, train_loss_fn, eval_loss_fn, n_classes, label_name,
weight_column_name, enable_centered_bias, head_name,
thresholds=None):
class _BinaryLogisticHead(_Head):
"""_Head for binary logistic classifciation."""
def __init__(self, label_name, weight_column_name, enable_centered_bias,
head_name, loss_fn=_log_loss_with_two_classes, thresholds=None):
"""Base type for all single heads.
Args:
train_loss_fn: loss_fn for training.
eval_loss_fn: loss_fn for eval.
n_classes: number of classes.
label_name: String, name of the key in label dict. Can be null if label
is a tensor (single headed models).
weight_column_name: A string defining feature column name representing
......@@ -425,53 +429,47 @@ class _MultiClassHead(_Head):
residual after centered bias.
head_name: name of the head. If provided, predictions, summary and metrics
keys will be prefixed by the head_name and an underscore.
loss_fn: Loss function.
thresholds: thresholds for eval.
Raises:
ValueError: if n_classes is invalid.
"""
if n_classes < 2:
raise ValueError("n_classes must be >= 2")
self._thresholds = thresholds if thresholds else [.5]
self._train_loss_fn = train_loss_fn
self._eval_loss_fn = eval_loss_fn
self._logits_dimension = 1 if n_classes == 2 else n_classes
self._label_name = label_name
self._weight_column_name = weight_column_name
self._head_name = head_name
self._loss_fn = loss_fn
self._enable_centered_bias = enable_centered_bias
self._centered_bias_weight_collection = _head_prefixed(head_name,
"centered_bias")
@property
def logits_dimension(self):
return self._logits_dimension
return 1
def head_ops(self, features, labels, mode, train_op_fn, logits=None,
logits_input=None, scope=None):
logits_input=None):
"""See `_Head`."""
_check_mode_valid(mode)
_check_logits_input_not_supported(logits, logits_input)
predictions = self._predictions(logits)
if (mode == model_fn.ModeKeys.INFER) or (labels is None):
loss = None
train_op = None
eval_metric_ops = None
else:
loss = self._training_loss(features, labels, logits)
train_op = (None if train_op_fn is None or mode == model_fn.ModeKeys.EVAL
train_op = (None if train_op_fn is None
else self._train_op(features, labels, train_op_fn, logits))
eval_metric_ops = self._eval_metric_ops(features, labels, logits)
signature_fn = self._signature_fn()
return model_fn.ModelFnOps(
mode=mode,
predictions=predictions,
predictions=self._predictions(logits),
loss=loss,
train_op=train_op,
eval_metric_ops=eval_metric_ops,
signature_fn=signature_fn)
signature_fn=self._signature_fn())
def _training_loss(self, features, labels, logits=None, name="training_loss"):
"""Returns training loss tensor for this head.
......@@ -492,7 +490,7 @@ class _MultiClassHead(_Head):
name: Op name.
Returns:
A loss `Tensor`.
A loss `Output`.
"""
labels = _check_labels(labels, self._label_name)
......@@ -501,7 +499,7 @@ class _MultiClassHead(_Head):
self.logits_dimension,
self._centered_bias_weight_collection))
loss_unweighted = self._train_loss_fn(logits, labels)
loss_unweighted = self._loss_fn(logits, labels)
loss, weighted_average_loss = _loss(
loss_unweighted,
_weight_tensor(features, self._weight_column_name),
......@@ -520,7 +518,7 @@ class _MultiClassHead(_Head):
self.logits_dimension,
self._centered_bias_weight_collection,
labels,
self._train_loss_fn)]
self._loss_fn)]
train_op = control_flow_ops.group(train_op, *centered_bias_step)
return train_op
......@@ -536,10 +534,10 @@ class _MultiClassHead(_Head):
"""Returns a dict of predictions.
Args:
logits: logits `Tensor` before applying possible centered bias.
logits: logits `Output` before applying possible centered bias.
Returns:
Dict of prediction `Tensor` keyed by `PredictionKey`.
Dict of prediction `Output` keyed by `PredictionKey`.
"""
if self._enable_centered_bias:
logits = nn.bias_add(logits, _centered_bias(
......@@ -551,13 +549,12 @@ class _MultiClassHead(_Head):
"""Returns a dict of predictions.
Args:
logits: logits `Tensor` after applying possible centered bias.
logits: logits `Output` after applying possible centered bias.
Returns:
Dict of prediction `Tensor` keyed by `PredictionKey`.
Dict of prediction `Output` keyed by `PredictionKey`.
"""
predictions = {prediction_key.PredictionKey.LOGITS: logits}
if self.logits_dimension == 1:
predictions[prediction_key.PredictionKey.LOGISTIC] = math_ops.sigmoid(
logits)
logits = array_ops.concat(1, [array_ops.zeros_like(logits), logits])
......@@ -565,7 +562,6 @@ class _MultiClassHead(_Head):
logits)
predictions[prediction_key.PredictionKey.CLASSES] = math_ops.argmax(
logits, 1)
return predictions
def _signature_fn(self):
......@@ -591,7 +587,7 @@ class _MultiClassHead(_Head):
"""Returns a dict of `MetricSpec` objects keyed by name."""
metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
_weighted_average_loss_metric_spec(
self._eval_loss_fn,
self._loss_fn,
prediction_key.PredictionKey.LOGITS,
self._label_name,
self._weight_column_name)}
......@@ -603,7 +599,6 @@ class _MultiClassHead(_Head):
prediction_key.PredictionKey.CLASSES,
self._label_name,
self._weight_column_name))
if self.logits_dimension == 1:
def _add_binary_metric(key, metric_fn):
metrics[_head_prefixed(self._head_name, key)] = (
metric_spec.MetricSpec(metric_fn,
......@@ -638,6 +633,216 @@ class _MultiClassHead(_Head):
return metrics
def _softmax_cross_entropy_loss(logits, labels):
# Check that we got integer for classification.
if not labels.dtype.is_integer:
raise ValueError("Labels dtype should be integer "
"Instead got %s." % labels.dtype)
# sparse_softmax_cross_entropy_with_logits requires [batch_size] labels.
if len(labels.get_shape()) == 2:
labels = array_ops.squeeze(labels, squeeze_dims=[1])
return nn.sparse_softmax_cross_entropy_with_logits(logits, labels)
class _MultiClassHead(_Head):
"""_Head for classification."""
def __init__(self, n_classes, label_name,
weight_column_name, enable_centered_bias, head_name,
loss_fn=_softmax_cross_entropy_loss, thresholds=None):
"""Base type for all single heads.
Args:
n_classes: Number of classes, must be greater than 2 (for 2 classes, use
`_BinaryLogisticHead`).
label_name: String, name of the key in label dict. Can be null if label
is a tensor (single headed models).
weight_column_name: A string defining feature column name representing
weights. It is used to down weight or boost examples during training. It
will be multiplied by the loss of the example.
enable_centered_bias: A bool. If True, estimator will learn a centered
bias variable for each class. Rest of the model structure learns the
residual after centered bias.
head_name: name of the head. If provided, predictions, summary and metrics
keys will be prefixed by the head_name and an underscore.
loss_fn: Loss function.
thresholds: thresholds for eval.
Raises:
ValueError: if n_classes is invalid.
"""
if (n_classes is None) or (n_classes <= 2):
raise ValueError("n_classes must be > 2: %s." % n_classes)
self._thresholds = thresholds if thresholds else [.5]
self._logits_dimension = n_classes
self._label_name = label_name
self._weight_column_name = weight_column_name
self._head_name = head_name
self._loss_fn = loss_fn
self._enable_centered_bias = enable_centered_bias
self._centered_bias_weight_collection = _head_prefixed(head_name,
"centered_bias")
@property
def logits_dimension(self):
return self._logits_dimension
def head_ops(self, features, labels, mode, train_op_fn, logits=None,
logits_input=None, scope=None):
"""See `_Head`."""
_check_mode_valid(mode)
_check_logits_input_not_supported(logits, logits_input)
if (mode == model_fn.ModeKeys.INFER) or (labels is None):
loss = None
train_op = None
eval_metric_ops = None
else:
loss = self._training_loss(features, labels, logits)
train_op = (None if train_op_fn is None or mode == model_fn.ModeKeys.EVAL
else self._train_op(features, labels, train_op_fn, logits))
eval_metric_ops = self._eval_metric_ops(features, labels, logits)
return model_fn.ModelFnOps(
mode=mode,
predictions=self._predictions(logits),
loss=loss,
train_op=train_op,
eval_metric_ops=eval_metric_ops,
signature_fn=self._signature_fn())
def _training_loss(self, features, labels, logits=None, name="training_loss"):
"""Returns training loss tensor for this head.
Training loss is different from the loss reported on the tensorboard as we
should respect the example weights when computing the gradient.
L = sum_{i} w_{i} * l_{i} / B
where B is the number of examples in the batch, l_{i}, w_{i} are individual
losses, and example weight.
Args:
features: features dict.
labels: either a tensor for labels or in multihead case, a dict of string
to labels tensor.
logits: logits, a float tensor.
name: Op name.
Returns:
A loss `Tensor`.
"""
labels = _check_labels(labels, self._label_name)
if self._enable_centered_bias:
logits = nn.bias_add(logits, _centered_bias(
self.logits_dimension,
self._centered_bias_weight_collection))
loss_unweighted = self._loss_fn(logits, labels)
loss, weighted_average_loss = _loss(
loss_unweighted,
_weight_tensor(features, self._weight_column_name),
name=name)
summary.scalar(
_head_prefixed(self._head_name, "loss"), weighted_average_loss)
return loss
def _train_op(self, features, labels, train_op_fn, logits):
"""Returns op for the training step."""
loss = self._training_loss(features, labels, logits)
train_op = train_op_fn(loss)
if self._enable_centered_bias:
centered_bias_step = [_centered_bias_step(
self.logits_dimension,
self._centered_bias_weight_collection,
labels,
self._loss_fn)]
train_op = control_flow_ops.group(train_op, *centered_bias_step)
return train_op
def _eval_metric_ops(self, features, labels, logits):
"""Returns a dict of metric ops keyed by name."""
labels = _check_labels(labels, self._label_name)
predictions = self._predictions(logits)
return estimator._make_metrics_ops( # pylint: disable=protected-access
self._default_metrics(), features, labels, predictions)
def _predictions(self, logits):
"""Returns a dict of predictions.
Args:
logits: logits `Tensor` before applying possible centered bias.
Returns:
Dict of prediction `Tensor` keyed by `PredictionKey`.
"""
if self._enable_centered_bias:
logits = nn.bias_add(logits, _centered_bias(
self.logits_dimension,
self._centered_bias_weight_collection))
return self._logits_to_predictions(logits)
def _logits_to_predictions(self, logits):
"""Returns a dict of predictions.
Args:
logits: logits `Tensor` after applying possible centered bias.
Returns:
Dict of prediction `Tensor` keyed by `PredictionKey`.
"""
predictions = {prediction_key.PredictionKey.LOGITS: logits}
predictions[prediction_key.PredictionKey.PROBABILITIES] = nn.softmax(
logits)
predictions[prediction_key.PredictionKey.CLASSES] = math_ops.argmax(
logits, 1)
return predictions
def _signature_fn(self):
"""Returns the signature_fn to be used in exporting."""
def _classification_signature_fn(examples, unused_features, predictions):
"""Servo signature function."""
if isinstance(predictions, dict):
default_signature = exporter.classification_signature(
input_tensor=examples,
classes_tensor=predictions[prediction_key.PredictionKey.CLASSES],
scores_tensor=predictions[
prediction_key.PredictionKey.PROBABILITIES])
else:
default_signature = exporter.classification_signature(
input_tensor=examples,
scores_tensor=predictions)
# TODO(zakaria): add validation
return default_signature, {}
return _classification_signature_fn
def _default_metrics(self):
"""Returns a dict of `MetricSpec` objects keyed by name."""
metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
_weighted_average_loss_metric_spec(
self._loss_fn,
prediction_key.PredictionKey.LOGITS,
self._label_name,
self._weight_column_name)}
# TODO(b/29366811): This currently results in both an "accuracy" and an
# "accuracy/threshold_0.500000_mean" metric for binary classification.
metrics[_head_prefixed(self._head_name, metric_key.MetricKey.ACCURACY)] = (
metric_spec.MetricSpec(metrics_lib.streaming_accuracy,
prediction_key.PredictionKey.CLASSES,
self._label_name,
self._weight_column_name))
# TODO(b/32953199): Add multiclass metrics.
return metrics
def _check_labels(labels, label_name):
labels = labels[label_name] if isinstance(labels, dict) else labels
if isinstance(labels, sparse_tensor.SparseTensor):
......@@ -645,12 +850,12 @@ def _check_labels(labels, label_name):
return labels
class _BinarySvmHead(_MultiClassHead):
class _BinarySvmHead(_BinaryLogisticHead):
"""_Head for binary classification using SVMs."""
def __init__(self, label_name, weight_column_name, enable_centered_bias,
head_name, thresholds):
def loss_fn(logits, labels):
def _loss_fn(logits, labels):
check_shape_op = control_flow_ops.Assert(
math_ops.less_equal(array_ops.rank(labels), 2),
["labels shape should be either [batch_size, 1] or [batch_size]"])
......@@ -660,9 +865,7 @@ class _BinarySvmHead(_MultiClassHead):
return losses.hinge_loss(logits, labels)
super(_BinarySvmHead, self).__init__(
train_loss_fn=loss_fn,
eval_loss_fn=loss_fn,
n_classes=2,
loss_fn=_loss_fn,
label_name=label_name,
weight_column_name=weight_column_name,
enable_centered_bias=enable_centered_bias,
......@@ -683,7 +886,7 @@ class _BinarySvmHead(_MultiClassHead):
"""See `_MultiClassHead`."""
metrics = {_head_prefixed(self._head_name, metric_key.MetricKey.LOSS):
_weighted_average_loss_metric_spec(
self._eval_loss_fn,
self._loss_fn,
prediction_key.PredictionKey.LOGITS,
self._label_name,
self._weight_column_name)}
......@@ -821,27 +1024,6 @@ def _mean_squared_loss(logits, labels):
return math_ops.square(logits - math_ops.to_float(labels))
def _log_loss_with_two_classes(logits, labels):
# sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels.
if len(labels.get_shape()) == 1:
labels = array_ops.expand_dims(labels, dim=[1])
loss_vec = nn.sigmoid_cross_entropy_with_logits(logits,
math_ops.to_float(labels))
return loss_vec
def _softmax_cross_entropy_loss(logits, labels):
# Check that we got integer for classification.
if not labels.dtype.is_integer:
raise ValueError("Labels dtype should be integer "
"Instead got %s." % labels.dtype)
# sparse_softmax_cross_entropy_with_logits requires [batch_size] labels.
if len(labels.get_shape()) == 2:
labels = array_ops.squeeze(labels, squeeze_dims=[1])
loss_vec = nn.sparse_softmax_cross_entropy_with_logits(logits, labels)
return loss_vec
def _sigmoid_cross_entropy_loss(logits, labels):
# sigmoid_cross_entropy_with_logits requires [batch_size, n_classes] labels.
return nn.sigmoid_cross_entropy_with_logits(logits, math_ops.to_float(labels))
......
......@@ -131,13 +131,10 @@ class MultiClassModelHeadTest(tf.test.TestCase):
_noop_train_op, logits=logits)
self.assertAlmostEqual(.15514446, sess.run(model_fn_ops.loss))
def testMultiClassWithInvalidNClass(self):
try:
head_lib._multi_class_head(n_classes=1)
self.fail("Softmax with no n_classes did not raise error.")
except ValueError:
# Expected
pass
def testInvalidNClasses(self):
for n_classes in (None, -1, 0, 1):
with self.assertRaisesRegexp(ValueError, "n_classes must be > 1"):
head_lib._multi_class_head(n_classes=n_classes)
class BinarySvmModelHeadTest(tf.test.TestCase):
......
......@@ -196,14 +196,17 @@ def sdca_model_fn(features, labels, mode, params):
if not isinstance(optimizer, sdca_optimizer.SDCAOptimizer):
raise ValueError("Optimizer must be of type SDCAOptimizer")
if isinstance(head, head_lib._BinarySvmHead): # pylint: disable=protected-access
# pylint: disable=protected-access
if isinstance(head, head_lib._BinarySvmHead):
loss_type = "hinge_loss"
elif isinstance(head, head_lib._MultiClassHead): # pylint: disable=protected-access
elif isinstance(
head, (head_lib._MultiClassHead, head_lib._BinaryLogisticHead)):
loss_type = "logistic_loss"
elif isinstance(head, head_lib._RegressionHead): # pylint: disable=protected-access
elif isinstance(head, head_lib._RegressionHead):
loss_type = "squared_loss"
else:
return ValueError("Unsupported head type: {}".format(head))
raise ValueError("Unsupported head type: {}".format(head))
# pylint: enable=protected-access
parent_scope = "linear"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册