Unverified commit 874c1ac6
Authored on Feb 07, 2023 by ceci3
Committed by GitHub on Feb 07, 2023
fix ernie example to adapt develop paddle (#1644)
* fix ac
* adapt x2paddle version
* adapt develop paddle
Parent d0a1e2b6
Showing 4 changed files with 67 additions and 67 deletions (+67 −67)
paddleslim/auto_compression/compressor.py           +13 −15
paddleslim/auto_compression/transformer_pruner.py   +22 −21
paddleslim/common/load_model.py                     +21 −20
paddleslim/common/recover_program.py                +11 −11
paddleslim/auto_compression/compressor.py
@@ -241,9 +241,9 @@ class AutoCompression:
             ], f'Type of input_shapes should be in [dict, tuple or list] but got {type(input_shapes)}.'
         paddle.enable_static()
         exe = paddle.static.Executor(paddle.CPUPlace())
-        [inference_program, feed_target_names,
-         fetch_targets] = load_inference_model(
-            model_dir, exe, model_filename, params_filename)
+        [inference_program, feed_target_names,
+         fetch_targets] = load_inference_model(model_dir, exe, model_filename,
+                                               params_filename)
         if type(input_shapes) in [list, tuple]:
             assert len(
...
@@ -451,30 +451,29 @@ class AutoCompression:
             strategy.build_strategy = build_strategy
         if train_config.recompute_config is not None:
             strategy.recompute = True
             strategy.recompute_configs = {**train_config.recompute_config}
         if train_config.sharding_config is not None:
             strategy.sharding = True
             strategy.sharding_configs = {**train_config.sharding_config}
         if train_config.amp_config is not None:
             strategy.amp = True
             strategy.amp_configs = {**train_config.amp_config}
         if train_config.asp_config is not None:
             strategy.asp = True
         return strategy

     def _prepare_program(self, program, feed_target_names, fetch_targets,
                          patterns, strategy, config, train_config):
-        train_program = recover_inference_program(program)
-        startup_program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+        train_program = recover_inference_program(program, startup_program)
         train_program_info = ProgramInfo(startup_program, train_program,
                                          feed_target_names, fetch_targets)

         config_dict = config.__dict__
         if "prune_strategy" in config_dict and config_dict[
                 "prune_strategy"] == "gmp" and config_dict['gmp_config'] is None:
             _logger.info(
                 "Calculating the iterations per epoch……(It will take some time)")
             # NOTE:XXX: This way of calculating the iters needs to be improved.
             if train_config.epochs:
                 iters_per_epoch = len(list(self.train_dataloader()))
...
@@ -587,9 +586,8 @@ class AutoCompression:
         train_config = None
         strategy_idx = None
         self.final_metric = -1.0
         for strategy_idx, (
-                strategy, config, train_config) in enumerate(
-                    zip(self._strategy, self._config, self.train_config)):
+                strategy, config, train_config
+        ) in enumerate(zip(self._strategy, self._config, self.train_config)):
             self.single_strategy_compress(strategy, config, strategy_idx,
                                           train_config)
...
@@ -815,7 +813,7 @@ class AutoCompression:
                 train_config.eval_iter) == 0 and total_train_iter != 0:
             if self.eval_function is not None:
                 # GMP pruner step 3: update params before summrizing sparsity, saving model or evaluation.
                 if 'unstructure' in strategy:
                     self._pruner.update_params()
...
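The behavioral change in this file is in `_prepare_program`: the startup program is now created first and handed to `recover_inference_program` (the matching parameter is added in recover_program.py below), while the strategy loop is only re-wrapped and still walks three parallel lists in lock-step with an index. A standalone sketch of that loop shape, with made-up stand-in values rather than real strategy objects from the repo:

    strategies = ['quant_post', 'prune', 'distill']
    configs = [{'bit': 8}, {'ratio': 0.25}, {'T': 4}]
    train_configs = [None, None, None]
    # enumerate(zip(...)) pairs each strategy with its config and train config
    # while keeping the running strategy index
    for strategy_idx, (strategy, config, train_config) in enumerate(
            zip(strategies, configs, train_configs)):
        print(strategy_idx, strategy, config, train_config)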
paddleslim/auto_compression/transformer_pruner.py
@@ -296,13 +296,14 @@ class TransformerPruner:
         head_num = -1
         tmp_mha_ops = patterns['MHA$0']
         for op in tmp_mha_ops:
-            if op.type() in ['matmul', 'matmul_v2'] and (
-                    not has_trainable_var(op)) and head_num == -1:
+            if op.type() in [
+                    'matmul', 'matmul_v2'] and (
+                    not has_trainable_var(op)) and head_num == -1:
                 inp_var = op.inputs("X")
                 head_num = inp_var[0].shape()[1]

-        mha_weight, ffn_weight = preprocess_transformer_patterns(patterns,
-                                                                 graph)
+        mha_weight, ffn_weight = preprocess_transformer_patterns(
+            patterns, graph)
         return input_mask_op, layer_num, head_num, mha_weight, ffn_weight

     def _program_add_mask(self, program, patterns, layer_num, head_num,
...
@@ -312,7 +313,7 @@ class TransformerPruner:
         for ft in fetch_targets:
             fetch_list.append(ft.name)
         program = recover_inference_program(program)
-        block = program.global_block()
+        block = program.current_block()
         head_mask = block.create_var(
             name='head_mask',
             shape=[layer_num, head_num],
...
@@ -325,11 +326,12 @@ class TransformerPruner:
             1.0,
             out=head_mask,
             stop_gradient=False)
-        head_mask = unsqueeze_op(
-            block, -1,
-            unsqueeze_op(block, -1,
-                         unsqueeze_op(block, 1, head_mask, feed_num + 1),
-                         feed_num + 2), feed_num + 3)
+        head_mask = unsqueeze_op(
+            block, -1,
+            unsqueeze_op(block, -1,
+                         unsqueeze_op(block, 1, head_mask, feed_num + 2),
+                         feed_num + 3),
+            feed_num + 1)
         for pattern_name, pattern in patterns.items():
             if 'MHA' in pattern_name:
...
@@ -432,8 +434,7 @@ class TransformerPruner:
         index = np.reshape(
             np.take(
                 np.reshape(
                     np.arange(
                         0, head_num * num_per_head, dtype='int64'),
                     (head_num, num_per_head)),
                 idx,
                 axis=0), (-1))
...
@@ -455,13 +456,13 @@ class TransformerPruner:
         for w_idx, weight_name in enumerate(qkv):
             if w_idx % 2 == 0:
                 ### reorder qkv weight
                 reorder_head_matrix(weight_name, qkv_index, dim=1)
             else:
                 ### reorder qkv bias
                 reorder_head_matrix(weight_name, qkv_index, dim=0)
         ### reorder attention output weight
         reorder_head_matrix(attn_out[0], index, dim=0)

     def _reorder_neuron(self, scope, place, weight, idx):
...
@@ -528,8 +529,8 @@ class TransformerPruner:
         if _var is None:
             return
         param_t = _var.get_tensor()
-        pruned_ratio = [pruned_ratio[1]] if len(param_t.shape(
-        )) == 1 else pruned_ratio
+        pruned_ratio = [pruned_ratio[1]
+                        ] if len(param_t.shape()) == 1 else pruned_ratio
         origin_shape = param_t.shape()

     def process_qkv(qkv_param, pruned_ratio):
...
@@ -602,12 +603,12 @@ class TransformerPruner:
             origin_shape = op.attr('shape')
             pruned_shape = origin_shape
             if len(origin_shape) == 3:
-                pruned_shape[-1] = int(origin_shape[-1] *
-                                       self.width_mult)
+                pruned_shape[-1] = int(
+                    origin_shape[-1] * self.width_mult)
                 op.set_attr('shape', pruned_shape)
             elif len(origin_shape) == 4 or len(origin_shape) == 5:
-                pruned_shape[-2] = int(origin_shape[-2] *
-                                       self.width_mult)
+                pruned_shape[-2] = int(
+                    origin_shape[-2] * self.width_mult)
                 op.set_attr('shape', pruned_shape)
             else:
                 raise IndexError
...
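Besides switching the mask-insertion block from `global_block()` to `current_block()`, the last hunk rescales the `shape` attribute of reshape ops by `width_mult`: the last axis for rank-3 shapes, the second-to-last for rank-4/5 shapes. The arithmetic, sketched with assumed example values (a 768-wide hidden dimension and a 0.75 ratio are illustrations, not values from this commit):

    width_mult = 0.75                     # assumed pruning ratio
    origin_shape = [0, 128, 768]          # rank 3: rescale the last axis
    pruned_shape = list(origin_shape)
    pruned_shape[-1] = int(origin_shape[-1] * width_mult)
    assert pruned_shape == [0, 128, 576]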
paddleslim/common/load_model.py
@@ -50,28 +50,28 @@ def load_inference_model(path_prefix,
             ), 'Please check {}, or fix params_filename parameter.'.format(
                 os.path.join(path_prefix, model_name + '.pdiparams'))
             model_path_prefix = os.path.join(path_prefix, model_name)
-            [inference_program, feed_target_names, fetch_targets] = (
-                paddle.static.load_inference_model(
-                    path_prefix=model_path_prefix, executor=executor))
+            [inference_program, feed_target_names,
+             fetch_targets] = (paddle.static.load_inference_model(
+                 path_prefix=model_path_prefix, executor=executor))
         elif model_filename is not None and params_filename is not None:
-            [inference_program, feed_target_names, fetch_targets] = (
-                paddle.static.load_inference_model(
-                    path_prefix=path_prefix,
-                    executor=executor,
-                    model_filename=model_filename,
-                    params_filename=params_filename))
+            [inference_program, feed_target_names,
+             fetch_targets] = (paddle.static.load_inference_model(
+                 path_prefix=path_prefix,
+                 executor=executor,
+                 model_filename=model_filename,
+                 params_filename=params_filename))
         else:
             model_name = '.'.join(model_filename.split('.')
                                   [:-1]) if model_filename is not None else 'model'
             if os.path.exists(os.path.join(path_prefix, model_name + '.pdmodel')):
                 model_path_prefix = os.path.join(path_prefix, model_name)
-                [inference_program, feed_target_names, fetch_targets] = (
-                    paddle.static.load_inference_model(
-                        path_prefix=model_path_prefix, executor=executor))
+                [inference_program, feed_target_names,
+                 fetch_targets] = (paddle.static.load_inference_model(
+                     path_prefix=model_path_prefix, executor=executor))
             else:
-                [inference_program, feed_target_names, fetch_targets] = (
-                    paddle.static.load_inference_model(
-                        path_prefix=path_prefix, executor=executor))
+                [inference_program, feed_target_names,
+                 fetch_targets] = (paddle.static.load_inference_model(
+                     path_prefix=path_prefix, executor=executor))
     return [inference_program, feed_target_names, fetch_targets]
...
@@ -125,13 +125,13 @@ def load_onnx_model(model_path,
         version = x2paddle.__version__
         v0, v1, v2 = version.split('.')
         version_sum = int(v0) * 100 + int(v1) * 10 + int(v2)
-        if version_sum < 139:
+        if version_sum != 139:
             _logger.warning(
-                "x2paddle >=1.3.9 is required, please use \"pip install x2paddle\"."
+                "x2paddle ==1.3.9 is required, please use \"pip install x2paddle==1.3.9\"."
             )
-            os.system('python -m pip install -U x2paddle')
+            os.system('python -m pip install -U x2paddle==1.3.9')
     except:
-        os.system('python -m pip install -U x2paddle')
+        os.system('python -m pip install -U x2paddle==1.3.9')
     # check onnx installation and version
     try:
         pkg.require('onnx')
...
@@ -153,7 +153,8 @@ def load_onnx_model(model_path,
         time_info = int(time.time())
         if not disable_feedback:
             ConverterCheck(
                 task="ONNX", time_info=time_info, convert_state="Start").start()
         # support distributed convert model
         model_idx = paddle.distributed.get_rank(
         ) if paddle.distributed.get_world_size() > 1 else 0
...
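`load_onnx_model` now pins x2paddle to exactly 1.3.9: the three version components are folded into `v0*100 + v1*10 + v2` and compared with `!=` instead of `<`. How that folding maps a version string onto the sentinel 139:

    v0, v1, v2 = '1.3.9'.split('.')
    version_sum = int(v0) * 100 + int(v1) * 10 + int(v2)
    assert version_sum == 139

Note that this encoding is only unambiguous while each component stays below 10; for example both '1.4.0' and '1.3.10' fold to 140, which is one more reason an exact `== 139` pin is safer here than a `<` range check.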
paddleslim/common/recover_program.py
@@ -41,9 +41,9 @@ def _recover_outputs_attr(program):
             if "ReserveSpace" not in op.output_names or len(
                     op.output("ReserveSpace")) == 0:
                 reserve_space = block.create_var(
-                    name=paddle.fluid.unique_name.
-                    generate_with_ignorable_key(".".join(
-                        ["reserve_space", 'tmp'])),
+                    name=paddle.fluid.
+                    unique_name.generate_with_ignorable_key(
+                        ".".join(["reserve_space", 'tmp'])),
                     dtype=block.var(op.input("X")[0]).dtype,
                     type=paddle.framework.core.VarDesc.VarType.LOD_TENSOR,
                     persistable=False,
...
@@ -52,9 +52,9 @@ def _recover_outputs_attr(program):
         if op.type == 'transpose2':
             if 'XShape' not in op.output_names:
                 xshape = block.create_var(
-                    name=paddle.fluid.unique_name.
-                    generate_with_ignorable_key(".".join(["xshape", 'tmp'
-                                                          ])),
+                    name=paddle.fluid.
+                    unique_name.generate_with_ignorable_key(
+                        ".".join(["xshape", 'tmp'])),
                     dtype=block.var(op.input("X")[0]).dtype,
                     type=paddle.framework.core.VarDesc.VarType.LOD_TENSOR,
                     shape=(0, ) + block.var(op.input("X")[0]).shape,
...
@@ -64,24 +64,24 @@ def _recover_outputs_attr(program):
     return program


-def _recover_param_attr(program):
+def _recover_param_attr(program, startup_program):
     """recover parameters attribute.
     Params in infermodel are stored in the form of variable, which can not be trained."""
     all_weights = [param for param in program.list_vars() \
         if param.persistable is True and param.name != 'feed' and param.name != 'fetch']
-    with paddle.static.program_guard(program):
+    with paddle.static.program_guard(program, startup_program):
         for w in all_weights:
             new_w = paddle.create_parameter(
                 shape=w.shape, dtype=w.dtype, name=w.name)
             new_w.set_value(w.get_value())
-            program.block(0).vars[w.name] = new_w
+            program.current_block().vars[w.name] = new_w
     return program


-def recover_inference_program(inference_program):
+def recover_inference_program(inference_program, startup_program=None):
     """ recover inference program to train program which can be trained. """
     _remove_fetch_node(inference_program)
-    inference_program = _recover_param_attr(inference_program)
+    inference_program = _recover_param_attr(inference_program, startup_program)
     inference_program = _recover_outputs_attr(inference_program)
     for var in inference_program.list_vars():
         var.stop_gradient = False
...
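`recover_inference_program` keeps backward compatibility through the `startup_program=None` default: single-argument callers such as transformer_pruner.py above still work, while compressor.py now passes an explicit startup program so that `paddle.static.program_guard(program, startup_program)` routes the re-created parameters' initialization there. A minimal sketch of both call forms, assuming a PaddleSlim build that includes this commit; the empty programs are placeholders for a real loaded inference model:

    import paddle
    from paddleslim.common.recover_program import recover_inference_program

    paddle.enable_static()
    main_program = paddle.static.Program()     # placeholder inference program
    startup_program = paddle.static.Program()

    # legacy single-argument form, still valid via the None default
    recovered = recover_inference_program(main_program)
    # new form used by compressor.py: parameters land in startup_program
    recovered = recover_inference_program(main_program, startup_program)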