提交 01e24a52 编写于 作者: Z Zhichao Lu 提交者: TF Object Detection Team

Internal change.

PiperOrigin-RevId: 338383183
上级 58b6666f
......@@ -133,6 +133,7 @@ class GenerateDetectionDataTest(tf.test.TestCase):
with mock.patch.object(
model_builder, 'build', autospec=True) as mock_builder:
mock_builder.return_value = FakeModel()
exporter_lib_v2.INPUT_BUILDER_UTIL_MAP['model_build'] = mock_builder
output_directory = os.path.join(tmp_dir, 'output')
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
exporter_lib_v2.export_inference_graph(
......
......@@ -139,6 +139,7 @@ class GenerateEmbeddingData(tf.test.TestCase):
with mock.patch.object(
model_builder, 'build', autospec=True) as mock_builder:
mock_builder.return_value = FakeModel()
exporter_lib_v2.INPUT_BUILDER_UTIL_MAP['model_build'] = mock_builder
output_directory = os.path.join(tmp_dir, 'output')
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
exporter_lib_v2.export_inference_graph(
......
......@@ -120,6 +120,7 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase):
with mock.patch.object(
model_builder, 'build', autospec=True) as mock_builder:
mock_builder.return_value = FakeModel()
exporter_lib_v2.INPUT_BUILDER_UTIL_MAP['model_build'] = mock_builder
output_directory = os.path.join(tmp_dir, 'output')
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
exporter_lib_v2.export_inference_graph(
......@@ -181,6 +182,7 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase):
with mock.patch.object(
model_builder, 'build', autospec=True) as mock_builder:
mock_builder.return_value = FakeModel()
exporter_lib_v2.INPUT_BUILDER_UTIL_MAP['model_build'] = mock_builder
output_directory = os.path.join(tmp_dir, 'output')
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
exporter_lib_v2.export_inference_graph(
......@@ -217,6 +219,7 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase):
with mock.patch.object(
model_builder, 'build', autospec=True) as mock_builder:
mock_builder.return_value = FakeModel()
exporter_lib_v2.INPUT_BUILDER_UTIL_MAP['model_build'] = mock_builder
output_directory = os.path.join(tmp_dir, 'output')
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
exporter_lib_v2.export_inference_graph(
......@@ -262,6 +265,7 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase):
with mock.patch.object(
model_builder, 'build', autospec=True) as mock_builder:
mock_builder.return_value = FakeModel()
exporter_lib_v2.INPUT_BUILDER_UTIL_MAP['model_build'] = mock_builder
output_directory = os.path.join(tmp_dir, 'output')
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
exporter_lib_v2.export_inference_graph(
......
......@@ -25,6 +25,11 @@ from object_detection.data_decoders import tf_example_decoder
from object_detection.utils import config_util
INPUT_BUILDER_UTIL_MAP = {
'model_build': model_builder.build,
}
def _decode_image(encoded_image_string_tensor):
image_tensor = tf.image.decode_image(encoded_image_string_tensor,
channels=3)
......@@ -230,8 +235,8 @@ def export_inference_graph(input_type,
output_checkpoint_directory = os.path.join(output_directory, 'checkpoint')
output_saved_model_directory = os.path.join(output_directory, 'saved_model')
detection_model = model_builder.build(pipeline_config.model,
is_training=False)
detection_model = INPUT_BUILDER_UTIL_MAP['model_build'](
pipeline_config.model, is_training=False)
ckpt = tf.train.Checkpoint(
model=detection_model)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册