diff --git a/official/nlp/serving/export_savedmodel_util.py b/official/nlp/serving/export_savedmodel_util.py index e06dd487becb493971757a26fb0b436832cc9c93..0d783b30dac0bb7594137eb6200e2469bbb331bf 100644 --- a/official/nlp/serving/export_savedmodel_util.py +++ b/official/nlp/serving/export_savedmodel_util.py @@ -13,12 +13,19 @@ # limitations under the License. """Common library to export a SavedModel from the export module.""" +import os +import time from typing import Dict, List, Optional, Text, Union +from absl import logging import tensorflow as tf + from official.core import export_base +MAX_DIRECTORY_CREATION_ATTEMPTS = 10 + + def export(export_module: export_base.ExportModule, function_keys: Union[List[Text], Dict[Text, Text]], export_savedmodel_dir: Text, @@ -39,7 +46,39 @@ def export(export_module: export_base.ExportModule, The savedmodel directory path. """ save_options = tf.saved_model.SaveOptions(function_aliases={ - "tpu_candidate": export_module.serve, + 'tpu_candidate': export_module.serve, }) return export_base.export(export_module, function_keys, export_savedmodel_dir, checkpoint_path, timestamped, save_options) + + +def get_timestamped_export_dir(export_dir_base): + """Builds a path to a new subdirectory within the base directory. + + Args: + export_dir_base: A string containing a directory to write the exported graph + and checkpoints. + + Returns: + The full path of the new subdirectory (which is not actually created yet). + + Raises: + RuntimeError: if repeated attempts fail to obtain a unique timestamped + directory name. + """ + attempts = 0 + while attempts < MAX_DIRECTORY_CREATION_ATTEMPTS: + timestamp = int(time.time()) + + result_dir = os.path.join(export_dir_base, str(timestamp)) + if not tf.io.gfile.exists(result_dir): + # Collisions are still possible (though extremely unlikely): this + # directory is not actually created yet, but it will be almost + # instantly on return from this function. + return result_dir + time.sleep(1) + attempts += 1 + logging.warning('Directory %s already exists; retrying (attempt %s/%s)', + str(result_dir), attempts, MAX_DIRECTORY_CREATION_ATTEMPTS) + raise RuntimeError('Failed to obtain a unique export directory name after ' + f'{MAX_DIRECTORY_CREATION_ATTEMPTS} attempts.')