From 05a95e16dec3572ef0aca70ae134e846c1a6c4fc Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Tue, 27 Oct 2020 09:40:10 -0700 Subject: [PATCH] Add in a comment about legacy input masks for XLNet. PiperOrigin-RevId: 339267061 --- official/nlp/xlnet/xlnet_modeling.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/official/nlp/xlnet/xlnet_modeling.py b/official/nlp/xlnet/xlnet_modeling.py index d5899addc..3cf5d85a2 100644 --- a/official/nlp/xlnet/xlnet_modeling.py +++ b/official/nlp/xlnet/xlnet_modeling.py @@ -795,6 +795,8 @@ class PretrainingXLNetModel(tf.keras.Model): masked_tokens = features["input_q"] seg_ids = features["seg_id"] if self._use_legacy_mask: + # Legacy input mask assumes `real` values are 0 and `padding` + # values are 1. perm_mask = 1 - features["perm_mask"] else: perm_mask = features["perm_mask"] @@ -885,6 +887,8 @@ class ClassificationXLNetModel(tf.keras.Model): input_ids = features["input_ids"] segment_ids = features["segment_ids"] if self._use_legacy_mask: + # Legacy input mask assumes `real` values are 0 and `padding` + # values are 1. input_mask = 1 - features["input_mask"] else: input_mask = features["input_mask"] @@ -1130,6 +1134,8 @@ class QAXLNetModel(tf.keras.Model): input_ids = features["input_ids"] segment_ids = features["segment_ids"] if self._use_legacy_mask: + # Legacy input mask assumes `real` values are 0 and `padding` + # values are 1. input_mask = 1 - features["input_mask"] else: input_mask = features["input_mask"] -- GitLab