未验证 提交 521dd7eb 编写于 作者: L liu zhengxi 提交者: GitHub

Add error message when parameter is set to 0 (#33859)

上级 a0a90798
......@@ -161,6 +161,12 @@ class MultiHeadAttention(Layer):
weight_attr=None,
bias_attr=None):
super(MultiHeadAttention, self).__init__()
assert embed_dim > 0, ("Expected embed_dim to be greater than 0, "
"but recieved {}".format(embed_dim))
assert num_heads > 0, ("Expected num_heads to be greater than 0, "
"but recieved {}".format(num_heads))
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
......@@ -501,6 +507,15 @@ class TransformerEncoderLayer(Layer):
self._config.pop("__class__", None) # py3
super(TransformerEncoderLayer, self).__init__()
assert d_model > 0, ("Expected d_model to be greater than 0, "
"but recieved {}".format(d_model))
assert nhead > 0, ("Expected nhead to be greater than 0, "
"but recieved {}".format(nhead))
assert dim_feedforward > 0, (
"Expected dim_feedforward to be greater than 0, "
"but recieved {}".format(dim_feedforward))
attn_dropout = dropout if attn_dropout is None else attn_dropout
act_dropout = dropout if act_dropout is None else act_dropout
self.normalize_before = normalize_before
......@@ -797,6 +812,15 @@ class TransformerDecoderLayer(Layer):
self._config.pop("__class__", None) # py3
super(TransformerDecoderLayer, self).__init__()
assert d_model > 0, ("Expected d_model to be greater than 0, "
"but recieved {}".format(d_model))
assert nhead > 0, ("Expected nhead to be greater than 0, "
"but recieved {}".format(nhead))
assert dim_feedforward > 0, (
"Expected dim_feedforward to be greater than 0, "
"but recieved {}".format(dim_feedforward))
attn_dropout = dropout if attn_dropout is None else attn_dropout
act_dropout = dropout if act_dropout is None else act_dropout
self.normalize_before = normalize_before
......@@ -1196,6 +1220,14 @@ class Transformer(Layer):
custom_decoder=None):
super(Transformer, self).__init__()
assert d_model > 0, ("Expected d_model to be greater than 0, "
"but recieved {}".format(d_model))
assert nhead > 0, ("Expected nhead to be greater than 0, "
"but recieved {}".format(nhead))
assert dim_feedforward > 0, (
"Expected dim_feedforward to be greater than 0, "
"but recieved {}".format(dim_feedforward))
if isinstance(bias_attr, (list, tuple)):
if len(bias_attr) == 1:
encoder_bias_attr = [bias_attr[0]] * 2
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册