diff --git a/python/paddle/nn/layer/transformer.py b/python/paddle/nn/layer/transformer.py
index 891177532a438993a0291ef99d548af04a651ff6..5aba8ae85ad1b32a35de48cddc8dadd5d3929e70 100644
--- a/python/paddle/nn/layer/transformer.py
+++ b/python/paddle/nn/layer/transformer.py
@@ -161,6 +161,12 @@ class MultiHeadAttention(Layer):
                  weight_attr=None,
                  bias_attr=None):
         super(MultiHeadAttention, self).__init__()
+
+        assert embed_dim > 0, ("Expected embed_dim to be greater than 0, "
+                               "but received {}".format(embed_dim))
+        assert num_heads > 0, ("Expected num_heads to be greater than 0, "
+                               "but received {}".format(num_heads))
+
         self.embed_dim = embed_dim
         self.kdim = kdim if kdim is not None else embed_dim
         self.vdim = vdim if vdim is not None else embed_dim
@@ -501,6 +507,15 @@ class TransformerEncoderLayer(Layer):
         self._config.pop("__class__", None)  # py3
 
         super(TransformerEncoderLayer, self).__init__()
+
+        assert d_model > 0, ("Expected d_model to be greater than 0, "
+                             "but received {}".format(d_model))
+        assert nhead > 0, ("Expected nhead to be greater than 0, "
+                           "but received {}".format(nhead))
+        assert dim_feedforward > 0, (
+            "Expected dim_feedforward to be greater than 0, "
+            "but received {}".format(dim_feedforward))
+
         attn_dropout = dropout if attn_dropout is None else attn_dropout
         act_dropout = dropout if act_dropout is None else act_dropout
         self.normalize_before = normalize_before
@@ -797,6 +812,15 @@ class TransformerDecoderLayer(Layer):
         self._config.pop("__class__", None)  # py3
 
         super(TransformerDecoderLayer, self).__init__()
+
+        assert d_model > 0, ("Expected d_model to be greater than 0, "
+                             "but received {}".format(d_model))
+        assert nhead > 0, ("Expected nhead to be greater than 0, "
+                           "but received {}".format(nhead))
+        assert dim_feedforward > 0, (
+            "Expected dim_feedforward to be greater than 0, "
+            "but received {}".format(dim_feedforward))
+
         attn_dropout = dropout if attn_dropout is None else attn_dropout
         act_dropout = dropout if act_dropout is None else act_dropout
         self.normalize_before = normalize_before
@@ -1196,6 +1220,14 @@ class Transformer(Layer):
                  custom_decoder=None):
         super(Transformer, self).__init__()
 
+        assert d_model > 0, ("Expected d_model to be greater than 0, "
+                             "but received {}".format(d_model))
+        assert nhead > 0, ("Expected nhead to be greater than 0, "
+                           "but received {}".format(nhead))
+        assert dim_feedforward > 0, (
+            "Expected dim_feedforward to be greater than 0, "
+            "but received {}".format(dim_feedforward))
+
         if isinstance(bias_attr, (list, tuple)):
             if len(bias_attr) == 1:
                 encoder_bias_attr = [bias_attr[0]] * 2
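
As a quick illustration of the behavior this patch introduces, here is a minimal usage sketch (assuming the assertions above are applied to paddle.nn): layers built with non-positive dimensions now fail fast in __init__ with a descriptive AssertionError, while valid configurations are unaffected.

# Minimal sketch, assuming the patched validation is in place.
from paddle.nn import MultiHeadAttention, TransformerEncoderLayer

# A valid configuration constructs as before.
attn = MultiHeadAttention(embed_dim=128, num_heads=8)

# Non-positive arguments are rejected at construction time instead of
# surfacing later as an obscure shape or division error.
try:
    MultiHeadAttention(embed_dim=0, num_heads=8)
except AssertionError as e:
    print(e)  # Expected embed_dim to be greater than 0, but received 0

try:
    TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=-1)
except AssertionError as e:
    print(e)  # Expected dim_feedforward to be greater than 0, but received -1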