From 5ad38fe5ea3b9062038ac7ca247bb0aae7e53a9f Mon Sep 17 00:00:00 2001 From: minghaoBD <79566150+minghaoBD@users.noreply.github.com> Date: Wed, 25 May 2022 10:54:28 +0800 Subject: [PATCH] [auto-compression-asp] support fleet training and asp+distillation in auto compression (#1139) * Update README.md * Update README.md * Update README.md * [auto-compression-asp] support fleet training and asp+distillation in auto compression --- paddleslim/auto_compression/compressor.py | 2 ++ paddleslim/auto_compression/create_compressed_program.py | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/paddleslim/auto_compression/compressor.py b/paddleslim/auto_compression/compressor.py index b7a86f08..df764017 100644 --- a/paddleslim/auto_compression/compressor.py +++ b/paddleslim/auto_compression/compressor.py @@ -248,6 +248,8 @@ class AutoCompression: if train_config.amp_config is not None: strategy.amp = True strategy.amp_configs = { ** train_config.amp_config} + if train_config.asp_config is not None: + strategy.asp = True return strategy def _prepare_program(self, program, feed_target_names, fetch_targets, diff --git a/paddleslim/auto_compression/create_compressed_program.py b/paddleslim/auto_compression/create_compressed_program.py index aae9f3ea..5df2fdc0 100644 --- a/paddleslim/auto_compression/create_compressed_program.py +++ b/paddleslim/auto_compression/create_compressed_program.py @@ -278,7 +278,8 @@ def build_distill_program(executor, loss.stop_gradient = False if 'prune_algo' in config: ### prune & asp - if config['prune_algo'] == 'asp': + if config['prune_algo'] == 'asp' and not train_config.get( + 'use_fleet'): optimizer = pruner.decorate(optimizer) optimizer.minimize(loss) elif 'prune_strategy' in config: ###unstructure prune @@ -414,6 +415,8 @@ def build_prune_program(executor, 'prune_params_name'] is not None and param.name not in config[ 'prune_params_name']: excluded_params_name.append(param.name) + if "teacher_" in param.name: + excluded_params_name.append(param.name) pruner.set_excluded_layers(train_program_info.program, excluded_params_name) elif config['prune_algo'] == 'transformer_pruner': -- GitLab