Unverified commit f68ec4b8 authored by zhouzj, committed by GitHub

[cherry-pick] add loss info and skd distillation (#1612)

* add skd distillation. (#1587)

* add skd distillation.

* update skd's test.

* [ACT] add loss info (#1597)

* add loss info on ACT training.

* Add flops info.
Parent d5214607
...@@ -28,7 +28,19 @@ def analysis_prune(eval_function,
                   params_filename,
                   analysis_file,
                   pruned_ratios,
                   target_loss=None,
                   criterion='l1_norm'):
    '''
    Args:
        eval_function(function): The callback function used to evaluate the model. It should accept an instance of `paddle.static.Program` as its argument and return a score on the test dataset.
        model_dir(str): Directory path to load the model. To load an ONNX model, just set ``model_dir=model.onnx``.
        model_filename(str): Specify model_filename. For an ONNX model, the model filename should be None.
        params_filename(str): Specify params_filename. For an ONNX model, the params filename should be None.
        analysis_file(str): The file to save the sensitivities. The latest computed sensitivities are appended to this file, and sensitivities already in the file are not computed again. The file can be loaded with the `pickle` library.
        pruned_ratios(list): The ratios to be pruned.
        criterion(str|function): The criterion used to sort channels for pruning. Currently supports l1_norm, bn_scale, geometry_median. Default: l1_norm.
    '''
    devices = paddle.device.get_device().split(':')[0]
    places = paddle.device._convert_to_place(devices)
    exe = paddle.static.Executor(places)
...@@ -47,7 +59,8 @@ def analysis_prune(eval_function,
        eval_function,
        sensitivities_file=analysis_file,
        eval_args=[exe, feed_target_names, fetch_targets],
        pruned_ratios=pruned_ratios,
        criterion=criterion)
    with open(analysis_file, 'rb') as f:
        if sys.version_info < (3, 0):
...
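For reference, a minimal usage sketch of the extended analysis_prune API follows. The import path, model files, and the body of eval_function are assumptions for illustration only; the signature and the new criterion argument are taken from this diff.

# Hedged usage sketch; paths and the eval callback are hypothetical.
from paddleslim.auto_compression.analysis import analysis_prune  # assumed import path

def eval_function(program, *args):
    # Hypothetical callback: per the docstring it receives a paddle.static.Program
    # and should return a score on the test dataset.
    return 0.0

analysis_prune(
    eval_function=eval_function,
    model_dir='./inference_model',      # hypothetical model directory
    model_filename='model.pdmodel',
    params_filename='model.pdiparams',
    analysis_file='sensitivity.pkl',    # sensitivities are appended to this file
    pruned_ratios=[0.1, 0.2, 0.3],
    criterion='geometry_median')        # new argument; default is 'l1_norm'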
...@@ -783,13 +783,17 @@ class AutoCompression:
        total_epochs = train_config.epochs if train_config.epochs else 100
        total_train_iter = 0
        stop_training = False
        loss_vars = [var for var in train_program_info.loss_dict.values()]
        loss_names = [name for name in train_program_info.loss_dict.keys()]
        for epoch_id in range(total_epochs):
            if stop_training:
                break
            for batch_id, data in enumerate(self.train_dataloader()):
                loss = self._exe.run(train_program_info.program, \
                    feed=data, \
                    fetch_list=train_program_info.fetch_targets+loss_vars)
                if not isinstance(train_program_info.learning_rate, float):
                    train_program_info.learning_rate.step()
                if 'unstructure' in strategy:
...@@ -800,10 +804,12 @@ class AutoCompression:
                else:
                    logging_iter = train_config.logging_iter
                if batch_id % int(logging_iter) == 0:
                    print_info = "Total iter: {}, epoch: {}, batch: {}, loss: {}".format(
                        total_train_iter, epoch_id, batch_id, loss[0])
                    for idx, loss_value in enumerate(loss[1:]):
                        print_info += '{}: {} '.format(loss_names[idx],
                                                       loss_value)
                    _logger.info(print_info)
                total_train_iter += 1
                if total_train_iter % int(
                        train_config.eval_iter) == 0 and total_train_iter != 0:
...
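To see what the new log line contains, here is a small self-contained sketch of the string assembly above; the names and numbers are made up, loss[0] stands for the original fetch target, and the remaining entries come from the keys of train_program_info.loss_dict.

# Stand-ins for one fetched batch: first the original fetch target, then each loss var.
loss = [0.8321, 0.5113, 0.3208]
loss_names = ['soft_label', 'l2']   # hypothetical keys of train_program_info.loss_dict
total_train_iter, epoch_id, batch_id = 100, 1, 20

print_info = "Total iter: {}, epoch: {}, batch: {}, loss: {}".format(
    total_train_iter, epoch_id, batch_id, loss[0])
for idx, loss_value in enumerate(loss[1:]):
    print_info += '{}: {} '.format(loss_names[idx], loss_value)
print(print_info)
# Total iter: 100, epoch: 1, batch: 20, loss: 0.8321soft_label: 0.5113 l2: 0.3208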
...@@ -24,6 +24,7 @@ from ..common.recover_program import recover_inference_program, _remove_fetch_no
from ..common import get_logger
from .strategy_config import ProgramInfo
from ..common.load_model import load_inference_model
from ..analysis import flops

_logger = get_logger(__name__, level=logging.INFO)

__all__ = [
...@@ -118,7 +119,7 @@ def _parse_distill_loss(distill_node_pair,
                        distill_lambda=1.0):
    """parse distill loss config"""
    loss_dist = 0.0
    losses = {}
    if isinstance(distill_node_pair[0], str):
        assert isinstance(distill_loss, str)
        assert isinstance(distill_lambda, float)
...@@ -128,16 +129,17 @@ def _parse_distill_loss(distill_node_pair,
    assert len(distill_node_pair) == len(distill_loss)
    assert len(distill_node_pair) == len(distill_lambda)
    for node, loss_clas, lam in zip(distill_node_pair, distill_loss,
                                    distill_lambda):
        tmp_loss = losses.get(loss_clas, 0.0)
        _logger.info("train config.distill_node_pair: {}".format(
            node, loss_clas, lam))
        assert len(node) % 2 == 0, \
            "distill_node_pair config wrong, the length needs to be an even number"
        for i in range(len(node) // 2):
            tmp_loss += eval(loss_clas)(node[i * 2], node[i * 2 + 1]) * lam
        loss_dist += tmp_loss
        losses[loss_clas] = tmp_loss
    return loss_dist, losses
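A pure-Python analogue of the new accumulation logic may help: the per-pair losses are now keyed by their loss type instead of being appended to a list. The loss types and float values below are hypothetical stand-ins for the distillation loss variables.

# Hypothetical per-pair results of eval(loss_clas)(node[i * 2], node[i * 2 + 1]) * lam.
pair_results = [('soft_label', 0.4), ('l2', 0.1)]

loss_dist, losses = 0.0, {}
for loss_clas, value in pair_results:
    tmp_loss = losses.get(loss_clas, 0.0)   # accumulate per loss type
    tmp_loss += value
    loss_dist += tmp_loss
    losses[loss_clas] = tmp_loss

print(loss_dist)   # 0.5 -> total distillation loss
print(losses)      # {'soft_label': 0.4, 'l2': 0.1} -> per-type components for logging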
...@@ -313,7 +315,7 @@ def build_distill_program(executor,
                use_dynamic_loss_scaling=True,
                **train_config['amp_config'])
        distill_loss, loss_dict = _parse_distill_loss(
            distill_node_pair,
            config.get('loss') or 'l2',  ### default loss is l2
            config.get('alpha') or 1.0)  ### default alpha is 1.0
...@@ -334,7 +336,7 @@ def build_distill_program(executor,
    train_program_info = ProgramInfo(startup_program, train_program,
                                     feed_target_names, train_fetch_list,
                                     optimizer, learning_rate, loss_dict)
    test_program_info = ProgramInfo(startup_program, test_program,
                                    feed_target_names, fetch_targets)
    return train_program_info, test_program_info
...@@ -469,6 +471,8 @@ def build_prune_program(executor,
                params.append(param.name)
                original_shapes[param.name] = param.shape
        origin_flops = flops(train_program_info.program)
        pruned_program, _, _ = pruner.prune(
            train_program_info.program,
            paddle.static.global_scope(),
...@@ -485,6 +489,12 @@ def build_prune_program(executor,
                    param.name, original_shapes[param.name], param.shape))
        _logger.info(
            "####################channel pruning end##########################")
        final_flops = flops(pruned_program)
        pruned_flops = abs(origin_flops - final_flops) / origin_flops
        _logger.info("FLOPs before pruning: {}".format(origin_flops))
        _logger.info("FLOPs after pruning: {}. Pruned FLOPs: {}%.".format(
            final_flops, round(pruned_flops * 100, 2)))
        train_program_info.program = pruned_program
    elif strategy.startswith('asp'):
...
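The reported pruning percentage is simply the relative FLOPs reduction; a tiny worked example with made-up numbers:

# Made-up FLOPs counts to illustrate the logged percentage.
origin_flops, final_flops = 1000000, 650000
pruned_flops = abs(origin_flops - final_flops) / origin_flops
print(round(pruned_flops * 100, 2))   # 35.0 -> logged as "Pruned FLOPs: 35.0%."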
...@@ -431,7 +431,8 @@ class ProgramInfo:
                 feed_target_names,
                 fetch_targets,
                 optimizer=None,
                 learning_rate=None,
                 loss_dict=None):
        """
        ProgramInfo Config.
        Args:
...@@ -441,6 +442,7 @@ class ProgramInfo:
            fetch_targets(list(Variable)): The fetch variable in the program.
            optimizer(Optimizer, optional): Optimizer in training. Default: None.
            learning_rate(float|paddle.optimizer.lr, optional): learning_rate in training. Default: None.
            loss_dict(dict, optional): The components of the total loss, mapping each loss name to its loss variable. Default: None.
""" """
self.startup_program = startup_program self.startup_program = startup_program
self.program = program self.program = program
...@@ -448,3 +450,4 @@ class ProgramInfo: ...@@ -448,3 +450,4 @@ class ProgramInfo:
self.fetch_targets = fetch_targets self.fetch_targets = fetch_targets
self.optimizer = optimizer self.optimizer = optimizer
self.learning_rate = learning_rate self.learning_rate = learning_rate
self.loss_dict = loss_dict
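A minimal sketch of constructing a ProgramInfo with the new loss_dict field follows; the import path and the 'l2' key are assumptions, and the tiny static program exists only to provide a fetchable loss variable.

import paddle
from paddleslim.auto_compression.strategy_config import ProgramInfo  # assumed path

paddle.enable_static()
startup_program = paddle.static.Program()
train_program = paddle.static.Program()
with paddle.static.program_guard(train_program, startup_program):
    x = paddle.static.data(name='x', shape=[None, 8], dtype='float32')
    loss_var = paddle.mean(x)

info = ProgramInfo(
    startup_program,
    train_program,
    feed_target_names=['x'],
    fetch_targets=[loss_var],
    loss_dict={'l2': loss_var})   # new field: per-component losses for logging
assert info.loss_dict['l2'] is loss_var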
...@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .single_distiller import merge, fsp, l2, soft_label, loss, dkd, skd
from .dml import DML
...@@ -15,6 +15,7 @@
import numpy as np
import paddle
from paddleslim.core import GraphWrapper
import paddle.nn.functional as F

def merge(teacher_program,
...@@ -203,8 +204,11 @@ def soft_label(teacher_var_name,
        teacher_var = paddle.nn.functional.softmax(teacher_var /
                                                   teacher_temperature)
    soft_label_loss = paddle.mean(
        paddle.nn.functional.cross_entropy(
            input=student_var,
            label=teacher_var,
            soft_label=True,
            use_softmax=False))
    return soft_label_loss
...@@ -305,3 +309,53 @@ def dkd(teacher_var_name,
        temperature=temperature,
        alpha=alpha,
        beta=beta)


def skd(teacher_var_name, student_var_name, program=None, multiplier=None):
    """Combine variables from student model and teacher model
    by Spherical Knowledge Distillation loss (aka. skd-loss).
    Reference: https://github.com/forjiuzhou/Spherical-Knowledge-Distillation
    Args:
        teacher_var_name(str): The name of teacher_var.
        student_var_name(str): The name of student_var.
        program(Program): The input distiller program. If not specified,
            the default program will be used. Default: None
        multiplier(float): The multiplier used to recover the norm of the
            normalized logits to its original level. When it is None, it is
            computed from the teacher's logits as
            paddle.std(teacher_var, axis=1, keepdim=True). Default: None.
    Returns:
        Variable: skd distiller loss.
    """
    if program == None:
        program = paddle.static.default_main_program()
    student_var = program.global_block().var(student_var_name)
    teacher_var = program.global_block().var(teacher_var_name)
    teacher_var.stop_gradient = True
    if multiplier is None:
        multiplier = paddle.std(teacher_var, axis=1, keepdim=True)
    logits_student = F.layer_norm(
        student_var,
        student_var.shape[1:],
        weight=None,
        bias=None,
        epsilon=1e-7) * multiplier
    logits_teacher = F.layer_norm(
        teacher_var,
        teacher_var.shape[1:],
        weight=None,
        bias=None,
        epsilon=1e-7) * multiplier
    student_out = F.softmax(logits_student, axis=1)
    teacher_out = F.softmax(logits_teacher, axis=1)
    skd_loss = paddle.mean(
        F.cross_entropy(
            input=student_out,
            label=teacher_out,
            soft_label=True,
            use_softmax=False))
    return skd_loss
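For intuition, here is an eager-mode restatement of the same SKD computation with random tensors; it is not part of this commit, just a dynamic-graph sketch of the static-graph code above.

import paddle
import paddle.nn.functional as F

# Made-up logits: batch of 4 samples, 10 classes.
student_logits = paddle.randn([4, 10])
teacher_logits = paddle.randn([4, 10])

# Rescale the layer-normalized logits back to the teacher's magnitude.
multiplier = paddle.std(teacher_logits, axis=1, keepdim=True)
norm_student = F.layer_norm(
    student_logits, student_logits.shape[1:], weight=None, bias=None,
    epsilon=1e-7) * multiplier
norm_teacher = F.layer_norm(
    teacher_logits, teacher_logits.shape[1:], weight=None, bias=None,
    epsilon=1e-7) * multiplier

# Soft-label cross entropy between the two softened distributions.
skd_loss = paddle.mean(
    F.cross_entropy(
        input=F.softmax(norm_student, axis=1),
        label=F.softmax(norm_teacher, axis=1),
        soft_label=True,
        use_softmax=False))
print(float(skd_loss))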
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
import paddle
from paddleslim.dist import merge, skd
from layers import conv_bn_layer
from static_case import StaticCase
class TestSKDLoss(StaticCase):
def test_skd_loss(self):
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
student_program = paddle.static.Program()
student_startup = paddle.static.Program()
with paddle.static.program_guard(student_program, student_startup):
with paddle.utils.unique_name.guard():
input = paddle.static.data(
name="image", shape=[None, 3, 224, 224])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
student_predict = conv1 + conv2
teacher_program = paddle.static.Program()
teacher_startup = paddle.static.Program()
with paddle.static.program_guard(teacher_program, teacher_startup):
with paddle.utils.unique_name.guard():
input = paddle.static.data(
name="image", shape=[None, 3, 224, 224])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
sum1 = conv1 + conv2
conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
teacher_predict = conv_bn_layer(conv5, 8, 3, "conv6")
exe.run(teacher_startup)
exe.run(student_startup)
data_name_map = {'image': 'image'}
merge(teacher_program, student_program, data_name_map, place)
merged_ops = []
for block in student_program.blocks:
for op in block.ops:
merged_ops.append(op.type)
with paddle.static.program_guard(student_program, student_startup):
distill_loss = skd('teacher_' + teacher_predict.name,
student_predict.name,
program=None,
multiplier=None)
loss_ops = []
for block in student_program.blocks:
for op in block.ops:
loss_ops.append(op.type)
print(f"ret: {set(loss_ops).difference(set(merged_ops))}")
self.assertTrue(set(merged_ops).difference(set(loss_ops)) == set())
self.assertTrue({
'softmax_with_cross_entropy', 'softmax', 'reduce_mean', 'layer_norm'
}.issubset(set(loss_ops).difference(set(merged_ops))))
if __name__ == '__main__':
unittest.main()