BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)

Commit 84738545 (unverified)
Add dygraph_to_static training unitTest of transformer model (#23316)
Authored by Aurelius84 on Mar 31, 2020; committed via GitHub on Mar 31, 2020
Parent: 420944e5

Showing 3 changed files with 999 additions and 0 deletions:

python/paddle/fluid/tests/unittests/dygraph_to_static/test_transformer.py            +235  -0
python/paddle/fluid/tests/unittests/dygraph_to_static/transformer_dygraph_model.py   +488  -0
python/paddle/fluid/tests/unittests/dygraph_to_static/transformer_util.py             +276  -0
python/paddle/fluid/tests/unittests/dygraph_to_static/test_transformer.py (new file, mode 0 → 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import numpy as np
import time
import os
import unittest

import paddle.fluid as fluid

import transformer_util as util
from transformer_dygraph_model import Transformer
from transformer_dygraph_model import CrossEntropyCriterion

trainer_count = 1
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
SEED = 10


def train_static(args, batch_generator):
    train_prog = fluid.default_main_program()
    startup_prog = fluid.default_startup_program()
    train_prog.random_seed = SEED
    startup_prog.random_seed = SEED
    with fluid.program_guard(train_prog, startup_prog):
        with fluid.unique_name.guard():
            # define input and reader
            input_field_names = util.encoder_data_input_fields + \
                util.decoder_data_input_fields[:-1] + util.label_data_input_fields
            input_descs = util.get_input_descs(args)
            input_slots = [{
                "name": name,
                "shape": input_descs[name][0],
                "dtype": input_descs[name][1]
            } for name in input_field_names]
            input_field = util.InputField(input_slots)
            # Define DataLoader
            data_loader = fluid.io.DataLoader.from_generator(
                input_field.feed_list, capacity=60)
            data_loader.set_batch_generator(batch_generator, places=place)
            # define model
            transformer = Transformer(
                args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
                args.n_layer, args.n_head, args.d_key, args.d_value,
                args.d_model, args.d_inner_hid, args.prepostprocess_dropout,
                args.attention_dropout, args.relu_dropout, args.preprocess_cmd,
                args.postprocess_cmd, args.weight_sharing, args.bos_idx,
                args.eos_idx)
            logits = transformer(*input_field.feed_list[:7])
            # define loss
            criterion = CrossEntropyCriterion(args.label_smooth_eps)
            lbl_word, lbl_weight = input_field.feed_list[7:]
            sum_cost, avg_cost, token_num = criterion(logits, lbl_word,
                                                      lbl_weight)
            # define optimizer
            learning_rate = fluid.layers.learning_rate_scheduler.noam_decay(
                args.d_model, args.warmup_steps, args.learning_rate)
            optimizer = fluid.optimizer.Adam(
                learning_rate=learning_rate,
                beta1=args.beta1,
                beta2=args.beta2,
                epsilon=float(args.eps))
            optimizer.minimize(avg_cost)
    # the best cross-entropy value with label smoothing
    loss_normalizer = -((1. - args.label_smooth_eps) * np.log(
        (1. - args.label_smooth_eps)) + args.label_smooth_eps *
                        np.log(args.label_smooth_eps /
                               (args.trg_vocab_size - 1) + 1e-20))
    step_idx = 0
    total_batch_num = 0
    avg_loss = []
    exe = fluid.Executor(place)
    exe.run(startup_prog)
    for pass_id in range(args.epoch):
        batch_id = 0
        for feed_dict in data_loader:
            outs = exe.run(program=train_prog,
                           feed=feed_dict,
                           fetch_list=[sum_cost.name, token_num.name])
            if step_idx % args.print_step == 0:
                sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1])
                total_sum_cost = sum_cost_val.sum()
                total_token_num = token_num_val.sum()
                total_avg_cost = total_sum_cost / total_token_num
                avg_loss.append(total_avg_cost)
                if step_idx == 0:
                    logging.info(
                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                        "normalized loss: %f, ppl: %f" %
                        (step_idx, pass_id, batch_id, total_avg_cost,
                         total_avg_cost - loss_normalizer,
                         np.exp([min(total_avg_cost, 100)])))
                    avg_batch_time = time.time()
                else:
                    logging.info(
                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                        "normalized loss: %f, ppl: %f, speed: %.2f step/s" %
                        (step_idx, pass_id, batch_id, total_avg_cost,
                         total_avg_cost - loss_normalizer,
                         np.exp([min(total_avg_cost, 100)]),
                         args.print_step / (time.time() - avg_batch_time)))
                    avg_batch_time = time.time()
            batch_id += 1
            step_idx += 1
            total_batch_num = total_batch_num + 1
            if step_idx == 10:
                if args.save_model:
                    model_path = os.path.join(args.save_model,
                                              "step_" + str(step_idx),
                                              "transformer")
                    fluid.save(train_prog, model_path)
                break
    return np.array(avg_loss)


def train_dygraph(args, batch_generator):
    with fluid.dygraph.guard(place):
        if SEED is not None:
            fluid.default_main_program().random_seed = SEED
            fluid.default_startup_program().random_seed = SEED
        # define data loader
        train_loader = fluid.io.DataLoader.from_generator(capacity=10)
        train_loader.set_batch_generator(batch_generator, places=place)
        # define model
        transformer = Transformer(
            args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
            args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
            args.d_inner_hid, args.prepostprocess_dropout,
            args.attention_dropout, args.relu_dropout, args.preprocess_cmd,
            args.postprocess_cmd, args.weight_sharing, args.bos_idx,
            args.eos_idx)
        # define loss
        criterion = CrossEntropyCriterion(args.label_smooth_eps)
        # define optimizer
        learning_rate = fluid.layers.learning_rate_scheduler.noam_decay(
            args.d_model, args.warmup_steps, args.learning_rate)
        # define optimizer
        optimizer = fluid.optimizer.Adam(
            learning_rate=learning_rate,
            beta1=args.beta1,
            beta2=args.beta2,
            epsilon=float(args.eps),
            parameter_list=transformer.parameters())
        # the best cross-entropy value with label smoothing
        loss_normalizer = -(
            (1. - args.label_smooth_eps) * np.log(
                (1. - args.label_smooth_eps)) + args.label_smooth_eps *
            np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
        ce_time = []
        ce_ppl = []
        avg_loss = []
        step_idx = 0
        for pass_id in range(args.epoch):
            pass_start_time = time.time()
            batch_id = 0
            for input_data in train_loader():
                (src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
                 trg_slf_attn_bias, trg_src_attn_bias, lbl_word,
                 lbl_weight) = input_data
                logits = transformer(src_word, src_pos, src_slf_attn_bias,
                                     trg_word, trg_pos, trg_slf_attn_bias,
                                     trg_src_attn_bias)
                sum_cost, avg_cost, token_num = criterion(logits, lbl_word,
                                                          lbl_weight)
                avg_cost.backward()
                optimizer.minimize(avg_cost)
                transformer.clear_gradients()
                if step_idx % args.print_step == 0:
                    total_avg_cost = avg_cost.numpy() * trainer_count
                    avg_loss.append(total_avg_cost[0])
                    if step_idx == 0:
                        logging.info(
                            "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                            "normalized loss: %f, ppl: %f" %
                            (step_idx, pass_id, batch_id, total_avg_cost,
                             total_avg_cost - loss_normalizer,
                             np.exp([min(total_avg_cost, 100)])))
                        avg_batch_time = time.time()
                    else:
                        logging.info(
                            "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                            "normalized loss: %f, ppl: %f, speed: %.2f step/s" %
                            (step_idx, pass_id, batch_id, total_avg_cost,
                             total_avg_cost - loss_normalizer,
                             np.exp([min(total_avg_cost, 100)]),
                             args.print_step / (time.time() - avg_batch_time)))
                        ce_ppl.append(np.exp([min(total_avg_cost, 100)]))
                        avg_batch_time = time.time()
                batch_id += 1
                step_idx += 1
                if step_idx == 10:
                    if args.save_model:
                        model_dir = os.path.join(args.save_model + '_dygraph',
                                                 "step_" + str(step_idx))
                        if not os.path.exists(model_dir):
                            os.makedirs(model_dir)
                        fluid.save_dygraph(
                            transformer.state_dict(),
                            os.path.join(model_dir, "transformer"))
                        fluid.save_dygraph(
                            optimizer.state_dict(),
                            os.path.join(model_dir, "transformer"))
                    break
            time_consumed = time.time() - pass_start_time
            ce_time.append(time_consumed)
        return np.array(avg_loss)


class TestTransformer(unittest.TestCase):
    def prepare(self, mode='train'):
        args = util.ModelHyperParams()
        batch_generator = util.get_feed_data_reader(args, mode)
        return args, batch_generator

    def test_train(self):
        args, batch_generator = self.prepare(mode='train')
        static_avg_loss = train_static(args, batch_generator)
        dygraph_avg_loss = train_dygraph(args, batch_generator)
        self.assertTrue(np.allclose(static_avg_loss, dygraph_avg_loss))


if __name__ == '__main__':
    unittest.main()
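For reference, the loss_normalizer computed in both training loops above is the minimum cross-entropy attainable under label smoothing; a quick standalone check (not part of the diff) with the ModelHyperParams defaults (label_smooth_eps=0.1, trg_vocab_size=10000) gives roughly 1.246 nats:

# Standalone sanity check of the label-smoothing loss normalizer (not part of the commit).
import numpy as np

eps, trg_vocab_size = 0.1, 10000
loss_normalizer = -((1. - eps) * np.log(1. - eps) +
                    eps * np.log(eps / (trg_vocab_size - 1) + 1e-20))
print(round(float(loss_normalizer), 3))  # ~1.246; the logs report avg loss minus this value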
python/paddle/fluid/tests/unittests/dygraph_to_static/transformer_dygraph_model.py (new file, mode 0 → 100644)
(Diff collapsed in the capture; contents not shown.)
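The collapsed file defines the dygraph Transformer and CrossEntropyCriterion imported by the test. A minimal construction sketch (not part of the diff) that simply mirrors the call made in train_dygraph(), using the ModelHyperParams defaults from transformer_util:

# Sketch: building the dygraph model and loss exactly as train_dygraph() does.
# Assumes the dygraph_to_static test directory is on sys.path.
import paddle.fluid as fluid
import transformer_util as util
from transformer_dygraph_model import Transformer, CrossEntropyCriterion

args = util.ModelHyperParams()
with fluid.dygraph.guard():
    model = Transformer(
        args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
        args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
        args.d_inner_hid, args.prepostprocess_dropout, args.attention_dropout,
        args.relu_dropout, args.preprocess_cmd, args.postprocess_cmd,
        args.weight_sharing, args.bos_idx, args.eos_idx)
    criterion = CrossEntropyCriterion(args.label_smooth_eps)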
python/paddle/fluid/tests/unittests/dygraph_to_static/transformer_util.py (new file, mode 0 → 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pickle
import warnings
import six
from functools import partial

import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.dataset.wmt16 as wmt16


def get_input_descs(args):
    batch_size = args.batch_size  # TODO None(before)
    seq_len = None
    n_head = getattr(args, "n_head", 8)
    d_model = getattr(args, "d_model", 512)
    input_descs = {
        "src_word": [(batch_size, seq_len), "int64", 2],
        "src_pos": [(batch_size, seq_len), "int64"],
        "src_slf_attn_bias":
        [(batch_size, n_head, seq_len, seq_len), "float32"],
        "trg_word": [(batch_size, seq_len), "int64", 2],
        "trg_pos": [(batch_size, seq_len), "int64"],
        "trg_slf_attn_bias":
        [(batch_size, n_head, seq_len, seq_len), "float32"],
        "trg_src_attn_bias": [(batch_size, n_head, seq_len, seq_len),
                              "float32"],  # TODO: 1 for predict, seq_len for train
        "enc_output": [(batch_size, seq_len, d_model), "float32"],
        "lbl_word": [(None, 1), "int64"],
        "lbl_weight": [(None, 1), "float32"],
        "init_score": [(batch_size, 1), "float32", 2],
        "init_idx": [(batch_size, ), "int32"],
    }
    return input_descs


encoder_data_input_fields = (
    "src_word",
    "src_pos",
    "src_slf_attn_bias", )
decoder_data_input_fields = (
    "trg_word",
    "trg_pos",
    "trg_slf_attn_bias",
    "trg_src_attn_bias",
    "enc_output", )
label_data_input_fields = (
    "lbl_word",
    "lbl_weight", )
fast_decoder_data_input_fields = (
    "trg_word",
    "trg_src_attn_bias", )


class ModelHyperParams(object):
    print_step = 2
    init_from_params = "trained_models/step_10/"
    save_model = "trained_models"
    inference_model_dir = "infer_model"
    output_file = "predict.txt"
    batch_size = 5
    epoch = 1
    learning_rate = 2.0
    beta1 = 0.9
    beta2 = 0.997
    eps = 1e-9
    warmup_steps = 8000
    label_smooth_eps = 0.1
    beam_size = 5
    max_out_len = 256
    n_best = 1
    src_vocab_size = 10000
    trg_vocab_size = 10000
    bos_idx = 0  # index for <bos> token
    eos_idx = 1  # index for <eos> token
    unk_idx = 2  # index for <unk> token
    max_length = 256
    d_model = 512
    d_inner_hid = 2048
    d_key = 64
    d_value = 64
    n_head = 8
    n_layer = 6
    prepostprocess_dropout = 0.1
    attention_dropout = 0.1
    relu_dropout = 0.1
    preprocess_cmd = "n"  # layer normalization
    postprocess_cmd = "da"  # dropout + residual connection
    weight_sharing = True


def pad_batch_data(insts,
                   pad_idx,
                   n_head,
                   is_target=False,
                   is_label=False,
                   return_attn_bias=True,
                   return_max_len=True,
                   return_num_token=False):
    return_list = []
    max_len = max(len(inst) for inst in insts)
    inst_data = np.array(
        [inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
    return_list += [inst_data.astype("int64").reshape([-1, 1])]
    if is_label:  # label weight
        inst_weight = np.array([[1.] * len(inst) + [0.] *
                                (max_len - len(inst)) for inst in insts])
        return_list += [inst_weight.astype("float32").reshape([-1, 1])]
    else:  # position data
        inst_pos = np.array([
            list(range(0, len(inst))) + [0] * (max_len - len(inst))
            for inst in insts
        ])
        return_list += [inst_pos.astype("int64").reshape([-1, 1])]
    if return_attn_bias:
        if is_target:
            slf_attn_bias_data = np.ones(
                (inst_data.shape[0], max_len, max_len))
            slf_attn_bias_data = np.triu(slf_attn_bias_data,
                                         1).reshape([-1, 1, max_len, max_len])
            slf_attn_bias_data = np.tile(slf_attn_bias_data,
                                         [1, n_head, 1, 1]) * [-1e9]
        else:
            slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
                                           (max_len - len(inst))
                                           for inst in insts])
            slf_attn_bias_data = np.tile(
                slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
                [1, n_head, max_len, 1])
        return_list += [slf_attn_bias_data.astype("float32")]
    if return_max_len:
        return_list += [max_len]
    if return_num_token:
        num_token = 0
        for inst in insts:
            num_token += len(inst)
        return_list += [num_token]
    return return_list if len(return_list) > 1 else return_list[0]


def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
    src_word = src_word.reshape(-1, src_max_len)
    src_pos = src_pos.reshape(-1, src_max_len)
    trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
        [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
    trg_word = trg_word.reshape(-1, trg_max_len)
    trg_pos = trg_pos.reshape(-1, trg_max_len)
    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                [1, 1, trg_max_len, 1]).astype("float32")
    lbl_word, lbl_weight, num_token = pad_batch_data(
        [inst[2] for inst in insts],
        trg_pad_idx,
        n_head,
        is_target=False,
        is_label=True,
        return_attn_bias=False,
        return_max_len=False,
        return_num_token=True)
    lbl_word = lbl_word.reshape(-1, 1)
    lbl_weight = lbl_weight.reshape(-1, 1)
    data_inputs = [
        src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
        trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
    ]
    return data_inputs


def prepare_infer_input(insts, src_pad_idx, bos_idx, n_head):
    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
    # start tokens
    trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64")
    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                [1, 1, 1, 1]).astype("float32")
    trg_word = trg_word.reshape(-1, 1)
    src_word = src_word.reshape(-1, src_max_len)
    src_pos = src_pos.reshape(-1, src_max_len)
    data_inputs = [
        src_word, src_pos, src_slf_attn_bias, trg_word, trg_src_attn_bias
    ]
    return data_inputs


def get_feed_data_reader(args, mode='train'):
    def __for_train__():
        train_reader = paddle.batch(
            wmt16.train(args.src_vocab_size, args.trg_vocab_size),
            batch_size=args.batch_size)
        for batch in train_reader():
            tensors = prepare_train_input(batch, args.eos_idx, args.eos_idx,
                                          args.n_head)
            yield tensors

    def __for_test__():
        test_reader = paddle.batch(
            wmt16.train(args.src_vocab_size, args.trg_vocab_size),
            batch_size=args.batch_size)
        for batch in test_reader():
            tensors = prepare_infer_input(batch, args.eos_idx, args.eos_idx,
                                          args.n_head)
            yield tensors

    return __for_train__ if mode == 'train' else __for_test__


class InputField(object):
    def __init__(self, input_slots):
        self.feed_list = []
        for slot in input_slots:
            self.feed_list.append(
                fluid.layers.data(
                    name=slot['name'],
                    shape=slot['shape'],
                    dtype=slot['dtype'],
                    lod_level=slot.get('lod_level', 0),
                    append_batch_size=False))


def load(program, model_path, executor=None, var_list=None):
    """
    To load python2 saved models in python3.
    """
    try:
        fluid.load(program, model_path, executor, var_list)
    except UnicodeDecodeError:
        warnings.warn(
            "An UnicodeDecodeError is catched, which might be caused by loading "
            "a python2 saved model. Encoding of pickle.load would be set and "
            "load again automatically.")
        if six.PY3:
            load_bak = pickle.load
            pickle.load = partial(load_bak, encoding="latin1")
            fluid.load(program, model_path, executor, var_list)
            pickle.load = load_bak


def load_dygraph(model_path, keep_name_table=False):
    """
    To load python2 saved models in python3.
    """
    try:
        para_dict, opti_dict = fluid.load_dygraph(model_path, keep_name_table)
        return para_dict, opti_dict
    except UnicodeDecodeError:
        warnings.warn(
            "An UnicodeDecodeError is catched, which might be caused by loading "
            "a python2 saved model. Encoding of pickle.load would be set and "
            "load again automatically.")
        if six.PY3:
            load_bak = pickle.load
            pickle.load = partial(load_bak, encoding="latin1")
            para_dict, opti_dict = fluid.load_dygraph(model_path,
                                                      keep_name_table)
            pickle.load = load_bak
            return para_dict, opti_dict
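To illustrate the padding helper above, a small toy call of pad_batch_data (not part of the diff) with the default flags; it returns the padded word ids, the position ids, the additive attention bias, and the batch max length:

# Toy source-side batch: two sequences of lengths 3 and 2, padded to max_len=3
# with pad_idx=1 and n_head=8. Assumes the dygraph_to_static test directory is
# on sys.path.
import transformer_util as util

insts = [[3, 4, 5], [6, 7]]
src_word, src_pos, src_slf_attn_bias, max_len = util.pad_batch_data(
    insts, pad_idx=1, n_head=8, is_target=False)
print(src_word.shape, src_pos.shape, src_slf_attn_bias.shape, max_len)
# (6, 1) (6, 1) (2, 8, 3, 3) 3 -- prepare_train_input then reshapes the word and
# position arrays to (batch_size, max_len) before feeding the model.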