Unverified commit 79a7b164, authored by zhouzj, committed by GitHub

Fix sparse quant (#1076)

* Remove 'mask' nodes in sparse model

* Fixed sparsity of compressed model.

Co-authored-by: ceci3 <ceci3@users.noreply.github.com>
Parent 2350af8e
@@ -29,7 +29,7 @@ from ..common.recover_program import recover_inference_program
from ..common import get_logger
from ..common.patterns import get_patterns
from ..analysis import TableLatencyPredictor
from .create_compressed_program import build_distill_program, build_quant_program, build_prune_program
from .create_compressed_program import build_distill_program, build_quant_program, build_prune_program, remove_unused_var_nodes
from .strategy_config import ProgramInfo, merge_config
from .auto_strategy import prepare_strategy, get_final_quant_config, create_strategy_config
@@ -274,6 +274,15 @@ class AutoCompression:
train_program_info, test_program_info, self._quant_config = build_quant_program(
self._exe, self._places, config_dict, train_program_info,
test_program_info)
if self.train_config.sparse_model:
from ..prune.unstructured_pruner import UnstructuredPruner
self._pruner = UnstructuredPruner(
train_program_info.program,
mode='ratio',
ratio=0.75,
prune_params_type='conv1x1_only',
place=self._places)
self._pruner.set_static_masks()
self._exe.run(train_program_info.startup_program)
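For intuition, the effect of the static masks added above can be sketched in a few lines of plain NumPy (illustrative only, not PaddleSlim internals): the zero pattern of an already-sparsified weight is recorded once and re-applied after every update, so further training such as quant-aware fine-tuning cannot destroy the sparsity; update_params() plays the role of the final mask application before the weights are saved.

import numpy as np

rng = np.random.default_rng(0)

# A 1x1-conv-like weight that has already been 75% unstructured-pruned.
weight = rng.standard_normal((8, 8)).astype('float32')
weight[np.abs(weight) < np.quantile(np.abs(weight), 0.75)] = 0.0

# Analogue of set_static_masks(): record the current zero pattern once.
static_mask = (weight != 0).astype('float32')

for _ in range(10):
    grad = rng.standard_normal(weight.shape).astype('float32')
    weight -= 0.01 * grad    # a training step may touch the pruned entries...
    weight *= static_mask    # ...so the fixed mask is re-applied to keep them zero

# Analogue of update_params(): the weights that get saved keep their sparsity.
assert (weight == 0).mean() >= 0.75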
@@ -402,7 +411,9 @@ class AutoCompression:
train_program_info, test_program_info = self._prepare_program(
inference_program, feed_target_names, fetch_targets, patterns,
default_distill_node_pair, strategy, config)
if 'unstructure' in self._strategy:
test_program_info.program._program = remove_unused_var_nodes(
test_program_info.program._program)
test_program_info = self._start_train(train_program_info,
test_program_info, strategy)
self._save_model(test_program_info, strategy, strategy_idx)
@@ -462,6 +473,9 @@ class AutoCompression:
"Not set eval function, so unable to test accuracy performance."
)
if 'unstructure' in self._strategy or self.train_config.sparse_model:
self._pruner.update_params()
return test_program_info
def _save_model(self, test_program_info, strategy, strategy_idx):
......
@@ -25,7 +25,8 @@ from .strategy_config import ProgramInfo
_logger = get_logger(__name__, level=logging.INFO)
__all__ = [
'build_distill_program', 'build_quant_program', 'build_prune_program'
'build_distill_program', 'build_quant_program', 'build_prune_program',
'remove_unused_var_nodes'
]
@@ -425,3 +426,25 @@ def build_prune_program(executor,
format(config['prune_algo']))
return pruner, train_program_info
def remove_unused_var_nodes(program):
    '''
    This function is called before saving the sparse model to remove the redundant mask-related nodes.
    Args:
        program(paddle.static.Program): The sparse model to be saved.
    Returns:
        program(paddle.static.Program): The sparse model with the mask nodes removed.
    '''
    from paddle.fluid import core
    from paddle.fluid.framework import IrGraph
    graph = IrGraph(core.Graph(program.desc), for_test=True)
    removed_nodes = set()
    ops = graph.all_op_nodes()
    for op_node in ops:
        for input_node in op_node.inputs:
            # Any op that consumes a '_mask' variable is part of the
            # pruning-mask machinery and is not needed at inference time.
            if '_mask' in input_node.name():
                removed_nodes.add(op_node)
    graph.safe_remove_nodes(removed_nodes)
    program = graph.to_program()
    return program
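A short usage sketch of the helper above (the model path is hypothetical and the import path is assumed from this module's location in paddleslim.auto_compression): load a sparse inference program, strip the ops that only consume '_mask' variables, and look the fetch targets up again, since graph.to_program() returns a fresh Program object.

import paddle
# Assumed import path for the helper defined above.
from paddleslim.auto_compression.create_compressed_program import remove_unused_var_nodes

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())

# Hypothetical path prefix of a previously saved sparse inference model.
program, feed_names, fetch_targets = paddle.static.load_inference_model(
    'output/sparse_model', exe)

clean_program = remove_unused_var_nodes(program)

# to_program() rebuilds the Program, so re-resolve the output variables by name.
fetch_vars = [clean_program.global_block().var(v.name) for v in fetch_targets]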
@@ -103,18 +103,9 @@ UnstructurePrune.__new__.__defaults__ = (None, ) * len(UnstructurePrune._fields)
### Train
TrainConfig = namedtuple("Train", [
"epochs",
"learning_rate",
"optimizer",
"optim_args",
"eval_iter",
"logging_iter",
"origin_metric",
"target_metric",
"use_fleet",
"amp_config",
"recompute_config",
"sharding_config",
"epochs", "learning_rate", "optimizer", "optim_args", "eval_iter",
"logging_iter", "origin_metric", "target_metric", "use_fleet", "amp_config",
"recompute_config", "sharding_config", "sparse_model"
])
TrainConfig.__new__.__defaults__ = (None, ) * len(TrainConfig._fields)
......
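With the new field in place, a TrainConfig that turns the behavior on could look like the sketch below (the import path is assumed and the values are illustrative; every unspecified field still defaults to None). Setting sparse_model=True makes AutoCompression freeze the existing zero pattern of an already-sparsified model via static masks while the training strategy, e.g. quant-aware training, runs, and update_params() restores that pattern before the model is saved.

# Assumed import path; TrainConfig is the namedtuple defined above.
from paddleslim.auto_compression.strategy_config import TrainConfig

train_config = TrainConfig(
    epochs=1,
    learning_rate=1e-5,
    optimizer='SGD',      # illustrative values
    eval_iter=1000,
    logging_iter=10,
    sparse_model=True)    # keep the zeros of the already-sparse weights fixed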