Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
qq_38905368
tensorflow
提交
8eff2d62
T
tensorflow
项目概览
qq_38905368
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
5
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
8eff2d62
编写于
11月 22, 2016
作者:
A
A. Unique TensorFlower
提交者:
TensorFlower Gardener
11月 22, 2016
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Split _BinaryLogisticHead from _MultiClassHead.
Change: 139971064
上级
c348cead
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
285 addition
and
103 deletion
+285
-103
tensorflow/contrib/learn/python/learn/estimators/head.py
tensorflow/contrib/learn/python/learn/estimators/head.py
+274
-92
tensorflow/contrib/learn/python/learn/estimators/head_test.py
...orflow/contrib/learn/python/learn/estimators/head_test.py
+4
-7
tensorflow/contrib/learn/python/learn/estimators/linear.py
tensorflow/contrib/learn/python/learn/estimators/linear.py
+7
-4
未找到文件。
tensorflow/contrib/learn/python/learn/estimators/head.py
浏览文件 @
8eff2d62
...
@@ -102,16 +102,18 @@ def _multi_class_head(n_classes, label_name=None, weight_column_name=None,
...
@@ -102,16 +102,18 @@ def _multi_class_head(n_classes, label_name=None, weight_column_name=None,
Raises:
Raises:
ValueError: if n_classes is < 2
ValueError: if n_classes is < 2
"""
"""
if
n_classes
<
2
:
if
(
n_classes
is
None
)
or
(
n_classes
<
2
):
raise
ValueError
(
"n_classes must be > 1 for classification."
)
raise
ValueError
(
"n_classes must be > 1 for classification: %s."
%
n_classes
)
if
n_classes
==
2
:
if
n_classes
==
2
:
loss_fn
=
_log_loss_with_two_classes
return
_BinaryLogisticHead
(
label_name
=
label_name
,
else
:
weight_column_name
=
weight_column_name
,
loss_fn
=
_softmax_cross_entropy_loss
enable_centered_bias
=
enable_centered_bias
,
return
_MultiClassHead
(
train_loss_fn
=
loss_fn
,
head_name
=
head_name
,
eval_loss_fn
=
loss_fn
,
thresholds
=
thresholds
)
n_classes
=
n_classes
,
return
_MultiClassHead
(
n_classes
=
n_classes
,
label_name
=
label_name
,
label_name
=
label_name
,
weight_column_name
=
weight_column_name
,
weight_column_name
=
weight_column_name
,
enable_centered_bias
=
enable_centered_bias
,
enable_centered_bias
=
enable_centered_bias
,
...
@@ -268,7 +270,6 @@ class _RegressionHead(_Head):
...
@@ -268,7 +270,6 @@ class _RegressionHead(_Head):
"""See `_Head`."""
"""See `_Head`."""
_check_mode_valid
(
mode
)
_check_mode_valid
(
mode
)
_check_logits_input_not_supported
(
logits
,
logits_input
)
_check_logits_input_not_supported
(
logits
,
logits_input
)
predictions
=
self
.
_predictions
(
logits
)
if
(
mode
==
model_fn
.
ModeKeys
.
INFER
)
or
(
labels
is
None
):
if
(
mode
==
model_fn
.
ModeKeys
.
INFER
)
or
(
labels
is
None
):
loss
=
None
loss
=
None
train_op
=
None
train_op
=
None
...
@@ -278,15 +279,14 @@ class _RegressionHead(_Head):
...
@@ -278,15 +279,14 @@ class _RegressionHead(_Head):
train_op
=
(
None
if
train_op_fn
is
None
or
mode
==
model_fn
.
ModeKeys
.
EVAL
train_op
=
(
None
if
train_op_fn
is
None
or
mode
==
model_fn
.
ModeKeys
.
EVAL
else
self
.
_train_op
(
features
,
labels
,
train_op_fn
,
logits
))
else
self
.
_train_op
(
features
,
labels
,
train_op_fn
,
logits
))
eval_metric_ops
=
self
.
_eval_metric_ops
(
features
,
labels
,
logits
)
eval_metric_ops
=
self
.
_eval_metric_ops
(
features
,
labels
,
logits
)
signature_fn
=
self
.
_signature_fn
()
return
model_fn
.
ModelFnOps
(
return
model_fn
.
ModelFnOps
(
mode
=
mode
,
mode
=
mode
,
predictions
=
predictions
,
predictions
=
self
.
_predictions
(
logits
)
,
loss
=
loss
,
loss
=
loss
,
train_op
=
train_op
,
train_op
=
train_op
,
eval_metric_ops
=
eval_metric_ops
,
eval_metric_ops
=
eval_metric_ops
,
signature_fn
=
s
ignature_fn
)
signature_fn
=
s
elf
.
_signature_fn
()
)
def
_training_loss
(
self
,
features
,
labels
,
logits
,
name
=
"training_loss"
):
def
_training_loss
(
self
,
features
,
labels
,
logits
,
name
=
"training_loss"
):
"""Returns training loss tensor for this head.
"""Returns training loss tensor for this head.
...
@@ -403,18 +403,22 @@ class _RegressionHead(_Head):
...
@@ -403,18 +403,22 @@ class _RegressionHead(_Head):
self
.
_weight_column_name
)}
self
.
_weight_column_name
)}
class
_MultiClassHead
(
_Head
):
def
_log_loss_with_two_classes
(
logits
,
labels
):
"""_Head for classification."""
# sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels.
if
len
(
labels
.
get_shape
())
==
1
:
labels
=
array_ops
.
expand_dims
(
labels
,
dim
=
[
1
])
return
nn
.
sigmoid_cross_entropy_with_logits
(
logits
,
math_ops
.
to_float
(
labels
))
def
__init__
(
self
,
train_loss_fn
,
eval_loss_fn
,
n_classes
,
label_name
,
weight_column_name
,
enable_centered_bias
,
head_name
,
class
_BinaryLogisticHead
(
_Head
):
thresholds
=
None
):
"""_Head for binary logistic classifciation."""
def
__init__
(
self
,
label_name
,
weight_column_name
,
enable_centered_bias
,
head_name
,
loss_fn
=
_log_loss_with_two_classes
,
thresholds
=
None
):
"""Base type for all single heads.
"""Base type for all single heads.
Args:
Args:
train_loss_fn: loss_fn for training.
eval_loss_fn: loss_fn for eval.
n_classes: number of classes.
label_name: String, name of the key in label dict. Can be null if label
label_name: String, name of the key in label dict. Can be null if label
is a tensor (single headed models).
is a tensor (single headed models).
weight_column_name: A string defining feature column name representing
weight_column_name: A string defining feature column name representing
...
@@ -425,53 +429,47 @@ class _MultiClassHead(_Head):
...
@@ -425,53 +429,47 @@ class _MultiClassHead(_Head):
residual after centered bias.
residual after centered bias.
head_name: name of the head. If provided, predictions, summary and metrics
head_name: name of the head. If provided, predictions, summary and metrics
keys will be prefixed by the head_name and an underscore.
keys will be prefixed by the head_name and an underscore.
loss_fn: Loss function.
thresholds: thresholds for eval.
thresholds: thresholds for eval.
Raises:
Raises:
ValueError: if n_classes is invalid.
ValueError: if n_classes is invalid.
"""
"""
if
n_classes
<
2
:
raise
ValueError
(
"n_classes must be >= 2"
)
self
.
_thresholds
=
thresholds
if
thresholds
else
[.
5
]
self
.
_thresholds
=
thresholds
if
thresholds
else
[.
5
]
self
.
_train_loss_fn
=
train_loss_fn
self
.
_eval_loss_fn
=
eval_loss_fn
self
.
_logits_dimension
=
1
if
n_classes
==
2
else
n_classes
self
.
_label_name
=
label_name
self
.
_label_name
=
label_name
self
.
_weight_column_name
=
weight_column_name
self
.
_weight_column_name
=
weight_column_name
self
.
_head_name
=
head_name
self
.
_head_name
=
head_name
self
.
_loss_fn
=
loss_fn
self
.
_enable_centered_bias
=
enable_centered_bias
self
.
_enable_centered_bias
=
enable_centered_bias
self
.
_centered_bias_weight_collection
=
_head_prefixed
(
head_name
,
self
.
_centered_bias_weight_collection
=
_head_prefixed
(
head_name
,
"centered_bias"
)
"centered_bias"
)
@
property
@
property
def
logits_dimension
(
self
):
def
logits_dimension
(
self
):
return
self
.
_logits_dimension
return
1
def
head_ops
(
self
,
features
,
labels
,
mode
,
train_op_fn
,
logits
=
None
,
def
head_ops
(
self
,
features
,
labels
,
mode
,
train_op_fn
,
logits
=
None
,
logits_input
=
None
,
scope
=
None
):
logits_input
=
None
):
"""See `_Head`."""
"""See `_Head`."""
_check_mode_valid
(
mode
)
_check_mode_valid
(
mode
)
_check_logits_input_not_supported
(
logits
,
logits_input
)
_check_logits_input_not_supported
(
logits
,
logits_input
)
predictions
=
self
.
_predictions
(
logits
)
if
(
mode
==
model_fn
.
ModeKeys
.
INFER
)
or
(
labels
is
None
):
if
(
mode
==
model_fn
.
ModeKeys
.
INFER
)
or
(
labels
is
None
):
loss
=
None
loss
=
None
train_op
=
None
train_op
=
None
eval_metric_ops
=
None
eval_metric_ops
=
None
else
:
else
:
loss
=
self
.
_training_loss
(
features
,
labels
,
logits
)
loss
=
self
.
_training_loss
(
features
,
labels
,
logits
)
train_op
=
(
None
if
train_op_fn
is
None
or
mode
==
model_fn
.
ModeKeys
.
EVAL
train_op
=
(
None
if
train_op_fn
is
None
else
self
.
_train_op
(
features
,
labels
,
train_op_fn
,
logits
))
else
self
.
_train_op
(
features
,
labels
,
train_op_fn
,
logits
))
eval_metric_ops
=
self
.
_eval_metric_ops
(
features
,
labels
,
logits
)
eval_metric_ops
=
self
.
_eval_metric_ops
(
features
,
labels
,
logits
)
signature_fn
=
self
.
_signature_fn
()
return
model_fn
.
ModelFnOps
(
return
model_fn
.
ModelFnOps
(
mode
=
mode
,
mode
=
mode
,
predictions
=
predictions
,
predictions
=
self
.
_predictions
(
logits
)
,
loss
=
loss
,
loss
=
loss
,
train_op
=
train_op
,
train_op
=
train_op
,
eval_metric_ops
=
eval_metric_ops
,
eval_metric_ops
=
eval_metric_ops
,
signature_fn
=
s
ignature_fn
)
signature_fn
=
s
elf
.
_signature_fn
()
)
def
_training_loss
(
self
,
features
,
labels
,
logits
=
None
,
name
=
"training_loss"
):
def
_training_loss
(
self
,
features
,
labels
,
logits
=
None
,
name
=
"training_loss"
):
"""Returns training loss tensor for this head.
"""Returns training loss tensor for this head.
...
@@ -492,7 +490,7 @@ class _MultiClassHead(_Head):
...
@@ -492,7 +490,7 @@ class _MultiClassHead(_Head):
name: Op name.
name: Op name.
Returns:
Returns:
A loss `
Tensor
`.
A loss `
Output
`.
"""
"""
labels
=
_check_labels
(
labels
,
self
.
_label_name
)
labels
=
_check_labels
(
labels
,
self
.
_label_name
)
...
@@ -501,7 +499,7 @@ class _MultiClassHead(_Head):
...
@@ -501,7 +499,7 @@ class _MultiClassHead(_Head):
self
.
logits_dimension
,
self
.
logits_dimension
,
self
.
_centered_bias_weight_collection
))
self
.
_centered_bias_weight_collection
))
loss_unweighted
=
self
.
_
train_
loss_fn
(
logits
,
labels
)
loss_unweighted
=
self
.
_loss_fn
(
logits
,
labels
)
loss
,
weighted_average_loss
=
_loss
(
loss
,
weighted_average_loss
=
_loss
(
loss_unweighted
,
loss_unweighted
,
_weight_tensor
(
features
,
self
.
_weight_column_name
),
_weight_tensor
(
features
,
self
.
_weight_column_name
),
...
@@ -520,7 +518,7 @@ class _MultiClassHead(_Head):
...
@@ -520,7 +518,7 @@ class _MultiClassHead(_Head):
self
.
logits_dimension
,
self
.
logits_dimension
,
self
.
_centered_bias_weight_collection
,
self
.
_centered_bias_weight_collection
,
labels
,
labels
,
self
.
_
train_
loss_fn
)]
self
.
_loss_fn
)]
train_op
=
control_flow_ops
.
group
(
train_op
,
*
centered_bias_step
)
train_op
=
control_flow_ops
.
group
(
train_op
,
*
centered_bias_step
)
return
train_op
return
train_op
...
@@ -536,10 +534,10 @@ class _MultiClassHead(_Head):
...
@@ -536,10 +534,10 @@ class _MultiClassHead(_Head):
"""Returns a dict of predictions.
"""Returns a dict of predictions.
Args:
Args:
logits: logits `
Tensor
` before applying possible centered bias.
logits: logits `
Output
` before applying possible centered bias.
Returns:
Returns:
Dict of prediction `
Tensor
` keyed by `PredictionKey`.
Dict of prediction `
Output
` keyed by `PredictionKey`.
"""
"""
if
self
.
_enable_centered_bias
:
if
self
.
_enable_centered_bias
:
logits
=
nn
.
bias_add
(
logits
,
_centered_bias
(
logits
=
nn
.
bias_add
(
logits
,
_centered_bias
(
...
@@ -551,13 +549,12 @@ class _MultiClassHead(_Head):
...
@@ -551,13 +549,12 @@ class _MultiClassHead(_Head):
"""Returns a dict of predictions.
"""Returns a dict of predictions.
Args:
Args:
logits: logits `
Tensor
` after applying possible centered bias.
logits: logits `
Output
` after applying possible centered bias.
Returns:
Returns:
Dict of prediction `
Tensor
` keyed by `PredictionKey`.
Dict of prediction `
Output
` keyed by `PredictionKey`.
"""
"""
predictions
=
{
prediction_key
.
PredictionKey
.
LOGITS
:
logits
}
predictions
=
{
prediction_key
.
PredictionKey
.
LOGITS
:
logits
}
if
self
.
logits_dimension
==
1
:
predictions
[
prediction_key
.
PredictionKey
.
LOGISTIC
]
=
math_ops
.
sigmoid
(
predictions
[
prediction_key
.
PredictionKey
.
LOGISTIC
]
=
math_ops
.
sigmoid
(
logits
)
logits
)
logits
=
array_ops
.
concat
(
1
,
[
array_ops
.
zeros_like
(
logits
),
logits
])
logits
=
array_ops
.
concat
(
1
,
[
array_ops
.
zeros_like
(
logits
),
logits
])
...
@@ -565,7 +562,6 @@ class _MultiClassHead(_Head):
...
@@ -565,7 +562,6 @@ class _MultiClassHead(_Head):
logits
)
logits
)
predictions
[
prediction_key
.
PredictionKey
.
CLASSES
]
=
math_ops
.
argmax
(
predictions
[
prediction_key
.
PredictionKey
.
CLASSES
]
=
math_ops
.
argmax
(
logits
,
1
)
logits
,
1
)
return
predictions
return
predictions
def
_signature_fn
(
self
):
def
_signature_fn
(
self
):
...
@@ -591,7 +587,7 @@ class _MultiClassHead(_Head):
...
@@ -591,7 +587,7 @@ class _MultiClassHead(_Head):
"""Returns a dict of `MetricSpec` objects keyed by name."""
"""Returns a dict of `MetricSpec` objects keyed by name."""
metrics
=
{
_head_prefixed
(
self
.
_head_name
,
metric_key
.
MetricKey
.
LOSS
):
metrics
=
{
_head_prefixed
(
self
.
_head_name
,
metric_key
.
MetricKey
.
LOSS
):
_weighted_average_loss_metric_spec
(
_weighted_average_loss_metric_spec
(
self
.
_
eval_
loss_fn
,
self
.
_loss_fn
,
prediction_key
.
PredictionKey
.
LOGITS
,
prediction_key
.
PredictionKey
.
LOGITS
,
self
.
_label_name
,
self
.
_label_name
,
self
.
_weight_column_name
)}
self
.
_weight_column_name
)}
...
@@ -603,7 +599,6 @@ class _MultiClassHead(_Head):
...
@@ -603,7 +599,6 @@ class _MultiClassHead(_Head):
prediction_key
.
PredictionKey
.
CLASSES
,
prediction_key
.
PredictionKey
.
CLASSES
,
self
.
_label_name
,
self
.
_label_name
,
self
.
_weight_column_name
))
self
.
_weight_column_name
))
if
self
.
logits_dimension
==
1
:
def
_add_binary_metric
(
key
,
metric_fn
):
def
_add_binary_metric
(
key
,
metric_fn
):
metrics
[
_head_prefixed
(
self
.
_head_name
,
key
)]
=
(
metrics
[
_head_prefixed
(
self
.
_head_name
,
key
)]
=
(
metric_spec
.
MetricSpec
(
metric_fn
,
metric_spec
.
MetricSpec
(
metric_fn
,
...
@@ -638,6 +633,216 @@ class _MultiClassHead(_Head):
...
@@ -638,6 +633,216 @@ class _MultiClassHead(_Head):
return
metrics
return
metrics
def
_softmax_cross_entropy_loss
(
logits
,
labels
):
# Check that we got integer for classification.
if
not
labels
.
dtype
.
is_integer
:
raise
ValueError
(
"Labels dtype should be integer "
"Instead got %s."
%
labels
.
dtype
)
# sparse_softmax_cross_entropy_with_logits requires [batch_size] labels.
if
len
(
labels
.
get_shape
())
==
2
:
labels
=
array_ops
.
squeeze
(
labels
,
squeeze_dims
=
[
1
])
return
nn
.
sparse_softmax_cross_entropy_with_logits
(
logits
,
labels
)
class
_MultiClassHead
(
_Head
):
"""_Head for classification."""
def
__init__
(
self
,
n_classes
,
label_name
,
weight_column_name
,
enable_centered_bias
,
head_name
,
loss_fn
=
_softmax_cross_entropy_loss
,
thresholds
=
None
):
"""Base type for all single heads.
Args:
n_classes: Number of classes, must be greater than 2 (for 2 classes, use
`_BinaryLogisticHead`).
label_name: String, name of the key in label dict. Can be null if label
is a tensor (single headed models).
weight_column_name: A string defining feature column name representing
weights. It is used to down weight or boost examples during training. It
will be multiplied by the loss of the example.
enable_centered_bias: A bool. If True, estimator will learn a centered
bias variable for each class. Rest of the model structure learns the
residual after centered bias.
head_name: name of the head. If provided, predictions, summary and metrics
keys will be prefixed by the head_name and an underscore.
loss_fn: Loss function.
thresholds: thresholds for eval.
Raises:
ValueError: if n_classes is invalid.
"""
if
(
n_classes
is
None
)
or
(
n_classes
<=
2
):
raise
ValueError
(
"n_classes must be > 2: %s."
%
n_classes
)
self
.
_thresholds
=
thresholds
if
thresholds
else
[.
5
]
self
.
_logits_dimension
=
n_classes
self
.
_label_name
=
label_name
self
.
_weight_column_name
=
weight_column_name
self
.
_head_name
=
head_name
self
.
_loss_fn
=
loss_fn
self
.
_enable_centered_bias
=
enable_centered_bias
self
.
_centered_bias_weight_collection
=
_head_prefixed
(
head_name
,
"centered_bias"
)
@
property
def
logits_dimension
(
self
):
return
self
.
_logits_dimension
def
head_ops
(
self
,
features
,
labels
,
mode
,
train_op_fn
,
logits
=
None
,
logits_input
=
None
,
scope
=
None
):
"""See `_Head`."""
_check_mode_valid
(
mode
)
_check_logits_input_not_supported
(
logits
,
logits_input
)
if
(
mode
==
model_fn
.
ModeKeys
.
INFER
)
or
(
labels
is
None
):
loss
=
None
train_op
=
None
eval_metric_ops
=
None
else
:
loss
=
self
.
_training_loss
(
features
,
labels
,
logits
)
train_op
=
(
None
if
train_op_fn
is
None
or
mode
==
model_fn
.
ModeKeys
.
EVAL
else
self
.
_train_op
(
features
,
labels
,
train_op_fn
,
logits
))
eval_metric_ops
=
self
.
_eval_metric_ops
(
features
,
labels
,
logits
)
return
model_fn
.
ModelFnOps
(
mode
=
mode
,
predictions
=
self
.
_predictions
(
logits
),
loss
=
loss
,
train_op
=
train_op
,
eval_metric_ops
=
eval_metric_ops
,
signature_fn
=
self
.
_signature_fn
())
def
_training_loss
(
self
,
features
,
labels
,
logits
=
None
,
name
=
"training_loss"
):
"""Returns training loss tensor for this head.
Training loss is different from the loss reported on the tensorboard as we
should respect the example weights when computing the gradient.
L = sum_{i} w_{i} * l_{i} / B
where B is the number of examples in the batch, l_{i}, w_{i} are individual
losses, and example weight.
Args:
features: features dict.
labels: either a tensor for labels or in multihead case, a dict of string
to labels tensor.
logits: logits, a float tensor.
name: Op name.
Returns:
A loss `Tensor`.
"""
labels
=
_check_labels
(
labels
,
self
.
_label_name
)
if
self
.
_enable_centered_bias
:
logits
=
nn
.
bias_add
(
logits
,
_centered_bias
(
self
.
logits_dimension
,
self
.
_centered_bias_weight_collection
))
loss_unweighted
=
self
.
_loss_fn
(
logits
,
labels
)
loss
,
weighted_average_loss
=
_loss
(
loss_unweighted
,
_weight_tensor
(
features
,
self
.
_weight_column_name
),
name
=
name
)
summary
.
scalar
(
_head_prefixed
(
self
.
_head_name
,
"loss"
),
weighted_average_loss
)
return
loss
def
_train_op
(
self
,
features
,
labels
,
train_op_fn
,
logits
):
"""Returns op for the training step."""
loss
=
self
.
_training_loss
(
features
,
labels
,
logits
)
train_op
=
train_op_fn
(
loss
)
if
self
.
_enable_centered_bias
:
centered_bias_step
=
[
_centered_bias_step
(
self
.
logits_dimension
,
self
.
_centered_bias_weight_collection
,
labels
,
self
.
_loss_fn
)]
train_op
=
control_flow_ops
.
group
(
train_op
,
*
centered_bias_step
)
return
train_op
def
_eval_metric_ops
(
self
,
features
,
labels
,
logits
):
"""Returns a dict of metric ops keyed by name."""
labels
=
_check_labels
(
labels
,
self
.
_label_name
)
predictions
=
self
.
_predictions
(
logits
)
return
estimator
.
_make_metrics_ops
(
# pylint: disable=protected-access
self
.
_default_metrics
(),
features
,
labels
,
predictions
)
def
_predictions
(
self
,
logits
):
"""Returns a dict of predictions.
Args:
logits: logits `Tensor` before applying possible centered bias.
Returns:
Dict of prediction `Tensor` keyed by `PredictionKey`.
"""
if
self
.
_enable_centered_bias
:
logits
=
nn
.
bias_add
(
logits
,
_centered_bias
(
self
.
logits_dimension
,
self
.
_centered_bias_weight_collection
))
return
self
.
_logits_to_predictions
(
logits
)
def
_logits_to_predictions
(
self
,
logits
):
"""Returns a dict of predictions.
Args:
logits: logits `Tensor` after applying possible centered bias.
Returns:
Dict of prediction `Tensor` keyed by `PredictionKey`.
"""
predictions
=
{
prediction_key
.
PredictionKey
.
LOGITS
:
logits
}
predictions
[
prediction_key
.
PredictionKey
.
PROBABILITIES
]
=
nn
.
softmax
(
logits
)
predictions
[
prediction_key
.
PredictionKey
.
CLASSES
]
=
math_ops
.
argmax
(
logits
,
1
)
return
predictions
def
_signature_fn
(
self
):
"""Returns the signature_fn to be used in exporting."""
def
_classification_signature_fn
(
examples
,
unused_features
,
predictions
):
"""Servo signature function."""
if
isinstance
(
predictions
,
dict
):
default_signature
=
exporter
.
classification_signature
(
input_tensor
=
examples
,
classes_tensor
=
predictions
[
prediction_key
.
PredictionKey
.
CLASSES
],
scores_tensor
=
predictions
[
prediction_key
.
PredictionKey
.
PROBABILITIES
])
else
:
default_signature
=
exporter
.
classification_signature
(
input_tensor
=
examples
,
scores_tensor
=
predictions
)
# TODO(zakaria): add validation
return
default_signature
,
{}
return
_classification_signature_fn
def
_default_metrics
(
self
):
"""Returns a dict of `MetricSpec` objects keyed by name."""
metrics
=
{
_head_prefixed
(
self
.
_head_name
,
metric_key
.
MetricKey
.
LOSS
):
_weighted_average_loss_metric_spec
(
self
.
_loss_fn
,
prediction_key
.
PredictionKey
.
LOGITS
,
self
.
_label_name
,
self
.
_weight_column_name
)}
# TODO(b/29366811): This currently results in both an "accuracy" and an
# "accuracy/threshold_0.500000_mean" metric for binary classification.
metrics
[
_head_prefixed
(
self
.
_head_name
,
metric_key
.
MetricKey
.
ACCURACY
)]
=
(
metric_spec
.
MetricSpec
(
metrics_lib
.
streaming_accuracy
,
prediction_key
.
PredictionKey
.
CLASSES
,
self
.
_label_name
,
self
.
_weight_column_name
))
# TODO(b/32953199): Add multiclass metrics.
return
metrics
def
_check_labels
(
labels
,
label_name
):
def
_check_labels
(
labels
,
label_name
):
labels
=
labels
[
label_name
]
if
isinstance
(
labels
,
dict
)
else
labels
labels
=
labels
[
label_name
]
if
isinstance
(
labels
,
dict
)
else
labels
if
isinstance
(
labels
,
sparse_tensor
.
SparseTensor
):
if
isinstance
(
labels
,
sparse_tensor
.
SparseTensor
):
...
@@ -645,12 +850,12 @@ def _check_labels(labels, label_name):
...
@@ -645,12 +850,12 @@ def _check_labels(labels, label_name):
return
labels
return
labels
class
_BinarySvmHead
(
_
MultiClass
Head
):
class
_BinarySvmHead
(
_
BinaryLogistic
Head
):
"""_Head for binary classification using SVMs."""
"""_Head for binary classification using SVMs."""
def
__init__
(
self
,
label_name
,
weight_column_name
,
enable_centered_bias
,
def
__init__
(
self
,
label_name
,
weight_column_name
,
enable_centered_bias
,
head_name
,
thresholds
):
head_name
,
thresholds
):
def
loss_fn
(
logits
,
labels
):
def
_
loss_fn
(
logits
,
labels
):
check_shape_op
=
control_flow_ops
.
Assert
(
check_shape_op
=
control_flow_ops
.
Assert
(
math_ops
.
less_equal
(
array_ops
.
rank
(
labels
),
2
),
math_ops
.
less_equal
(
array_ops
.
rank
(
labels
),
2
),
[
"labels shape should be either [batch_size, 1] or [batch_size]"
])
[
"labels shape should be either [batch_size, 1] or [batch_size]"
])
...
@@ -660,9 +865,7 @@ class _BinarySvmHead(_MultiClassHead):
...
@@ -660,9 +865,7 @@ class _BinarySvmHead(_MultiClassHead):
return
losses
.
hinge_loss
(
logits
,
labels
)
return
losses
.
hinge_loss
(
logits
,
labels
)
super
(
_BinarySvmHead
,
self
).
__init__
(
super
(
_BinarySvmHead
,
self
).
__init__
(
train_loss_fn
=
loss_fn
,
loss_fn
=
_loss_fn
,
eval_loss_fn
=
loss_fn
,
n_classes
=
2
,
label_name
=
label_name
,
label_name
=
label_name
,
weight_column_name
=
weight_column_name
,
weight_column_name
=
weight_column_name
,
enable_centered_bias
=
enable_centered_bias
,
enable_centered_bias
=
enable_centered_bias
,
...
@@ -683,7 +886,7 @@ class _BinarySvmHead(_MultiClassHead):
...
@@ -683,7 +886,7 @@ class _BinarySvmHead(_MultiClassHead):
"""See `_MultiClassHead`."""
"""See `_MultiClassHead`."""
metrics
=
{
_head_prefixed
(
self
.
_head_name
,
metric_key
.
MetricKey
.
LOSS
):
metrics
=
{
_head_prefixed
(
self
.
_head_name
,
metric_key
.
MetricKey
.
LOSS
):
_weighted_average_loss_metric_spec
(
_weighted_average_loss_metric_spec
(
self
.
_
eval_
loss_fn
,
self
.
_loss_fn
,
prediction_key
.
PredictionKey
.
LOGITS
,
prediction_key
.
PredictionKey
.
LOGITS
,
self
.
_label_name
,
self
.
_label_name
,
self
.
_weight_column_name
)}
self
.
_weight_column_name
)}
...
@@ -821,27 +1024,6 @@ def _mean_squared_loss(logits, labels):
...
@@ -821,27 +1024,6 @@ def _mean_squared_loss(logits, labels):
return
math_ops
.
square
(
logits
-
math_ops
.
to_float
(
labels
))
return
math_ops
.
square
(
logits
-
math_ops
.
to_float
(
labels
))
def
_log_loss_with_two_classes
(
logits
,
labels
):
# sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels.
if
len
(
labels
.
get_shape
())
==
1
:
labels
=
array_ops
.
expand_dims
(
labels
,
dim
=
[
1
])
loss_vec
=
nn
.
sigmoid_cross_entropy_with_logits
(
logits
,
math_ops
.
to_float
(
labels
))
return
loss_vec
def
_softmax_cross_entropy_loss
(
logits
,
labels
):
# Check that we got integer for classification.
if
not
labels
.
dtype
.
is_integer
:
raise
ValueError
(
"Labels dtype should be integer "
"Instead got %s."
%
labels
.
dtype
)
# sparse_softmax_cross_entropy_with_logits requires [batch_size] labels.
if
len
(
labels
.
get_shape
())
==
2
:
labels
=
array_ops
.
squeeze
(
labels
,
squeeze_dims
=
[
1
])
loss_vec
=
nn
.
sparse_softmax_cross_entropy_with_logits
(
logits
,
labels
)
return
loss_vec
def
_sigmoid_cross_entropy_loss
(
logits
,
labels
):
def
_sigmoid_cross_entropy_loss
(
logits
,
labels
):
# sigmoid_cross_entropy_with_logits requires [batch_size, n_classes] labels.
# sigmoid_cross_entropy_with_logits requires [batch_size, n_classes] labels.
return
nn
.
sigmoid_cross_entropy_with_logits
(
logits
,
math_ops
.
to_float
(
labels
))
return
nn
.
sigmoid_cross_entropy_with_logits
(
logits
,
math_ops
.
to_float
(
labels
))
...
...
tensorflow/contrib/learn/python/learn/estimators/head_test.py
浏览文件 @
8eff2d62
...
@@ -131,13 +131,10 @@ class MultiClassModelHeadTest(tf.test.TestCase):
...
@@ -131,13 +131,10 @@ class MultiClassModelHeadTest(tf.test.TestCase):
_noop_train_op
,
logits
=
logits
)
_noop_train_op
,
logits
=
logits
)
self
.
assertAlmostEqual
(.
15514446
,
sess
.
run
(
model_fn_ops
.
loss
))
self
.
assertAlmostEqual
(.
15514446
,
sess
.
run
(
model_fn_ops
.
loss
))
def
testMultiClassWithInvalidNClass
(
self
):
def
testInvalidNClasses
(
self
):
try
:
for
n_classes
in
(
None
,
-
1
,
0
,
1
):
head_lib
.
_multi_class_head
(
n_classes
=
1
)
with
self
.
assertRaisesRegexp
(
ValueError
,
"n_classes must be > 1"
):
self
.
fail
(
"Softmax with no n_classes did not raise error."
)
head_lib
.
_multi_class_head
(
n_classes
=
n_classes
)
except
ValueError
:
# Expected
pass
class
BinarySvmModelHeadTest
(
tf
.
test
.
TestCase
):
class
BinarySvmModelHeadTest
(
tf
.
test
.
TestCase
):
...
...
tensorflow/contrib/learn/python/learn/estimators/linear.py
浏览文件 @
8eff2d62
...
@@ -196,14 +196,17 @@ def sdca_model_fn(features, labels, mode, params):
...
@@ -196,14 +196,17 @@ def sdca_model_fn(features, labels, mode, params):
if
not
isinstance
(
optimizer
,
sdca_optimizer
.
SDCAOptimizer
):
if
not
isinstance
(
optimizer
,
sdca_optimizer
.
SDCAOptimizer
):
raise
ValueError
(
"Optimizer must be of type SDCAOptimizer"
)
raise
ValueError
(
"Optimizer must be of type SDCAOptimizer"
)
if
isinstance
(
head
,
head_lib
.
_BinarySvmHead
):
# pylint: disable=protected-access
# pylint: disable=protected-access
if
isinstance
(
head
,
head_lib
.
_BinarySvmHead
):
loss_type
=
"hinge_loss"
loss_type
=
"hinge_loss"
elif
isinstance
(
head
,
head_lib
.
_MultiClassHead
):
# pylint: disable=protected-access
elif
isinstance
(
head
,
(
head_lib
.
_MultiClassHead
,
head_lib
.
_BinaryLogisticHead
)):
loss_type
=
"logistic_loss"
loss_type
=
"logistic_loss"
elif
isinstance
(
head
,
head_lib
.
_RegressionHead
):
# pylint: disable=protected-access
elif
isinstance
(
head
,
head_lib
.
_RegressionHead
):
loss_type
=
"squared_loss"
loss_type
=
"squared_loss"
else
:
else
:
return
ValueError
(
"Unsupported head type: {}"
.
format
(
head
))
raise
ValueError
(
"Unsupported head type: {}"
.
format
(
head
))
# pylint: enable=protected-access
parent_scope
=
"linear"
parent_scope
=
"linear"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录