Unverified commit 9c214906, authored by ceci3, committed by GitHub

fix when params_filename is None (#1106)

Parent 410d90ef
@@ -71,14 +71,26 @@ def create_strategy_config(strategy_str, model_type):
     dis_config = Distillation()
     if len(tmp_s) == 3:
+        ### TODO(ceci3): choose prune algo automatically
+        if 'prune' in tmp_s[0]:
+            ### default prune config
+            default_prune_config = {
+                'pruned_ratio': float(tmp_s[1]),
+                'prune_algo': 'prune',
+                'criterion': 'l1_norm'
+            }
+        else:
+            ### default unstruture prune config
+            default_prune_config = {
+                'prune_strategy':
+                'gmp',  ### default unstruture prune strategy is gmp
+                'prune_mode': 'ratio',
+                'pruned_ratio': float(tmp_s[1]),
+                'local_sparsity': True,
+                'prune_params_type': 'conv1x1_only'
+            }
         tmp_s[0] = tmp_s[0].replace('prune', 'Prune')
         tmp_s[0] = tmp_s[0].replace('sparse', 'UnstructurePrune')
-        ### TODO(ceci3): auto choose prune algo
-        default_prune_config = {
-            'pruned_ratio': float(tmp_s[1]),
-            'prune_algo': 'prune',
-            'criterion': 'l1_norm'
-        }
         if model_type == 'transformer' and tmp_s[0] == 'Prune':
             default_prune_config['prune_algo'] = 'transformer_pruner'
         prune_config = eval(tmp_s[0])(**default_prune_config)
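The branch added above picks a default config depending on whether the strategy token asks for structured pruning or unstructured sparsity. A minimal standalone sketch of that selection logic (the helper name `build_prune_config` is illustrative, not part of the PaddleSlim API, and the logic is simplified from the hunk):

```python
def build_prune_config(token, ratio, model_type='cv'):
    """Pick a default pruning config for a strategy token such as
    'prune' or 'sparse' (simplified sketch of the diff above)."""
    if 'prune' in token:
        # Structured (filter) pruning defaults.
        config = {
            'pruned_ratio': float(ratio),
            'prune_algo': 'prune',
            'criterion': 'l1_norm',
        }
        if model_type == 'transformer':
            # Transformer models use a dedicated pruner.
            config['prune_algo'] = 'transformer_pruner'
    else:
        # Unstructured sparsity defaults; GMP = gradual magnitude pruning.
        config = {
            'prune_strategy': 'gmp',
            'prune_mode': 'ratio',
            'pruned_ratio': float(ratio),
            'local_sparsity': True,
            'prune_params_type': 'conv1x1_only',
        }
    return config


print(build_prune_config('sparse', '0.75'))
```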
@@ -97,7 +97,11 @@ class AutoCompression:
             deploy_hardware(str, optional): The hardware you want to deploy. Default: 'gpu'.
         """
         self.model_dir = model_dir
+        if model_filename == 'None':
+            model_filename = None
         self.model_filename = model_filename
+        if params_filename == 'None':
+            params_filename = None
         self.params_filename = params_filename
         base_path = os.path.basename(os.path.normpath(save_dir))
         parent_path = os.path.abspath(os.path.join(save_dir, os.pardir))
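The two added checks exist because a config parsed from YAML or the command line can hand over the literal string 'None' instead of a real Python None. A small sketch of that normalization (the helper name `_normalize_filename` is hypothetical, introduced only for illustration):

```python
def _normalize_filename(name):
    # A literal 'None' from a parsed config means "not set"; map it to a
    # real Python None so the model loader falls back to its defaults
    # (e.g. parameters stored as separate files rather than a combined one).
    return None if name == 'None' else name


assert _normalize_filename('None') is None
assert _normalize_filename('model.pdiparams') == 'model.pdiparams'
```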
@@ -100,7 +100,8 @@ def _load_program_and_merge(executor,
                             feed_target_names=None):
     scope = paddle.static.global_scope()
     new_scope = paddle.static.Scope()
-    print(model_dir, model_filename, params_filename)
+    if params_filename == 'None':
+        params_filename = None
     try:
         with paddle.static.scope_guard(new_scope):
             [teacher_program, teacher_feed_target_names, teacher_fetch_targets]= paddle.fluid.io.load_inference_model( \
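This hunk removes a leftover debug print and applies the same 'None' normalization before loading the teacher model into its own scope. A hedged sketch of the pattern, mirroring the legacy `paddle.fluid.io.load_inference_model` call shown above (the helper name `load_teacher` is hypothetical):

```python
import paddle

paddle.enable_static()


def load_teacher(model_dir, model_filename, params_filename, exe):
    # Same normalization as above: a literal 'None' means "use the default".
    if params_filename == 'None':
        params_filename = None
    teacher_scope = paddle.static.Scope()
    # Load the teacher into a private scope so its parameters cannot collide
    # with variables of the program being compressed.
    with paddle.static.scope_guard(teacher_scope):
        [program, feed_names, fetch_targets] = paddle.fluid.io.load_inference_model(
            dirname=model_dir,
            executor=exe,
            model_filename=model_filename,
            params_filename=params_filename)
    return program, feed_names, fetch_targets, teacher_scope
```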
@@ -34,8 +34,9 @@ Quantization = namedtuple(
         "weight_quantize_type"
     ])
-Quantization.__new__.__defaults__ = (None, ) * (len(Quantization._fields) - 1
-                                                ) + (False, )
+Quantization.__new__.__defaults__ = (None, ) * (
+    len(Quantization._fields) - 3) + (False, 'moving_average_abs_max',
+                                      'channel_wise_abs_max')
 ### Distillation:
 Distillation = namedtuple(
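The rewritten `__defaults__` line relies on namedtuple defaults being right-aligned: the tuple fills the last fields first, so the two new string defaults land on the quantize-type fields. A short sketch with an abridged, hypothetical field list (the real Quantization tuple has more fields):

```python
from collections import namedtuple

# Abridged field list for illustration only.
Quant = namedtuple('Quant', [
    'quantize_op_types', 'weight_bits', 'is_full_quantize',
    'activation_quantize_type', 'weight_quantize_type'
])
# Right-aligned defaults: every field gets None except the last three.
Quant.__new__.__defaults__ = (None, ) * (len(Quant._fields) - 3) + (
    False, 'moving_average_abs_max', 'channel_wise_abs_max')

q = Quant()
print(q.quantize_op_types)          # None
print(q.is_full_quantize)           # False
print(q.activation_quantize_type)   # moving_average_abs_max
print(q.weight_quantize_type)       # channel_wise_abs_max
```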