未验证 提交 c0b31c51 编写于 作者: H Haoyu Zhang 提交者: GitHub

Fix trivial model to work properly with fp16 (#6760)

* Fix trivial model to work properly with fp16

* Add comment on manual casting
上级 5e876e6e
......@@ -201,7 +201,7 @@ def run(flags_obj):
input_layer_batch_size = None
if flags_obj.use_trivial_model:
model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES)
model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES, dtype)
else:
model = resnet_model.resnet50(
num_classes=imagenet_main.NUM_CLASSES,
......
......@@ -23,15 +23,19 @@ from tensorflow.python.keras import layers
from tensorflow.python.keras import models
def trivial_model(num_classes):
def trivial_model(num_classes, dtype='float32'):
"""Trivial model for ImageNet dataset."""
input_shape = (224, 224, 3)
img_input = layers.Input(shape=input_shape)
img_input = layers.Input(shape=input_shape, dtype=dtype)
x = layers.Lambda(lambda x: backend.reshape(x, [-1, 224 * 224 * 3]),
name='reshape')(img_input)
x = layers.Dense(1, name='fc1')(x)
x = layers.Dense(num_classes, activation='softmax', name='fc1000')(x)
x = layers.Dense(num_classes, name='fc1000')(x)
# TODO(reedwm): Remove manual casts once mixed precision can be enabled with a
# single line of code.
x = backend.cast(x, 'float32')
x = layers.Activation('softmax')(x)
return models.Model(img_input, x, name='trivial')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册