提交 f13be764 编写于 作者: L Liangzhe Yuan 提交者: A. Unique TensorFlower

internal change.

PiperOrigin-RevId: 422665603
上级 27fb855b
......@@ -88,14 +88,13 @@ class MovinetClassifier(tf.keras.Model):
# Move backbone after super() call so Keras is happy
self._backbone = backbone
def _build_network(
def _build_backbone(
self,
backbone: tf.keras.Model,
input_specs: Mapping[str, tf.keras.layers.InputSpec],
state_specs: Optional[Mapping[str, tf.keras.layers.InputSpec]] = None,
) -> Tuple[Mapping[str, tf.keras.Input], Union[Tuple[Mapping[ # pytype: disable=invalid-annotation # typed-keras
str, tf.Tensor], Mapping[str, tf.Tensor]], Mapping[str, tf.Tensor]]]:
"""Builds the model network.
) -> Tuple[Mapping[str, Any], Any, Any]:
"""Builds the backbone network and gets states and endpoints.
Args:
backbone: the model backbone.
......@@ -104,9 +103,9 @@ class MovinetClassifier(tf.keras.Model):
layer, will overwrite the contents of the buffer(s).
Returns:
Inputs and outputs as a tuple. Inputs are expected to be a dict with
base input and states. Outputs are expected to be a dict of endpoints
and (optionally) output states.
inputs: a dict of input specs.
endpoints: a dict of model endpoints.
states: a dict of model states.
"""
state_specs = state_specs if state_specs is not None else {}
......@@ -145,7 +144,30 @@ class MovinetClassifier(tf.keras.Model):
mismatched_shapes))
else:
endpoints, states = backbone(inputs)
return inputs, endpoints, states
def _build_network(
self,
backbone: tf.keras.Model,
input_specs: Mapping[str, tf.keras.layers.InputSpec],
state_specs: Optional[Mapping[str, tf.keras.layers.InputSpec]] = None,
) -> Tuple[Mapping[str, tf.keras.Input], Union[Tuple[Mapping[ # pytype: disable=invalid-annotation # typed-keras
str, tf.Tensor], Mapping[str, tf.Tensor]], Mapping[str, tf.Tensor]]]:
"""Builds the model network.
Args:
backbone: the model backbone.
input_specs: the model input spec to use.
state_specs: a dict of states such that, if any of the keys match for a
layer, will overwrite the contents of the buffer(s).
Returns:
Inputs and outputs as a tuple. Inputs are expected to be a dict with
base input and states. Outputs are expected to be a dict of endpoints
and (optionally) output states.
"""
inputs, endpoints, states = self._build_backbone(
backbone=backbone, input_specs=input_specs, state_specs=state_specs)
x = endpoints['head']
x = movinet_layers.ClassifierHead(
......
......@@ -46,7 +46,8 @@ from official.modeling import performance
# Import movinet libraries to register the backbone and model into tf.vision
# model garden factory.
# pylint: disable=unused-import
# the followings are the necessary imports.
from official.projects.movinet.google.configs import movinet_google
from official.projects.movinet.google.modeling import movinet_model_google
from official.projects.movinet.modeling import movinet
from official.projects.movinet.modeling import movinet_model
# pylint: enable=unused-import
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册