提交 4ed7a014 编写于 作者: S Scott Zhu 提交者: TensorFlower Gardener

Fixing RNN compute_dtype in v1.

In v1, since there isn't a global policy, the layer compute_dtype will be "_inferred" from input, and the inferred dtype are actually populate on the cell.

PiperOrigin-RevId: 394779149
上级 1e859f75
......@@ -880,7 +880,9 @@ class RNN(Layer):
else:
initial_state = self.states
initial_state = tf.nest.map_structure(
lambda v: tf.cast(v, self.compute_dtype), initial_state
# When the layer has a inferred dtype, use the dtype from the cell.
lambda v: tf.cast(v, self.compute_dtype or self.cell.compute_dtype),
initial_state
)
elif initial_state is None:
initial_state = self.get_initial_state(inputs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册