提交 98a558b7 编写于 作者: L Liangzhe Yuan 提交者: A. Unique TensorFlower

#movinet Add se_type option in tools/convert_3d_2plus1d.py

PiperOrigin-RevId: 420368239
上级 d58be675
......@@ -29,6 +29,8 @@ flags.DEFINE_string(
'Export path to save the saved_model file.')
flags.DEFINE_string(
'model_id', 'a0', 'MoViNet model name.')
flags.DEFINE_string(
'se_type', '2plus3d', 'MoViNet model SE type.')
flags.DEFINE_bool(
'causal', True, 'Run the model in causal mode.')
flags.DEFINE_bool(
......@@ -46,6 +48,7 @@ def main(_) -> None:
backbone_2plus1d = movinet.Movinet(
model_id=FLAGS.model_id,
causal=FLAGS.causal,
se_type=FLAGS.se_type,
conv_type='2plus1d',
use_positional_encoding=FLAGS.use_positional_encoding)
model_2plus1d = movinet_model.MovinetClassifier(
......@@ -56,6 +59,7 @@ def main(_) -> None:
backbone_3d_2plus1d = movinet.Movinet(
model_id=FLAGS.model_id,
causal=FLAGS.causal,
se_type=FLAGS.se_type,
conv_type='3d_2plus1d',
use_positional_encoding=FLAGS.use_positional_encoding)
model_3d_2plus1d = movinet_model.MovinetClassifier(
......
......@@ -36,6 +36,7 @@ class Convert3d2plus1dTest(tf.test.TestCase):
model_3d_2plus1d = movinet_model.MovinetClassifier(
backbone=movinet.Movinet(
model_id='a0',
se_type='2plus3d',
conv_type='3d_2plus1d'),
num_classes=600)
model_3d_2plus1d.build([1, 1, 1, 1, 3])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册