Commit 054ba614 authored by LiuChiaChi

correct some args

Parent 5a50806e
@@ -34,12 +34,6 @@ def parse_args():
     parser.add_argument("--src_lang", type=str, help="source language suffix")
     parser.add_argument("--tar_lang", type=str, help="target language suffix")
-    parser.add_argument(
-        "--attention",
-        type=eval,
-        default=False,
-        help="Whether use attention model")
     parser.add_argument(
         "--optimizer",
         type=str,
...
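Note on this hunk: with the --attention flag removed, attention is no longer toggled from the command line. A minimal sketch of how the surviving options might be declared is shown below; only --src_lang, --tar_lang, and --optimizer are confirmed by the hunk, while the --padding_idx entry and its default are assumptions made only because the updated run script further down passes --padding_idx 2.

    import argparse

    def parse_args():
        # Sketch of the parser after the removal; --padding_idx is assumed,
        # its default here merely mirrors the value used in the run script.
        parser = argparse.ArgumentParser(description="seq2seq training arguments")
        parser.add_argument("--src_lang", type=str, help="source language suffix")
        parser.add_argument("--tar_lang", type=str, help="target language suffix")
        parser.add_argument("--optimizer", type=str, default="adam", help="optimizer name")
        parser.add_argument("--padding_idx", type=int, default=2, help="vocabulary id used for padding")
        return parser.parse_args()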
@@ -28,13 +28,12 @@ class AttentionModel(Layer):
                  trg_vocab_size,
                  num_layers=1,
                  init_scale=0.1,
-                 padding_idx=2,
+                 padding_idx=0,
                  dropout=None,
                  beam_size=1,
                  beam_start_token=1,
                  beam_end_token=2,
                  beam_max_step_num=100,
-                 mode='train',
                  dtype="float32"):
         super(AttentionModel, self).__init__()
         self.hidden_size = hidden_size
@@ -47,7 +46,6 @@ class AttentionModel(Layer):
         self.beam_start_token = beam_start_token
         self.beam_end_token = beam_end_token
         self.beam_max_step_num = beam_max_step_num
-        self.mode = mode
         self.kinf = 1e9
         self.encoder = Encoder(src_vocab_size, hidden_size, num_layers,
...
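Note on this hunk: the constructor drops the unused mode argument, and padding_idx now defaults to 0 instead of 2, so callers that rely on id 2 as the pad token must pass it explicitly. As a reminder of why the default matters, padding_idx usually names the vocabulary id that fills short sentences up to the batch length and is later masked out; a minimal, hypothetical illustration independent of this model's actual batching code:

    def pad_batch(sentences, padding_idx):
        # Pad every id list in the batch to the longest length using padding_idx.
        max_len = max(len(s) for s in sentences)
        return [s + [padding_idx] * (max_len - len(s)) for s in sentences]

    # With padding_idx=2 (the old default), short sentences are filled with 2s.
    batch = pad_batch([[5, 6, 7], [8]], padding_idx=2)
    assert batch == [[5, 6, 7], [8, 2, 2]]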
@@ -125,14 +125,14 @@ def raw_data(src_lang,
     src_vocab = _build_vocab(src_vocab_file)
     tar_vocab = _build_vocab(tar_vocab_file)
-    train_src, train_tar = _para_file_to_ids( src_train_file, tar_train_file, \
-                                              src_vocab, tar_vocab )
+    train_src, train_tar = _para_file_to_ids(src_train_file, tar_train_file, \
+                                             src_vocab, tar_vocab)
     train_src, train_tar = filter_len(
         train_src, train_tar, max_sequence_len=max_sequence_len)
-    eval_src, eval_tar = _para_file_to_ids( src_eval_file, tar_eval_file, \
-                                            src_vocab, tar_vocab )
+    eval_src, eval_tar = _para_file_to_ids(src_eval_file, tar_eval_file, \
+                                           src_vocab, tar_vocab)
-    test_src, test_tar = _para_file_to_ids( src_test_file, tar_test_file, \
+    test_src, test_tar = _para_file_to_ids(src_test_file, tar_test_file, \
                                              src_vocab, tar_vocab )
     return (train_src, train_tar), (eval_src, eval_tar), (test_src, test_tar),\
@@ -143,8 +143,8 @@ def raw_mono_data(vocab_file, file_path):
     src_vocab = _build_vocab(vocab_file)
-    test_src, test_tar = _para_file_to_ids( file_path, file_path, \
-                                            src_vocab, src_vocab )
+    test_src, test_tar = _para_file_to_ids(file_path, file_path, \
+                                           src_vocab, src_vocab)
     return (test_src, test_tar)
@@ -160,6 +160,7 @@ class IWSLTDataset(Dataset):
         src_data, trg_data = raw_data
         data_pair = []
         for src, trg in zip(src_data, trg_data):
-            data_pair.append([src, trg])
+            if len(src) > 0:
+                data_pair.append([src, trg])
         sorted_data_pair = sorted(data_pair, key=lambda k: len(k[0]))
...
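Note on this hunk: besides the whitespace cleanup in the _para_file_to_ids calls, IWSLTDataset now skips pairs whose source sentence is empty before sorting by source length. A standalone sketch of that filtering step, using a hypothetical helper name:

    def build_sorted_pairs(src_data, trg_data):
        # Pair source and target id lists, drop empty sources, then sort the
        # surviving pairs by source length, as the updated dataset code does.
        data_pair = []
        for src, trg in zip(src_data, trg_data):
            if len(src) > 0:
                data_pair.append([src, trg])
        return sorted(data_pair, key=lambda k: len(k[0]))

    # The pair with the empty source id list is discarded.
    pairs = build_sorted_pairs([[4, 9, 2], [], [7]], [[11, 2], [2], [5, 2]])
    assert pairs == [[[7], [5, 2]], [[4, 9, 2], [11, 2]]]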
...@@ -3,13 +3,12 @@ export CUDA_VISIBLE_DEVICES=0 ...@@ -3,13 +3,12 @@ export CUDA_VISIBLE_DEVICES=0
python train.py \ python train.py \
--src_lang en --tar_lang vi \ --src_lang en --tar_lang vi \
--attention True \
--num_layers 2 \ --num_layers 2 \
--hidden_size 512 \ --hidden_size 512 \
--src_vocab_size 17191 \ --src_vocab_size 17191 \
--tar_vocab_size 7709 \ --tar_vocab_size 7709 \
--batch_size 128 \ --batch_size 128 \
--dropout 0.0 \ --dropout 0.2 \
--init_scale 0.2 \ --init_scale 0.2 \
--max_grad_norm 5.0 \ --max_grad_norm 5.0 \
--train_data_prefix data/en-vi/train \ --train_data_prefix data/en-vi/train \
...@@ -20,6 +19,7 @@ python train.py \ ...@@ -20,6 +19,7 @@ python train.py \
--model_path attention_models \ --model_path attention_models \
--enable_ce \ --enable_ce \
--learning_rate 0.002 \ --learning_rate 0.002 \
--dtype float64 \ --dtype float32 \
--optimizer adam \ --optimizer adam \
--max_epoch 1 --max_epoch 12 \
--padding_idx 2
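Note on this hunk: the script now trains for 12 epochs with dropout 0.2 in float32, and passes --padding_idx 2 explicitly because the constructor default above changed to 0. A hedged sketch of how train.py presumably forwards these flags to the model; the actual wiring is not part of this commit, and the attribute names simply follow the flag names:

    args = parse_args()
    model = AttentionModel(
        hidden_size=args.hidden_size,
        src_vocab_size=args.src_vocab_size,
        trg_vocab_size=args.tar_vocab_size,
        num_layers=args.num_layers,
        init_scale=args.init_scale,
        dropout=args.dropout,
        padding_idx=args.padding_idx)  # 2 from the script overrides the new default of 0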