Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_40195168达庆意
keras
提交
434d29d2
K
keras
项目概览
weixin_40195168达庆意
/
keras
与 Fork 源项目一致
从无法访问的项目Fork
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
K
keras
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
434d29d2
编写于
6月 24, 2021
作者:
C
Chenkai Kuang
提交者:
TensorFlower Gardener
6月 24, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix keras metric.result_state when the metric variables are sharded variable.
PiperOrigin-RevId: 381292911
上级
1232f05b
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
12 addition
and
5 deletion
+12
-5
keras/distribute/sharded_variable_test.py
keras/distribute/sharded_variable_test.py
+6
-0
keras/metrics.py
keras/metrics.py
+6
-5
未找到文件。
keras/distribute/sharded_variable_test.py
浏览文件 @
434d29d2
...
...
@@ -108,22 +108,28 @@ class ShardedVariableTest(tf.test.TestCase):
def
test_keras_metrics
(
self
):
with
self
.
strategy
.
scope
():
fp
=
keras
.
metrics
.
FalsePositives
(
thresholds
=
[
0.2
,
0.5
,
0.7
,
0.8
])
auc
=
keras
.
metrics
.
AUC
(
num_thresholds
=
10
)
@
tf
.
function
def
update
():
fp
.
update_state
([
0.
,
1.
,
0.
,
0.
],
[
0.
,
0.
,
0.3
,
0.9
])
auc
.
update_state
([
0
,
0
,
1
,
1
],
[
0
,
0.5
,
0.3
,
0.9
])
@
tf
.
function
def
reset
():
fp
.
reset_state
()
auc
.
reset_state
()
update
()
self
.
assertEqual
(
auc
.
result
(),
0.75
)
self
.
assertAllEqual
(
fp
.
result
(),
[
2.
,
1.
,
1.
,
1.
])
reset
()
self
.
assertEqual
(
auc
.
result
(),
0.0
)
self
.
assertAllEqual
(
fp
.
result
(),
[
0.
,
0.
,
0.
,
0.
])
self
.
assertTrue
(
hasattr
(
auc
.
true_positives
,
'variables'
))
self
.
assertTrue
(
hasattr
(
fp
.
accumulator
,
'variables'
))
def
test_saved_model
(
self
):
...
...
keras/metrics.py
浏览文件 @
434d29d2
...
...
@@ -1038,9 +1038,9 @@ class _ConfusionMatrixConditionCount(Metric):
return
tf
.
convert_to_tensor
(
result
)
def
reset_state
(
self
):
num_thresholds
=
len
(
to_list
(
self
.
thresholds
))
backend
.
batch_set_value
(
[(
v
,
np
.
zeros
((
num_thresholds
,)))
for
v
in
self
.
variables
])
backend
.
batch_set_value
([
(
v
,
np
.
zeros
(
v
.
shape
.
as_list
()))
for
v
in
self
.
variables
])
def
get_config
(
self
):
config
=
{
'thresholds'
:
self
.
init_thresholds
}
...
...
@@ -3175,8 +3175,9 @@ class MeanTensor(Metric):
def
reset_state
(
self
):
if
self
.
_built
:
backend
.
batch_set_value
(
[(
v
,
np
.
zeros
(
self
.
_shape
.
as_list
()))
for
v
in
self
.
variables
])
backend
.
batch_set_value
([
(
v
,
np
.
zeros
(
v
.
shape
.
as_list
()))
for
v
in
self
.
variables
])
@
keras_export
(
'keras.metrics.BinaryCrossentropy'
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录