提交 d08af90d 编写于 作者: S Scott Zhu 提交者: A. Unique TensorFlower

Restructure the Keras class hierarchy for Network, Model and Sequential.

The intention of this change is to reduce the code complexity within Keras class, especially for Network, which currently contains logic for both subclass Model and functional Model.

After this change, the subclass model and functional model become individual class and become self contained.

1. Model is now the base class for subclass model. It doesn't contains network structure management, and the topology will be created within __init__ and __call__, which is for user to implement. It also contains compile/fit/eval/predict, which is the basic functionality for model training.

2. Functional is created based on existing Network class. It extends the Model, which allows it leverage compile/fit/eval/predict. In addition, it also take input/output as init parameter and manage the network topology.

3. Sequential model is now a subclass of Functional, since it will use Functional's method to manage it topology (layer stacking).

Model(input, output) will create a Functional under the hood, and behave the same way as before.

PiperOrigin-RevId: 311232972
上级 3ae3aecf
......@@ -21,13 +21,12 @@ from __future__ import print_function
import tensorflow as tf
from tensorflow.python.keras.engine import network # pylint: disable=g-direct-tensorflow-import
from official.modeling import activations
from official.nlp.modeling import layers
@tf.keras.utils.register_keras_serializable(package='Text')
class AlbertTransformerEncoder(network.Network):
class AlbertTransformerEncoder(tf.keras.Model):
"""ALBERT (https://arxiv.org/abs/1810.04805) text encoder network.
This network implements the encoder described in the paper "ALBERT: A Lite
......
......@@ -21,12 +21,9 @@ from __future__ import print_function
import tensorflow as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.keras.engine import network
@tf.keras.utils.register_keras_serializable(package='Text')
class Classification(network.Network):
class Classification(tf.keras.Model):
"""Classification network head for BERT modeling.
This network implements a simple classifier head based on a dense layer.
......
......@@ -25,13 +25,12 @@ import inspect
import gin
import tensorflow as tf
from tensorflow.python.keras.engine import network # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling import layers
@tf.keras.utils.register_keras_serializable(package='Text')
@gin.configurable
class EncoderScaffold(network.Network):
class EncoderScaffold(tf.keras.Model):
"""Bi-directional Transformer-based encoder network scaffold.
This network allows users to flexibly implement an encoder similar to the one
......
......@@ -21,12 +21,11 @@ from __future__ import print_function
import tensorflow as tf
from tensorflow.python.keras.engine import network # pylint: disable=g-direct-tensorflow-import
from official.modeling import tf_utils
@tf.keras.utils.register_keras_serializable(package='Text')
class MaskedLM(network.Network):
class MaskedLM(tf.keras.Model):
"""Masked language model network head for BERT modeling.
This network implements a masked language model based on the provided network.
......
......@@ -21,12 +21,9 @@ from __future__ import print_function
import tensorflow as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.keras.engine import network
@tf.keras.utils.register_keras_serializable(package='Text')
class SpanLabeling(network.Network):
class SpanLabeling(tf.keras.Model):
"""Span labeling network head for BERT modeling.
This network implements a simple single-span labeler based on a dense layer.
......
......@@ -21,13 +21,12 @@ from __future__ import print_function
import tensorflow as tf
from tensorflow.python.keras.engine import network # pylint: disable=g-direct-tensorflow-import
from official.modeling import activations
from official.nlp.modeling import layers
@tf.keras.utils.register_keras_serializable(package='Text')
class TransformerEncoder(network.Network):
class TransformerEncoder(tf.keras.Model):
"""Bi-directional Transformer-based encoder network.
This network implements a bi-directional Transformer-based encoder as
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册