提交 8ed878f4 编写于 作者: C Chaochao Yan 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 481938528
上级 7a78713d
......@@ -26,7 +26,6 @@ from official.projects.vit.modeling.vit_specs import VIT_SPECS
from official.vision.modeling.backbones import factory
from official.vision.modeling.layers import nn_layers
layers = tf.keras.layers
......@@ -67,9 +66,7 @@ class AddPositionEmbs(tf.keras.layers.Layer):
self.pos_embedding = self.add_weight(
'pos_embedding', pos_emb_shape, initializer=self.posemb_init)
def _interpolate(self,
pos_embedding: tf.Tensor,
from_shape: Tuple[int, int],
def _interpolate(self, pos_embedding: tf.Tensor, from_shape: Tuple[int, int],
to_shape: Tuple[int, int]) -> tf.Tensor:
"""Interpolates the positional embeddings."""
logging.info('Interpolating postional embedding from length: %d to %d',
......@@ -84,9 +81,10 @@ class AddPositionEmbs(tf.keras.layers.Layer):
pos_embedding = self.pos_embedding
# inputs.shape is (batch_size, seq_len, emb_dim).
if inputs.shape[1] != pos_embedding.shape[1]:
pos_embedding = self._interpolate(pos_embedding,
from_shape=self.posemb_origin_shape,
to_shape=self.posemb_target_shape)
pos_embedding = self._interpolate(
pos_embedding,
from_shape=self.posemb_origin_shape,
to_shape=self.posemb_target_shape)
pos_embedding = tf.cast(pos_embedding, inputs.dtype)
return inputs + pos_embedding
......@@ -262,7 +260,8 @@ class VisionTransformer(tf.keras.Model):
class_name='TruncatedNormal', config=dict(stddev=.02)),
init_stochastic_depth_rate=init_stochastic_depth_rate,
pos_embed_origin_shape=pos_embed_shape,
pos_embed_target_shape=pos_embed_target_shape)(x)
pos_embed_target_shape=pos_embed_target_shape)(
x)
if pooler == 'token':
x = x[:, 0]
......@@ -303,8 +302,8 @@ def build_vit(input_specs,
del norm_activation_config
backbone_type = backbone_config.type
backbone_cfg = backbone_config.get()
assert backbone_type == 'vit', (f'Inconsistent backbone type '
f'{backbone_type}')
assert backbone_type == 'legacy_vit', (f'Inconsistent backbone type '
f'{backbone_type}')
backbone_cfg.override(VIT_SPECS[backbone_cfg.model_name])
return VisionTransformer(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册