From 01e24a525234c6542f3ce312327a3627bae70d05 Mon Sep 17 00:00:00 2001 From: Zhichao Lu Date: Wed, 21 Oct 2020 18:45:08 -0700 Subject: [PATCH] Internal change. PiperOrigin-RevId: 338383183 --- .../context_rcnn/generate_detection_data_tf2_test.py | 1 + .../context_rcnn/generate_embedding_data_tf2_test.py | 1 + research/object_detection/exporter_lib_tf2_test.py | 4 ++++ research/object_detection/exporter_lib_v2.py | 9 +++++++-- 4 files changed, 13 insertions(+), 2 deletions(-) diff --git a/research/object_detection/dataset_tools/context_rcnn/generate_detection_data_tf2_test.py b/research/object_detection/dataset_tools/context_rcnn/generate_detection_data_tf2_test.py index 89f743dd9..71b327635 100644 --- a/research/object_detection/dataset_tools/context_rcnn/generate_detection_data_tf2_test.py +++ b/research/object_detection/dataset_tools/context_rcnn/generate_detection_data_tf2_test.py @@ -133,6 +133,7 @@ class GenerateDetectionDataTest(tf.test.TestCase): with mock.patch.object( model_builder, 'build', autospec=True) as mock_builder: mock_builder.return_value = FakeModel() + exporter_lib_v2.INPUT_BUILDER_UTIL_MAP['model_build'] = mock_builder output_directory = os.path.join(tmp_dir, 'output') pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() exporter_lib_v2.export_inference_graph( diff --git a/research/object_detection/dataset_tools/context_rcnn/generate_embedding_data_tf2_test.py b/research/object_detection/dataset_tools/context_rcnn/generate_embedding_data_tf2_test.py index 677adaf92..5566d6d5f 100644 --- a/research/object_detection/dataset_tools/context_rcnn/generate_embedding_data_tf2_test.py +++ b/research/object_detection/dataset_tools/context_rcnn/generate_embedding_data_tf2_test.py @@ -139,6 +139,7 @@ class GenerateEmbeddingData(tf.test.TestCase): with mock.patch.object( model_builder, 'build', autospec=True) as mock_builder: mock_builder.return_value = FakeModel() + exporter_lib_v2.INPUT_BUILDER_UTIL_MAP['model_build'] = mock_builder output_directory = os.path.join(tmp_dir, 'output') pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() exporter_lib_v2.export_inference_graph( diff --git a/research/object_detection/exporter_lib_tf2_test.py b/research/object_detection/exporter_lib_tf2_test.py index 07272aebe..8e85e1124 100644 --- a/research/object_detection/exporter_lib_tf2_test.py +++ b/research/object_detection/exporter_lib_tf2_test.py @@ -120,6 +120,7 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase): with mock.patch.object( model_builder, 'build', autospec=True) as mock_builder: mock_builder.return_value = FakeModel() + exporter_lib_v2.INPUT_BUILDER_UTIL_MAP['model_build'] = mock_builder output_directory = os.path.join(tmp_dir, 'output') pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() exporter_lib_v2.export_inference_graph( @@ -181,6 +182,7 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase): with mock.patch.object( model_builder, 'build', autospec=True) as mock_builder: mock_builder.return_value = FakeModel() + exporter_lib_v2.INPUT_BUILDER_UTIL_MAP['model_build'] = mock_builder output_directory = os.path.join(tmp_dir, 'output') pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() exporter_lib_v2.export_inference_graph( @@ -217,6 +219,7 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase): with mock.patch.object( model_builder, 'build', autospec=True) as mock_builder: mock_builder.return_value = FakeModel() + exporter_lib_v2.INPUT_BUILDER_UTIL_MAP['model_build'] = mock_builder output_directory = os.path.join(tmp_dir, 'output') pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() exporter_lib_v2.export_inference_graph( @@ -262,6 +265,7 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase): with mock.patch.object( model_builder, 'build', autospec=True) as mock_builder: mock_builder.return_value = FakeModel() + exporter_lib_v2.INPUT_BUILDER_UTIL_MAP['model_build'] = mock_builder output_directory = os.path.join(tmp_dir, 'output') pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() exporter_lib_v2.export_inference_graph( diff --git a/research/object_detection/exporter_lib_v2.py b/research/object_detection/exporter_lib_v2.py index 2a4f1162e..5a7a182c6 100644 --- a/research/object_detection/exporter_lib_v2.py +++ b/research/object_detection/exporter_lib_v2.py @@ -25,6 +25,11 @@ from object_detection.data_decoders import tf_example_decoder from object_detection.utils import config_util +INPUT_BUILDER_UTIL_MAP = { + 'model_build': model_builder.build, +} + + def _decode_image(encoded_image_string_tensor): image_tensor = tf.image.decode_image(encoded_image_string_tensor, channels=3) @@ -230,8 +235,8 @@ def export_inference_graph(input_type, output_checkpoint_directory = os.path.join(output_directory, 'checkpoint') output_saved_model_directory = os.path.join(output_directory, 'saved_model') - detection_model = model_builder.build(pipeline_config.model, - is_training=False) + detection_model = INPUT_BUILDER_UTIL_MAP['model_build']( + pipeline_config.model, is_training=False) ckpt = tf.train.Checkpoint( model=detection_model) -- GitLab