From 521dd7eb48b8e19ebab38e5cff408017d2de2ada Mon Sep 17 00:00:00 2001
From: liu zhengxi <380185688@qq.com>
Date: Thu, 1 Jul 2021 11:35:44 +0800
Subject: [PATCH] Add error message when parameter is set to 0 (#33859)

---
 python/paddle/nn/layer/transformer.py | 32 +++++++++++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/python/paddle/nn/layer/transformer.py b/python/paddle/nn/layer/transformer.py
index 891177532a..5aba8ae85a 100644
--- a/python/paddle/nn/layer/transformer.py
+++ b/python/paddle/nn/layer/transformer.py
@@ -161,6 +161,12 @@ class MultiHeadAttention(Layer):
                  weight_attr=None,
                  bias_attr=None):
         super(MultiHeadAttention, self).__init__()
+
+        assert embed_dim > 0, ("Expected embed_dim to be greater than 0, "
+                               "but received {}".format(embed_dim))
+        assert num_heads > 0, ("Expected num_heads to be greater than 0, "
+                               "but received {}".format(num_heads))
+
         self.embed_dim = embed_dim
         self.kdim = kdim if kdim is not None else embed_dim
         self.vdim = vdim if vdim is not None else embed_dim
@@ -501,6 +507,15 @@ class TransformerEncoderLayer(Layer):
         self._config.pop("__class__", None)  # py3
 
         super(TransformerEncoderLayer, self).__init__()
+
+        assert d_model > 0, ("Expected d_model to be greater than 0, "
+                             "but received {}".format(d_model))
+        assert nhead > 0, ("Expected nhead to be greater than 0, "
+                           "but received {}".format(nhead))
+        assert dim_feedforward > 0, (
+            "Expected dim_feedforward to be greater than 0, "
+            "but received {}".format(dim_feedforward))
+
         attn_dropout = dropout if attn_dropout is None else attn_dropout
         act_dropout = dropout if act_dropout is None else act_dropout
         self.normalize_before = normalize_before
@@ -797,6 +812,15 @@ class TransformerDecoderLayer(Layer):
         self._config.pop("__class__", None)  # py3
 
         super(TransformerDecoderLayer, self).__init__()
+
+        assert d_model > 0, ("Expected d_model to be greater than 0, "
+                             "but received {}".format(d_model))
+        assert nhead > 0, ("Expected nhead to be greater than 0, "
+                           "but received {}".format(nhead))
+        assert dim_feedforward > 0, (
+            "Expected dim_feedforward to be greater than 0, "
+            "but received {}".format(dim_feedforward))
+
         attn_dropout = dropout if attn_dropout is None else attn_dropout
         act_dropout = dropout if act_dropout is None else act_dropout
         self.normalize_before = normalize_before
@@ -1196,6 +1220,14 @@ class Transformer(Layer):
                  custom_decoder=None):
         super(Transformer, self).__init__()
 
+        assert d_model > 0, ("Expected d_model to be greater than 0, "
+                             "but received {}".format(d_model))
+        assert nhead > 0, ("Expected nhead to be greater than 0, "
+                           "but received {}".format(nhead))
+        assert dim_feedforward > 0, (
+            "Expected dim_feedforward to be greater than 0, "
+            "but received {}".format(dim_feedforward))
+
         if isinstance(bias_attr, (list, tuple)):
             if len(bias_attr) == 1:
                 encoder_bias_attr = [bias_attr[0]] * 2
--
GitLab
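
For reference, a minimal sketch of the behavior these asserts introduce, assuming a Paddle build that includes this patch. The constructor parameters (embed_dim, num_heads, d_model, nhead, dim_feedforward) and the printed messages come directly from the diff above; before this change, a zero-valued size parameter would only surface later as an opaque shape or division error.

    import paddle

    # A zero-valued size parameter now fails fast at construction time
    # with a descriptive AssertionError.
    try:
        paddle.nn.MultiHeadAttention(embed_dim=0, num_heads=2)
    except AssertionError as e:
        print(e)  # Expected embed_dim to be greater than 0, but received 0

    try:
        paddle.nn.Transformer(d_model=512, nhead=0, dim_feedforward=2048)
    except AssertionError as e:
        print(e)  # Expected nhead to be greater than 0, but received 0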