diff --git a/tensorflow/contrib/seq2seq/python/ops/dynamic_attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/dynamic_attention_wrapper.py
index 9e8fd5a7128564e3efa63d8ec0b45ba967308aaa..c53d087bcf739c6db73403535273fa7066058f08 100644
--- a/tensorflow/contrib/seq2seq/python/ops/dynamic_attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/dynamic_attention_wrapper.py
@@ -88,10 +88,13 @@ def _prepare_memory(memory, memory_sequence_length, check_inner_dims_defined):
     rank = m.get_shape().ndims
     rank = rank if rank is not None else array_ops.rank(m)
     extra_ones = array_ops.ones(rank - 2, dtype=dtypes.int32)
-    seq_len_mask = array_ops.reshape(
-        seq_len_mask,
-        array_ops.concat((array_ops.shape(seq_len_mask), extra_ones), 0))
-    return m * seq_len_mask if memory_sequence_length is not None else m
+    if memory_sequence_length is not None:
+      seq_len_mask = array_ops.reshape(
+          seq_len_mask,
+          array_ops.concat((array_ops.shape(seq_len_mask), extra_ones), 0))
+      return m * seq_len_mask
+    else:
+      return m
   return nest.map_structure(lambda m: _maybe_mask(m, seq_len_mask), memory)
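
The hunk above guards the mask reshape: previously `array_ops.reshape(seq_len_mask, ...)` ran unconditionally, so when `memory_sequence_length` was `None` (and `seq_len_mask` with it, per the enclosing `_prepare_memory`) the reshape failed before the `else m` branch could ever return. A minimal standalone sketch of the corrected control flow, written in NumPy purely for illustration (the `_maybe_mask` name is borrowed from the diff; the shapes and example values are assumptions, not TensorFlow's API):

import numpy as np

def _maybe_mask(m, seq_len_mask):
    # Sketch of the patched control flow, not the TF implementation.
    # `m` is a memory tensor of shape [batch, time, ...]; `seq_len_mask`
    # is a [batch, time] 0/1 mask, or None when no sequence lengths were
    # given. The pre-patch code reshaped `seq_len_mask` before checking
    # for None; the patch guards the reshape and the multiply together.
    if seq_len_mask is not None:
        # Append singleton axes so the [batch, time] mask broadcasts
        # over the trailing feature dimensions of `m`.
        extra_ones = (1,) * (m.ndim - 2)
        seq_len_mask = seq_len_mask.reshape(seq_len_mask.shape + extra_ones)
        return m * seq_len_mask
    else:
        return m

# Masking zeroes out time steps past each sequence's length.
memory = np.ones((2, 3, 4))                     # [batch=2, time=3, depth=4]
mask = np.array([[1, 1, 0], [1, 0, 0]], float)  # sequence lengths 2 and 1
assert _maybe_mask(memory, mask)[0, 2].sum() == 0
# With no lengths given, the memory now passes through untouched
# instead of raising on the None mask.
assert _maybe_mask(memory, None) is memory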