Unverified commit d2cc9663, authored by ceci3 and committed by GitHub

[distill] how to get feature map (#799)

Parent 6238fd7b
paddleslim/dygraph/dist/__init__.py
@@ -12,4 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import distill
from .distill import *
__all__ = []
__all__ += distill.__all__
paddleslim/dygraph/dist/distill.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import collections
from collections import namedtuple
import paddle.nn as nn
from .losses import *
__all__ = ['Distill', 'AdaptorBase']
class LayerConfig:
def __init__(self,
s_feature_idx,
t_feature_idx,
feature_type,
loss_function,
weight=1.0,
align=False,
align_shape=None):
self.s_feature_idx = s_feature_idx
self.t_feature_idx = t_feature_idx
self.feature_type = feature_type
if loss_function in ['l1', 'l2', 'smooth_l1']:
self.loss_function = 'DistillationDistanceLoss'
elif loss_function in ['dml']:
self.loss_function = 'DistillationDMLLoss'
elif loss_function in ['rkl']:
self.loss_function = 'DistillationRKDLoss'
else:
            raise NotImplementedError("loss function is not supported!")
self.weight = weight
self.align = align
self.align_shape = align_shape
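# Illustrative example: a config entry such as
#   {'s_feature_idx': 0, 't_feature_idx': 0, 'feature_type': 'hidden',
#    'loss_function': 'l2'}
# becomes a LayerConfig with loss_function 'DistillationDistanceLoss' and is
# later keyed as 'hidden_0_0' in the loss config (see Distill.__init__).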
class AdaptorBase:
def __init__(self, model):
self.model = model
self.add_tensor = False
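        # Flipped to True by Distill.forward after the first pass; from then
        # on mapping_layers() may return tensors recorded on the model
        # instead of sublayer names (see the second test case below).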
def _get_activation(self, outs, name):
def get_output_hook(layer, input, output):
outs[name] = output
return get_output_hook
def _add_distill_hook(self, outs, mapping_layers_name, layers_type):
"""
Get output by name.
outs(dict): save the middle outputs of model according to the name.
mapping_layers(list): name of middle layers.
layers_type(list): type of the middle layers to calculate distill loss.
"""
### TODO: support DP model
for idx, (n, m) in enumerate(self.model.named_sublayers()):
if n in mapping_layers_name:
midx = mapping_layers_name.index(n)
m.register_forward_post_hook(
self._get_activation(outs, layers_type[midx]))
def mapping_layers(self):
raise NotImplementedError("function mapping_layers is not implemented")
class Distill(nn.Layer):
### TODO: support list of student model and teacher model
def __init__(self, distill_configs, student_models, teacher_models,
adaptors_S, adaptors_T):
super(Distill, self).__init__()
self._distill_configs = distill_configs
self._student_models = student_models
self._teacher_models = teacher_models
self._adaptors_S = adaptors_S(self._student_models)
self._adaptors_T = adaptors_T(self._teacher_models)
self.stu_outs_dict, self.tea_outs_dict = self._prepare_outputs()
self.configs = []
for c in self._distill_configs:
self.configs.append(LayerConfig(**c).__dict__)
self.distill_idx = self._get_distill_idx()
self._loss_config_list = []
for c in self.configs:
loss_config = {}
loss_config[str(c['loss_function'])] = {}
loss_config[str(c['loss_function'])]['weight'] = c['weight']
loss_config[str(c['loss_function'])]['key'] = c[
'feature_type'] + '_' + str(c['s_feature_idx']) + '_' + str(c[
't_feature_idx'])
### TODO: support list of student models and teacher_models
loss_config[str(c['loss_function'])][
'model_name_pairs'] = [['student', 'teacher']]
self._loss_config_list.append(loss_config)
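        # Each resulting entry looks like, e.g.:
        #   {'DistillationDistanceLoss': {'weight': 1.0, 'key': 'hidden_0_0',
        #     'model_name_pairs': [['student', 'teacher']]}}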
self._prepare_loss()
def _prepare_hook(self, adaptors, outs_dict):
mapping_layers = adaptors.mapping_layers()
for layer_type, layer in mapping_layers.items():
if isinstance(layer, str):
adaptors._add_distill_hook(outs_dict, [layer], [layer_type])
return outs_dict
def _get_model_intermediate_output(self, adaptors, outs_dict):
mapping_layers = adaptors.mapping_layers()
for layer_type, layer in mapping_layers.items():
if isinstance(layer, str):
continue
outs_dict[layer_type] = layer
return outs_dict
def _get_distill_idx(self):
distill_idx = {}
for config in self._distill_configs:
if config['feature_type'] not in distill_idx:
distill_idx[config['feature_type']] = [[
int(config['s_feature_idx']), int(config['t_feature_idx'])
]]
else:
distill_idx[config['feature_type']].append([
int(config['s_feature_idx']), int(config['t_feature_idx'])
])
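        # e.g. the configs in the first test below yield:
        #   {'hidden': [[0, 0], [1, 1]], 'logits': [[0, 0]]}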
return distill_idx
def _prepare_loss(self):
self.distill_loss = CombinedLoss(self._loss_config_list)
def _prepare_outputs(self):
stu_outs_dict = collections.OrderedDict()
tea_outs_dict = collections.OrderedDict()
stu_outs_dict = self._prepare_hook(self._adaptors_S, stu_outs_dict)
tea_outs_dict = self._prepare_hook(self._adaptors_T, tea_outs_dict)
return stu_outs_dict, tea_outs_dict
def _post_outputs(self):
final_keys = []
for key, value in self.stu_outs_dict.items():
if len(key.split('_')) == 1:
final_keys.append(key)
### TODO: support list of student models and teacher_models
final_distill_dict = {
"student": collections.OrderedDict(),
"teacher": collections.OrderedDict()
}
for feature_type, dist_idx in self.distill_idx.items():
for idx, idx_list in enumerate(dist_idx):
sidx, tidx = idx_list[0], idx_list[1]
final_distill_dict['student'][feature_type + '_' + str(
sidx) + '_' + str(tidx)] = self.stu_outs_dict[
feature_type + '_' + str(sidx)]
final_distill_dict['teacher'][feature_type + '_' + str(
sidx) + '_' + str(tidx)] = self.tea_outs_dict[
feature_type + '_' + str(tidx)]
return final_distill_dict
def forward(self, *inputs, **kwargs):
stu_batch_outs = self._student_models.forward(*inputs, **kwargs)
tea_batch_outs = self._teacher_models.forward(*inputs, **kwargs)
        if not self._adaptors_S.add_tensor:
            self._adaptors_S.add_tensor = True
        if not self._adaptors_T.add_tensor:
            self._adaptors_T.add_tensor = True
self.stu_outs_dict = self._get_model_intermediate_output(
self._adaptors_S, self.stu_outs_dict)
self.tea_outs_dict = self._get_model_intermediate_output(
self._adaptors_T, self.tea_outs_dict)
distill_inputs = self._post_outputs()
### batch is None just for now
distill_outputs = self.distill_loss(distill_inputs, None)
distill_loss = distill_outputs['loss']
return stu_batch_outs, tea_batch_outs, distill_loss
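# Minimal usage sketch (illustrative; it mirrors the unit test below, and
# MyAdaptor stands for any AdaptorBase subclass):
#
#   distill_model = Distill(configs, student, teacher, MyAdaptor, MyAdaptor)
#   student_out, teacher_out, distill_loss = distill_model(images)
#   loss = paddle.nn.functional.cross_entropy(student_out, labels).mean()
#   (loss + distill_loss).backward()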
import sys
sys.path.append("../../")
import logging
import numpy as np
import unittest
import paddle
import paddle.nn as nn
from paddle.vision.models import MobileNetV1
import paddle.vision.transforms as T
from paddleslim.dygraph.dist import Distill, AdaptorBase
from paddleslim.common.log_helper import get_logger
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
class TestImperativeDistill(unittest.TestCase):
def setUp(self):
self.s_model, self.t_model = self.prepare_model()
self.t_model.eval()
self.distill_configs = self.prepare_config()
self.adaptor = self.prepare_adaptor()
def prepare_model(self):
return MobileNetV1(), MobileNetV1()
def prepare_config(self):
distill_configs = [{
's_feature_idx': 0,
't_feature_idx': 0,
'feature_type': 'hidden',
'loss_function': 'l2'
}, {
's_feature_idx': 1,
't_feature_idx': 1,
'feature_type': 'hidden',
'loss_function': 'l2'
}, {
's_feature_idx': 0,
't_feature_idx': 0,
'feature_type': 'logits',
'loss_function': 'l2'
}]
return distill_configs
def prepare_adaptor(self):
class Adaptor(AdaptorBase):
def mapping_layers(self):
mapping_layers = {}
mapping_layers['hidden_0'] = 'conv1'
mapping_layers['hidden_1'] = 'conv2_2'
mapping_layers['hidden_2'] = 'conv3_2'
mapping_layers['logits_0'] = 'fc'
return mapping_layers
return Adaptor
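    # Note how configs resolve against the adaptor: 's_feature_idx': 0 with
    # 'feature_type': 'hidden' selects the student activation registered
    # under 'hidden_0', i.e. the output of the 'conv1' sublayer above.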
def test_distill(self):
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = paddle.vision.datasets.Cifar10(
mode='train', backend='cv2', transform=transform)
val_dataset = paddle.vision.datasets.Cifar10(
mode='test', backend='cv2', transform=transform)
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace()
train_reader = paddle.io.DataLoader(
train_dataset, drop_last=True, places=place, batch_size=64)
test_reader = paddle.io.DataLoader(
val_dataset, places=place, batch_size=64)
def test(model):
model.eval()
avg_acc = [[], []]
for batch_id, data in enumerate(test_reader):
img = paddle.to_tensor(data[0])
label = paddle.to_tensor(data[1])
label = paddle.reshape(label, [-1, 1])
out = model(img)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
avg_acc[0].append(acc_top1.numpy())
avg_acc[1].append(acc_top5.numpy())
if batch_id % 100 == 0:
_logger.info(
"Test | step {}: acc1 = {:}, acc5 = {:}".format(
batch_id, acc_top1.numpy(), acc_top5.numpy()))
_logger.info("Test |Average: acc_top1 {}, acc_top5 {}".format(
np.mean(avg_acc[0]), np.mean(avg_acc[1])))
def train(model):
adam = paddle.optimizer.Adam(
learning_rate=0.001, parameters=model.parameters())
for batch_id, data in enumerate(train_reader):
img = paddle.to_tensor(data[0])
label = paddle.to_tensor(data[1])
student_out, teacher_out, distill_loss = model(img)
loss = paddle.nn.functional.loss.cross_entropy(student_out,
label)
avg_loss = paddle.mean(loss)
all_loss = avg_loss + distill_loss
all_loss.backward()
adam.step()
adam.clear_grad()
if batch_id % 100 == 0:
_logger.info("Train | At epoch {} step {}: loss = {:}".
format(str(0), batch_id, all_loss.numpy()))
test(self.s_model)
self.s_model.train()
distill_model = Distill(self.distill_configs, self.s_model,
self.t_model, self.adaptor, self.adaptor)
train(distill_model)
class TestImperativeDistillCase1(TestImperativeDistill):
def prepare_model(self):
class Model(nn.Layer):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2D(3, 3, 3, padding=1)
self.conv2 = nn.Conv2D(3, 3, 3, padding=1)
self.conv3 = nn.Conv2D(3, 3, 3, padding=1)
self.fc = nn.Linear(3072, 10)
def forward(self, x):
self.conv1_out = self.conv1(x)
conv2_out = self.conv2(self.conv1_out)
self.conv3_out = self.conv3(conv2_out)
out = paddle.reshape(self.conv3_out, shape=[x.shape[0], -1])
out = self.fc(out)
return out
return Model(), Model()
def prepare_adaptor(self):
class Adaptor(AdaptorBase):
def mapping_layers(self):
mapping_layers = {}
mapping_layers['hidden_1'] = 'conv2'
if self.add_tensor:
mapping_layers['hidden_0'] = self.model.conv1_out
mapping_layers['hidden_2'] = self.model.conv3_out
return mapping_layers
return Adaptor
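    # Unlike the name-based mapping above, 'hidden_0' and 'hidden_2' are
    # tensors recorded on the model during forward, so the adaptor exposes
    # them only once add_tensor is True (set by Distill.forward after the
    # first pass).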
def prepare_config(self):
distill_configs = [{
's_feature_idx': 0,
't_feature_idx': 0,
'feature_type': 'hidden',
'loss_function': 'l2'
}, {
's_feature_idx': 1,
't_feature_idx': 2,
'feature_type': 'hidden',
'loss_function': 'l2'
}]
return distill_configs
if __name__ == '__main__':
unittest.main()