From f68ec4b855f772dcd59b44f4c16cfb6841b1ed76 Mon Sep 17 00:00:00 2001 From: zhouzj <41366441+zzjjay@users.noreply.github.com> Date: Wed, 28 Dec 2022 11:37:33 +0800 Subject: [PATCH] [cherry-pick] add loss info and skd distillation (#1612) * add skd distillation. (#1587) * add skd distillation. * update skd's test. * [ACT] add loss info (#1597) * add loss info on ACT training. * Add flops info. --- paddleslim/auto_compression/analysis.py | 17 +++- paddleslim/auto_compression/compressor.py | 18 +++-- .../create_compressed_program.py | 30 ++++--- .../auto_compression/strategy_config.py | 5 +- paddleslim/dist/__init__.py | 2 +- paddleslim/dist/single_distiller.py | 58 ++++++++++++- tests/test_skd_loss.py | 81 +++++++++++++++++++ 7 files changed, 189 insertions(+), 22 deletions(-) create mode 100644 tests/test_skd_loss.py diff --git a/paddleslim/auto_compression/analysis.py b/paddleslim/auto_compression/analysis.py index db9f601e..3423db4a 100644 --- a/paddleslim/auto_compression/analysis.py +++ b/paddleslim/auto_compression/analysis.py @@ -28,7 +28,19 @@ def analysis_prune(eval_function, params_filename, analysis_file, pruned_ratios, - target_loss=None): + target_loss=None, + criterion='l1_norm'): + ''' + Args: + eval_func(function): The callback function used to evaluate the model. It should accept a instance of `paddle.static.Program` as argument and return a score on test dataset. + model_dir(str): Directory path to load model. If you want to load onnx model, only set ``model_dir=model.onnx``. + model_filename(str): Specify model_filename. If you want to load onnx model, model filename should be None. + params_filename(str): Specify params_filename. If you want to load onnx model, params filename should be None. + analysis_file(str): The file to save the sensitivities. It will append the latest computed sensitivities into the file. And the sensitivities in the file would not be computed again. This file can be loaded by `pickle` library. + pruned_ratios(list): The ratios to be pruned. + criterion(str|function): The criterion used to sort channels for pruning. Currently supports l1_ norm, bn_scale, geometry_median. Default: l1_norm. + ''' + devices = paddle.device.get_device().split(':')[0] places = paddle.device._convert_to_place(devices) exe = paddle.static.Executor(places) @@ -47,7 +59,8 @@ def analysis_prune(eval_function, eval_function, sensitivities_file=analysis_file, eval_args=[exe, feed_target_names, fetch_targets], - pruned_ratios=pruned_ratios) + pruned_ratios=pruned_ratios, + criterion=criterion) with open(analysis_file, 'rb') as f: if sys.version_info < (3, 0): diff --git a/paddleslim/auto_compression/compressor.py b/paddleslim/auto_compression/compressor.py index e098830f..5ca66163 100644 --- a/paddleslim/auto_compression/compressor.py +++ b/paddleslim/auto_compression/compressor.py @@ -783,13 +783,17 @@ class AutoCompression: total_epochs = train_config.epochs if train_config.epochs else 100 total_train_iter = 0 stop_training = False + + loss_vars = [var for var in train_program_info.loss_dict.values()] + loss_names = [name for name in train_program_info.loss_dict.keys()] + for epoch_id in range(total_epochs): if stop_training: break for batch_id, data in enumerate(self.train_dataloader()): - np_probs_float, = self._exe.run(train_program_info.program, \ + loss = self._exe.run(train_program_info.program, \ feed=data, \ - fetch_list=train_program_info.fetch_targets) + fetch_list=train_program_info.fetch_targets+loss_vars) if not isinstance(train_program_info.learning_rate, float): train_program_info.learning_rate.step() if 'unstructure' in strategy: @@ -800,10 +804,12 @@ class AutoCompression: else: logging_iter = train_config.logging_iter if batch_id % int(logging_iter) == 0: - _logger.info( - "Total iter: {}, epoch: {}, batch: {}, loss: {}".format( - total_train_iter, epoch_id, batch_id, - np_probs_float)) + print_info = "Total iter: {}, epoch: {}, batch: {}, loss: {}".format( + total_train_iter, epoch_id, batch_id, loss[0]) + for idx, loss_value in enumerate(loss[1:]): + print_info += '{}: {} '.format(loss_names[idx], + loss_value) + _logger.info(print_info) total_train_iter += 1 if total_train_iter % int( train_config.eval_iter) == 0 and total_train_iter != 0: diff --git a/paddleslim/auto_compression/create_compressed_program.py b/paddleslim/auto_compression/create_compressed_program.py index eacc39ff..7217b033 100644 --- a/paddleslim/auto_compression/create_compressed_program.py +++ b/paddleslim/auto_compression/create_compressed_program.py @@ -24,6 +24,7 @@ from ..common.recover_program import recover_inference_program, _remove_fetch_no from ..common import get_logger from .strategy_config import ProgramInfo from ..common.load_model import load_inference_model +from ..analysis import flops _logger = get_logger(__name__, level=logging.INFO) __all__ = [ @@ -118,7 +119,7 @@ def _parse_distill_loss(distill_node_pair, distill_lambda=1.0): """parse distill loss config""" loss_dist = 0.0 - losses = [] + losses = {} if isinstance(distill_node_pair[0], str): assert isinstance(distill_loss, str) assert isinstance(distill_lambda, float) @@ -128,16 +129,17 @@ def _parse_distill_loss(distill_node_pair, assert len(distill_node_pair) == len(distill_loss) assert len(distill_node_pair) == len(distill_lambda) - for node, loss, lam in zip(distill_node_pair, distill_loss, distill_lambda): - tmp_loss = 0.0 - _logger.info("train config.distill_node_pair: {}".format(node, loss, - lam)) + for node, loss_clas, lam in zip(distill_node_pair, distill_loss, + distill_lambda): + tmp_loss = losses.get(loss_clas, 0.0) + _logger.info("train config.distill_node_pair: {}".format( + node, loss_clas, lam)) assert len(node) % 2 == 0, \ "distill_node_pair config wrong, the length needs to be an even number" for i in range(len(node) // 2): - tmp_loss += eval(loss)(node[i * 2], node[i * 2 + 1]) - loss_dist += lam * tmp_loss - losses.append(tmp_loss) + tmp_loss += eval(loss_clas)(node[i * 2], node[i * 2 + 1]) * lam + loss_dist += tmp_loss + losses[loss_clas] = tmp_loss return loss_dist, losses @@ -313,7 +315,7 @@ def build_distill_program(executor, use_dynamic_loss_scaling=True, **train_config['amp_config']) - distill_loss, losses = _parse_distill_loss( + distill_loss, loss_dict = _parse_distill_loss( distill_node_pair, config.get('loss') or 'l2', ### default loss is l2 config.get('alpha') or 1.0) ### default alpha is 1.0 @@ -334,7 +336,7 @@ def build_distill_program(executor, train_program_info = ProgramInfo(startup_program, train_program, feed_target_names, train_fetch_list, - optimizer, learning_rate) + optimizer, learning_rate, loss_dict) test_program_info = ProgramInfo(startup_program, test_program, feed_target_names, fetch_targets) return train_program_info, test_program_info @@ -469,6 +471,8 @@ def build_prune_program(executor, params.append(param.name) original_shapes[param.name] = param.shape + origin_flops = flops(train_program_info.program) + pruned_program, _, _ = pruner.prune( train_program_info.program, paddle.static.global_scope(), @@ -485,6 +489,12 @@ def build_prune_program(executor, param.name, original_shapes[param.name], param.shape)) _logger.info( "####################channel pruning end##########################") + + final_flops = flops(pruned_program) + pruned_flops = abs(origin_flops - final_flops) / origin_flops + _logger.info("FLOPs before pruning: {}".format(origin_flops)) + _logger.info("FLOPs after pruning: {}. Pruned FLOPs: {}%.".format( + final_flops, round(pruned_flops * 100, 2))) train_program_info.program = pruned_program elif strategy.startswith('asp'): diff --git a/paddleslim/auto_compression/strategy_config.py b/paddleslim/auto_compression/strategy_config.py index b68537c6..508da644 100644 --- a/paddleslim/auto_compression/strategy_config.py +++ b/paddleslim/auto_compression/strategy_config.py @@ -431,7 +431,8 @@ class ProgramInfo: feed_target_names, fetch_targets, optimizer=None, - learning_rate=None): + learning_rate=None, + loss_dict=None): """ ProgramInfo Config. Args: @@ -441,6 +442,7 @@ class ProgramInfo: fetch_targets(list(Variable)): The fetch variable in the program. optimizer(Optimizer, optional): Optimizer in training. Default: None. learning_rate(float|paddle.optimizer.lr, optional): learning_rate in training. Default: None. + loss_dict(dict): The components of losses. """ self.startup_program = startup_program self.program = program @@ -448,3 +450,4 @@ class ProgramInfo: self.fetch_targets = fetch_targets self.optimizer = optimizer self.learning_rate = learning_rate + self.loss_dict = loss_dict diff --git a/paddleslim/dist/__init__.py b/paddleslim/dist/__init__.py index de4b6196..46a02564 100755 --- a/paddleslim/dist/__init__.py +++ b/paddleslim/dist/__init__.py @@ -12,5 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .single_distiller import merge, fsp, l2, soft_label, loss, dkd +from .single_distiller import merge, fsp, l2, soft_label, loss, dkd, skd from .dml import DML diff --git a/paddleslim/dist/single_distiller.py b/paddleslim/dist/single_distiller.py index 8a658a6a..ac349d58 100644 --- a/paddleslim/dist/single_distiller.py +++ b/paddleslim/dist/single_distiller.py @@ -15,6 +15,7 @@ import numpy as np import paddle from paddleslim.core import GraphWrapper +import paddle.nn.functional as F def merge(teacher_program, @@ -203,8 +204,11 @@ def soft_label(teacher_var_name, teacher_var = paddle.nn.functional.softmax(teacher_var / teacher_temperature) soft_label_loss = paddle.mean( - paddle.fluid.layers.cross_entropy( - student_var, teacher_var, soft_label=True)) + paddle.nn.functional.cross_entropy( + input=student_var, + label=teacher_var, + soft_label=True, + use_softmax=False)) return soft_label_loss @@ -305,3 +309,53 @@ def dkd(teacher_var_name, temperature=temperature, alpha=alpha, beta=beta) + + +def skd(teacher_var_name, student_var_name, program=None, multiplier=None): + """Combine variables from student model and teacher model + by Spherical Knowledge Distillation loss (aka. skd-loss). + Reference: https://github.com/forjiuzhou/Spherical-Knowledge-Distillation + Args: + teacher_var_name(str): The name of teacher_var. + student_var_name(str): The name of student_var. + program(Program): The input distiller program. If not specified, + the default program will be used. Default: None + multiplier(float): The multiplier to recover its norm to the original + level. When it's None, the appropriate multiplier can be computed by + teacher's logits with paddle.std(output_t, axis=1). Default: None. + + Returns: + Variable: skd distiller loss. + """ + if program == None: + program = paddle.static.default_main_program() + + student_var = program.global_block().var(student_var_name) + teacher_var = program.global_block().var(teacher_var_name) + teacher_var.stop_gradient = True + + if multiplier is None: + multiplier = paddle.std(teacher_var, axis=1, keepdim=True) + + logits_student = F.layer_norm( + student_var, + student_var.shape[1:], + weight=None, + bias=None, + epsilon=1e-7) * multiplier + logits_teacher = F.layer_norm( + teacher_var, + teacher_var.shape[1:], + weight=None, + bias=None, + epsilon=1e-7) * multiplier + + student_out = F.softmax(logits_student, axis=1) + teacher_out = F.softmax(logits_teacher, axis=1) + skd_loss = paddle.mean( + F.cross_entropy( + input=student_out, + label=teacher_out, + soft_label=True, + use_softmax=False)) + return skd_loss diff --git a/tests/test_skd_loss.py b/tests/test_skd_loss.py new file mode 100644 index 00000000..19a07b34 --- /dev/null +++ b/tests/test_skd_loss.py @@ -0,0 +1,81 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +sys.path.append("../") +import unittest +import paddle +from paddleslim.dist import merge, skd +from layers import conv_bn_layer +from static_case import StaticCase + + +class TestSKDLoss(StaticCase): + def test_skd_loss(self): + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + + student_program = paddle.static.Program() + student_startup = paddle.static.Program() + with paddle.static.program_guard(student_program, student_startup): + with paddle.utils.unique_name.guard(): + input = paddle.static.data( + name="image", shape=[None, 3, 224, 224]) + conv1 = conv_bn_layer(input, 8, 3, "conv1") + conv2 = conv_bn_layer(conv1, 8, 3, "conv2") + student_predict = conv1 + conv2 + + teacher_program = paddle.static.Program() + teacher_startup = paddle.static.Program() + with paddle.static.program_guard(teacher_program, teacher_startup): + with paddle.utils.unique_name.guard(): + input = paddle.static.data( + name="image", shape=[None, 3, 224, 224]) + conv1 = conv_bn_layer(input, 8, 3, "conv1") + conv2 = conv_bn_layer(conv1, 8, 3, "conv2") + sum1 = conv1 + conv2 + conv3 = conv_bn_layer(sum1, 8, 3, "conv3") + conv4 = conv_bn_layer(conv3, 8, 3, "conv4") + sum2 = conv4 + sum1 + conv5 = conv_bn_layer(sum2, 8, 3, "conv5") + teacher_predict = conv_bn_layer(conv5, 8, 3, "conv6") + + exe.run(teacher_startup) + exe.run(student_startup) + + data_name_map = {'image': 'image'} + merge(teacher_program, student_program, data_name_map, place) + merged_ops = [] + for block in student_program.blocks: + for op in block.ops: + merged_ops.append(op.type) + with paddle.static.program_guard(student_program, student_startup): + distill_loss = skd('teacher_' + teacher_predict.name, + student_predict.name, + program=None, + multiplier=None) + + loss_ops = [] + for block in student_program.blocks: + for op in block.ops: + loss_ops.append(op.type) + print(f"ret: {set(loss_ops).difference(set(merged_ops))}") + self.assertTrue(set(merged_ops).difference(set(loss_ops)) == set()) + + self.assertTrue({ + 'softmax_with_cross_entropy', 'softmax', 'reduce_mean', 'layer_norm' + }.issubset(set(loss_ops).difference(set(merged_ops)))) + + +if __name__ == '__main__': + unittest.main() -- GitLab