Commit 054ba614 authored by LiuChiaChi

correct some args

Parent 5a50806e
@@ -34,12 +34,6 @@ def parse_args():
     parser.add_argument("--src_lang", type=str, help="source language suffix")
     parser.add_argument("--tar_lang", type=str, help="target language suffix")
-    parser.add_argument(
-        "--attention",
-        type=eval,
-        default=False,
-        help="Whether use attention model")
     parser.add_argument(
         "--optimizer",
         type=str,
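Reviewer note: the removed --attention flag relied on type=eval, which runs eval() on whatever string is given on the command line, so eval("True") works but any other Python expression would be executed too. A minimal sketch (stdlib argparse only, not code from this repo) of an explicit boolean parser, in case the flag is ever reintroduced:

import argparse

def str2bool(v):
    # Accept a few explicit spellings instead of eval()-ing arbitrary input.
    if v.lower() in ("true", "1", "yes"):
        return True
    if v.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % v)

parser = argparse.ArgumentParser()
parser.add_argument("--attention", type=str2bool, default=False,
                    help="Whether to use the attention model")
print(parser.parse_args(["--attention", "True"]).attention)  # True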
@@ -28,13 +28,12 @@ class AttentionModel(Layer):
                  trg_vocab_size,
                  num_layers=1,
                  init_scale=0.1,
-                 padding_idx=2,
+                 padding_idx=0,
                  dropout=None,
                  beam_size=1,
                  beam_start_token=1,
                  beam_end_token=2,
                  beam_max_step_num=100,
-                 mode='train',
                  dtype="float32"):
         super(AttentionModel, self).__init__()
         self.hidden_size = hidden_size
@@ -47,7 +46,6 @@ class AttentionModel(Layer):
         self.beam_start_token = beam_start_token
         self.beam_end_token = beam_end_token
         self.beam_max_step_num = beam_max_step_num
-        self.mode = mode
         self.kinf = 1e9
         self.encoder = Encoder(src_vocab_size, hidden_size, num_layers,
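The padding_idx default here drops from 2 to 0, and the training script below compensates by passing --padding_idx 2 explicitly. A minimal sketch (assuming the Paddle 2.x paddle.nn.Embedding API and toy sizes, not this model's actual wiring) of what the value typically controls in an embedding layer: ids equal to padding_idx always embed to an all-zero vector, so it must match the id the vocabulary actually uses for padding.

import paddle

emb = paddle.nn.Embedding(num_embeddings=10, embedding_dim=4, padding_idx=0)
ids = paddle.to_tensor([0, 3, 0])  # id 0 is the pad token in this toy setup
print(emb(ids))                    # rows for id 0 come back as all zeros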
@@ -125,14 +125,14 @@ def raw_data(src_lang,
     src_vocab = _build_vocab(src_vocab_file)
     tar_vocab = _build_vocab(tar_vocab_file)
-    train_src, train_tar = _para_file_to_ids( src_train_file, tar_train_file, \
-                                              src_vocab, tar_vocab )
+    train_src, train_tar = _para_file_to_ids(src_train_file, tar_train_file, \
+                                             src_vocab, tar_vocab)
     train_src, train_tar = filter_len(
         train_src, train_tar, max_sequence_len=max_sequence_len)
-    eval_src, eval_tar = _para_file_to_ids( src_eval_file, tar_eval_file, \
-                                            src_vocab, tar_vocab )
+    eval_src, eval_tar = _para_file_to_ids(src_eval_file, tar_eval_file, \
+                                           src_vocab, tar_vocab)
-    test_src, test_tar = _para_file_to_ids( src_test_file, tar_test_file, \
+    test_src, test_tar = _para_file_to_ids(src_test_file, tar_test_file, \
                                             src_vocab, tar_vocab )
     return (train_src, train_tar), (eval_src, eval_tar), (test_src, test_tar),\
@@ -143,8 +143,8 @@ def raw_mono_data(vocab_file, file_path):
     src_vocab = _build_vocab(vocab_file)
-    test_src, test_tar = _para_file_to_ids( file_path, file_path, \
-                                            src_vocab, src_vocab )
+    test_src, test_tar = _para_file_to_ids(file_path, file_path, \
+                                           src_vocab, src_vocab)
     return (test_src, test_tar)
@@ -160,6 +160,7 @@ class IWSLTDataset(Dataset):
         src_data, trg_data = raw_data
         data_pair = []
         for src, trg in zip(src_data, trg_data):
-            data_pair.append([src, trg])
+            if len(src) > 0:
+                data_pair.append([src, trg])
         sorted_data_pair = sorted(data_pair, key=lambda k: len(k[0]))
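This hunk drops sentence pairs with an empty source side before sorting by source length (sorting keeps similarly sized sentences together so batches need less padding). A self-contained sketch of the same preprocessing on hypothetical toy data, not the real IWSLT reader:

src_data = [[4, 9, 7], [], [5], [8, 2]]
trg_data = [[1, 6], [3], [2, 2, 2], [7]]

# Drop pairs whose source side is empty, then sort by source length.
data_pair = [[src, trg] for src, trg in zip(src_data, trg_data) if len(src) > 0]
sorted_data_pair = sorted(data_pair, key=lambda k: len(k[0]))
print(sorted_data_pair)  # [[[5], [2, 2, 2]], [[8, 2], [7]], [[4, 9, 7], [1, 6]]]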
@@ -3,13 +3,12 @@
 export CUDA_VISIBLE_DEVICES=0
 python train.py \
     --src_lang en --tar_lang vi \
-    --attention True \
     --num_layers 2 \
     --hidden_size 512 \
     --src_vocab_size 17191 \
     --tar_vocab_size 7709 \
     --batch_size 128 \
-    --dropout 0.0 \
+    --dropout 0.2 \
     --init_scale 0.2 \
     --max_grad_norm 5.0 \
     --train_data_prefix data/en-vi/train \
@@ -20,6 +19,7 @@ python train.py \
     --model_path attention_models \
     --enable_ce \
     --learning_rate 0.002 \
-    --dtype float64 \
+    --dtype float32 \
     --optimizer adam \
-    --max_epoch 1
+    --max_epoch 12 \
+    --padding_idx 2
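Taken together, the script changes track the model changes: --attention is dropped because the flag itself was removed from parse_args, --padding_idx 2 presumably restores the old constructor default now that the default is 0, --dtype float64 is corrected to the float32 the model declares, dropout rises from 0.0 to 0.2, and --max_epoch moves from a 1-epoch run to a 12-epoch schedule, gaining the line continuation it now needs.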