Unverified commit e76079dc, authored by zhouzj, committed by GitHub

[ACT] Reduce preparation time before starting training. (#1499)

* Optimize get_patterns.

* Fix bugs.

* Fix bugs.

* Fix comments.

* Fix bugs.
Co-authored-by: ceci3 <ceci3@users.noreply.github.com>
Parent f11ae71e
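In short, this commit decouples distillation-node collection from pattern analysis: `get_patterns` now returns two values instead of three, the default distillation nodes are derived once from `find_final_nodes`, and full pattern analysis runs only when a strategy actually needs it. A minimal sketch of the signature change, using toy stand-ins rather than the real PaddleSlim functions (the real call sites appear in the hunks below):

```python
# Toy stand-ins mirroring the contract change; the real functions operate
# on a Paddle inference program.

def get_patterns_old(program):
    # Pre-commit: pattern analysis and distill-node collection were coupled.
    return {'FFN$0': ['mul', 'relu']}, ['teacher_out.tmp_0', 'out.tmp_0'], 'transformer'


def get_patterns_new(program):
    # Post-commit: only the patterns and the detected model type are returned.
    return {'FFN$0': ['mul', 'relu']}, 'transformer'


patterns, distill_node, model_type = get_patterns_old(None)  # old unpacking
patterns, model_type = get_patterns_new(None)                # new unpacking
```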
@@ -28,7 +28,7 @@ import paddle.distributed.fleet as fleet
 from ..quant.quanter import convert, quant_post
 from ..common.recover_program import recover_inference_program
 from ..common import get_logger
-from ..common.patterns import get_patterns
+from ..common.patterns import get_patterns, find_final_nodes
 from ..common.load_model import load_inference_model, get_model_dir, export_onnx
 from ..common.dataloader import wrap_dataloader, get_feed_vars
 from ..common.config_helper import load_config
@@ -155,7 +155,7 @@ class AutoCompression:
         paddle.enable_static()
         self._exe, self._places = self._prepare_envs()
-        self.model_type = self._get_model_type()
+        self.default_distill_node_pair, self.model_type = self._get_model_info()

         if self.train_config is not None and self.train_config.use_fleet:
             fleet.init(is_collective=True)
@@ -188,7 +188,6 @@ class AutoCompression:
         self._strategy, self._config = self._prepare_strategy(
             self.strategy_config)
-
         self.train_config = self._get_final_train_config(
             self.train_config, self._strategy, self.model_type)
         _logger.info(f"Selected strategies: {self._strategy}")
@@ -206,7 +205,7 @@ class AutoCompression:
         ### The TrainConfig for quantization is extrapolate from above.
         tmp_train_config = copy.deepcopy(train_config.__dict__)
         ### the epoch, train_iter, learning rate of quant is 10% of the prune compress
-        if self.model_type != 'transformer':
+        if self.model_type != 'transformer' and train_config.epochs is not None:
             tmp_train_config['epochs'] = max(
                 int(train_config.epochs * 0.1), 1)
             if train_config.train_iter is not None:
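The added `train_config.epochs is not None` guard fixes a crash for configs that specify `train_iter` but leave `epochs` unset: scaling `None` by 0.1 raises a `TypeError`. A tiny illustration (plain variables, not the real TrainConfig object):

```python
epochs = None  # user configured train_iter instead of epochs

# Pre-commit behavior: unconditional scaling crashes on a None epochs.
try:
    scaled = max(int(epochs * 0.1), 1)
except TypeError as e:
    print(e)  # unsupported operand type(s) for *: 'NoneType' and 'float'

# Post-commit behavior: scale only when epochs is actually set.
if epochs is not None:
    scaled = max(int(epochs * 0.1), 1)
```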
@@ -301,13 +300,25 @@ class AutoCompression:
         exe = paddle.static.Executor(places)
         return exe, places

-    def _get_model_type(self):
+    def _get_model_info(self):
         [inference_program, _, _] = (load_inference_model(
             self.model_dir,
             model_filename=self.model_filename,
             params_filename=self.params_filename,
             executor=self._exe))
-        _, _, model_type = get_patterns(inference_program)
+
+        ### set the output of final weight node as the default distillation node
+        distill_node = []
+        final_weight_node = find_final_nodes(inference_program)
+        for out_var in final_weight_node:
+            distill_node.append('teacher_' + out_var.name())
+            distill_node.append(out_var.name())
+
+        model_type = None
+        if not isinstance(self.strategy_config, dict):
+            _, model_type = get_patterns(inference_program)
+            _logger.info(f"Detect model type: {model_type}")
+
         if self.model_filename is None:
             opt_model_filename = '__opt_model__'
         else:
@@ -321,8 +332,8 @@ class AutoCompression:
             shutil.move(
                 os.path.join(self.updated_model_dir, opt_model_filename),
                 os.path.join(self.updated_model_dir, self.model_filename))
-        _logger.info(f"Detect model type: {model_type}")
-        return model_type
+
+        return distill_node, model_type

     def _prepare_strategy(self, strategy_config):
         if not isinstance(strategy_config, list):
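`_get_model_info` pairs each output of the final weight-bearing nodes with a `teacher_`-prefixed twin, which the distillation pass later matches against the teacher program. A self-contained sketch of that pairing loop; `FakeVar` is a hypothetical stand-in for the variable wrappers that `find_final_nodes(inference_program)` returns:

```python
from typing import List


class FakeVar:
    """Hypothetical stand-in for a Paddle variable wrapper."""

    def __init__(self, name: str):
        self._name = name

    def name(self) -> str:
        return self._name


def build_distill_node_pair(final_weight_nodes: List[FakeVar]) -> List[str]:
    # Mirrors the loop in _get_model_info: flattened pairs of
    # [teacher_a, a, teacher_b, b, ...].
    distill_node = []
    for out_var in final_weight_nodes:
        distill_node.append('teacher_' + out_var.name())
        distill_node.append(out_var.name())
    return distill_node


print(build_distill_node_pair([FakeVar('linear_1.tmp_0')]))
# ['teacher_linear_1.tmp_0', 'linear_1.tmp_0']
```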
@@ -438,8 +449,7 @@
         return strategy

     def _prepare_program(self, program, feed_target_names, fetch_targets,
-                         patterns, default_distill_node_pair, strategy, config,
-                         train_config):
+                         patterns, strategy, config, train_config):
         train_program = recover_inference_program(program)
         startup_program = paddle.static.Program()
         train_program_info = ProgramInfo(startup_program, train_program,
@@ -476,7 +486,7 @@
             strategy, patterns, self.eval_dataloader)

         if train_config.use_fleet:
-            dist_strategy = _prepare_fleet_strategy(train_config)
+            dist_strategy = self._prepare_fleet_strategy(train_config)
         else:
             dist_strategy = None
@@ -490,7 +500,7 @@
                 train_program_info,
                 pruner=self._pruner,
                 dist_strategy=dist_strategy,
-                default_distill_node_pair=default_distill_node_pair)
+                default_distill_node_pair=self.default_distill_node_pair)
             self._quant_config = None
             ### add quant_aware program, quant always is last step
@@ -702,12 +712,12 @@
                         train_config.origin_metric, metric))
             self.metric_before_compressed = metric

-        patterns, default_distill_node_pair, _ = get_patterns(
-            inference_program)
+        patterns = None
+        if 'transformer' in strategy:
+            patterns, _ = get_patterns(inference_program)
         train_program_info, test_program_info = self._prepare_program(
             inference_program, feed_target_names, fetch_targets, patterns,
-            default_distill_node_pair, strategy, config, train_config)
+            strategy, config, train_config)
         if 'unstructure' in self._strategy:
             test_program_info.program._program = remove_unused_var_nodes(
                 test_program_info.program._program)
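This hunk is where the preparation time is actually saved: pattern analysis walks every op in the graph, and it now runs only for transformer strategies instead of unconditionally. A toy sketch of the lazy-analysis shape (the sleep stands in for the graph traversal; names are illustrative, not the PaddleSlim API):

```python
import time
from typing import Optional


def expensive_pattern_analysis(program: str) -> dict:
    time.sleep(0.5)  # stand-in for walking every op in a GraphWrapper
    return {'FFN$0': ['mul', 'elementwise_add']}


def prepare(program: str, strategy: str) -> Optional[dict]:
    # Post-commit logic: skip analysis unless the strategy needs patterns.
    patterns = None
    if 'transformer' in strategy:
        patterns = expensive_pattern_analysis(program)
    return patterns


start = time.time()
prepare('toy_program', strategy='quant_post')
print(f"non-transformer prep took {time.time() - start:.2f}s")  # ~0.00s
```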
...
@@ -79,8 +79,7 @@ def _is_ffn(pattern_ops, pattern_ops_type):
 def get_patterns(program, only_final_node=True):
-    """ distinguish the pattern in the program and get distillation node """
-    distill_node = []
+    """ distinguish the pattern in the program and get model type """
     skip_quant_tensor_list = []
     patterns = {}
     graph = GraphWrapper(program)
@@ -124,10 +123,6 @@ def get_patterns(program, only_final_node=True):
                     pattern_name = 'FFN$' + str(block_num)
                     block_num += 1
-
-                    if not only_final_node:
-                        distill_node.append('teacher_' + out_var_name)
-                        distill_node.append(out_var_name)

                 if model_type == 'transformer' and (
                         'fetch' in pattern_ops_type or
                         pattern_ops_type[-1] == 'scale'):
@@ -140,16 +135,6 @@ def get_patterns(program, only_final_node=True):
                 patterns[pattern_name] = pattern_ops

-                if model_type != 'transformer' and (not only_final_node):
-                    distill_node.append('teacher_' + out_var_name)
-                    distill_node.append(out_var_name)
-
-    ### add the output of final weight node to distill node
-    final_weight_node = find_final_nodes(program)
-    for out_var in final_weight_node:
-        distill_node.append('teacher_' + out_var.name())
-        distill_node.append(out_var.name())
-
     #### skip quant matmul in attention
     if model_type == 'transformer':
         for block_id in range(len(program.blocks)):
@@ -158,4 +143,4 @@ def get_patterns(program, only_final_node=True):
             if inp_name in skip_quant_tensor_list:
                 op._set_attr("op_namescope", "skip_quant")

-    return patterns, distill_node, model_type
+    return patterns, model_type
@@ -319,7 +319,7 @@ def quant_aware(program,
     skip_tensor_list = []
     same_scale_tensor_list = []
     if model_type == 'transformer' and pattern_ops is None:
-        pattern_ops, _, model_type = get_patterns(program)
+        pattern_ops, model_type = get_patterns(program)
         if model_type != 'transformer':
             _logger.info(
                 'Warning! After analysis, the real model type is not transformer! If you encounter this situation, please raise an issue let us know in which case "get_patterns" determines model type is not transformer.'
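The call-site update in `quant_aware` is required, not cosmetic: with the new two-element return, the old three-name unpacking fails at runtime. A tiny repro with a stub standing in for the new `get_patterns`:

```python
def get_patterns_stub(program):
    return {}, 'transformer'  # new-style return: (patterns, model_type)


try:
    pattern_ops, _, model_type = get_patterns_stub(None)  # stale unpacking
except ValueError as e:
    print(e)  # not enough values to unpack (expected 3, got 2)

pattern_ops, model_type = get_patterns_stub(None)  # updated unpacking
```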
...