提交 fb1c4cd8 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Add None check for seq_len_mask before reshape.

Change: 150477638
上级 547a5402
......@@ -88,10 +88,13 @@ def _prepare_memory(memory, memory_sequence_length, check_inner_dims_defined):
rank = m.get_shape().ndims
rank = rank if rank is not None else array_ops.rank(m)
extra_ones = array_ops.ones(rank - 2, dtype=dtypes.int32)
seq_len_mask = array_ops.reshape(
seq_len_mask,
array_ops.concat((array_ops.shape(seq_len_mask), extra_ones), 0))
return m * seq_len_mask if memory_sequence_length is not None else m
if memory_sequence_length is not None:
seq_len_mask = array_ops.reshape(
seq_len_mask,
array_ops.concat((array_ops.shape(seq_len_mask), extra_ones), 0))
return m * seq_len_mask
else:
return m
return nest.map_structure(lambda m: _maybe_mask(m, seq_len_mask), memory)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册