From fb1c4cd8283f262bca95ccd04df6f9eb4ae1da0c Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Fri, 17 Mar 2017 12:32:32 -0800
Subject: [PATCH] Add None check for seq_len_mask before reshape.

Change: 150477638
---
 .../seq2seq/python/ops/dynamic_attention_wrapper.py   | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/tensorflow/contrib/seq2seq/python/ops/dynamic_attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/dynamic_attention_wrapper.py
index 9e8fd5a7128..c53d087bcf7 100644
--- a/tensorflow/contrib/seq2seq/python/ops/dynamic_attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/dynamic_attention_wrapper.py
@@ -88,10 +88,13 @@ def _prepare_memory(memory, memory_sequence_length, check_inner_dims_defined):
     rank = m.get_shape().ndims
     rank = rank if rank is not None else array_ops.rank(m)
     extra_ones = array_ops.ones(rank - 2, dtype=dtypes.int32)
-    seq_len_mask = array_ops.reshape(
-        seq_len_mask,
-        array_ops.concat((array_ops.shape(seq_len_mask), extra_ones), 0))
-    return m * seq_len_mask if memory_sequence_length is not None else m
+    if memory_sequence_length is not None:
+      seq_len_mask = array_ops.reshape(
+          seq_len_mask,
+          array_ops.concat((array_ops.shape(seq_len_mask), extra_ones), 0))
+      return m * seq_len_mask
+    else:
+      return m

   return nest.map_structure(lambda m: _maybe_mask(m, seq_len_mask), memory)

--
GitLab