未验证 提交 317f7ce2 编写于 作者: G Guo Sheng 提交者: GitHub

[API 2.0] Add transformer apis (#26418)

* Add MultiHeadAttention api.
test=develop

* Add MultiHeadAttention cache type and gen_cache.
test=develop

* Add TransformerEncoderLayer and TransformerEncoder.
test=develop

* Add Transformer decoder apis.
test=develop

* Add Transformer api.
test=develop

* add unittests for transformer api

* add unittests for transformer api

* Fix some bugs in Transformer apis.
test=develop

* add unittests for encoder, decoder and transformer

* clean conflicts infor in code

* clean Chinese comments

* Add TransformerDecoderCell and TransformerBeamSearchDecoder.
test=develop

* Remove TransformerDecoderCell and TransformerBeamSearchDecoder temporarily.
test=develop

* Add import for Transformer apis.
test=develop

* Update usage of weight_attr and Tensor in Transformer api docs.
test=develop

* Update Transformer apis by renaming MultiheadAttention and cal_kv according to comments.
test=develop

* Fix MultiHeadAttention in test_transformer_api.py.
test=develop
Co-authored-by: NLiuChiaChi <709153940@qq.com>
上级 8645591d
...@@ -130,6 +130,12 @@ from .layer.norm import InstanceNorm #DEFINE_ALIAS ...@@ -130,6 +130,12 @@ from .layer.norm import InstanceNorm #DEFINE_ALIAS
# from .layer.rnn import RNNCell #DEFINE_ALIAS # from .layer.rnn import RNNCell #DEFINE_ALIAS
# from .layer.rnn import GRUCell #DEFINE_ALIAS # from .layer.rnn import GRUCell #DEFINE_ALIAS
# from .layer.rnn import LSTMCell #DEFINE_ALIAS # from .layer.rnn import LSTMCell #DEFINE_ALIAS
from .layer.transformer import MultiHeadAttention
from .layer.transformer import TransformerEncoderLayer
from .layer.transformer import TransformerEncoder
from .layer.transformer import TransformerDecoderLayer
from .layer.transformer import TransformerDecoder
from .layer.transformer import Transformer
from .layer.distance import PairwiseDistance #DEFINE_ALIAS from .layer.distance import PairwiseDistance #DEFINE_ALIAS
from .layer import loss #DEFINE_ALIAS from .layer import loss #DEFINE_ALIAS
......
...@@ -21,6 +21,7 @@ from . import extension ...@@ -21,6 +21,7 @@ from . import extension
from . import activation from . import activation
from . import norm from . import norm
from . import distance from . import distance
from . import transformer
from .activation import * from .activation import *
from .loss import * from .loss import *
...@@ -28,6 +29,7 @@ from .conv import * ...@@ -28,6 +29,7 @@ from .conv import *
from .extension import * from .extension import *
from .activation import * from .activation import *
from .norm import * from .norm import *
from .transformer import *
# from .activation import PReLU #DEFINE_ALIAS # from .activation import PReLU #DEFINE_ALIAS
from .activation import ReLU #DEFINE_ALIAS from .activation import ReLU #DEFINE_ALIAS
from .activation import LeakyReLU #DEFINE_ALIAS from .activation import LeakyReLU #DEFINE_ALIAS
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册