PaddlePaddle / hapi

Commit ee442428
Authored Apr 08, 2020 by guosheng

Fix distribute BatchSampler

Parent: 6431daed

Showing 2 changed files with 172 additions and 146 deletions:
    transformer/reader.py   +162  -94
    transformer/train.py    +10   -52
transformer/reader.py

-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -15,8 +15,9 @@
 import glob
 import six
 import os
 import tarfile
+import io
 import itertools
 from functools import partial

 import numpy as np
 import paddle.fluid as fluid
...
@@ -24,16 +25,67 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
 from paddle.fluid.io import BatchSampler, DataLoader, Dataset


-def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
+def create_data_loader(args, device):
+    data_loaders = [None, None]
+    data_files = [args.training_file, args.validation_file
+                  ] if args.validation_file else [args.training_file]
+    for i, data_file in enumerate(data_files):
+        dataset = Seq2SeqDataset(fpattern=data_file,
+                                 src_vocab_fpath=args.src_vocab_fpath,
+                                 trg_vocab_fpath=args.trg_vocab_fpath,
+                                 token_delimiter=args.token_delimiter,
+                                 start_mark=args.special_token[0],
+                                 end_mark=args.special_token[1],
+                                 unk_mark=args.special_token[2],
+                                 byte_data=True)
+        args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
+            args.unk_idx = dataset.get_vocab_summary()
+        batch_sampler = Seq2SeqBatchSampler(dataset=dataset,
+                                            use_token_batch=args.use_token_batch,
+                                            batch_size=args.batch_size,
+                                            pool_size=args.pool_size,
+                                            sort_type=args.sort_type,
+                                            shuffle=args.shuffle,
+                                            shuffle_batch=args.shuffle_batch,
+                                            max_length=args.max_length,
+                                            distribute_mode=True if i == 0 else False)
+        # every device eval all data
+        data_loader = DataLoader(dataset=dataset,
+                                 batch_sampler=batch_sampler,
+                                 places=device,
+                                 collate_fn=partial(prepare_train_input,
+                                                    bos_idx=args.bos_idx,
+                                                    eos_idx=args.eos_idx,
+                                                    src_pad_idx=args.eos_idx,
+                                                    trg_pad_idx=args.eos_idx,
+                                                    n_head=args.n_head),
+                                 num_workers=0,  # TODO: use multi-process
+                                 return_list=True)
+        data_loaders[i] = data_loader
+    return data_loaders
+
+
+def prepare_train_input(insts, bos_idx, eos_idx, src_pad_idx, trg_pad_idx,
+                        n_head):
     """
     Put all padded data needed by training into a list.
     """
     src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
-        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
+        [inst[0] + [eos_idx] for inst in insts],
+        src_pad_idx,
+        n_head,
+        is_target=False)
     src_word = src_word.reshape(-1, src_max_len)
     src_pos = src_pos.reshape(-1, src_max_len)
     trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
-        [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
+        [[bos_idx] + inst[1] for inst in insts],
+        trg_pad_idx,
+        n_head,
+        is_target=True)
     trg_word = trg_word.reshape(-1, trg_max_len)
     trg_pos = trg_pos.reshape(-1, trg_max_len)
...
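For orientation (not part of the commit), a minimal sketch of how the new create_data_loader might be wired up. It assumes an argparse-style args namespace carrying the fields the function reads; every path and value below is a placeholder, and device stands in for whatever set_device(...) returns in train.py.

    from argparse import Namespace
    from reader import create_data_loader  # the function added above

    # Hypothetical args: field names mirror those read by create_data_loader,
    # values are placeholders only.
    args = Namespace(training_file="data/train.src-trg",
                     validation_file="data/dev.src-trg",
                     src_vocab_fpath="data/vocab.src",
                     trg_vocab_fpath="data/vocab.trg",
                     token_delimiter=" ",
                     special_token=["<s>", "<e>", "<unk>"],
                     use_token_batch=True,
                     batch_size=4096,
                     pool_size=200000,
                     sort_type="pool",
                     shuffle=True,
                     shuffle_batch=True,
                     max_length=256,
                     n_head=8)

    device = None  # stand-in; train.py passes the device obtained from set_device(...)
    train_loader, eval_loader = create_data_loader(args, device)
    # Both loaders are later handed to model.fit(train_data=..., eval_data=...),
    # as the train.py hunks below show.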
@@ -41,7 +93,7 @@ def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
         [1, 1, trg_max_len, 1]).astype("float32")

     lbl_word, lbl_weight, num_token = pad_batch_data(
-        [inst[2] for inst in insts],
+        [inst[1] + [eos_idx] for inst in insts],
         trg_pad_idx,
         n_head,
         is_target=False,
...
@@ -71,9 +123,7 @@ def prepare_infer_input(insts, src_pad_idx, n_head):
     src_word = src_word.reshape(-1, src_max_len)
     src_pos = src_pos.reshape(-1, src_max_len)

-    data_inputs = [
-        src_word, src_pos, src_slf_attn_bias, trg_src_attn_bias
-    ]
+    data_inputs = [src_word, src_pos, src_slf_attn_bias, trg_src_attn_bias]

     return data_inputs
...
@@ -142,29 +192,30 @@ class SortType(object):
 class Converter(object):
-    def __init__(self, vocab, beg, end, unk, delimiter, add_beg):
+    def __init__(self, vocab, beg, end, unk, delimiter, add_beg, add_end):
         self._vocab = vocab
         self._beg = beg
         self._end = end
         self._unk = unk
         self._delimiter = delimiter
         self._add_beg = add_beg
+        self._add_end = add_end

     def __call__(self, sentence):
         return ([self._beg] if self._add_beg else []) + [
             self._vocab.get(w, self._unk)
             for w in sentence.split(self._delimiter)
-        ] + [self._end]
+        ] + ([self._end] if self._add_end else [])


 class ComposedConverter(object):
     def __init__(self, converters):
         self._converters = converters

-    def __call__(self, parallel_sentence):
+    def __call__(self, fields):
         return [
-            self._converters[i](parallel_sentence[i])
-            for i in range(len(self._converters))
+            converter(field)
+            for field, converter in zip(fields, self._converters)
         ]
...
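As a side note, the reworked Converter makes the end marker optional too. A self-contained sketch (illustrative, not the repo class) of how add_beg/add_end shape the id sequence:

    class ConverterSketch(object):
        """Standalone illustration of the new Converter behaviour."""

        def __init__(self, vocab, beg, end, unk, delimiter, add_beg, add_end):
            self._vocab, self._beg, self._end, self._unk = vocab, beg, end, unk
            self._delimiter, self._add_beg, self._add_end = delimiter, add_beg, add_end

        def __call__(self, sentence):
            ids = [self._vocab.get(w, self._unk)
                   for w in sentence.split(self._delimiter)]
            return (([self._beg] if self._add_beg else []) + ids +
                    ([self._end] if self._add_end else []))

    vocab = {"hello": 3, "world": 4}
    conv = ConverterSketch(vocab, beg=0, end=1, unk=2, delimiter=" ",
                           add_beg=False, add_end=False)
    assert conv("hello there world") == [3, 2, 4]  # no <s>/<e> attached

Leaving both markers off matches how load_src_trg_ids now builds the raw id lists, with prepare_train_input adding bos/eos itself.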
@@ -201,10 +252,11 @@ class TokenBatchCreator(object):
 class SampleInfo(object):
-    def __init__(self, i, max_len, min_len):
+    def __init__(self, i, lens):
         self.i = i
-        self.min_len = min_len
-        self.max_len = max_len
+        # take bos and eos into account
+        self.min_len = min(lens[0] + 1, lens[1] + 2)
+        self.max_len = max(lens[0] + 1, lens[1] + 2)


 class MinMaxFilter(object):
...
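To make the new length bookkeeping concrete (illustrative, not the repo class): SampleInfo now receives the raw source/target token counts and folds the special tokens in itself, since the source later gains an end token and the target gains both begin and end tokens.

    class SampleInfoSketch(object):
        """Standalone illustration of the new SampleInfo bookkeeping."""

        def __init__(self, i, lens):
            self.i = i
            # lens = [src_token_count, trg_token_count] before special tokens;
            # +1 accounts for the source <e>, +2 for the target <s> and <e>.
            self.min_len = min(lens[0] + 1, lens[1] + 2)
            self.max_len = max(lens[0] + 1, lens[1] + 2)

    # A pair with 10 source tokens and 7 target tokens:
    info = SampleInfoSketch(i=0, lens=[10, 7])
    assert (info.min_len, info.max_len) == (9, 11)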
@@ -229,98 +281,109 @@ class Seq2SeqDataset(Dataset):
                  src_vocab_fpath,
                  trg_vocab_fpath,
                  fpattern,
                  tar_fname=None,
                  field_delimiter="\t",
                  token_delimiter=" ",
                  start_mark="<s>",
                  end_mark="<e>",
                  unk_mark="<unk>",
-                 only_src=False):
-        # convert str to bytes, and use byte data
+                 only_src=False,
+                 trg_fpattern=None,
+                 byte_data=False):
+        if byte_data:
+            # The WMT16 bpe data used here seems including bytes can not be
+            # decoded by utf8. Thus convert str to bytes, and use byte data
             field_delimiter = field_delimiter.encode("utf8")
             token_delimiter = token_delimiter.encode("utf8")
             start_mark = start_mark.encode("utf8")
             end_mark = end_mark.encode("utf8")
             unk_mark = unk_mark.encode("utf8")
-        self._src_vocab = self.load_dict(src_vocab_fpath)
-        self._trg_vocab = self.load_dict(trg_vocab_fpath)
+        self._byte_data = byte_data
+        self._src_vocab = self.load_dict(src_vocab_fpath, byte_data=byte_data)
+        self._trg_vocab = self.load_dict(trg_vocab_fpath, byte_data=byte_data)
         self._bos_idx = self._src_vocab[start_mark]
         self._eos_idx = self._src_vocab[end_mark]
         self._unk_idx = self._src_vocab[unk_mark]
         self._only_src = only_src
         self._field_delimiter = field_delimiter
         self._token_delimiter = token_delimiter
-        self.load_src_trg_ids(fpattern, tar_fname)
+        self.load_src_trg_ids(fpattern, trg_fpattern)

-    def load_src_trg_ids(self, fpattern, tar_fname):
-        converters = [
-            Converter(
-                vocab=self._src_vocab,
-                beg=self._bos_idx,
-                end=self._eos_idx,
-                unk=self._unk_idx,
-                delimiter=self._token_delimiter,
-                add_beg=False)
-        ]
-        if not self._only_src:
-            converters.append(
-                Converter(
-                    vocab=self._trg_vocab,
-                    beg=self._bos_idx,
-                    end=self._eos_idx,
-                    unk=self._unk_idx,
-                    delimiter=self._token_delimiter,
-                    add_beg=True))
+    def load_src_trg_ids(self, fpattern, trg_fpattern=None):
+        src_converter = Converter(vocab=self._src_vocab,
+                                  beg=self._bos_idx,
+                                  end=self._eos_idx,
+                                  unk=self._unk_idx,
+                                  delimiter=self._token_delimiter,
+                                  add_beg=False,
+                                  add_end=False)

-        converters = ComposedConverter(converters)
+        trg_converter = Converter(vocab=self._trg_vocab,
+                                  beg=self._bos_idx,
+                                  end=self._eos_idx,
+                                  unk=self._unk_idx,
+                                  delimiter=self._token_delimiter,
+                                  add_beg=False,
+                                  add_end=False)
+
+        converters = ComposedConverter([src_converter, trg_converter])

         self._src_seq_ids = []
-        self._trg_seq_ids = None if self._only_src else []
+        self._trg_seq_ids = []
         self._sample_infos = []

-        for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
-            src_trg_ids = converters(line)
-            self._src_seq_ids.append(src_trg_ids[0])
-            lens = [len(src_trg_ids[0])]
-            if not self._only_src:
-                self._trg_seq_ids.append(src_trg_ids[1])
-                lens.append(len(src_trg_ids[1]))
-            self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))
+        slots = [self._src_seq_ids, self._trg_seq_ids]
+        for i, line in enumerate(self._load_lines(fpattern, trg_fpattern)):
+            lens = []
+            for field, slot in zip(converters(line), slots):
+                slot.append(field)
+                lens.append(len(field))
+            self._sample_infos.append(SampleInfo(i, lens))

-    def _load_lines(self, fpattern, tar_fname):
+    def _load_lines(self, fpattern, trg_fpattern=None):
         fpaths = glob.glob(fpattern)
+        fpaths = sorted(fpaths)  # TODO: Add custum sort
         assert len(fpaths) > 0, "no matching file to the provided data path"

-        if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
-            if tar_fname is None:
-                raise Exception("If tar file provided, please set tar_fname.")
-
-            f = tarfile.open(fpaths[0], "rb")
-            for line in f.extractfile(tar_fname):
-                fields = line.strip(b"\n").split(self._field_delimiter)
-                if (not self._only_src and len(fields) == 2) or (
-                        self._only_src and len(fields) == 1):
-                    yield fields
-        else:
+        (f_mode, f_encoding, endl) = ("rb", None, b"\n") if self._byte_data else ("r", "utf8", "\n")
+        if trg_fpattern is None:
             for fpath in fpaths:
-                if not os.path.isfile(fpath):
-                    raise IOError("Invalid file: %s" % fpath)
-
-                with open(fpath, "rb") as f:
+                with io.open(fpath, f_mode, encoding=f_encoding) as f:
                     for line in f:
-                        fields = line.strip(b"\n").split(self._field_delimiter)
-                        if (not self._only_src and len(fields) == 2) or (
-                                self._only_src and len(fields) == 1):
+                        fields = line.strip(endl).split(self._field_delimiter)
                         yield fields
+        else:
+            # separated source and target language data files
+            # assume we can get aligned data by sort the two language files
+            # TODO: Need more rigorous check
+            trg_fpaths = glob.glob(trg_fpattern)
+            trg_fpaths = sorted(trg_fpaths)
+            assert len(fpaths) == len(trg_fpaths), \
+                "the number of source language data files must equal \
+                with that of source language"
+            for fpath, trg_fpath in zip(fpaths, trg_fpaths):
+                with io.open(fpath, f_mode, encoding=f_encoding) as f:
+                    with io.open(trg_fpath, f_mode, encoding=f_encoding) as trg_f:
+                        for line in zip(f, trg_f):
+                            fields = [field.strip(endl) for field in line]
+                            yield fields

     @staticmethod
-    def load_dict(dict_path, reverse=False):
+    def load_dict(dict_path, reverse=False, byte_data=False):
         word_dict = {}
-        with open(dict_path, "rb") as fdict:
+        (f_mode, f_encoding, endl) = ("rb", None, b"\n") if byte_data else ("r", "utf8", "\n")
+        with io.open(dict_path, f_mode, encoding=f_encoding) as fdict:
             for idx, line in enumerate(fdict):
                 if reverse:
-                    word_dict[idx] = line.strip(b"\n")
+                    word_dict[idx] = line.strip(endl)
                 else:
-                    word_dict[line.strip(b"\n")] = idx
+                    word_dict[line.strip(endl)] = idx
         return word_dict

     def get_vocab_summary(self):
...
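The byte_data switch used throughout this hunk boils down to choosing a file mode, an encoding, and a line terminator once and reusing them everywhere. A compact, self-contained restatement of the new vocabulary loader under that convention (hypothetical name, mirrors the diff above):

    import io

    def load_dict_sketch(dict_path, reverse=False, byte_data=False):
        """Illustrative restatement of the new load_dict logic."""
        (f_mode, f_encoding, endl) = ("rb", None, b"\n") if byte_data \
            else ("r", "utf8", "\n")
        word_dict = {}
        with io.open(dict_path, f_mode, encoding=f_encoding) as fdict:
            for idx, line in enumerate(fdict):
                if reverse:
                    word_dict[idx] = line.strip(endl)
                else:
                    word_dict[line.strip(endl)] = idx
        return word_dict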
@@ -328,9 +391,8 @@ class Seq2SeqDataset(Dataset):
                    self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx

     def __getitem__(self, idx):
-        return (self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
-                self._trg_seq_ids[idx][1:]
-                ) if not self._only_src else self._src_seq_ids[idx]
+        return (self._src_seq_ids[idx], self._trg_seq_ids[idx]
+                ) if self._trg_seq_ids else self._src_seq_ids[idx]

     def __len__(self):
         return len(self._sample_infos)
...
@@ -348,6 +410,7 @@ class Seq2SeqBatchSampler(BatchSampler):
                  shuffle_batch=False,
                  use_token_batch=False,
                  clip_last_batch=False,
+                 distribute_mode=True,
                  seed=0):
         for arg, value in locals().items():
             if arg != "self":
...
@@ -355,6 +418,7 @@ class Seq2SeqBatchSampler(BatchSampler):
         self._random = np.random
         self._random.seed(seed)
         # for multi-devices
+        self._distribute_mode = distribute_mode
         self._nranks = ParallelEnv().nranks
         self._local_rank = ParallelEnv().local_rank
         self._device_id = ParallelEnv().dev_id
...
@@ -362,8 +426,8 @@ class Seq2SeqBatchSampler(BatchSampler):
     def __iter__(self):
         # global sort or global shuffle
         if self._sort_type == SortType.GLOBAL:
-            infos = sorted(
-                self._dataset._sample_infos, key=lambda x: x.max_len)
+            infos = sorted(self._dataset._sample_infos,
+                           key=lambda x: x.max_len)
         else:
             if self._shuffle:
                 infos = self._dataset._sample_infos
...
@@ -383,9 +447,9 @@ class Seq2SeqBatchSampler(BatchSampler):
             batches = []
             batch_creator = TokenBatchCreator(
-                self._batch_size) if self._use_token_batch else SentenceBatchCreator(
-                    self._batch_size * self._nranks)
+                self._batch_size
+            ) if self._use_token_batch else SentenceBatchCreator(
+                self._batch_size * self._nranks)
             batch_creator = MinMaxFilter(self._max_length, self._min_length,
                                          batch_creator)
...
@@ -413,11 +477,15 @@ class Seq2SeqBatchSampler(BatchSampler):
         # for multi-device
         for batch_id, batch in enumerate(batches):
-            if batch_id % self._nranks == self._local_rank:
+            if not self._distribute_mode or (
+                    batch_id % self._nranks == self._local_rank):
                 batch_indices = [info.i for info in batch]
                 yield batch_indices
-        if self._local_rank > len(batches) % self._nranks:
+        if self._distribute_mode and len(batches) % self._nranks != 0:
+            if self._local_rank >= len(batches) % self._nranks:
                 # use previous data to pad
                 yield batch_indices

     def __len__(self):
-        return 100
+        # TODO(guosheng): fix the uncertain length
+        return 0
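The heart of the fix is in __iter__ above: batches are dealt out round-robin across nranks, and when distribute_mode is on and the final round is incomplete, the ranks that missed out re-yield their previous batch so every rank runs the same number of steps. A standalone sketch of that assignment (hypothetical helper, not the repo API):

    def assign_batches(batches, nranks, local_rank, distribute_mode=True):
        """Yield the batches one rank processes: round-robin assignment plus
        last-round padding, mirroring the sampler logic above."""
        last = None
        for batch_id, batch in enumerate(batches):
            if not distribute_mode or batch_id % nranks == local_rank:
                last = batch
                yield batch
        # Pad the incomplete final round so all ranks see the same batch count.
        if distribute_mode and len(batches) % nranks != 0:
            if local_rank >= len(batches) % nranks:
                yield last  # reuse the previous data to pad

    # With 5 batches over 2 ranks: rank 0 takes batches 0, 2, 4; rank 1 takes
    # 1 and 3 and then repeats 3 as padding, so both ranks yield 3 batches.
    assert list(assign_batches(list(range(5)), 2, 0)) == [0, 2, 4]
    assert list(assign_batches(list(range(5)), 2, 1)) == [1, 3, 3]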
transformer/train.py

...
@@ -17,7 +17,6 @@ import os
 import six
 import sys
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-from functools import partial

 import numpy as np
 import paddle
...
@@ -29,14 +28,18 @@ from utils.check import check_gpu, check_version
 from model import Input, set_device
 from callbacks import ProgBarLogger
-from reader import prepare_train_input, Seq2SeqDataset, Seq2SeqBatchSampler
+from reader import create_data_loader
 from transformer import Transformer, CrossEntropyCriterion


 class TrainCallback(ProgBarLogger):
-    def __init__(self, log_freq=1, verbose=2, loss_normalizer=0.):
-        super(TrainCallback, self).__init__(log_freq, verbose)
+    # TODO: wrap these override function to simplify
+    def __init__(self, args, verbose=2):
+        super(TrainCallback, self).__init__(args.print_step, verbose)
+        # the best cross-entropy value with label smoothing
+        loss_normalizer = -(
+            (1. - args.label_smooth_eps) * np.log(
+                (1. - args.label_smooth_eps)) + args.label_smooth_eps *
+            np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
         self.loss_normalizer = loss_normalizer

     def on_train_begin(self, logs=None):
...
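The loss_normalizer moved into TrainCallback is the best attainable label-smoothed cross-entropy: with smoothing eps, the target distribution puts 1 - eps on the gold token and eps / (V - 1) on each of the other V - 1 tokens, and the minimum of the cross-entropy is the entropy of that distribution. A quick numeric check under assumed values (the eps and vocabulary size below are placeholders, not the repo defaults):

    import numpy as np

    label_smooth_eps = 0.1   # assumed value for illustration
    trg_vocab_size = 30000   # assumed value for illustration

    loss_normalizer = -(
        (1. - label_smooth_eps) * np.log(1. - label_smooth_eps) +
        label_smooth_eps * np.log(
            label_smooth_eps / (trg_vocab_size - 1) + 1e-20))
    print(loss_normalizer)  # ~1.36: entropy of the smoothed target distribution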
@@ -100,42 +103,7 @@ def do_train(args):
     ]

     # def dataloader
-    data_loaders = [None, None]
-    data_files = [args.training_file, args.validation_file
-                  ] if args.validation_file else [args.training_file]
-    for i, data_file in enumerate(data_files):
-        dataset = Seq2SeqDataset(fpattern=data_file,
-                                 src_vocab_fpath=args.src_vocab_fpath,
-                                 trg_vocab_fpath=args.trg_vocab_fpath,
-                                 token_delimiter=args.token_delimiter,
-                                 start_mark=args.special_token[0],
-                                 end_mark=args.special_token[1],
-                                 unk_mark=args.special_token[2])
-        args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
-            args.unk_idx = dataset.get_vocab_summary()
-        batch_sampler = Seq2SeqBatchSampler(dataset=dataset,
-                                            use_token_batch=args.use_token_batch,
-                                            batch_size=args.batch_size,
-                                            pool_size=args.pool_size,
-                                            sort_type=args.sort_type,
-                                            shuffle=args.shuffle,
-                                            shuffle_batch=args.shuffle_batch,
-                                            max_length=args.max_length)
-        data_loader = DataLoader(dataset=dataset,
-                                 batch_sampler=batch_sampler,
-                                 places=device,
-                                 collate_fn=partial(prepare_train_input,
-                                                    src_pad_idx=args.eos_idx,
-                                                    trg_pad_idx=args.eos_idx,
-                                                    n_head=args.n_head),
-                                 num_workers=0,  # TODO: use multi-process
-                                 return_list=True)
-        data_loaders[i] = data_loader
-    train_loader, eval_loader = data_loaders
+    train_loader, eval_loader = create_data_loader(args, device)

     # define model
     transformer = Transformer(
...
@@ -166,12 +134,6 @@ def do_train(args):
     if args.init_from_pretrain_model:
         transformer.load(args.init_from_pretrain_model, reset_optimizer=True)

-    # the best cross-entropy value with label smoothing
-    loss_normalizer = -(
-        (1. - args.label_smooth_eps) * np.log(
-            (1. - args.label_smooth_eps)) + args.label_smooth_eps *
-        np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
-
     # model train
     transformer.fit(train_data=train_loader,
                     eval_data=eval_loader,
...
@@ -180,11 +142,7 @@ def do_train(args):
                     save_freq=1,
                     save_dir=args.save_model,
                     verbose=2,
-                    callbacks=[
-                        TrainCallback(
-                            log_freq=args.print_step,
-                            loss_normalizer=loss_normalizer)
-                    ])
+                    callbacks=[TrainCallback(args)])


 if __name__ == "__main__":
...