未验证 提交 f68ec4b8 编写于 作者: Z zhouzj 提交者: GitHub

[cherry-pick] add loss info and skd distillation (#1612)

* add skd distillation. (#1587)

* add skd distillation.

* update skd's test.

* [ACT] add loss info (#1597)

* add loss info on ACT training.

* Add flops info.
上级 d5214607
......@@ -28,7 +28,19 @@ def analysis_prune(eval_function,
params_filename,
analysis_file,
pruned_ratios,
target_loss=None):
target_loss=None,
criterion='l1_norm'):
'''
Args:
eval_func(function): The callback function used to evaluate the model. It should accept a instance of `paddle.static.Program` as argument and return a score on test dataset.
model_dir(str): Directory path to load model. If you want to load onnx model, only set ``model_dir=model.onnx``.
model_filename(str): Specify model_filename. If you want to load onnx model, model filename should be None.
params_filename(str): Specify params_filename. If you want to load onnx model, params filename should be None.
analysis_file(str): The file to save the sensitivities. It will append the latest computed sensitivities into the file. And the sensitivities in the file would not be computed again. This file can be loaded by `pickle` library.
pruned_ratios(list): The ratios to be pruned.
criterion(str|function): The criterion used to sort channels for pruning. Currently supports l1_ norm, bn_scale, geometry_median. Default: l1_norm.
'''
devices = paddle.device.get_device().split(':')[0]
places = paddle.device._convert_to_place(devices)
exe = paddle.static.Executor(places)
......@@ -47,7 +59,8 @@ def analysis_prune(eval_function,
eval_function,
sensitivities_file=analysis_file,
eval_args=[exe, feed_target_names, fetch_targets],
pruned_ratios=pruned_ratios)
pruned_ratios=pruned_ratios,
criterion=criterion)
with open(analysis_file, 'rb') as f:
if sys.version_info < (3, 0):
......
......@@ -783,13 +783,17 @@ class AutoCompression:
total_epochs = train_config.epochs if train_config.epochs else 100
total_train_iter = 0
stop_training = False
loss_vars = [var for var in train_program_info.loss_dict.values()]
loss_names = [name for name in train_program_info.loss_dict.keys()]
for epoch_id in range(total_epochs):
if stop_training:
break
for batch_id, data in enumerate(self.train_dataloader()):
np_probs_float, = self._exe.run(train_program_info.program, \
loss = self._exe.run(train_program_info.program, \
feed=data, \
fetch_list=train_program_info.fetch_targets)
fetch_list=train_program_info.fetch_targets+loss_vars)
if not isinstance(train_program_info.learning_rate, float):
train_program_info.learning_rate.step()
if 'unstructure' in strategy:
......@@ -800,10 +804,12 @@ class AutoCompression:
else:
logging_iter = train_config.logging_iter
if batch_id % int(logging_iter) == 0:
_logger.info(
"Total iter: {}, epoch: {}, batch: {}, loss: {}".format(
total_train_iter, epoch_id, batch_id,
np_probs_float))
print_info = "Total iter: {}, epoch: {}, batch: {}, loss: {}".format(
total_train_iter, epoch_id, batch_id, loss[0])
for idx, loss_value in enumerate(loss[1:]):
print_info += '{}: {} '.format(loss_names[idx],
loss_value)
_logger.info(print_info)
total_train_iter += 1
if total_train_iter % int(
train_config.eval_iter) == 0 and total_train_iter != 0:
......
......@@ -24,6 +24,7 @@ from ..common.recover_program import recover_inference_program, _remove_fetch_no
from ..common import get_logger
from .strategy_config import ProgramInfo
from ..common.load_model import load_inference_model
from ..analysis import flops
_logger = get_logger(__name__, level=logging.INFO)
__all__ = [
......@@ -118,7 +119,7 @@ def _parse_distill_loss(distill_node_pair,
distill_lambda=1.0):
"""parse distill loss config"""
loss_dist = 0.0
losses = []
losses = {}
if isinstance(distill_node_pair[0], str):
assert isinstance(distill_loss, str)
assert isinstance(distill_lambda, float)
......@@ -128,16 +129,17 @@ def _parse_distill_loss(distill_node_pair,
assert len(distill_node_pair) == len(distill_loss)
assert len(distill_node_pair) == len(distill_lambda)
for node, loss, lam in zip(distill_node_pair, distill_loss, distill_lambda):
tmp_loss = 0.0
_logger.info("train config.distill_node_pair: {}".format(node, loss,
lam))
for node, loss_clas, lam in zip(distill_node_pair, distill_loss,
distill_lambda):
tmp_loss = losses.get(loss_clas, 0.0)
_logger.info("train config.distill_node_pair: {}".format(
node, loss_clas, lam))
assert len(node) % 2 == 0, \
"distill_node_pair config wrong, the length needs to be an even number"
for i in range(len(node) // 2):
tmp_loss += eval(loss)(node[i * 2], node[i * 2 + 1])
loss_dist += lam * tmp_loss
losses.append(tmp_loss)
tmp_loss += eval(loss_clas)(node[i * 2], node[i * 2 + 1]) * lam
loss_dist += tmp_loss
losses[loss_clas] = tmp_loss
return loss_dist, losses
......@@ -313,7 +315,7 @@ def build_distill_program(executor,
use_dynamic_loss_scaling=True,
**train_config['amp_config'])
distill_loss, losses = _parse_distill_loss(
distill_loss, loss_dict = _parse_distill_loss(
distill_node_pair,
config.get('loss') or 'l2', ### default loss is l2
config.get('alpha') or 1.0) ### default alpha is 1.0
......@@ -334,7 +336,7 @@ def build_distill_program(executor,
train_program_info = ProgramInfo(startup_program, train_program,
feed_target_names, train_fetch_list,
optimizer, learning_rate)
optimizer, learning_rate, loss_dict)
test_program_info = ProgramInfo(startup_program, test_program,
feed_target_names, fetch_targets)
return train_program_info, test_program_info
......@@ -469,6 +471,8 @@ def build_prune_program(executor,
params.append(param.name)
original_shapes[param.name] = param.shape
origin_flops = flops(train_program_info.program)
pruned_program, _, _ = pruner.prune(
train_program_info.program,
paddle.static.global_scope(),
......@@ -485,6 +489,12 @@ def build_prune_program(executor,
param.name, original_shapes[param.name], param.shape))
_logger.info(
"####################channel pruning end##########################")
final_flops = flops(pruned_program)
pruned_flops = abs(origin_flops - final_flops) / origin_flops
_logger.info("FLOPs before pruning: {}".format(origin_flops))
_logger.info("FLOPs after pruning: {}. Pruned FLOPs: {}%.".format(
final_flops, round(pruned_flops * 100, 2)))
train_program_info.program = pruned_program
elif strategy.startswith('asp'):
......
......@@ -431,7 +431,8 @@ class ProgramInfo:
feed_target_names,
fetch_targets,
optimizer=None,
learning_rate=None):
learning_rate=None,
loss_dict=None):
"""
ProgramInfo Config.
Args:
......@@ -441,6 +442,7 @@ class ProgramInfo:
fetch_targets(list(Variable)): The fetch variable in the program.
optimizer(Optimizer, optional): Optimizer in training. Default: None.
learning_rate(float|paddle.optimizer.lr, optional): learning_rate in training. Default: None.
loss_dict(dict): The components of losses.
"""
self.startup_program = startup_program
self.program = program
......@@ -448,3 +450,4 @@ class ProgramInfo:
self.fetch_targets = fetch_targets
self.optimizer = optimizer
self.learning_rate = learning_rate
self.loss_dict = loss_dict
......@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .single_distiller import merge, fsp, l2, soft_label, loss, dkd
from .single_distiller import merge, fsp, l2, soft_label, loss, dkd, skd
from .dml import DML
......@@ -15,6 +15,7 @@
import numpy as np
import paddle
from paddleslim.core import GraphWrapper
import paddle.nn.functional as F
def merge(teacher_program,
......@@ -203,8 +204,11 @@ def soft_label(teacher_var_name,
teacher_var = paddle.nn.functional.softmax(teacher_var /
teacher_temperature)
soft_label_loss = paddle.mean(
paddle.fluid.layers.cross_entropy(
student_var, teacher_var, soft_label=True))
paddle.nn.functional.cross_entropy(
input=student_var,
label=teacher_var,
soft_label=True,
use_softmax=False))
return soft_label_loss
......@@ -305,3 +309,53 @@ def dkd(teacher_var_name,
temperature=temperature,
alpha=alpha,
beta=beta)
def skd(teacher_var_name, student_var_name, program=None, multiplier=None):
"""Combine variables from student model and teacher model
by Spherical Knowledge Distillation loss (aka. skd-loss).
Reference: https://github.com/forjiuzhou/Spherical-Knowledge-Distillation
Args:
teacher_var_name(str): The name of teacher_var.
student_var_name(str): The name of student_var.
program(Program): The input distiller program. If not specified,
the default program will be used. Default: None
multiplier(float): The multiplier to recover its norm to the original
level. When it's None, the appropriate multiplier can be computed by
teacher's logits with paddle.std(output_t, axis=1). Default: None.
Returns:
Variable: skd distiller loss.
"""
if program == None:
program = paddle.static.default_main_program()
student_var = program.global_block().var(student_var_name)
teacher_var = program.global_block().var(teacher_var_name)
teacher_var.stop_gradient = True
if multiplier is None:
multiplier = paddle.std(teacher_var, axis=1, keepdim=True)
logits_student = F.layer_norm(
student_var,
student_var.shape[1:],
weight=None,
bias=None,
epsilon=1e-7) * multiplier
logits_teacher = F.layer_norm(
teacher_var,
teacher_var.shape[1:],
weight=None,
bias=None,
epsilon=1e-7) * multiplier
student_out = F.softmax(logits_student, axis=1)
teacher_out = F.softmax(logits_teacher, axis=1)
skd_loss = paddle.mean(
F.cross_entropy(
input=student_out,
label=teacher_out,
soft_label=True,
use_softmax=False))
return skd_loss
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
import paddle
from paddleslim.dist import merge, skd
from layers import conv_bn_layer
from static_case import StaticCase
class TestSKDLoss(StaticCase):
def test_skd_loss(self):
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
student_program = paddle.static.Program()
student_startup = paddle.static.Program()
with paddle.static.program_guard(student_program, student_startup):
with paddle.utils.unique_name.guard():
input = paddle.static.data(
name="image", shape=[None, 3, 224, 224])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
student_predict = conv1 + conv2
teacher_program = paddle.static.Program()
teacher_startup = paddle.static.Program()
with paddle.static.program_guard(teacher_program, teacher_startup):
with paddle.utils.unique_name.guard():
input = paddle.static.data(
name="image", shape=[None, 3, 224, 224])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
sum1 = conv1 + conv2
conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
teacher_predict = conv_bn_layer(conv5, 8, 3, "conv6")
exe.run(teacher_startup)
exe.run(student_startup)
data_name_map = {'image': 'image'}
merge(teacher_program, student_program, data_name_map, place)
merged_ops = []
for block in student_program.blocks:
for op in block.ops:
merged_ops.append(op.type)
with paddle.static.program_guard(student_program, student_startup):
distill_loss = skd('teacher_' + teacher_predict.name,
student_predict.name,
program=None,
multiplier=None)
loss_ops = []
for block in student_program.blocks:
for op in block.ops:
loss_ops.append(op.type)
print(f"ret: {set(loss_ops).difference(set(merged_ops))}")
self.assertTrue(set(merged_ops).difference(set(loss_ops)) == set())
self.assertTrue({
'softmax_with_cross_entropy', 'softmax', 'reduce_mean', 'layer_norm'
}.issubset(set(loss_ops).difference(set(merged_ops))))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册