提交 0dadbbc8 编写于 作者: F Frederick Liu 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 420497751
上级 993dbf54
......@@ -21,6 +21,7 @@ from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_layers
from official.vision.beta.projects.vit.modeling import nn_blocks
layers = tf.keras.layers
VIT_SPECS = {
......@@ -121,6 +122,7 @@ class Encoder(tf.keras.layers.Layer):
inputs_positions=None,
init_stochastic_depth_rate=0.0,
kernel_initializer='glorot_uniform',
add_pos_embed=True,
**kwargs):
super().__init__(**kwargs)
self._num_layers = num_layers
......@@ -132,11 +134,13 @@ class Encoder(tf.keras.layers.Layer):
self._inputs_positions = inputs_positions
self._init_stochastic_depth_rate = init_stochastic_depth_rate
self._kernel_initializer = kernel_initializer
self._add_pos_embed = add_pos_embed
def build(self, input_shape):
self._pos_embed = AddPositionEmbs(
posemb_init=tf.keras.initializers.RandomNormal(stddev=0.02),
name='posembed_input')
if self._add_pos_embed:
self._pos_embed = AddPositionEmbs(
posemb_init=tf.keras.initializers.RandomNormal(stddev=0.02),
name='posembed_input')
self._dropout = layers.Dropout(rate=self._dropout_rate)
self._encoder_layers = []
......@@ -160,7 +164,9 @@ class Encoder(tf.keras.layers.Layer):
super().build(input_shape)
def call(self, inputs, training=None):
x = self._pos_embed(inputs, inputs_positions=self._inputs_positions)
x = inputs
if self._add_pos_embed:
x = self._pos_embed(x, inputs_positions=self._inputs_positions)
x = self._dropout(x, training=training)
for encoder_layer in self._encoder_layers:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册