提交 685da94a 编写于 作者: L Liangzhe Yuan 提交者: A. Unique TensorFlower

#movinet Move export_saved_model.py to movinet/tools/ and add support for...

#movinet Move export_saved_model.py to movinet/tools/ and add support for customized classifier activation.

PiperOrigin-RevId: 425669451
上级 98d7c8e7
......@@ -82,6 +82,9 @@ flags.DEFINE_string(
flags.DEFINE_string(
'activation', 'swish',
'The main activation to use across layers.')
flags.DEFINE_string(
'classifier_activation', 'swish',
'The classifier activation to use.')
flags.DEFINE_string(
'gating_activation', 'sigmoid',
'The gating activation to use in squeeze-excitation layers.')
......@@ -124,11 +127,15 @@ def main(_) -> None:
# states. These dimensions can be set to `None` once the model is built.
input_shape = [1 if s is None else s for s in input_specs.shape]
# Override swish activation implementation to remove custom gradients
activation = FLAGS.activation
if activation == 'swish':
# Override swish activation implementation to remove custom gradients
activation = 'simple_swish'
classifier_activation = FLAGS.classifier_activation
if classifier_activation == 'swish':
classifier_activation = 'simple_swish'
backbone = movinet.Movinet(
model_id=FLAGS.model_id,
causal=FLAGS.causal,
......@@ -145,9 +152,7 @@ def main(_) -> None:
num_classes=FLAGS.num_classes,
output_states=FLAGS.causal,
input_specs=dict(image=input_specs),
# TODO(dankondratyuk): currently set to swish, but will need to
# re-train to use other activations.
activation='simple_swish')
activation=classifier_activation)
model.build(input_shape)
# Compile model to generate some internal Keras variables.
......
......@@ -18,7 +18,7 @@ from absl import flags
import tensorflow as tf
import tensorflow_hub as hub
from official.projects.movinet import export_saved_model
from official.projects.movinet.tools import export_saved_model
FLAGS = flags.FLAGS
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册