Unverified commit e76079dc, authored by zhouzj, committed by GitHub

[ACT] Reduce preparation time before starting training. (#1499)

* Optimize get_patterns.

* Fix bugs.

* Fix bugs.

* Fix comments.

* Fix bugs.
Co-authored-by: ceci3 <ceci3@users.noreply.github.com>
Parent f11ae71e
@@ -28,7 +28,7 @@ import paddle.distributed.fleet as fleet
 from ..quant.quanter import convert, quant_post
 from ..common.recover_program import recover_inference_program
 from ..common import get_logger
-from ..common.patterns import get_patterns
+from ..common.patterns import get_patterns, find_final_nodes
 from ..common.load_model import load_inference_model, get_model_dir, export_onnx
 from ..common.dataloader import wrap_dataloader, get_feed_vars
 from ..common.config_helper import load_config
@@ -155,7 +155,7 @@ class AutoCompression:
         paddle.enable_static()
         self._exe, self._places = self._prepare_envs()
-        self.model_type = self._get_model_type()
+        self.default_distill_node_pair, self.model_type = self._get_model_info()
         if self.train_config is not None and self.train_config.use_fleet:
             fleet.init(is_collective=True)
@@ -188,7 +188,6 @@ class AutoCompression:
         self._strategy, self._config = self._prepare_strategy(
             self.strategy_config)
         self.train_config = self._get_final_train_config(
             self.train_config, self._strategy, self.model_type)
         _logger.info(f"Selected strategies: {self._strategy}")
@@ -206,7 +205,7 @@ class AutoCompression:
                 ### The TrainConfig for quantization is extrapolate from above.
                 tmp_train_config = copy.deepcopy(train_config.__dict__)
                 ### the epoch, train_iter, learning rate of quant is 10% of the prune compress
-                if self.model_type != 'transformer':
+                if self.model_type != 'transformer' and train_config.epochs is not None:
                     tmp_train_config['epochs'] = max(
                         int(train_config.epochs * 0.1), 1)
                     if train_config.train_iter is not None:
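The new `train_config.epochs is not None` guard matters when a user schedules training by `train_iter` alone and leaves `epochs` unset: previously `int(train_config.epochs * 0.1)` would raise a `TypeError` on `None`. A minimal sketch of the guarded derivation (the `SimpleTrainConfig` stand-in and its fields are illustrative, not the real `TrainConfig` API):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class SimpleTrainConfig:
    # Illustrative stand-in for TrainConfig; either field may be unset.
    epochs: Optional[int] = None
    train_iter: Optional[int] = None

def derive_quant_schedule(cfg: SimpleTrainConfig) -> dict:
    """Derive the 10% quantization schedule, guarding unset fields."""
    derived = {}
    # Without the None guard, `int(None * 0.1)` would raise a TypeError
    # for configs that set train_iter but leave epochs unset.
    if cfg.epochs is not None:
        derived['epochs'] = max(int(cfg.epochs * 0.1), 1)
    if cfg.train_iter is not None:
        derived['train_iter'] = int(cfg.train_iter * 0.1)
    return derived

print(derive_quant_schedule(SimpleTrainConfig(train_iter=5000)))  # {'train_iter': 500}
```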
@@ -301,13 +300,25 @@ class AutoCompression:
             exe = paddle.static.Executor(places)
         return exe, places

-    def _get_model_type(self):
+    def _get_model_info(self):
         [inference_program, _, _] = (load_inference_model(
             self.model_dir,
             model_filename=self.model_filename,
             params_filename=self.params_filename,
             executor=self._exe))
-        _, _, model_type = get_patterns(inference_program)
+
+        ### set the output of final weight node as the default distillation node
+        distill_node = []
+        final_weight_node = find_final_nodes(inference_program)
+        for out_var in final_weight_node:
+            distill_node.append('teacher_' + out_var.name())
+            distill_node.append(out_var.name())
+
+        model_type = None
+        if not isinstance(self.strategy_config, dict):
+            _, model_type = get_patterns(inference_program)
+            _logger.info(f"Detect model type: {model_type}")
+
         if self.model_filename is None:
             opt_model_filename = '__opt_model__'
         else:
@@ -321,8 +332,8 @@ class AutoCompression:
             shutil.move(
                 os.path.join(self.updated_model_dir, opt_model_filename),
                 os.path.join(self.updated_model_dir, self.model_filename))
-        _logger.info(f"Detect model type: {model_type}")
-        return model_type
+
+        return distill_node, model_type

     def _prepare_strategy(self, strategy_config):
         if not isinstance(strategy_config, list):
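With this change the default distillation pairs no longer come from a full `get_patterns` scan: `_get_model_info` reads only the outputs of the final weight-bearing layers via `find_final_nodes` and pairs each output name with its `teacher_`-prefixed counterpart. A minimal sketch of that pairing convention (`FakeVar` and the node list are stand-ins for the `GraphWrapper` variables `find_final_nodes` actually returns):

```python
class FakeVar:
    """Stand-in for a GraphWrapper variable wrapper."""
    def __init__(self, name):
        self._name = name
    def name(self):
        return self._name

# Pretend the graph scan found one final weight-bearing layer output.
fake_final_nodes = [FakeVar('linear_147.tmp_1')]

distill_node = []
for out_var in fake_final_nodes:
    # Each student tensor is paired with its 'teacher_'-prefixed counterpart,
    # matching how the teacher program's variables are renamed for distillation.
    distill_node.append('teacher_' + out_var.name())
    distill_node.append(out_var.name())

print(distill_node)  # ['teacher_linear_147.tmp_1', 'linear_147.tmp_1']
```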
@@ -438,8 +449,7 @@ class AutoCompression:
         return strategy

     def _prepare_program(self, program, feed_target_names, fetch_targets,
-                         patterns, default_distill_node_pair, strategy, config,
-                         train_config):
+                         patterns, strategy, config, train_config):
         train_program = recover_inference_program(program)
         startup_program = paddle.static.Program()
         train_program_info = ProgramInfo(startup_program, train_program,
@@ -476,7 +486,7 @@ class AutoCompression:
                 strategy, patterns, self.eval_dataloader)

         if train_config.use_fleet:
-            dist_strategy = _prepare_fleet_strategy(train_config)
+            dist_strategy = self._prepare_fleet_strategy(train_config)
         else:
             dist_strategy = None
@@ -490,7 +500,7 @@ class AutoCompression:
                 train_program_info,
                 pruner=self._pruner,
                 dist_strategy=dist_strategy,
-                default_distill_node_pair=default_distill_node_pair)
+                default_distill_node_pair=self.default_distill_node_pair)

         self._quant_config = None
         ### add quant_aware program, quant always is last step
@@ -702,12 +712,12 @@ class AutoCompression:
                         train_config.origin_metric, metric))
             self.metric_before_compressed = metric

-        patterns, default_distill_node_pair, _ = get_patterns(
-            inference_program)
+        patterns = None
+        if 'transformer' in strategy:
+            patterns, _ = get_patterns(inference_program)

         train_program_info, test_program_info = self._prepare_program(
             inference_program, feed_target_names, fetch_targets, patterns,
-            default_distill_node_pair, strategy, config, train_config)
+            strategy, config, train_config)

         if 'unstructure' in self._strategy:
             test_program_info.program._program = remove_unused_var_nodes(
                 test_program_info.program._program)
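This hunk carries the main speed-up of the commit: the whole-graph pattern match in `get_patterns` now runs only when a transformer strategy actually consumes its result, instead of unconditionally before every compression run. A minimal sketch of the gating idea, where `scan_graph` and the strategy string are illustrative stand-ins rather than PaddleSlim APIs:

```python
import time

def scan_graph():
    """Stand-in for the expensive whole-program pattern match."""
    time.sleep(0.5)
    return {'FFN$0': ['matmul', 'relu', 'matmul']}

def prepare(strategy):
    patterns = None
    # Only transformer strategies need patterns, so only they pay for the scan.
    if 'transformer' in strategy:
        patterns = scan_graph()
    return patterns

start = time.time()
prepare('qat_dis')  # non-transformer strategy: no scan, returns immediately
print(f"prep took {time.time() - start:.2f}s")
```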
@@ -79,8 +79,7 @@ def _is_ffn(pattern_ops, pattern_ops_type):

 def get_patterns(program, only_final_node=True):
-    """ distinguish the pattern in the program and get distillation node """
-    distill_node = []
+    """ distinguish the pattern in the program and get model type """
     skip_quant_tensor_list = []
     patterns = {}
     graph = GraphWrapper(program)
@@ -124,10 +123,6 @@ def get_patterns(program, only_final_node=True):
                         pattern_name = 'FFN$' + str(block_num)
                         block_num += 1
-                        if not only_final_node:
-                            distill_node.append('teacher_' + out_var_name)
-                            distill_node.append(out_var_name)
-
                 if model_type == 'transformer' and (
                         'fetch' in pattern_ops_type or
                         pattern_ops_type[-1] == 'scale'):
@@ -140,16 +135,6 @@ def get_patterns(program, only_final_node=True):
                 patterns[pattern_name] = pattern_ops

-                if model_type != 'transformer' and (not only_final_node):
-                    distill_node.append('teacher_' + out_var_name)
-                    distill_node.append(out_var_name)
-
-    ### add the output of final weight node to distill node
-    final_weight_node = find_final_nodes(program)
-    for out_var in final_weight_node:
-        distill_node.append('teacher_' + out_var.name())
-        distill_node.append(out_var.name())
-
     #### skip quant matmul in attention
     if model_type == 'transformer':
         for block_id in range(len(program.blocks)):
@@ -158,4 +143,4 @@ def get_patterns(program, only_final_node=True):
                     if inp_name in skip_quant_tensor_list:
                         op._set_attr("op_namescope", "skip_quant")

-    return patterns, distill_node, model_type
+    return patterns, model_type
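Since `get_patterns` now returns two values instead of three, any external caller that unpacked `(patterns, distill_node, model_type)` must be migrated; distillation nodes are obtained separately from `find_final_nodes`. A sketch of the caller-side change, with stub functions standing in for the before and after versions of `paddleslim.common.patterns.get_patterns`:

```python
def old_get_patterns(program):
    # Stand-in for the pre-change signature: distill nodes computed every call.
    return {}, ['teacher_fc_0.tmp_1', 'fc_0.tmp_1'], 'transformer'

def new_get_patterns(program):
    # Stand-in for the post-change signature: patterns and model type only.
    return {}, 'transformer'

program = object()  # placeholder for a paddle.static.Program

# Before: three return values, distill nodes included.
patterns, distill_node, model_type = old_get_patterns(program)

# After: two return values; distill nodes come from find_final_nodes instead.
patterns, model_type = new_get_patterns(program)
print(model_type)
```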
@@ -319,7 +319,7 @@ def quant_aware(program,
     skip_tensor_list = []
     same_scale_tensor_list = []
     if model_type == 'transformer' and pattern_ops is None:
-        pattern_ops, _, model_type = get_patterns(program)
+        pattern_ops, model_type = get_patterns(program)
         if model_type != 'transformer':
             _logger.info(
                 'Warning! After analysis, the real model type is not transformer! If you encounter this situation, please raise an issue let us know in which case "get_patterns" determines model type is not transformer.'
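The fallback in `quant_aware` keeps the same defensive shape under the new signature: when the caller claims a transformer but supplies no patterns, the model type is re-detected and a disagreement is logged rather than silently applying transformer-specific rules. A minimal sketch of that re-check, where `detect_model_type` merely stands in for the analysis `get_patterns` performs:

```python
import logging

logging.basicConfig(level=logging.INFO)
_logger = logging.getLogger(__name__)

def detect_model_type(program):
    # Stand-in analysis; pretend it disagrees with the caller's claim.
    return 'cnn'

def quant_aware_sketch(program, model_type, pattern_ops=None):
    if model_type == 'transformer' and pattern_ops is None:
        pattern_ops, model_type = {}, detect_model_type(program)
        if model_type != 'transformer':
            _logger.info('Declared transformer, but analysis says %s; '
                         'transformer-specific quant rules are skipped.',
                         model_type)
    return pattern_ops, model_type

quant_aware_sketch(object(), 'transformer')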