Unverified commit 5675042d authored by Feng Xing, committed by GitHub

replace pass with error exception (#35367)

This PR adds error exceptions to the fused transformer Python interface.
The function bodies are not implemented yet (they will be implemented later).
Following zhiqiu's comment on the previous PR-35206 (already merged), it is better to raise an exception than to use "pass".
Parent c3ad7775
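For context, a minimal sketch of why raising is preferable to "pass" for an unimplemented interface (the StubLayer below is hypothetical, not code from this PR): with "pass", construction silently succeeds and the failure surfaces later, far from its cause; raising NotImplementedError fails fast at the call site.

    # Hypothetical stub, for illustration only -- not PaddlePaddle code.
    class StubLayer:
        """Interface-only layer; the real implementation lands in a later PR."""

        def __init__(self):
            # `pass` here would let construction silently succeed; raising
            # makes the missing implementation explicit at the call site.
            raise NotImplementedError()

        def forward(self, x):
            raise NotImplementedError()

    try:
        StubLayer()
    except NotImplementedError:
        print("stub fails fast at construction, as intended")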
@@ -68,6 +68,7 @@ class FusedMultiHeadAttention(Layer):
                  weight_attr=None,
                  bias_attr=None):
         super(FusedMultiHeadAttention, self).__init__()
+        raise NotImplementedError()
 
     def forward(self, query, key=None, value=None, attn_mask=None, cache=None):
         """
@@ -119,7 +120,7 @@ class FusedMultiHeadAttention(Layer):
                 reserves tensors concatanating raw tensors with intermediate \
                 results of current query.
         """
-        pass
+        raise NotImplementedError()
 
 
 class FusedFeedForward(Layer):
@@ -134,9 +135,10 @@ class FusedFeedForward(Layer):
                  bias_attr=None):
         super(FusedFeedForward, self).__init__()
+        raise NotImplementedError()
 
     def forward(self, src, cache=None):
-        pass
+        raise NotImplementedError()
 
 
 class FusedTransformerEncoderLayer(Layer):
@@ -212,6 +214,7 @@ class FusedTransformerEncoderLayer(Layer):
         self._config.pop("__class__", None)  # py3
         super(FusedTransformerEncoderLayer, self).__init__()
+        raise NotImplementedError()
 
     def forward(self, src, src_mask=None, cache=None):
         """
@@ -243,7 +246,7 @@ class FusedTransformerEncoderLayer(Layer):
                 incremental length. See `MultiHeadAttention.gen_cache` and \
                 `MultiHeadAttention.forward` for more details.
         """
-        pass
+        raise NotImplementedError()
 
 
 class FusedTransformer(Layer):
@@ -356,6 +359,7 @@ class FusedTransformer(Layer):
                  custom_encoder=None,
                  custom_decoder=None):
         super(fusedTransformer, self).__init__()
+        raise NotImplementedError()
 
     def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
-        pass
+        raise NotImplementedError()
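After this change, instantiating any of these layers fails fast instead of silently constructing a no-op object. A hedged usage sketch follows; the import path and constructor arguments are assumptions, as neither is shown in this diff:

    # Assumed import path and constructor arguments; requires PaddlePaddle.
    from paddle.incubate.nn import FusedMultiHeadAttention

    try:
        # As of this commit the constructor itself raises, so instantiation
        # alone triggers the new behavior; once the bodies are implemented,
        # construction will succeed instead.
        attn = FusedMultiHeadAttention(embed_dim=128, num_heads=8)
    except NotImplementedError:
        print("fused attention interface is declared but not yet implemented")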