Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
xxadev
tensorflow
提交
62df65c7
T
tensorflow
项目概览
xxadev
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
3
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
62df65c7
编写于
10月 20, 2017
作者:
A
A. Unique TensorFlower
提交者:
TensorFlower Gardener
10月 20, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add dtype argument to Mean and Accuracy object-oriented metrics.
PiperOrigin-RevId: 172957714
上级
29c7b465
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
36 addition
and
11 deletion
+36
-11
tensorflow/contrib/eager/python/metrics_impl.py
tensorflow/contrib/eager/python/metrics_impl.py
+16
-11
tensorflow/contrib/eager/python/metrics_test.py
tensorflow/contrib/eager/python/metrics_test.py
+20
-0
未找到文件。
tensorflow/contrib/eager/python/metrics_impl.py
浏览文件 @
62df65c7
...
...
@@ -198,13 +198,19 @@ class Mean(Metric):
# TODO(josh11b): Maybe have a dtype argument that defaults to tf.float64?
# Or defaults to type of the input if it is tf.float32, else tf.float64?
def
build
(
self
,
values
,
weights
=
None
):
del
values
,
weights
# build() does not use call's arguments
def
__init__
(
self
,
name
=
None
,
dtype
=
dtypes
.
float64
):
super
(
Mean
,
self
).
__init__
(
name
=
name
)
self
.
dtype
=
dtype
def
build
(
self
,
*
args
,
**
kwargs
):
# build() does not use call's arguments, by using *args, **kwargs
# we make it easier to inherit from Mean().
del
args
,
kwargs
self
.
numer
=
self
.
add_variable
(
name
=
"numer"
,
shape
=
(),
dtype
=
dtypes
.
float64
,
dtype
=
self
.
dtype
,
initializer
=
init_ops
.
zeros_initializer
)
self
.
denom
=
self
.
add_variable
(
name
=
"denom"
,
shape
=
(),
dtype
=
dtypes
.
float64
,
dtype
=
self
.
dtype
,
initializer
=
init_ops
.
zeros_initializer
)
def
call
(
self
,
values
,
weights
=
None
):
...
...
@@ -219,13 +225,13 @@ class Mean(Metric):
"""
if
weights
is
None
:
self
.
denom
.
assign_add
(
math_ops
.
cast
(
array_ops
.
size
(
values
),
dtypes
.
float64
))
math_ops
.
cast
(
array_ops
.
size
(
values
),
self
.
dtype
))
values
=
math_ops
.
reduce_sum
(
values
)
self
.
numer
.
assign_add
(
math_ops
.
cast
(
values
,
dtypes
.
float64
))
self
.
numer
.
assign_add
(
math_ops
.
cast
(
values
,
self
.
dtype
))
else
:
weights
=
math_ops
.
cast
(
weights
,
dtypes
.
float64
)
weights
=
math_ops
.
cast
(
weights
,
self
.
dtype
)
self
.
denom
.
assign_add
(
math_ops
.
reduce_sum
(
weights
))
values
=
math_ops
.
cast
(
values
,
dtypes
.
float64
)
*
weights
values
=
math_ops
.
cast
(
values
,
self
.
dtype
)
*
weights
self
.
numer
.
assign_add
(
math_ops
.
reduce_sum
(
values
))
def
result
(
self
):
...
...
@@ -235,9 +241,8 @@ class Mean(Metric):
class
Accuracy
(
Mean
):
"""Calculates how often `predictions` matches `labels`."""
def
build
(
self
,
labels
,
predictions
,
weights
=
None
):
del
labels
,
predictions
,
weights
super
(
Accuracy
,
self
).
build
(
None
)
# Arguments are unused
def
__init__
(
self
,
name
=
None
,
dtype
=
dtypes
.
float64
):
super
(
Accuracy
,
self
).
__init__
(
name
=
name
,
dtype
=
dtype
)
def
call
(
self
,
labels
,
predictions
,
weights
=
None
):
"""Accumulate accuracy statistics.
...
...
tensorflow/contrib/eager/python/metrics_test.py
浏览文件 @
62df65c7
...
...
@@ -34,6 +34,8 @@ class MetricsTest(test.TestCase):
m
(
1000
)
m
([
10000.0
,
100000.0
])
self
.
assertEqual
(
111111.0
/
6
,
m
.
result
().
numpy
())
self
.
assertEqual
(
dtypes
.
float64
,
m
.
dtype
)
self
.
assertEqual
(
dtypes
.
float64
,
m
.
result
().
dtype
)
def
testWeightedMean
(
self
):
m
=
metrics
.
Mean
()
...
...
@@ -41,6 +43,14 @@ class MetricsTest(test.TestCase):
m
([
500000
,
5000
,
500
])
# weights of 1 each
self
.
assertNear
(
535521
/
4.5
,
m
.
result
().
numpy
(),
0.001
)
def
testMeanDtype
(
self
):
# Can override default dtype of float64.
m
=
metrics
.
Mean
(
dtype
=
dtypes
.
float32
)
m
([
0
,
2
])
self
.
assertEqual
(
1
,
m
.
result
().
numpy
())
self
.
assertEqual
(
dtypes
.
float32
,
m
.
dtype
)
self
.
assertEqual
(
dtypes
.
float32
,
m
.
result
().
dtype
)
def
testAccuracy
(
self
):
m
=
metrics
.
Accuracy
()
m
([
0
,
1
,
2
,
3
],
[
0
,
0
,
0
,
0
])
# 1 correct
...
...
@@ -49,6 +59,8 @@ class MetricsTest(test.TestCase):
m
([
6
],
[
6
])
# 1 correct
m
([
7
],
[
2
])
# 0 correct
self
.
assertEqual
(
3.0
/
8
,
m
.
result
().
numpy
())
self
.
assertEqual
(
dtypes
.
float64
,
m
.
dtype
)
self
.
assertEqual
(
dtypes
.
float64
,
m
.
result
().
dtype
)
def
testWeightedAccuracy
(
self
):
m
=
metrics
.
Accuracy
()
...
...
@@ -60,6 +72,14 @@ class MetricsTest(test.TestCase):
m
([
7
],
[
2
])
# 0 correct, weight 1
self
.
assertEqual
(
2.5
/
5
,
m
.
result
().
numpy
())
def
testAccuracyDtype
(
self
):
# Can override default dtype of float64.
m
=
metrics
.
Accuracy
(
dtype
=
dtypes
.
float32
)
m
([
0
,
0
],
[
0
,
1
])
self
.
assertEqual
(
0.5
,
m
.
result
().
numpy
())
self
.
assertEqual
(
dtypes
.
float32
,
m
.
dtype
)
self
.
assertEqual
(
dtypes
.
float32
,
m
.
result
().
dtype
)
def
testTwoMeans
(
self
):
# Verify two metrics with the same class and name don't
# accidentally share state.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录