提交 dd3d25ce 编写于 作者: L LiuChiaChi

Add padding_idx to the embedding layers and delete unused code in the reader

上级 7d018ca8
......@@ -56,7 +56,7 @@ def parse_args():
"--padding_idx",
type=int,
default=0,
help="padding_idx of embedding")
help="padding index of embedding")
parser.add_argument(
"--num_layers",
......
......@@ -15,7 +15,7 @@
import numpy as np
import paddle
from paddle.nn import Layer, Linear, Dropout, Embedding, LayerList, RNN, LSTM, LSTMCell, RNNCellBase
from paddle.nn.initializer import Uniform
import paddle.nn.initializer as I
import paddle.nn.functional as F
SEED = 102
paddle.framework.manual_seed(SEED)
......@@ -84,9 +84,10 @@ class Encoder(Layer):
self.embedder = Embedding(
vocab_size,
hidden_size,
padding_idx=padding_idx,
weight_attr=paddle.ParamAttr(
name='source_embedding',
initializer=Uniform(
initializer=I.Uniform(
low=-init_scale, high=init_scale)))
self.lstm = LSTM(
input_size=hidden_size,
......@@ -111,13 +112,13 @@ class AttentionLayer(Layer):
self.input_proj = Linear(
hidden_size,
hidden_size,
weight_attr=paddle.ParamAttr(initializer=Uniform(
weight_attr=paddle.ParamAttr(initializer=I.Uniform(
low=-init_scale, high=init_scale)),
bias_attr=bias)
self.output_proj = Linear(
hidden_size + hidden_size,
hidden_size,
weight_attr=paddle.ParamAttr(initializer=Uniform(
weight_attr=paddle.ParamAttr(initializer=I.Uniform(
low=-init_scale, high=init_scale)),
bias_attr=bias)
......@@ -202,9 +203,10 @@ class Decoder(Layer):
self.embedder = Embedding(
vocab_size,
hidden_size,
padding_idx=padding_idx,
weight_attr=paddle.ParamAttr(
name='target_embedding',
initializer=Uniform(
initializer=I.Uniform(
low=-init_scale, high=init_scale)))
self.dropout = dropout
self.lstm_attention = RNN(DecoderCell(hidden_size, hidden_size,
......@@ -214,7 +216,7 @@ class Decoder(Layer):
self.fc = Linear(
hidden_size,
vocab_size,
weight_attr=paddle.ParamAttr(initializer=Uniform(
weight_attr=paddle.ParamAttr(initializer=I.Uniform(
low=-init_scale, high=init_scale)),
bias_attr=False)
......
/models/dygraph/seq2seq_attn/data
\ No newline at end of file
......@@ -177,17 +177,6 @@ class IWSLTDataset(Dataset):
def __len__(self):
    """Return the value of ``self.num_samples`` as the dataset length."""
    return self.num_samples
def get_max_seq_len(self):
    """Return the longest source and target lengths, capped at ``self.max_seq_len``.

    Returns:
        tuple: ``(src_max_seq_len, trg_max_seq_len)``. Each value is the
        largest ``len()`` among the items of ``self.src_data`` and
        ``self.tar_data`` respectively (0 when the collection is empty),
        clipped so it never exceeds ``self.max_seq_len``.
    """
    # ``default=0`` reproduces the zero starting value of a running maximum,
    # so an empty collection yields 0 instead of raising ValueError.
    src_max_seq_len = max((len(data) for data in self.src_data), default=0)
    trg_max_seq_len = max((len(data) for data in self.tar_data), default=0)
    return (min(src_max_seq_len, self.max_seq_len),
            min(trg_max_seq_len, self.max_seq_len))
class DataCollector():
def __init__(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册