diff --git a/research/object_detection/dataset_tools/context_rcnn/generate_detection_data_tf2_test.py b/research/object_detection/dataset_tools/context_rcnn/generate_detection_data_tf2_test.py index 89f743dd96488468bf430accea12a8eed3d72e13..71b327635579ea812c38deb4190248cff1a187b8 100644 --- a/research/object_detection/dataset_tools/context_rcnn/generate_detection_data_tf2_test.py +++ b/research/object_detection/dataset_tools/context_rcnn/generate_detection_data_tf2_test.py @@ -133,6 +133,7 @@ class GenerateDetectionDataTest(tf.test.TestCase): with mock.patch.object( model_builder, 'build', autospec=True) as mock_builder: mock_builder.return_value = FakeModel() + exporter_lib_v2.INPUT_BUILDER_UTIL_MAP['model_build'] = mock_builder output_directory = os.path.join(tmp_dir, 'output') pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() exporter_lib_v2.export_inference_graph( diff --git a/research/object_detection/dataset_tools/context_rcnn/generate_embedding_data_tf2_test.py b/research/object_detection/dataset_tools/context_rcnn/generate_embedding_data_tf2_test.py index 677adaf924de65977631d24730d4209f144d70c3..5566d6d5f35b7bccf363be9a9a1088baef326e3f 100644 --- a/research/object_detection/dataset_tools/context_rcnn/generate_embedding_data_tf2_test.py +++ b/research/object_detection/dataset_tools/context_rcnn/generate_embedding_data_tf2_test.py @@ -139,6 +139,7 @@ class GenerateEmbeddingData(tf.test.TestCase): with mock.patch.object( model_builder, 'build', autospec=True) as mock_builder: mock_builder.return_value = FakeModel() + exporter_lib_v2.INPUT_BUILDER_UTIL_MAP['model_build'] = mock_builder output_directory = os.path.join(tmp_dir, 'output') pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() exporter_lib_v2.export_inference_graph( diff --git a/research/object_detection/exporter_lib_tf2_test.py b/research/object_detection/exporter_lib_tf2_test.py index 07272aebe05546cd9d0e90d32feb4831815dd132..8e85e1124bca40957464b5c80acb6a24ea7fcc3d 100644 --- a/research/object_detection/exporter_lib_tf2_test.py +++ b/research/object_detection/exporter_lib_tf2_test.py @@ -120,6 +120,7 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase): with mock.patch.object( model_builder, 'build', autospec=True) as mock_builder: mock_builder.return_value = FakeModel() + exporter_lib_v2.INPUT_BUILDER_UTIL_MAP['model_build'] = mock_builder output_directory = os.path.join(tmp_dir, 'output') pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() exporter_lib_v2.export_inference_graph( @@ -181,6 +182,7 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase): with mock.patch.object( model_builder, 'build', autospec=True) as mock_builder: mock_builder.return_value = FakeModel() + exporter_lib_v2.INPUT_BUILDER_UTIL_MAP['model_build'] = mock_builder output_directory = os.path.join(tmp_dir, 'output') pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() exporter_lib_v2.export_inference_graph( @@ -217,6 +219,7 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase): with mock.patch.object( model_builder, 'build', autospec=True) as mock_builder: mock_builder.return_value = FakeModel() + exporter_lib_v2.INPUT_BUILDER_UTIL_MAP['model_build'] = mock_builder output_directory = os.path.join(tmp_dir, 'output') pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() exporter_lib_v2.export_inference_graph( @@ -262,6 +265,7 @@ class ExportInferenceGraphTest(tf.test.TestCase, parameterized.TestCase): with mock.patch.object( model_builder, 'build', autospec=True) as mock_builder: mock_builder.return_value = FakeModel() + exporter_lib_v2.INPUT_BUILDER_UTIL_MAP['model_build'] = mock_builder output_directory = os.path.join(tmp_dir, 'output') pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() exporter_lib_v2.export_inference_graph( diff --git a/research/object_detection/exporter_lib_v2.py b/research/object_detection/exporter_lib_v2.py index 2a4f1162eae8a98bb543db9a276b2737a7a22144..5a7a182c62ab4a24d271a64bae9e3b4fb72fcb79 100644 --- a/research/object_detection/exporter_lib_v2.py +++ b/research/object_detection/exporter_lib_v2.py @@ -25,6 +25,11 @@ from object_detection.data_decoders import tf_example_decoder from object_detection.utils import config_util +INPUT_BUILDER_UTIL_MAP = { + 'model_build': model_builder.build, +} + + def _decode_image(encoded_image_string_tensor): image_tensor = tf.image.decode_image(encoded_image_string_tensor, channels=3) @@ -230,8 +235,8 @@ def export_inference_graph(input_type, output_checkpoint_directory = os.path.join(output_directory, 'checkpoint') output_saved_model_directory = os.path.join(output_directory, 'saved_model') - detection_model = model_builder.build(pipeline_config.model, - is_training=False) + detection_model = INPUT_BUILDER_UTIL_MAP['model_build']( + pipeline_config.model, is_training=False) ckpt = tf.train.Checkpoint( model=detection_model)