From 5675042dac638cea3ce2c095202519ee3b44874b Mon Sep 17 00:00:00 2001
From: Feng Xing <79969986+xingfeng01@users.noreply.github.com>
Date: Mon, 6 Sep 2021 10:32:26 +0800
Subject: [PATCH] replace pass with error exception (#35367)

This PR raises error exceptions in the fused transformer Python interface.
The function bodies are not implemented yet (they will be implemented later).

Following zhiqiu's comment on the previous PR-35206 (already merged), it is
better to raise an exception than to use "pass".
---
 python/paddle/nn/layer/fused_transformer.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/python/paddle/nn/layer/fused_transformer.py b/python/paddle/nn/layer/fused_transformer.py
index 6b24346c2b..0084f7ff33 100644
--- a/python/paddle/nn/layer/fused_transformer.py
+++ b/python/paddle/nn/layer/fused_transformer.py
@@ -68,6 +68,7 @@ class FusedMultiHeadAttention(Layer):
                  weight_attr=None,
                  bias_attr=None):
         super(FusedMultiHeadAttention, self).__init__()
+        raise NotImplementedError()
 
     def forward(self, query, key=None, value=None, attn_mask=None, cache=None):
         """
@@ -119,7 +120,7 @@ class FusedMultiHeadAttention(Layer):
                 reserves tensors concatanating raw tensors with intermediate \
                 results of current query.
         """
-        pass
+        raise NotImplementedError()
 
 
 class FusedFeedForward(Layer):
@@ -134,9 +135,10 @@ class FusedFeedForward(Layer):
                  bias_attr=None):
 
         super(FusedFeedForward, self).__init__()
+        raise NotImplementedError()
 
     def forward(self, src, cache=None):
-        pass
+        raise NotImplementedError()
 
 
 class FusedTransformerEncoderLayer(Layer):
@@ -212,6 +214,7 @@ class FusedTransformerEncoderLayer(Layer):
         self._config.pop("__class__", None)  # py3
 
         super(FusedTransformerEncoderLayer, self).__init__()
+        raise NotImplementedError()
 
     def forward(self, src, src_mask=None, cache=None):
         """
@@ -243,7 +246,7 @@ class FusedTransformerEncoderLayer(Layer):
                 incremental length. See `MultiHeadAttention.gen_cache` and \
                 `MultiHeadAttention.forward` for more details.
         """
-        pass
+        raise NotImplementedError()
 
 
 class FusedTransformer(Layer):
@@ -356,6 +359,7 @@ class FusedTransformer(Layer):
                  custom_encoder=None,
                  custom_decoder=None):
         super(fusedTransformer, self).__init__()
+        raise NotImplementedError()
 
     def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
-        pass
+        raise NotImplementedError()
-- 
GitLab
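
For reference, below is a minimal standalone sketch of the pattern this patch applies. It is plain Python, not the actual paddle.nn.Layer API, and the StubLayer name is hypothetical: an unimplemented layer raises NotImplementedError in __init__ and forward so that misuse fails loudly instead of silently doing nothing.

    # Minimal sketch (plain Python, not the real paddle.nn.Layer) of the
    # "raise instead of pass" pattern this patch introduces. StubLayer is a
    # hypothetical placeholder class used only for illustration.

    class StubLayer:
        """Placeholder for a layer whose fused kernel is not implemented yet."""

        def __init__(self):
            # Fail loudly at construction time so callers notice immediately.
            raise NotImplementedError(
                "StubLayer is a placeholder; the fused kernel is not available yet.")

        def forward(self, x):
            # Mirrors the patch, which also replaces `pass` in forward() with an
            # exception instead of silently returning None.
            raise NotImplementedError()

    if __name__ == "__main__":
        try:
            StubLayer()
        except NotImplementedError as e:
            print("Got the expected error:", e)

Running the sketch prints the caught error message, which is the behavior the review asked for: a user who instantiates an unfinished fused layer sees an explicit exception rather than a silent no-op.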