提交 a1821955 编写于 作者: F Fan Yang 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 528500682
上级 e71d3a5e
......@@ -21,6 +21,7 @@ import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import export_base
from official.projects.yolo.configs.yolo import YoloTask
from official.projects.yolo.configs.yolov7 import YoloV7Task
from official.projects.yolo.modeling import factory as yolo_factory
from official.projects.yolo.modeling.backbones import darknet # pylint: disable=unused-import
from official.projects.yolo.modeling.decoders import yolo_decoder # pylint: disable=unused-import
......@@ -163,10 +164,16 @@ def create_yolo_export_module(
input_type, batch_size, input_image_size, num_channels, input_name)
input_specs = tf.keras.layers.InputSpec(shape=[batch_size] +
input_image_size + [num_channels])
model, _ = yolo_factory.build_yolo(
input_specs=input_specs,
model_config=params.task.model,
l2_regularization=None)
if isinstance(params.task, YoloTask):
model, _ = yolo_factory.build_yolo(
input_specs=input_specs,
model_config=params.task.model,
l2_regularization=None)
elif isinstance(params.task, YoloV7Task):
model = yolo_factory.build_yolov7(
input_specs=input_specs,
model_config=params.task.model,
l2_regularization=None)
def preprocess_fn(inputs):
image_tensor = export_utils.parse_image(inputs, input_type,
......@@ -247,7 +254,7 @@ def get_export_module(params: cfg.ExperimentConfig,
input_image_size,
num_channels,
input_name)
elif isinstance(params.task, YoloTask):
elif isinstance(params.task, (YoloTask, YoloV7Task)):
export_module = create_yolo_export_module(params, input_type, batch_size,
input_image_size, num_channels,
input_name)
......
......@@ -38,9 +38,8 @@ from absl import flags
from official.core import exp_factory
from official.modeling import hyperparams
from official.projects.yolo.configs import yolo as cfg # pylint: disable=unused-import
from official.projects.yolo.common import registry_imports # pylint: disable=unused-import
from official.projects.yolo.serving import export_module_factory
from official.projects.yolo.tasks import yolo as task # pylint: disable=unused-import
from official.vision.serving import export_saved_model_lib
FLAGS = flags.FLAGS
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册