Commit e5c71d51 authored by Hongkun Yu, committed by A. Unique TensorFlower

Refactor: use common tf_utils

PiperOrigin-RevId: 285613648
Parent 4b06a97a
@@ -23,6 +23,7 @@ import copy
 import numpy as np
 import tensorflow as tf
+from official.modeling import tf_utils
 from official.nlp.xlnet import data_utils
@@ -102,52 +103,6 @@ def is_special_none_tensor(tensor):
   return tensor.shape.ndims == 0 and tensor.dtype == tf.int32
 
-
-def unpack_inputs(inputs):
-  """Unpacks a tuple of `inputs` tensors to a tuple.
-
-  Args:
-    inputs: A list of tensors.
-
-  Returns:
-    A tuple of tensors. If any input is a special constant tensor, replace it
-    with None.
-  """
-  inputs = tf.nest.flatten(inputs)
-  outputs = []
-  for x in inputs:
-    if is_special_none_tensor(x):
-      outputs.append(None)
-    else:
-      outputs.append(x)
-  x = tuple(outputs)
-
-  # To trick the very pointless 'unbalanced-tuple-unpacking' pylint check
-  # from triggering.
-  if len(x) == 1:
-    return x[0]
-
-  return tuple(outputs)
-
-
-def pack_inputs(inputs):
-  """Packs a list of `inputs` tensors to a tuple.
-
-  Args:
-    inputs: A list of tensors.
-
-  Returns:
-    A tuple of tensors. If any input is None, replace it with a special constant
-    tensor.
-  """
-  inputs = tf.nest.flatten(inputs)
-  outputs = []
-  for x in inputs:
-    if x is None:
-      outputs.append(tf.constant(0, shape=[], dtype=tf.int32))
-    else:
-      outputs.append(x)
-  return tuple(outputs)
-
-
 class PositionalEmbedding(tf.keras.layers.Layer):
   """Generates relative positional embeddings used in Transformer-XL and XLNet."""
@@ -196,7 +151,7 @@ class RelativeAttention(tf.keras.layers.Layer):
 
   def __call__(self, q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
                r_w_bias, r_r_bias, r_s_bias, attn_mask, **kwargs):
-    inputs = pack_inputs([
+    inputs = tf_utils.pack_inputs([
         q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
         r_r_bias, r_s_bias, attn_mask
     ])
@@ -205,7 +160,7 @@
   def call(self, inputs):
     """Implements call() for the layer."""
     (q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
-     r_r_bias, r_s_bias, attn_mask) = unpack_inputs(inputs)
+     r_r_bias, r_s_bias, attn_mask) = tf_utils.unpack_inputs(inputs)
 
     # content based attention score
     ac = tf.einsum('ibnd,jbnd->ijbn', q_head + r_w_bias, k_head_h)
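(The einsum in this hunk's context contracts the head dimension: with q_head shaped [qlen, batch, n_head, d_head] and k_head_h shaped [klen, batch, n_head, d_head], it yields one content score per (query position, key position, batch, head). A standalone sketch with made-up sizes; the shapes are inferred from the einsum equation, not stated in the diff:

    import tensorflow as tf

    qlen, klen, batch, n_head, d_head = 4, 6, 2, 8, 16
    q_head = tf.random.normal([qlen, batch, n_head, d_head])
    k_head_h = tf.random.normal([klen, batch, n_head, d_head])
    r_w_bias = tf.zeros([n_head, d_head])  # broadcasts over qlen and batch

    # 'ibnd,jbnd->ijbn' sums over d, leaving [qlen, klen, batch, n_head].
    ac = tf.einsum('ibnd,jbnd->ijbn', q_head + r_w_bias, k_head_h)
    assert ac.shape == (qlen, klen, batch, n_head)
)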
@@ -363,7 +318,7 @@ class RelativeMultiheadAttention(tf.keras.layers.Layer):
 
   def __call__(self, h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed,
                attn_mask_h, attn_mask_g, mems, target_mapping, **kwargs):
-    inputs = pack_inputs([
+    inputs = tf_utils.pack_inputs([
         h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask_h,
         attn_mask_g, mems, target_mapping,
     ])
@@ -372,7 +327,7 @@
   def call(self, inputs):
     """Implements call() for the layer."""
     (h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask_h,
-     attn_mask_g, mems, target_mapping) = unpack_inputs(inputs)
+     attn_mask_g, mems, target_mapping) = tf_utils.unpack_inputs(inputs)
 
     if mems is not None and mems.shape.ndims > 1:
       cat = tf.concat([mems, h], 0)
@@ -1011,12 +966,12 @@ class LMLossLayer(tf.keras.layers.Layer):
     super(LMLossLayer, self).build(unused_input_shapes)
 
   def __call__(self, hidden, target, lookup_table, target_mask, **kwargs):
-    inputs = pack_inputs([hidden, target, lookup_table, target_mask])
+    inputs = tf_utils.pack_inputs([hidden, target, lookup_table, target_mask])
     return super(LMLossLayer, self).__call__(inputs, **kwargs)
 
   def call(self, inputs):
     """Implements call() for the layer."""
-    (hidden, target, lookup_table, tgt_mask) = unpack_inputs(inputs)
+    (hidden, target, lookup_table, tgt_mask) = tf_utils.unpack_inputs(inputs)
     if self.use_proj:
       hidden = self.proj_layer_norm(self.proj_layer(hidden))
     if self.tie_weight:
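(The hunk cuts off just as the tie_weight branch begins. With tied weights, the usual pattern is to score hidden states directly against the embedding table instead of a separate output projection; a rough sketch of that pattern, whose shapes and einsum are assumptions rather than code taken from this diff:

    import tensorflow as tf

    vocab_size, hidden_dim = 100, 16
    lookup_table = tf.random.normal([vocab_size, hidden_dim])  # embedding matrix
    hidden = tf.random.normal([8, 2, hidden_dim])              # [seq, batch, hidden]

    # Tied weights: reuse the embedding table as the output projection,
    # contracting hidden_dim to get per-position vocabulary logits.
    logits = tf.einsum('ibd,nd->ibn', hidden, lookup_table)
    assert logits.shape == (8, 2, vocab_size)
)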
@@ -1122,12 +1077,12 @@ class ClassificationLossLayer(tf.keras.layers.Layer):
     super(ClassificationLossLayer, self).build(unused_input_shapes)
 
   def __call__(self, hidden, labels, **kwargs):
-    inputs = pack_inputs([hidden, labels])
+    inputs = tf_utils.pack_inputs([hidden, labels])
     return super(ClassificationLossLayer, self).__call__(inputs, **kwargs)
 
   def call(self, inputs):
     """Implements call() for the layer."""
-    (hidden, labels) = unpack_inputs(inputs)
+    (hidden, labels) = tf_utils.unpack_inputs(inputs)
 
     logits = self.proj_layer(hidden)
     one_hot_target = tf.one_hot(labels, self.n_class, dtype=hidden.dtype)  # pytype: disable=attribute-error
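(The hunk ends right after the one-hot encoding. What typically follows in a loss layer like this is softmax cross-entropy between the logits and the one-hot targets; a minimal self-contained sketch of that step, where the layer sizes and the reduction are illustrative assumptions:

    import tensorflow as tf

    n_class = 3
    hidden = tf.random.normal([2, 5])   # [batch, hidden_dim]
    labels = tf.constant([0, 2])
    proj_layer = tf.keras.layers.Dense(n_class)

    logits = proj_layer(hidden)
    one_hot_target = tf.one_hot(labels, n_class, dtype=hidden.dtype)
    # Per-example cross-entropy from log-softmax and one-hot targets.
    loss = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, axis=-1)
)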