提交 a75e870b 编写于 作者: A Abdullah Rashwan 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 339917189
上级 08d3c799
......@@ -15,7 +15,7 @@
# ==============================================================================
"""Image classification configuration definition."""
import os
from typing import List
from typing import List, Optional
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
......@@ -63,6 +63,8 @@ class ImageClassificationTask(cfg.TaskConfig):
validation_data: DataConfig = DataConfig(is_training=False)
losses: Losses = Losses()
gradient_clip_norm: float = 0.0
init_checkpoint: Optional[str] = None
init_checkpoint_modules: str = 'all' # all or backbone
@exp_factory.register_config_factory('image_classification')
......
......@@ -14,6 +14,7 @@
# limitations under the License.
# ==============================================================================
"""Image classification task definition."""
from absl import logging
import tensorflow as tf
from official.core import base_task
from official.core import input_reader
......@@ -46,6 +47,30 @@ class ImageClassificationTask(base_task.Task):
l2_regularizer=l2_regularizer)
return model
def initialize(self, model: tf.keras.Model):
"""Loading pretrained checkpoint."""
if not self.task_config.init_checkpoint:
return
ckpt_dir_or_file = self.task_config.init_checkpoint
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
# Restoring checkpoint.
if self.task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.assert_consumed()
elif self.task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
else:
assert "Only 'all' or 'backbone' can be used to initialize the model."
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
def build_inputs(self, params, input_context=None):
"""Builds classification input."""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册