diff --git a/official/vision/beta/configs/image_classification.py b/official/vision/beta/configs/image_classification.py index 870e82979fc1d28596b244ba0a1ed0c828b6898c..d88f30667ee5aa106ae0c2bf7a86040f3e7feb86 100644 --- a/official/vision/beta/configs/image_classification.py +++ b/official/vision/beta/configs/image_classification.py @@ -15,7 +15,7 @@ # ============================================================================== """Image classification configuration definition.""" import os -from typing import List +from typing import List, Optional import dataclasses from official.core import config_definitions as cfg from official.core import exp_factory @@ -63,6 +63,8 @@ class ImageClassificationTask(cfg.TaskConfig): validation_data: DataConfig = DataConfig(is_training=False) losses: Losses = Losses() gradient_clip_norm: float = 0.0 + init_checkpoint: Optional[str] = None + init_checkpoint_modules: str = 'all' # all or backbone @exp_factory.register_config_factory('image_classification') diff --git a/official/vision/beta/tasks/image_classification.py b/official/vision/beta/tasks/image_classification.py index ab93d3b67f4c45061ca9711ba629618686fa9d38..d34b4902ba60cfbc78f6b9519f7bb3aa069bdab8 100644 --- a/official/vision/beta/tasks/image_classification.py +++ b/official/vision/beta/tasks/image_classification.py @@ -14,6 +14,7 @@ # limitations under the License. # ============================================================================== """Image classification task definition.""" +from absl import logging import tensorflow as tf from official.core import base_task from official.core import input_reader @@ -46,6 +47,30 @@ class ImageClassificationTask(base_task.Task): l2_regularizer=l2_regularizer) return model + def initialize(self, model: tf.keras.Model): + """Loading pretrained checkpoint.""" + if not self.task_config.init_checkpoint: + return + + ckpt_dir_or_file = self.task_config.init_checkpoint + if tf.io.gfile.isdir(ckpt_dir_or_file): + ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file) + + # Restoring checkpoint. + if self.task_config.init_checkpoint_modules == 'all': + ckpt = tf.train.Checkpoint(**model.checkpoint_items) + status = ckpt.restore(ckpt_dir_or_file) + status.assert_consumed() + elif self.task_config.init_checkpoint_modules == 'backbone': + ckpt = tf.train.Checkpoint(backbone=model.backbone) + status = ckpt.restore(ckpt_dir_or_file) + status.expect_partial().assert_existing_objects_matched() + else: + assert "Only 'all' or 'backbone' can be used to initialize the model." + + logging.info('Finished loading pretrained checkpoint from %s', + ckpt_dir_or_file) + def build_inputs(self, params, input_context=None): """Builds classification input."""