Unverified commit 66b5da37, authored by Steffy-zxf and committed by GitHub

export text classification model in static graph (#5226)

export text classification model in static graph 
Parent 8e356b03
@@ -61,6 +61,7 @@ pip install paddlenlp==2.0.0b
```text
pretrained_models/
├── export_model.py # script that exports dynamic graph parameters as static graph parameters
├── predict.py # prediction script
├── README.md # usage instructions
└── train.py # training and evaluation script
@@ -109,6 +110,13 @@ checkpoints/
**NOTE:**
* To resume model training, set `init_from_ckpt`, e.g. `init_from_ckpt=checkpoints/model_100/model_state.pdparams`.
* To use the ernie-tiny model, install the sentencepiece dependency first, e.g. `pip install sentencepiece`.
* After dynamic graph training finishes, the dynamic graph parameters can also be exported as static graph parameters; see export_model.py for the code. The static graph parameters are saved to the path specified by `output_path`.
Run it as follows:
```shell
python export_model.py --model_type=roberta --model_name_or_path=roberta-wwm-ext --params_path=./checkpoint/model_200/model_state.pdparams --output_path=./static_graph_params
```
Here `params_path` is the path of the parameters saved during dynamic graph training, and `output_path` is the export path for the static graph parameters.
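The exported prefix can then be loaded back with `paddle.jit.load` for static graph inference. The snippet below is a minimal sketch (not part of this commit), assuming the export command above has already been run; the token ids are dummy placeholders rather than real tokenizer output.

```python
import paddle

# Load the static graph model saved by export_model.py (prefix matches --output_path).
model = paddle.jit.load("./static_graph_params")
model.eval()

# input_ids and segment_ids must both be int64 tensors of shape [batch_size, seq_len];
# the ids below are dummies standing in for tokenizer output.
input_ids = paddle.to_tensor([[1, 647, 986, 1255, 2]], dtype="int64")
segment_ids = paddle.zeros_like(input_ids)

logits = model(input_ids, segment_ids)
print(logits.shape)  # [1, num_classes]
```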
### Model Prediction
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import argparse
import os
import random
import time
import numpy as np
import paddle
import paddle.nn.functional as F
from paddlenlp.data import Stack, Tuple, Pad
import paddlenlp as ppnlp
MODEL_CLASSES = {
"bert": (ppnlp.transformers.BertForSequenceClassification,
ppnlp.transformers.BertTokenizer),
'ernie': (ppnlp.transformers.ErnieForSequenceClassification,
ppnlp.transformers.ErnieTokenizer),
'roberta': (ppnlp.transformers.RobertaForSequenceClassification,
ppnlp.transformers.RobertaTokenizer),
}
# yapf: disable
def parse_args():
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--model_type", default='roberta', required=True, type=str, help="Model type selected in the list: " +", ".join(MODEL_CLASSES.keys()))
parser.add_argument("--model_name_or_path", default='roberta-wwm-ext', required=True, type=str, help="Path to pre-trained model or shortcut name selected in the list: " +
", ".join(sum([list(classes[-1].pretrained_init_configuration.keys()) for classes in MODEL_CLASSES.values()], [])))
parser.add_argument("--params_path", type=str, required=True, default='./checkpoint/model_200/model_state.pdparams', help="The path to model parameters to be loaded.")
parser.add_argument("--output_path", type=str, default='./static_graph_params', help="The path of model parameter in static graph to be saved.")
args = parser.parse_args()
return args
# yapf: enable
if __name__ == "__main__":
args = parse_args()
args.model_type = args.model_type.lower()
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
if args.model_name_or_path == 'ernie-tiny':
        # ErnieTinyTokenizer is special for ernie-tiny pretrained model.
tokenizer = ppnlp.transformers.ErnieTinyTokenizer.from_pretrained(
args.model_name_or_path)
else:
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
# The number of labels should be in accordance with the training dataset.
label_map = {0: 'negative', 1: 'positive'}
model = model_class.from_pretrained(
args.model_name_or_path, num_classes=len(label_map))
if args.params_path and os.path.isfile(args.params_path):
state_dict = paddle.load(args.params_path)
model.set_dict(state_dict)
print("Loaded parameters from %s" % args.params_path)
model.eval()
# Convert to static graph with specific input description
model = paddle.jit.to_static(
model,
input_spec=[
paddle.static.InputSpec(
shape=[None, None], dtype="int64"), # input_ids
paddle.static.InputSpec(
shape=[None, None], dtype="int64") # segment_ids
])
# Save in static graph model.
paddle.jit.save(model, args.output_path)
@@ -32,8 +32,6 @@ MODEL_CLASSES = {
ppnlp.transformers.ErnieTokenizer),
'roberta': (ppnlp.transformers.RobertaForSequenceClassification,
ppnlp.transformers.RobertaTokenizer),
-'electra': (ppnlp.transformers.ElectraForSequenceClassification,
-ppnlp.transformers.ElectraTokenizer)
}
......
@@ -124,6 +124,7 @@ pip install paddlenlp==2.0.0b
```text
rnn/
├── export_model.py # script that exports dynamic graph parameters as static graph parameters
├── predict.py # model prediction
├── utils.py # data processing utilities
├── train.py # main entry point for model training, including training and evaluation
@@ -186,7 +187,17 @@ checkpoints/
└── final.pdparams
```
**NOTE:**
* To resume model training, `init_from_ckpt` only needs to point to the file name; do not add the file extension. For example, with `--init_from_ckpt=checkpoints/0` the program automatically loads the model parameters `checkpoints/0.pdparams` as well as the optimizer state `checkpoints/0.pdopt`.
* After dynamic graph training finishes, the dynamic graph parameters can also be exported as static graph parameters; see export_model.py for the code. The static graph parameters are saved to the path specified by `output_path`.
Run it as follows:
```shell
python export_model.py --vocab_path=./senta_word_dict.txt --network=bilstm --params_path=./checkpoints/final.pdparams --output_path=./static_graph_params
```
Here `params_path` is the path of the parameters saved during dynamic graph training, and `output_path` is the export path for the static graph parameters.
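As with the pretrained models, the exported prefix can be loaded back with `paddle.jit.load`. The snippet below is a minimal sketch (not part of this commit), assuming the bilstm export command above; the word ids are placeholders rather than ids looked up in senta_word_dict.txt.

```python
import paddle
import paddle.nn.functional as F

# Load the static graph Senta model saved by export_model.py.
model = paddle.jit.load("./static_graph_params")
model.eval()

# For the bilstm network the exported model takes token ids plus the real sequence lengths.
texts = paddle.to_tensor([[12, 45, 78, 9, 0, 0]], dtype="int64")  # [batch_size, seq_len], 0 = pad
seq_lens = paddle.to_tensor([4], dtype="int64")                   # unpadded length per sample

logits = model(texts, seq_lens)
probs = F.softmax(logits, axis=1)  # the model now returns logits, so apply softmax here
print(probs.numpy())
```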
### Model Prediction
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import paddle
import paddlenlp as ppnlp
from utils import load_vocab
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--vocab_path", type=str, default="./senta_word_dict.txt", help="The path to vocabulary.")
parser.add_argument('--network', type=str, default="bilstm", help="Which network you would like to choose bow, lstm, bilstm, gru, bigru, rnn, birnn, bilstm_attn, cnn and textcnn?")
parser.add_argument("--params_path", type=str, default='./checkpoints/final.pdparams', help="The path of model parameter to be loaded.")
parser.add_argument("--output_path", type=str, default='./static_graph_params', help="The path of model parameter in static graph to be saved.")
args = parser.parse_args()
# yapf: enable
def main():
# Load vocab.
vocab = load_vocab(args.vocab_path)
label_map = {0: 'negative', 1: 'positive'}
    # Construct the network.
model = ppnlp.models.Senta(
network=args.network, vocab_size=len(vocab), num_classes=len(label_map))
# Load model parameters.
state_dict = paddle.load(args.params_path)
model.set_dict(state_dict)
model.eval()
inputs = [paddle.static.InputSpec(shape=[None, None], dtype="int64")]
# Convert to static graph with specific input description
if args.network in [
"lstm", "bilstm", "gru", "bigru", "rnn", "birnn", "bilstm_attn"
]:
inputs.append(paddle.static.InputSpec(
shape=[None], dtype="int64")) # seq_len
model = paddle.jit.to_static(model, input_spec=inputs)
# Save in static graph model.
paddle.jit.save(model, args.output_path)
if __name__ == "__main__":
main()
@@ -15,16 +15,17 @@ import argparse
import paddle
import paddlenlp as ppnlp
+import paddle.nn.functional as F
from utils import load_vocab, generate_batch, preprocess_prediction_data
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--use_gpu", type=eval, default=False, help="Whether use GPU for training, input should be True or False")
-parser.add_argument("--batch_size", type=int, default=64, help="Total examples' number of a batch for training.")
+parser.add_argument("--batch_size", type=int, default=1, help="Total examples' number of a batch for training.")
-parser.add_argument("--vocab_path", type=str, default="./word_dict.txt", help="The path to vocabulary.")
+parser.add_argument("--vocab_path", type=str, default="./senta_word_dict.txt", help="The path to vocabulary.")
-parser.add_argument('--network', type=str, default="bilstm_attn", help="Which network you would like to choose bow, lstm, bilstm, gru, bigru, rnn, birnn, bilstm_attn, cnn and textcnn?")
+parser.add_argument('--network', type=str, default="bilstm", help="Which network you would like to choose bow, lstm, bilstm, gru, bigru, rnn, birnn, bilstm_attn, cnn and textcnn?")
-parser.add_argument("--params_path", type=str, default='./chekpoints/final.pdparams', help="The path of model parameter to be loaded.")
+parser.add_argument("--params_path", type=str, default='./checkpoints/final.pdparams', help="The path of model parameter to be loaded.")
args = parser.parse_args()
# yapf: enable
@@ -66,7 +67,8 @@ def predict(model, data, label_map, collate_fn, batch_size=1, pad_token_id=0):
batch, pad_token_id=pad_token_id, return_label=False)
texts = paddle.to_tensor(texts)
seq_lens = paddle.to_tensor(seq_lens)
-probs = model(texts, seq_lens)
+logits = model(texts, seq_lens)
+probs = F.softmax(logits, axis=1)
idx = paddle.argmax(probs, axis=1).numpy()
idx = idx.tolist()
labels = [label_map[i] for i in idx]
......
@@ -29,10 +29,10 @@ parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--epochs", type=int, default=10, help="Number of epoches for training.")
parser.add_argument('--use_gpu', type=eval, default=False, help="Whether use GPU for training, input should be True or False")
parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate used to train.")
-parser.add_argument("--save_dir", type=str, default='chekpoints/', help="Directory to save model checkpoint")
+parser.add_argument("--save_dir", type=str, default='checkpoints/', help="Directory to save model checkpoint")
parser.add_argument("--batch_size", type=int, default=64, help="Total examples' number of a batch for training.")
parser.add_argument("--vocab_path", type=str, default="./senta_word_dict.txt", help="The directory to dataset.")
-parser.add_argument('--network', type=str, default="bilstm_attn", help="Which network you would like to choose bow, lstm, bilstm, gru, bigru, rnn, birnn, bilstm_attn and textcnn?")
+parser.add_argument('--network', type=str, default="bilstm", help="Which network you would like to choose bow, lstm, bilstm, gru, bigru, rnn, birnn, bilstm_attn and textcnn?")
parser.add_argument("--init_from_ckpt", type=str, default=None, help="The path of checkpoint to be loaded.")
args = parser.parse_args()
# yapf: enable
......
@@ -119,6 +119,7 @@ def convert_example(example, vocab, unk_token_id=1, is_test=False):
token_id = vocab.get(token, unk_token_id)
input_ids.append(token_id)
valid_length = np.array(len(input_ids), dtype='int64')
+input_ids = np.array(input_ids, dtype='int64')
if not is_test:
label = np.array(example[-1], dtype="int64")
......
@@ -17,6 +17,7 @@ import argparse
import paddle
import paddlenlp as ppnlp
+import paddle.nn.functional as F
from utils import load_vocab, generate_batch, preprocess_prediction_data
@@ -70,7 +71,8 @@ def predict(model, data, label_map, collate_fn, batch_size=1, pad_token_id=0):
titles = paddle.to_tensor(titles)
query_seq_lens = paddle.to_tensor(query_seq_lens)
title_seq_lens = paddle.to_tensor(title_seq_lens)
-probs = model(queries, titles, query_seq_lens, title_seq_lens)
+logits = model(queries, titles, query_seq_lens, title_seq_lens)
+probs = F.softmax(logits, axis=1)
idx = paddle.argmax(probs, axis=1).numpy()
idx = idx.tolist()
labels = [label_map[i] for i in idx]
......
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = '2.0.0b0'
+__version__ = '2.0.0b3'
from . import data
from . import datasets
......
@@ -66,16 +66,13 @@ class Ernie(nn.Layer):
if self.task in ['seq-cls', 'token-cls']:
logits = self.model(input_ids, token_type_ids, position_ids,
attention_mask)
-probs = F.softmax(logits, axis=-1)
-return probs
+return logits
elif self.task == 'qa':
start_logits, end_logits = self.model(input_ids, token_type_ids,
position_ids, attention_mask)
start_position = paddle.unsqueeze(start_position, axis=-1)
end_position = paddle.unsqueeze(end_position, axis=-1)
-start_probs = F.softmax(start_position, axis=-1)
-end_probs = F.softmax(end_position, axis=-1)
-return start_probs, end_probs
+return start_position, end_position
elif self.task is None:
sequence_output, pooled_output = self.model(
input_ids, token_type_ids, position_ids, attention_mask)
......
@@ -39,14 +39,14 @@ class Senta(nn.Layer):
vocab_size,
num_classes,
emb_dim,
-direction='bidirectional',
+direction='bidirect',
padding_idx=pad_token_id)
elif network == 'bilstm':
self.model = LSTMModel(
vocab_size,
num_classes,
emb_dim,
-direction='bidirectional',
+direction='bidirect',
padding_idx=pad_token_id)
elif network == 'bilstm_attn':
lstm_hidden_size = 196
@@ -63,7 +63,7 @@ class Senta(nn.Layer):
vocab_size,
num_classes,
emb_dim,
-direction='bidirectional',
+direction='bidirect',
padding_idx=pad_token_id)
elif network == 'cnn':
self.model = CNNModel(
@@ -102,8 +102,7 @@ class Senta(nn.Layer):
def forward(self, text, seq_len=None):
logits = self.model(text, seq_len)
-probs = F.softmax(logits, axis=-1)
-return probs
+return logits
class BoWModel(nn.Layer):
@@ -185,7 +184,7 @@ class LSTMModel(nn.Layer):
# Shape: (batch_size, num_tokens, embedding_dim)
embedded_text = self.embedder(text)
# Shape: (batch_size, num_tokens, num_directions*lstm_hidden_size)
-# num_directions = 2 if direction is 'bidirectional'
+# num_directions = 2 if direction is 'bidirect'
# if not, num_directions = 1
text_repr = self.lstm_encoder(embedded_text, sequence_length=seq_len)
# Shape: (batch_size, fc_hidden_size)
@@ -226,7 +225,7 @@ class GRUModel(nn.Layer):
# Shape: (batch_size, num_tokens, embedding_dim)
embedded_text = self.embedder(text)
# Shape: (batch_size, num_tokens, num_directions*gru_hidden_size)
-# num_directions = 2 if direction is 'bidirectional'
+# num_directions = 2 if direction is 'bidirect'
# if not, num_directions = 1
text_repr = self.gru_encoder(embedded_text, sequence_length=seq_len)
# Shape: (batch_size, fc_hidden_size)
@@ -267,7 +266,7 @@ class RNNModel(nn.Layer):
# Shape: (batch_size, num_tokens, embedding_dim)
embedded_text = self.embedder(text)
# Shape: (batch_size, num_tokens, num_directions*rnn_hidden_size)
-# num_directions = 2 if direction is 'bidirectional'
+# num_directions = 2 if direction is 'bidirect'
# if not, num_directions = 1
text_repr = self.rnn_encoder(embedded_text, sequence_length=seq_len)
# Shape: (batch_size, fc_hidden_size)
@@ -300,7 +299,7 @@ class BiLSTMAttentionModel(nn.Layer):
hidden_size=lstm_hidden_size,
num_layers=lstm_layers,
dropout=dropout_rate,
-direction='bidirectional')
+direction='bidirect')
self.attention = attention_layer
if isinstance(attention_layer, SelfAttention):
self.fc = nn.Linear(lstm_hidden_size, fc_hidden_size)
@@ -353,8 +352,8 @@ class SelfAttention(nn.Layer):
# Shape: (batch_size, max_seq_len, hidden_size)
h = paddle.add_n([forward_input, backward_input])
# Shape: (batch_size, hidden_size, 1)
-att_weight = self.att_weight.expand(shape=(h.shape[0], self.hidden_size, 1))
+att_weight = self.att_weight.tile(repeat_times=(paddle.shape(h)[0], 1, 1))
# Shape: (batch_size, max_seq_len, 1)
att_score = paddle.bmm(paddle.tanh(h), att_weight)
if mask is not None:
@@ -398,16 +397,14 @@ class SelfInteractiveAttention(nn.Layer):
mask (obj: `paddle.Tensor`, optional, defaults to `None`) of shape (batch, seq_len) :
Tensor is a bool tensor, whose each element identifies whether the input word id is pad token or not.
"""
-batch_size = input.shape[0]
-hidden_size = input.shape[2]
-weight = self.input_weight.expand(shape=(batch_size, hidden_size, hidden_size))
-bias = self.bias.expand(shape=(batch_size, 1, hidden_size))
+weight = self.input_weight.tile(repeat_times=(paddle.shape(input)[0], 1, 1))
+bias = self.bias.tile(repeat_times=(paddle.shape(input)[0], 1, 1))
# Shape: (batch_size, max_seq_len, hidden_size)
word_squish = paddle.bmm(input, weight) + bias
-att_context_vector = self.att_context_vector.expand(shape=(batch_size, hidden_size, 1))
+att_context_vector = self.att_context_vector.tile(repeat_times=(paddle.shape(input)[0], 1, 1))
# Shape: (batch_size, max_seq_len, 1)
att_score = paddle.bmm(word_squish, att_context_vector)
if mask is not None:
@@ -415,7 +412,7 @@ class SelfInteractiveAttention(nn.Layer):
mask = paddle.cast(mask, dtype='float32')
mask = mask.unsqueeze(axis=-1)
inf_tensor = paddle.full(
-shape=mask.shape, dtype='float32', fill_value=-INF)
+shape=paddle.shape(mask), dtype='float32', fill_value=-INF)
att_score = paddle.multiply(att_score, mask) + paddle.multiply(
inf_tensor, (1 - mask))
att_weight = F.softmax(att_score, axis=1)
......
@@ -56,8 +56,7 @@ class SimNet(nn.Layer):
def forward(self, query, title, query_seq_len=None, title_seq_len=None):
logits = self.model(query, title, query_seq_len, title_seq_len)
-probs = F.softmax(logits, axis=-1)
-return probs
+return logits
class BoWModel(nn.Layer):
......
@@ -203,7 +203,7 @@ class GRUEncoder(nn.Layer):
A GRUEncoder takes as input a sequence of vectors and returns a
single vector, which is a combination of multiple GRU layers.
The input to this module is of shape `(batch_size, num_tokens, input_size)`,
-The output is of shape `(batch_size, hidden_size*2)` if GRU is bidirectional;
+The output is of shape `(batch_size, hidden_size*2)` if GRU is bidirection;
If not, output is of shape `(batch_size, hidden_size)`.
Paddle's GRU have two outputs: the hidden state for every time step at last layer,
@@ -211,7 +211,7 @@ class GRUEncoder(nn.Layer):
If `pooling_type` is None, we perform the pooling on the hidden state of every time
step at last layer to create a single vector. If not None, we use the hidden state
of the last time step at last layer as a single output (shape of `(batch_size, hidden_size)`);
-And if direction is bidirectional, the we concat the hidden state of the last forward
+And if direction is bidirection, the we concat the hidden state of the last forward
gru and backward gru layer to create a single vector (shape of `(batch_size, hidden_size*2)`).
Args:
@@ -220,9 +220,9 @@ class GRUEncoder(nn.Layer):
num_layers (obj:`int`, optional, defaults to 1): Number of recurrent layers.
E.g., setting num_layers=2 would mean stacking two GRUs together to form a stacked GRU,
with the second GRU taking in outputs of the first GRU and computing the final results.
-direction (obj:`str`, optional, defaults to obj:`forwrd`): The direction of the network.
+direction (obj:`str`, optional, defaults to obj:`forward`): The direction of the network.
-It can be "forward" and "bidirectional".
+It can be "forward" and "bidirect" (it means bidirection network).
-When "bidirectional", the way to merge outputs of forward and backward is concatenating.
+When "bidirect", the way to merge outputs of forward and backward is concatenating.
dropout (obj:`float`, optional, defaults to 0.0): If non-zero, introduces a Dropout layer
on the outputs of each GRU layer except the last layer, with dropout probability equal to dropout.
pooling_type (obj: `str`, optional, defaults to obj:`None`): If `pooling_type` is None,
@@ -266,7 +266,7 @@ class GRUEncoder(nn.Layer):
Returns the dimension of the final vector output by this `GRUEncoder`. This is not
the shape of the returned tensor, but the last element of that shape.
"""
-if self._direction == "bidirectional":
+if self._direction == "bidirect":
return self._hidden_size * 2
else:
return self._hidden_size
@@ -276,7 +276,7 @@ class GRUEncoder(nn.Layer):
GRUEncoder takes the a sequence of vectors and and returns a
single vector, which is a combination of multiple GRU layers.
The input to this module is of shape `(batch_size, num_tokens, input_size)`,
-The output is of shape `(batch_size, hidden_size*2)` if GRU is bidirectional;
+The output is of shape `(batch_size, hidden_size*2)` if GRU is bidirection;
If not, output is of shape `(batch_size, hidden_size)`.
Args:
@@ -293,11 +293,11 @@ class GRUEncoder(nn.Layer):
if not self._pooling_type:
# We exploit the `last_hidden` (the hidden state at the last time step for every layer)
# to create a single vector.
-# If gru is not bidirectional, then output is the hidden state of the last time step
+# If gru is not bidirection, then output is the hidden state of the last time step
# at last layer. Output is shape of `(batch_size, hidden_size)`.
-# If gru is bidirectional, then output is concatenation of the forward and backward hidden state
+# If gru is bidirection, then output is concatenation of the forward and backward hidden state
# of the last time step at last layer. Output is shape of `(batch_size, hidden_size*2)`.
-if self._direction != 'bidirectional':
+if self._direction != 'bidirect':
output = last_hidden[-1, :, :]
else:
output = paddle.concat(
@@ -305,8 +305,8 @@ class GRUEncoder(nn.Layer):
else:
# We exploit the `encoded_text` (the hidden state at the every time step for last layer)
# to create a single vector. We perform pooling on the encoded text.
-# If gru is not bidirectional, output is shape of `(batch_size, hidden_size)`.
+# If gru is not bidirection, output is shape of `(batch_size, hidden_size)`.
-# If gru is bidirectional, then output is shape of `(batch_size, hidden_size*2)`.
+# If gru is bidirection, then output is shape of `(batch_size, hidden_size*2)`.
if self._pooling_type == 'sum':
output = paddle.sum(encoded_text, axis=1)
elif self._pooling_type == 'max':
@@ -326,7 +326,7 @@ class LSTMEncoder(nn.Layer):
A LSTMEncoder takes as input a sequence of vectors and returns a
single vector, which is a combination of multiple LSTM layers.
The input to this module is of shape `(batch_size, num_tokens, input_size)`,
-The output is of shape `(batch_size, hidden_size*2)` if LSTM is bidirectional;
+The output is of shape `(batch_size, hidden_size*2)` if LSTM is bidirection;
If not, output is of shape `(batch_size, hidden_size)`.
Paddle's LSTM have two outputs: the hidden state for every time step at last layer,
@@ -334,7 +334,7 @@ class LSTMEncoder(nn.Layer):
If `pooling_type` is None, we perform the pooling on the hidden state of every time
step at last layer to create a single vector. If not None, we use the hidden state
of the last time step at last layer as a single output (shape of `(batch_size, hidden_size)`);
-And if direction is bidirectional, the we concat the hidden state of the last forward
+And if direction is bidirection, the we concat the hidden state of the last forward
lstm and backward lstm layer to create a single vector (shape of `(batch_size, hidden_size*2)`).
Args:
@@ -344,8 +344,8 @@ class LSTMEncoder(nn.Layer):
E.g., setting num_layers=2 would mean stacking two LSTMs together to form a stacked LSTM,
with the second LSTM taking in outputs of the first LSTM and computing the final results.
direction (obj:`str`, optional, defaults to obj:`forwrd`): The direction of the network.
-It can be "forward" and "bidirectional".
+It can be "forward" and "bidirect" (it means bidirection network).
-When "bidirectional", the way to merge outputs of forward and backward is concatenating.
+When "bidirection", the way to merge outputs of forward and backward is concatenating.
dropout (obj:`float`, optional, defaults to 0.0): If non-zero, introduces a Dropout layer
on the outputs of each LSTM layer except the last layer, with dropout probability equal to dropout.
pooling_type (obj: `str`, optional, defaults to obj:`None`): If `pooling_type` is None,
@@ -390,7 +390,7 @@ class LSTMEncoder(nn.Layer):
Returns the dimension of the final vector output by this `LSTMEncoder`. This is not
the shape of the returned tensor, but the last element of that shape.
"""
-if self._direction == "bidirectional":
+if self._direction == "bidirect":
return self._hidden_size * 2
else:
return self._hidden_size
@@ -400,7 +400,7 @@ class LSTMEncoder(nn.Layer):
LSTMEncoder takes the a sequence of vectors and and returns a
single vector, which is a combination of multiple LSTM layers.
The input to this module is of shape `(batch_size, num_tokens, input_size)`,
-The output is of shape `(batch_size, hidden_size*2)` if LSTM is bidirectional;
+The output is of shape `(batch_size, hidden_size*2)` if LSTM is bidirection;
If not, output is of shape `(batch_size, hidden_size)`.
Args:
@@ -417,11 +417,11 @@ class LSTMEncoder(nn.Layer):
if not self._pooling_type:
# We exploit the `last_hidden` (the hidden state at the last time step for every layer)
# to create a single vector.
-# If lstm is not bidirectional, then output is the hidden state of the last time step
+# If lstm is not bidirection, then output is the hidden state of the last time step
# at last layer. Output is shape of `(batch_size, hidden_size)`.
-# If lstm is bidirectional, then output is concatenation of the forward and backward hidden state
+# If lstm is bidirection, then output is concatenation of the forward and backward hidden state
# of the last time step at last layer. Output is shape of `(batch_size, hidden_size*2)`.
-if self._direction != 'bidirectional':
+if self._direction != 'bidirect':
output = last_hidden[-1, :, :]
else:
output = paddle.concat(
@@ -429,8 +429,8 @@ class LSTMEncoder(nn.Layer):
else:
# We exploit the `encoded_text` (the hidden state at the every time step for last layer)
# to create a single vector. We perform pooling on the encoded text.
-# If lstm is not bidirectional, output is shape of `(batch_size, hidden_size)`.
+# If lstm is not bidirection, output is shape of `(batch_size, hidden_size)`.
-# If lstm is bidirectional, then output is shape of `(batch_size, hidden_size*2)`.
+# If lstm is bidirection, then output is shape of `(batch_size, hidden_size*2)`.
if self._pooling_type == 'sum':
output = paddle.sum(encoded_text, axis=1)
elif self._pooling_type == 'max':
@@ -450,7 +450,7 @@ class RNNEncoder(nn.Layer):
A RNNEncoder takes as input a sequence of vectors and returns a
single vector, which is a combination of multiple RNN layers.
The input to this module is of shape `(batch_size, num_tokens, input_size)`,
-The output is of shape `(batch_size, hidden_size*2)` if RNN is bidirectional;
+The output is of shape `(batch_size, hidden_size*2)` if RNN is bidirection;
If not, output is of shape `(batch_size, hidden_size)`.
Paddle's RNN have two outputs: the hidden state for every time step at last layer,
@@ -458,7 +458,7 @@ class RNNEncoder(nn.Layer):
If `pooling_type` is None, we perform the pooling on the hidden state of every time
step at last layer to create a single vector. If not None, we use the hidden state
of the last time step at last layer as a single output (shape of `(batch_size, hidden_size)`);
-And if direction is bidirectional, the we concat the hidden state of the last forward
+And if direction is bidirection, the we concat the hidden state of the last forward
rnn and backward rnn layer to create a single vector (shape of `(batch_size, hidden_size*2)`).
Args:
@@ -468,8 +468,8 @@ class RNNEncoder(nn.Layer):
E.g., setting num_layers=2 would mean stacking two RNNs together to form a stacked RNN,
with the second RNN taking in outputs of the first RNN and computing the final results.
direction (obj:`str`, optional, defaults to obj:`forwrd`): The direction of the network.
-It can be "forward" and "bidirectional".
+It can be "forward" and "bidirect" (it means bidirection network).
-When "bidirectional", the way to merge outputs of forward and backward is concatenating.
+When "bidirection", the way to merge outputs of forward and backward is concatenating.
dropout (obj:`float`, optional, defaults to 0.0): If non-zero, introduces a Dropout layer
on the outputs of each RNN layer except the last layer, with dropout probability equal to dropout.
pooling_type (obj: `str`, optional, defaults to obj:`None`): If `pooling_type` is None,
@@ -514,7 +514,7 @@ class RNNEncoder(nn.Layer):
Returns the dimension of the final vector output by this `RNNEncoder`. This is not
the shape of the returned tensor, but the last element of that shape.
"""
-if self._direction == "bidirectional":
+if self._direction == "bidirect":
return self._hidden_size * 2
else:
return self._hidden_size
@@ -524,7 +524,7 @@ class RNNEncoder(nn.Layer):
RNNEncoder takes the a sequence of vectors and and returns a
single vector, which is a combination of multiple RNN layers.
The input to this module is of shape `(batch_size, num_tokens, input_size)`,
-The output is of shape `(batch_size, hidden_size*2)` if RNN is bidirectional;
+The output is of shape `(batch_size, hidden_size*2)` if RNN is bidirection;
If not, output is of shape `(batch_size, hidden_size)`.
Args:
@@ -541,11 +541,11 @@ class RNNEncoder(nn.Layer):
if not self._pooling_type:
# We exploit the `last_hidden` (the hidden state at the last time step for every layer)
# to create a single vector.
-# If rnn is not bidirectional, then output is the hidden state of the last time step
+# If rnn is not bidirection, then output is the hidden state of the last time step
# at last layer. Output is shape of `(batch_size, hidden_size)`.
-# If rnn is bidirectional, then output is concatenation of the forward and backward hidden state
+# If rnn is bidirection, then output is concatenation of the forward and backward hidden state
# of the last time step at last layer. Output is shape of `(batch_size, hidden_size*2)`.
-if self._direction != 'bidirectional':
+if self._direction != 'bidirect':
output = last_hidden[-1, :, :]
else:
output = paddle.concat(
@@ -553,8 +553,8 @@ class RNNEncoder(nn.Layer):
else:
# We exploit the `encoded_text` (the hidden state at the every time step for last layer)
# to create a single vector. We perform pooling on the encoded text.
-# If rnn is not bidirectional, output is shape of `(batch_size, hidden_size)`.
+# If rnn is not bidirection, output is shape of `(batch_size, hidden_size)`.
-# If rnn is bidirectional, then output is shape of `(batch_size, hidden_size*2)`.
+# If rnn is bidirection, then output is shape of `(batch_size, hidden_size*2)`.
if self._pooling_type == 'sum':
output = paddle.sum(encoded_text, axis=1)
elif self._pooling_type == 'max':
......