Commit 01e24a52 authored by Zhichao Lu, committed by TF Object Detection Team

Internal change.

PiperOrigin-RevId: 338383183
Parent 58b6666f
@@ -133,6 +133,7 @@ class GenerateDetectionDataTest(tf.test.TestCase):
     with mock.patch.object(
         model_builder, 'build', autospec=True) as mock_builder:
       mock_builder.return_value = FakeModel()
+      exporter_lib_v2.INPUT_BUILDER_UTIL_MAP['model_build'] = mock_builder
       output_directory = os.path.join(tmp_dir, 'output')
       pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
       exporter_lib_v2.export_inference_graph(
...
@@ -139,6 +139,7 @@ class GenerateEmbeddingData(tf.test.TestCase):
     with mock.patch.object(
         model_builder, 'build', autospec=True) as mock_builder:
       mock_builder.return_value = FakeModel()
+      exporter_lib_v2.INPUT_BUILDER_UTIL_MAP['model_build'] = mock_builder
       output_directory = os.path.join(tmp_dir, 'output')
       pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
       exporter_lib_v2.export_inference_graph(
...
@@ -120,6 +120,7 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase):
     with mock.patch.object(
         model_builder, 'build', autospec=True) as mock_builder:
       mock_builder.return_value = FakeModel()
+      exporter_lib_v2.INPUT_BUILDER_UTIL_MAP['model_build'] = mock_builder
       output_directory = os.path.join(tmp_dir, 'output')
       pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
       exporter_lib_v2.export_inference_graph(
@@ -181,6 +182,7 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase):
     with mock.patch.object(
         model_builder, 'build', autospec=True) as mock_builder:
       mock_builder.return_value = FakeModel()
+      exporter_lib_v2.INPUT_BUILDER_UTIL_MAP['model_build'] = mock_builder
       output_directory = os.path.join(tmp_dir, 'output')
       pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
       exporter_lib_v2.export_inference_graph(
@@ -217,6 +219,7 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase):
     with mock.patch.object(
         model_builder, 'build', autospec=True) as mock_builder:
       mock_builder.return_value = FakeModel()
+      exporter_lib_v2.INPUT_BUILDER_UTIL_MAP['model_build'] = mock_builder
       output_directory = os.path.join(tmp_dir, 'output')
       pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
       exporter_lib_v2.export_inference_graph(
@@ -262,6 +265,7 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase):
     with mock.patch.object(
         model_builder, 'build', autospec=True) as mock_builder:
       mock_builder.return_value = FakeModel()
+      exporter_lib_v2.INPUT_BUILDER_UTIL_MAP['model_build'] = mock_builder
       output_directory = os.path.join(tmp_dir, 'output')
       pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
       exporter_lib_v2.export_inference_graph(
...
@@ -25,6 +25,11 @@ from object_detection.data_decoders import tf_example_decoder
 from object_detection.utils import config_util

+INPUT_BUILDER_UTIL_MAP = {
+    'model_build': model_builder.build,
+}
+
 def _decode_image(encoded_image_string_tensor):
   image_tensor = tf.image.decode_image(encoded_image_string_tensor,
                                        channels=3)
@@ -230,8 +235,8 @@ def export_inference_graph(input_type,
   output_checkpoint_directory = os.path.join(output_directory, 'checkpoint')
   output_saved_model_directory = os.path.join(output_directory, 'saved_model')

-  detection_model = model_builder.build(pipeline_config.model,
-                                        is_training=False)
+  detection_model = INPUT_BUILDER_UTIL_MAP['model_build'](
+      pipeline_config.model, is_training=False)
   ckpt = tf.train.Checkpoint(
       model=detection_model)
...
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register