提交 0f176f6f 编写于 作者: A A. Unique TensorFlower

Merge pull request #7616 from Huawei-MRC-OSI:fix-bert-messages

PiperOrigin-RevId: 272700366
......@@ -563,11 +563,11 @@ class Dense3D(tf.keras.layers.Layer):
"""Implements build() for the layer."""
dtype = tf.as_dtype(self.dtype or tf.keras.backend.floatx())
if not (dtype.is_floating or dtype.is_complex):
raise TypeError("Unable to build `Dense` layer with non-floating point "
"dtype %s" % (dtype,))
raise TypeError("Unable to build `Dense3D` layer with non-floating "
"point (and non-complex) dtype %s" % (dtype,))
input_shape = tf.TensorShape(input_shape)
if tf.compat.dimension_value(input_shape[-1]) is None:
raise ValueError("The last dimension of the inputs to `Dense` "
raise ValueError("The last dimension of the inputs to `Dense3D` "
"should be defined. Found `None`.")
self.last_dim = tf.compat.dimension_value(input_shape[-1])
self.input_spec = tf.keras.layers.InputSpec(
......@@ -648,12 +648,14 @@ class Dense2DProjection(tf.keras.layers.Layer):
"""Implements build() for the layer."""
dtype = tf.as_dtype(self.dtype or tf.keras.backend.floatx())
if not (dtype.is_floating or dtype.is_complex):
raise TypeError("Unable to build `Dense` layer with non-floating point "
raise TypeError("Unable to build `Dense2DProjection` layer with "
"non-floating point (and non-complex) "
"dtype %s" % (dtype,))
input_shape = tf.TensorShape(input_shape)
if tf.compat.dimension_value(input_shape[-1]) is None:
raise ValueError("The last dimension of the inputs to `Dense` "
"should be defined. Found `None`.")
raise ValueError("The last dimension of the inputs to "
"`Dense2DProjection` should be defined. "
"Found `None`.")
last_dim = tf.compat.dimension_value(input_shape[-1])
self.input_spec = tf.keras.layers.InputSpec(min_ndim=3, axes={-1: last_dim})
self.kernel = self.add_weight(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册