Commit 467db6c5 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 500425343
Parent c5335062
@@ -346,7 +346,7 @@ class InputReader:
@property
def tfds_info(self) -> Union[tfds.core.DatasetInfo,
Dict[str, tfds.core.DatasetInfo]]:
dict[str, tfds.core.DatasetInfo]]:
"""Returns TFDS dataset info, if available."""
if self._tfds_builder:
if isinstance(self._tfds_builder, dict):
@@ -381,7 +381,7 @@ class InputReader:
input_context: Optional[tf.distribute.InputContext] = None,
tfds_builder: Optional[
Union[tfds.core.DatasetBuilder,
Dict[str, tfds.core.DatasetBuilder]]] = None):
dict[str, tfds.core.DatasetBuilder]]] = None):
"""Reads the data source (files/tfds) to a dataset."""
def _files_to_dataset(files: List[str]) -> tf.data.Dataset:
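For context, the InputReader hunks above only migrate type annotations from typing.Dict to the builtin dict. A minimal sketch, assuming Python 3.9+ (PEP 585) and an installed tensorflow_datasets package, showing that both spellings name the same parameterized type:

from typing import Dict, Union

import tensorflow_datasets as tfds

# Since Python 3.9 (PEP 585) the builtin dict can be parameterized directly,
# so the two aliases below describe the same mapping type; the change simply
# switches the annotations to the builtin form.
LegacyTfdsInfo = Union[tfds.core.DatasetInfo, Dict[str, tfds.core.DatasetInfo]]
TfdsInfo = Union[tfds.core.DatasetInfo, dict[str, tfds.core.DatasetInfo]]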
@@ -37,6 +37,9 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
scale_factor: Whether to scale the output embeddings. Defaults to None (that
is, no scaling). Setting this option to a float multiplies the output
embeddings by scale_factor.
weight_fallback_dtype: When Keras mixed precision infers the wrong dtype for
the variables, `weight_fallback_dtype` will be used to define the dtype of
the weights.
"""
def __init__(self,
@@ -45,6 +48,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
initializer="glorot_uniform",
use_one_hot=False,
scale_factor=None,
weight_fallback_dtype=tf.float32,
**kwargs):
super().__init__(**kwargs)
@@ -53,6 +57,10 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
self._initializer = initializer
self._use_one_hot = use_one_hot
self._scale_factor = scale_factor
# Fallback control of the weight dtype, because Keras mixed precision
# sometimes depends on the inputs to infer the compute dtype, but the inputs
# of this layer are of integer type.
self._weight_fallback_dtype = weight_fallback_dtype
def get_config(self):
config = {
@@ -61,28 +69,37 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
"initializer": self._initializer,
"use_one_hot": self._use_one_hot,
"scale_factor": self._scale_factor,
"weight_fallback_dtype": self._weight_fallback_dtype,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def build(self, input_shape):
if (
self.dtype is not None
and not tf.dtypes.as_dtype(self.dtype).is_floating
):
# Keras failed to infer the right dtype.
dtype = self._weight_fallback_dtype
else:
dtype = self.dtype
self.embeddings = self.add_weight(
"embeddings",
shape=[self._vocab_size, self._embedding_width],
initializer=self._initializer,
dtype=tf.float32)
dtype=dtype)
super().build(input_shape)
def call(self, inputs):
flat_inputs = tf.reshape(inputs, [-1])
if self._use_one_hot:
dtype = self._compute_dtype
dtype = self.compute_dtype
if not tf.dtypes.as_dtype(dtype).is_floating:
# TensorFlow 1 compatibility. In TF1, self._compute_dtype is int32
# TensorFlow 1 compatibility. In TF1, self.compute_dtype is int32
# instead of a floating-point dtype, as the dtype is inferred from the
# dtype of the inputs
dtype = tf.float32
dtype = self._weight_fallback_dtype
one_hot_data = tf.one_hot(
flat_inputs, depth=self._vocab_size, dtype=dtype)
embeddings = tf.matmul(one_hot_data, self.embeddings)
@@ -90,7 +107,6 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
embeddings = tf.gather(self.embeddings, flat_inputs)
embeddings = tf.reshape(
embeddings,
# Work around b/142213824: prefer concat to shape over a Python list.
tf.concat([tf.shape(inputs), [self._embedding_width]], axis=0))
embeddings.set_shape(inputs.shape.as_list() + [self._embedding_width])
if self._scale_factor:
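The OnDeviceEmbedding hunks above add a weight_fallback_dtype argument so that the embedding table is still created with a floating-point dtype when Keras mixed precision infers a non-floating layer dtype from the integer inputs. A minimal usage sketch, assuming the layer is importable from official.nlp.modeling.layers as in the TensorFlow Model Garden; passing dtype="int32" here is only a hypothetical stand-in for a mis-inferred dtype:

import tensorflow as tf

from official.nlp.modeling import layers  # assumed Model Garden import path

# "int32" stands in for the case where Keras infers a non-floating layer
# dtype; the embedding weights should then fall back to weight_fallback_dtype.
embedding = layers.OnDeviceEmbedding(
    vocab_size=100,
    embedding_width=16,
    weight_fallback_dtype=tf.float32,
    dtype="int32")

ids = tf.constant([[1, 2, 3], [4, 5, 6]])  # int32 token ids
outputs = embedding(ids)

print(embedding.embeddings.dtype)  # expected: float32 (fallback applied)
print(outputs.shape)               # expected: (2, 3, 16)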