diff --git a/official/projects/vit/modeling/vit.py b/official/projects/vit/modeling/vit.py index ca7e2a241f0dd66ffa39af24c37400fc1fe0b359..4245711dc921a1c6d4adc5f4c18d782767a1d502 100644 --- a/official/projects/vit/modeling/vit.py +++ b/official/projects/vit/modeling/vit.py @@ -26,7 +26,6 @@ from official.projects.vit.modeling.vit_specs import VIT_SPECS from official.vision.modeling.backbones import factory from official.vision.modeling.layers import nn_layers - layers = tf.keras.layers @@ -67,9 +66,7 @@ class AddPositionEmbs(tf.keras.layers.Layer): self.pos_embedding = self.add_weight( 'pos_embedding', pos_emb_shape, initializer=self.posemb_init) - def _interpolate(self, - pos_embedding: tf.Tensor, - from_shape: Tuple[int, int], + def _interpolate(self, pos_embedding: tf.Tensor, from_shape: Tuple[int, int], to_shape: Tuple[int, int]) -> tf.Tensor: """Interpolates the positional embeddings.""" logging.info('Interpolating postional embedding from length: %d to %d', @@ -84,9 +81,10 @@ class AddPositionEmbs(tf.keras.layers.Layer): pos_embedding = self.pos_embedding # inputs.shape is (batch_size, seq_len, emb_dim). if inputs.shape[1] != pos_embedding.shape[1]: - pos_embedding = self._interpolate(pos_embedding, - from_shape=self.posemb_origin_shape, - to_shape=self.posemb_target_shape) + pos_embedding = self._interpolate( + pos_embedding, + from_shape=self.posemb_origin_shape, + to_shape=self.posemb_target_shape) pos_embedding = tf.cast(pos_embedding, inputs.dtype) return inputs + pos_embedding @@ -262,7 +260,8 @@ class VisionTransformer(tf.keras.Model): class_name='TruncatedNormal', config=dict(stddev=.02)), init_stochastic_depth_rate=init_stochastic_depth_rate, pos_embed_origin_shape=pos_embed_shape, - pos_embed_target_shape=pos_embed_target_shape)(x) + pos_embed_target_shape=pos_embed_target_shape)( + x) if pooler == 'token': x = x[:, 0] @@ -303,8 +302,8 @@ def build_vit(input_specs, del norm_activation_config backbone_type = backbone_config.type backbone_cfg = backbone_config.get() - assert backbone_type == 'vit', (f'Inconsistent backbone type ' - f'{backbone_type}') + assert backbone_type == 'legacy_vit', (f'Inconsistent backbone type ' + f'{backbone_type}') backbone_cfg.override(VIT_SPECS[backbone_cfg.model_name]) return VisionTransformer(