提交 c923d629 编写于 作者: H Hongkun Yu 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 326947209
上级 3ea89d24
......@@ -457,7 +457,8 @@ class SequenceBeamSearch(tf.Module):
_StateKeys.ALIVE_LOG_PROBS:
tf.TensorShape([batch_size, self.beam_size]),
_StateKeys.ALIVE_CACHE:
tf.nest.map_structure(_get_shape, alive_cache),
tf.nest.map_structure(lambda state: state.get_shape(),
alive_cache),
_StateKeys.FINISHED_SEQ:
tf.TensorShape(
[batch_size, self.beam_size, self.max_decode_length + 1]),
......@@ -629,11 +630,6 @@ def _get_shape_keep_last_dim(tensor):
return tf.TensorShape(shape_list)
def _get_shape(tensor):
"""Return the shape of the input tensor."""
return tf.TensorShape(_shape_list(tensor))
def _flatten_beam_dim(tensor):
"""Reshapes first two dimensions in to single dimension.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册