Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
qq_38905368
tensorflow
提交
e0f979b8
T
tensorflow
项目概览
qq_38905368
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
5
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
e0f979b8
编写于
12月 12, 2018
作者:
A
A. Unique TensorFlower
提交者:
TensorFlower Gardener
12月 12, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix `predict` with `run_eagerly=True`
PiperOrigin-RevId: 225257343
上级
22af085f
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
20 addition
and
3 deletion
+20
-3
tensorflow/python/keras/engine/training_eager_test.py
tensorflow/python/keras/engine/training_eager_test.py
+15
-0
tensorflow/python/keras/engine/training_generator.py
tensorflow/python/keras/engine/training_generator.py
+5
-3
未找到文件。
tensorflow/python/keras/engine/training_eager_test.py
浏览文件 @
e0f979b8
...
...
@@ -246,6 +246,21 @@ class CorrectnessTest(test.TestCase):
layer
(
1.
)
# Plain-value inputs are only valid in eager mode.
self
.
assertEqual
(
1
,
len
(
layer
.
losses
))
def
test_predict_correctness
(
self
):
i1
=
keras
.
layers
.
Input
(
shape
=
(
4
,
5
))
i2
=
keras
.
layers
.
Input
(
shape
=
(
4
,
5
))
i3
=
keras
.
layers
.
Input
(
shape
=
(
4
,
5
))
o
=
keras
.
layers
.
add
([
i1
,
i2
,
i3
])
model
=
keras
.
models
.
Model
([
i1
,
i2
,
i3
],
o
)
model
.
run_eagerly
=
True
x1
=
np
.
random
.
random
((
2
,
4
,
5
))
x2
=
np
.
random
.
random
((
2
,
4
,
5
))
x3
=
np
.
random
.
random
((
2
,
4
,
5
))
out
=
model
.
predict
([
x1
,
x2
,
x3
])
self
.
assertAllClose
(
out
,
x1
+
x2
+
x3
)
if
__name__
==
'__main__'
:
ops
.
enable_eager_execution
()
...
...
tensorflow/python/keras/engine/training_generator.py
浏览文件 @
e0f979b8
...
...
@@ -49,7 +49,7 @@ def model_iteration(model,
max_queue_size
=
10
,
workers
=
1
,
use_multiprocessing
=
False
,
shuffle
=
Tru
e
,
shuffle
=
Fals
e
,
initial_epoch
=
0
,
mode
=
'train'
,
batch_size
=
None
,
...
...
@@ -246,8 +246,10 @@ def model_iteration(model,
# Maintain compatibility with the existing names.
fit_generator
=
functools
.
partial
(
model_iteration
,
mode
=
'train'
)
evaluate_generator
=
functools
.
partial
(
model_iteration
,
mode
=
'test'
)
predict_generator
=
functools
.
partial
(
model_iteration
,
mode
=
'predict'
)
evaluate_generator
=
functools
.
partial
(
model_iteration
,
mode
=
'test'
,
shuffle
=
False
)
predict_generator
=
functools
.
partial
(
model_iteration
,
mode
=
'predict'
,
shuffle
=
False
)
def
_get_next_batch
(
output_generator
,
mode
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录