# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from functools import partial

import paddle.fluid as fluid
from paddle.fluid.io import DataLoader

from args import parse_args
from seq2seq_base import CrossEntropyCriterion
from seq2seq_attn import AttentionModel
from model import Input, set_device
from reader import Seq2SeqDataset, Seq2SeqBatchSampler, SortType, prepare_train_input


def do_train(args):
    device = set_device("gpu" if args.use_gpu else "cpu")
    if args.eager_run:
        fluid.enable_dygraph(device)

    if args.enable_ce:
        # Fix seeds and disable shuffling so CE (continuous evaluation)
        # runs are reproducible.
        fluid.default_main_program().random_seed = 102
        fluid.default_startup_program().random_seed = 102
        args.shuffle = False
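
    # Note: when args.eager_run is off, the Model API runs in static-graph
    # mode; the input/label descriptors defined next then double as the
    # feed list (see the DataLoader's feed_list below).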
    inputs = [
        Input([None, None], "int64", name="src_word"),
        Input([None], "int64", name="src_length"),
        Input([None, None], "int64", name="trg_word"),
        Input([None], "int64", name="trg_length"),
    ]
    labels = [Input([None, None, 1], "int64", name="label")]
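
    # After collation each batch carries (shapes illustrative, None dims vary):
    #   src_word [batch, max_src_len], src_length [batch],
    #   trg_word [batch, max_trg_len], trg_length [batch],
    #   label    [batch, max_trg_len, 1] holding the gold target tokens.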

    # Define the data loaders: one for training and, when an eval data
    # prefix is given, one for evaluation.
    data_loaders = [None, None]
    data_prefixes = [args.train_data_prefix, args.eval_data_prefix
                     ] if args.eval_data_prefix else [args.train_data_prefix]
    for i, data_prefix in enumerate(data_prefixes):
        dataset = Seq2SeqDataset(
            fpattern=data_prefix + "." + args.src_lang,
            trg_fpattern=data_prefix + "." + args.tar_lang,
            src_vocab_fpath=args.vocab_prefix + "." + args.src_lang,
            trg_vocab_fpath=args.vocab_prefix + "." + args.tar_lang,
            token_delimiter=None,
            start_mark="<s>",
            end_mark="</s>",
            unk_mark="<unk>")
        (args.src_vocab_size, args.trg_vocab_size, bos_id, eos_id,
         unk_id) = dataset.get_vocab_summary()
        batch_sampler = Seq2SeqBatchSampler(
            dataset=dataset,
            use_token_batch=False,
            batch_size=args.batch_size,
            pool_size=args.batch_size * 20,
            sort_type=SortType.POOL,
            shuffle=args.shuffle)
        data_loader = DataLoader(
            dataset=dataset,
            batch_sampler=batch_sampler,
            places=device,
            # In static-graph mode, feed through the declared inputs/labels.
            feed_list=None if fluid.in_dygraph_mode() else
            [x.forward() for x in inputs + labels],
            collate_fn=partial(
                prepare_train_input,
                bos_id=bos_id,
                eos_id=eos_id,
                pad_id=eos_id),
            num_workers=0,
            return_list=True)
        data_loaders[i] = data_loader
    # eval_loader stays None when no eval data prefix was provided.
    train_loader, eval_loader = data_loaders
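
    # Build the attention-based seq2seq model; vocabulary sizes come from
    # the dataset summary above, and the encoder and decoder share
    # args.hidden_size.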
    model = AttentionModel(args.src_vocab_size, args.trg_vocab_size,
                           args.hidden_size, args.hidden_size,
                           args.num_layers, args.dropout)
    model.prepare(
        fluid.optimizer.Adam(
            learning_rate=args.learning_rate,
            parameter_list=model.parameters()),
        CrossEntropyCriterion(),
        inputs=inputs,
        labels=labels)
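
    # Note: recurrent seq2seq models are often trained with global-norm
    # gradient clipping. Depending on the installed Paddle 1.x release, a
    # clip object such as fluid.clip.GradientClipByGlobalNorm may be
    # attachable to the optimizer (e.g. via a grad_clip argument); this is
    # an assumption to verify against the version in use.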
    # Train with periodic evaluation, checkpointing, and logging handled by
    # the high-level Model API.
    model.fit(train_data=train_loader,
              eval_data=eval_loader,
              epochs=1,
              eval_freq=1,
              save_freq=1,
              log_freq=1,
              verbose=2)


if __name__ == "__main__":
    args = parse_args()
    do_train(args)
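
# Example invocation (assumption: the argparse flags in args.py mirror the
# attribute names used above; paths and the language pair are illustrative):
#   python train.py --src_lang en --tar_lang vi \
#       --vocab_prefix data/vocab --train_data_prefix data/train \
#       --eval_data_prefix data/tst2012 --use_gpu True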