未验证 提交 cb6ae979 编写于 作者: M minghaoBD 提交者: GitHub

[unstructured_pruner] add local_sparsity args in demo (#920)

* [Unstructured_prune] add local_sparsity demo
上级 69e3e28a
......@@ -50,6 +50,7 @@ add_arg('pruning_steps', int, 100, "How many times you want to increase y
add_arg('initial_ratio', float, 0.15, "The initial pruning ratio used at the start of pruning stage. Default: 0.15")
add_arg('pruning_strategy', str, 'base', "Which training strategy to use in pruning, we only support base and gmp for now. Default: base")
add_arg('prune_params_type', str, None, "Which kind of params should be pruned, we only support None (all but norms) and conv1x1_only for now. Default: None")
add_arg('local_sparsity', bool, False, "Whether to prune all the parameter matrix at the same ratio or not. Default: False")
# yapf: enable
......@@ -96,12 +97,14 @@ def create_unstructured_pruner(model, args, configs=None):
mode=args.pruning_mode,
ratio=args.ratio,
threshold=args.threshold,
prune_params_type=args.prune_params_type)
prune_params_type=args.prune_params_type,
local_sparsity=args.local_sparsity)
else:
return GMPUnstructuredPruner(
model,
ratio=args.ratio,
prune_params_type=args.prune_params_type,
local_sparsity=args.local_sparsity,
configs=configs)
......@@ -270,7 +273,6 @@ def compress(args):
train_reader_cost = 0.0
train_run_cost = 0.0
total_samples = 0
reader_start = time.time()
for i in range(args.last_epoch + 1, args.num_epochs):
......
......@@ -49,6 +49,7 @@ add_arg('tunning_epochs', int, 60, "The epoch numbers used to tune
add_arg('pruning_steps', int, 120, "How many times you want to increase your ratio during training. Default: 120")
add_arg('initial_ratio', float, 0.15, "The initial pruning ratio used at the start of pruning stage. Default: 0.15")
add_arg('prune_params_type', str, None, "Which kind of params should be pruned, we only support None (all but norms) and conv1x1_only for now. Default: None")
add_arg('local_sparsity', bool, False, "Whether to prune all the parameter matrix at the same ratio or not. Default: False")
# yapf: enable
model_list = models.__all__
......@@ -96,13 +97,15 @@ def create_unstructured_pruner(train_program, args, place, configs):
ratio=args.ratio,
threshold=args.threshold,
prune_params_type=args.prune_params_type,
place=place)
place=place,
local_sparsity=args.local_sparsity)
else:
return GMPUnstructuredPruner(
train_program,
ratio=args.ratio,
prune_params_type=args.prune_params_type,
place=place,
local_sparsity=args.local_sparsity,
configs=configs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册