Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
qq_38905368
tensorflow
提交
5ed1de77
T
tensorflow
项目概览
qq_38905368
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
5
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
5ed1de77
编写于
10月 19, 2016
作者:
A
A. Unique TensorFlower
提交者:
TensorFlower Gardener
10月 19, 2016
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Automated rollback of change 136502135
Change: 136641403
上级
ad3bcda5
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
13 addition
and
37 deletion
+13
-37
tensorflow/contrib/learn/python/learn/estimators/estimator.py
...orflow/contrib/learn/python/learn/estimators/estimator.py
+6
-31
tensorflow/contrib/learn/python/learn/graph_actions.py
tensorflow/contrib/learn/python/learn/graph_actions.py
+7
-6
未找到文件。
tensorflow/contrib/learn/python/learn/estimators/estimator.py
浏览文件 @
5ed1de77
...
...
@@ -216,12 +216,8 @@ def _make_metrics_ops(metrics, features, targets, predictions):
predictions.
Returns:
`dict` whose keys are summary names, and values are the result of the
metric, either:
- `Tensor` values (in which case only the result of the last eval batch
will be summarized).
- `tuple` of 2 `Tensor` objects, update op and value. The update op will
be run once each eval step, and the value written to summary.
A dict mapping the friendly given in `metrics` to the result of calling the
given metric function.
Raises:
ValueError: If metrics specifications do not work with the type of
...
...
@@ -271,13 +267,6 @@ def _make_metrics_ops(metrics, features, targets, predictions):
return
result
def
_maybe_add_streaming_mean
(
result
,
key
,
value
):
if
key
in
result
:
logging
.
warning
(
'Metrics already contains %s, skipping.'
,
key
)
return
result
[
key
]
=
metrics_lib
.
streaming_mean
(
value
)
class
BaseEstimator
(
sklearn
.
BaseEstimator
,
evaluable
.
Evaluable
,
trainable
.
Trainable
):
"""Abstract BaseEstimator class to train and evaluate TensorFlow models.
...
...
@@ -585,7 +574,7 @@ class BaseEstimator(
Args:
features: `Tensor` or `dict` of `Tensor` objects.
targets: `Tensor` or `dict` of `Tensor` objects.
metrics: Dict of metrics to run. If
`None`
, the default metric functions
metrics: Dict of metrics to run. If
None
, the default metric functions
are used; if {}, no metrics are used. Otherwise, `metrics` should map
friendly names for the metric to a `MetricSpec` object defining which
model outputs to evaluate against which targets with which metric
...
...
@@ -1055,28 +1044,14 @@ class Estimator(BaseEstimator):
`../metric_spec.py`.
Returns:
`dict` whose keys are summary names, and values are either:
- `Tensor` values (in which case only the result of the last eval batch
will be summarized).
- `tuple` of 2 `Tensor` objects, update op and value. The update op will
be run once each eval step, and the value written to summary.
metrics: `dict` of `Tensor` objects.
Raises:
ValueError: if `metrics` don't match `targets`.
"""
predictions
,
loss
,
_
=
self
.
_call_model_fn
(
features
,
targets
,
ModeKeys
.
EVAL
)
result
=
_make_metrics_ops
(
metrics
,
features
,
targets
,
predictions
)
_maybe_add_streaming_mean
(
result
,
'loss'
,
loss
)
# TODO(ptucker): Work-around until we have an easier way to specify metrics
# from model_fn.
if
predictions
is
not
None
:
if
isinstance
(
predictions
,
dict
):
for
k
,
v
in
six
.
iteritems
(
predictions
):
_maybe_add_streaming_mean
(
result
,
k
,
v
)
else
:
_maybe_add_streaming_mean
(
result
,
'predictions'
,
predictions
)
result
=
{
'loss'
:
metrics_lib
.
streaming_mean
(
loss
)}
result
.
update
(
_make_metrics_ops
(
metrics
,
features
,
targets
,
predictions
))
return
result
def
_get_predict_ops
(
self
,
features
):
...
...
tensorflow/contrib/learn/python/learn/graph_actions.py
浏览文件 @
5ed1de77
...
...
@@ -25,7 +25,8 @@ import threading
import
time
import
numpy
as
np
import
six
from
six
import
reraise
from
tensorflow.contrib.framework
import
load_variable
from
tensorflow.contrib.framework.python.ops
import
ops
as
contrib_ops
...
...
@@ -578,7 +579,7 @@ def _train_internal(graph,
logging
.
error
(
'Got exception during tf.learn final checkpoint %s.'
,
e
)
finally
:
if
excinfo
:
six
.
reraise
(
*
excinfo
)
reraise
(
*
excinfo
)
return
loss_value
...
...
@@ -628,14 +629,14 @@ def _write_summary_results(output_dir, eval_results, current_global_step):
_eval_results_to_str
(
eval_results
))
summary_writer
=
get_summary_writer
(
output_dir
)
summary
=
summary_pb2
.
Summary
()
for
key
,
eval_result
in
six
.
iteritems
(
eval_results
)
:
for
key
in
eval_results
:
if
eval_results
[
key
]
is
None
:
continue
value
=
summary
.
value
.
add
()
value
.
tag
=
key
if
(
isinstance
(
eval_result
,
np
.
float32
)
or
isinstance
(
eval_result
,
float
)):
value
.
simple_value
=
float
(
eval_result
)
if
(
isinstance
(
eval_result
s
[
key
]
,
np
.
float32
)
or
isinstance
(
eval_result
s
[
key
]
,
float
)):
value
.
simple_value
=
float
(
eval_result
s
[
key
]
)
else
:
logging
.
warn
(
'Skipping summary for %s, must be a float or np.float32.'
,
key
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录