Unverified commit 5662a660 authored by rehulisw, committed by GitHub

Add ITPruner (#1519)

* itpruner

* Delete paddleslim/nas/itpruner/__pycache__ directory

* test_itpruner

* Delete paddleslim/nas/itpruner/Cifar/__pycache__ directory

* Delete paddleslim/nas/itpruner/CKA/__pycache__ directory

* Delete paddleslim/nas/itpruner/Cifar/utils/__pycache__ directory

* Delete paddleslim/nas/itpruner/Cifar/nets/__pycache__ directory

* Update cka.py

* Update itpruner.py

* Update test_itpruner.py

* Update test_itpruner.py

* Update test_itpruner.py

* Delete initializer.py

* Update base_models.py

* Update resnet_cifar.py

* Update base_models.py

* Update resnet_cifar.py

* Update base_models.py

* Update cka.py

* Update resnet_cifar.py

* Update base_models.py

* Update utils.py

* Delete base_models.py

* Update resnet_cifar.py
Co-authored-by: minghaoBD <79566150+minghaoBD@users.noreply.github.com>
Parent fa1ed78e

paddleslim/nas/itpruner/CKA/cka.py

import paddle


def gram_linear(x):
    """Gram matrix of a feature matrix x with one example per row."""
    x = paddle.to_tensor(x).cuda()
    return paddle.matmul(x, paddle.t(x))


def center_gram(gram, unbiased=False):
    """Double-center a symmetric Gram matrix (biased estimator only)."""
    gram = gram.cuda()
    if not paddle.allclose(gram, paddle.t(gram)):
        raise ValueError('Input must be a symmetric matrix.')
    # Subtracting row and column means is equivalent to H @ gram @ H
    # with the centering matrix H = I - (1/n) * ones((n, n)).
    means = paddle.mean(gram, 0)
    means -= paddle.mean(means) / 2
    gram -= means[:, None]
    gram -= means[None, :]
    return gram


def cka(gram_x, gram_y, unbiased=False):
    """Centered Kernel Alignment between two Gram matrices."""
    gram_x = gram_x.cuda()
    gram_y = gram_y.cuda()
    gram_x = center_gram(gram_x, unbiased=unbiased)
    gram_y = center_gram(gram_y, unbiased=unbiased)
    # Inner product of the centered Gram matrices (scaled HSIC),
    # normalized by their Frobenius norms.
    scaled_hsic = paddle.dot(gram_x.reshape([-1]), gram_y.reshape([-1]))
    normalization_x = paddle.linalg.norm(gram_x)
    normalization_y = paddle.linalg.norm(gram_y)
    return scaled_hsic / (normalization_x * normalization_y)
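
For reference, the cka function above computes linear CKA (Kornblith et al., 2019) between two Gram matrices K and L:

$$\mathrm{CKA}(K,L)=\frac{\langle\operatorname{vec}(HKH),\,\operatorname{vec}(HLH)\rangle}{\lVert HKH\rVert_F\,\lVert HLH\rVert_F},\qquad H=I_n-\tfrac{1}{n}\mathbf{1}\mathbf{1}^{\top}$$

center_gram implements the double-centering HKH by subtracting row and column means; the unbiased flag is accepted for API symmetry, but only the biased estimator is implemented in this version.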

paddleslim/nas/itpruner/Cifar/nets/resnet_cifar.py

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np


class LambdaLayer(nn.Layer):
    """Wraps an arbitrary function as a Paddle layer."""

    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlock(nn.Layer):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2D
        self.conv1 = nn.Conv2D(in_planes, planes, 3, padding=1, stride=stride, bias_attr=False)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False)
        self.bn2 = norm_layer(planes)
        self.downsample = nn.Sequential()
        self.stride = stride
        self.in_planes = in_planes
        self.planes = planes

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # Parameter-free "option A" shortcut: subsample spatially where
        # stride != 1 and zero-pad the channel axis where the width grows.
        if self.stride != 1 or self.planes != self.in_planes:
            if self.stride != 1:
                self.downsample = LambdaLayer(
                    lambda x: F.pad(x[:, :, ::2, ::2],
                                    (0, 0, (self.planes - self.in_planes) // 2,
                                     self.planes - self.in_planes - (self.planes - self.in_planes) // 2,
                                     0, 0, 0, 0), "constant", 0))
            else:
                self.downsample = LambdaLayer(
                    lambda x: F.pad(x[:, :, :, :],
                                    (0, 0, (self.planes - self.in_planes) // 2,
                                     self.planes - self.in_planes - (self.planes - self.in_planes) // 2,
                                     0, 0, 0, 0), "constant", 0))
            identity = self.downsample(identity)
        out += identity
        out = self.relu(out)
        return out


class ResNetCifar(nn.Layer):
    def __init__(self, depth=20, num_classes=10, cfg=None, cutout=False):
        super(ResNetCifar, self).__init__()
        # Default (unpruned) width of every block: n blocks per stage
        # at 16, 32 and 64 channels.
        cfg_base = []
        n = (depth - 2) // 6
        for i in [16, 32, 64]:
            for j in range(n):
                cfg_base.append(i)
        if cfg is None:
            cfg = cfg_base
        num_blocks = [n, n, n]  # [3, 3, 3] for depth 20
        block = BasicBlock
        self.cfg_base = cfg_base
        self.num_classes = num_classes
        self.num_blocks = num_blocks
        self.cutout = cutout
        self.cfg = cfg
        self.in_planes = 16
        conv1 = nn.Conv2D(3, 16, kernel_size=3, stride=1, padding=1, bias_attr=False)
        bn1 = nn.BatchNorm2D(16)
        self.conv_bn = nn.Sequential(conv1, bn1)
        self.layer1 = self._make_layer(block, cfg[0:n], num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, cfg[n:2 * n], num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, cfg[2 * n:], num_blocks[2], stride=2)
        self.pool = nn.AdaptiveAvgPool2D(1)
        self.linear = nn.Linear(cfg[-1], num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        # The first block of a stage may downsample; the rest keep stride 1.
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for i in range(len(strides)):
            layers.append(('block_%d' % i, block(self.in_planes, planes[i], strides[i])))
            self.in_planes = planes[i]
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.conv_bn(x))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.pool(out).flatten(1)
        out = self.linear(out)
        return out

    def feature_extract(self, x):
        # Collect the output feature map of every BasicBlock; these are
        # the per-layer features compared with CKA in ITPruner.
        tensor = []
        out = F.relu(self.conv_bn(x))
        for i in [self.layer1, self.layer2, self.layer3]:
            for _layer in i:
                out = _layer(out)
                if type(_layer) is BasicBlock:
                    tensor.append(out)
        return tensor

    def cfg2flops(self, cfg):  # to simplify, only count convolution FLOPs
        size = 32
        flops = 0
        flops += (3 * 3 * 3 * 16 * 32 * 32 + 16 * 32 * 32 * 4)  # conv1 + bn1
        in_c = 16
        cfg_idx = 0
        for i in range(3):
            num_blocks = self.num_blocks[i]
            if i == 1 or i == 2:
                size = size // 2
            for j in range(num_blocks):
                c = cfg[cfg_idx]
                # Per block: two 3x3 convolutions plus two batch norms.
                flops += (in_c * 3 * 3 * c * size * size + c * size * size * 4
                          + c * 3 * 3 * c * size * size + c * size * size * 4)
                if in_c != c:
                    flops += in_c * c * size * size  # shortcut
                in_c = c
                cfg_idx += 1
        flops += (2 * cfg[-1] + 1) * self.num_classes  # fc layer
        return flops

    def cfg2flops_perlayer(self, cfg, length):  # to simplify, only count convolution FLOPs
        # Decompose the total FLOPs into terms linear in one width
        # (flops_singlecfg), bilinear in two adjacent widths
        # (flops_doublecfg) and quadratic in one width (flops_squarecfg),
        # so pruning ratios can be optimized under a FLOPs constraint.
        size = 32
        flops_singlecfg = [0 for j in range(length)]
        flops_doublecfg = np.zeros((length, length))
        flops_squarecfg = [0 for j in range(length)]
        in_c = 16
        cfg_idx = 0
        for i in range(3):
            num_blocks = self.num_blocks[i]
            if i == 1 or i == 2:
                size = size // 2
            for j in range(num_blocks):
                c = cfg[cfg_idx]
                if i == 0 and j == 0:
                    # First block: the stem width (16) is fixed, so the
                    # conv1 term is linear in c.
                    flops_singlecfg[cfg_idx] += (c * size * size * 4 + c * size * size * 4
                                                 + in_c * 3 * 3 * c * size * size)
                    flops_squarecfg[cfg_idx] += c * 3 * 3 * c * size * size
                else:
                    flops_singlecfg[cfg_idx] += (c * size * size * 4 + c * size * size * 4)
                    flops_doublecfg[cfg_idx - 1][cfg_idx] += in_c * 3 * 3 * c * size * size
                    flops_doublecfg[cfg_idx][cfg_idx - 1] += in_c * 3 * 3 * c * size * size
                    flops_squarecfg[cfg_idx] += c * 3 * 3 * c * size * size
                    if in_c != c:
                        flops_doublecfg[cfg_idx][cfg_idx - 1] += in_c * c * size * size  # shortcut
                        flops_doublecfg[cfg_idx - 1][cfg_idx] += in_c * c * size * size
                in_c = c
                cfg_idx += 1
        flops_singlecfg[-1] += 2 * cfg[-1] * self.num_classes  # fc layer
        return flops_singlecfg, flops_doublecfg, flops_squarecfg
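
A minimal sketch of the option-A shortcut used in BasicBlock.forward above: with an 8-element pad list, paddle.nn.functional.pad pads from the first dimension to the last, so the (0, 0, before, after, 0, 0, 0, 0) pattern grows only the channel axis:

    import paddle
    import paddle.nn.functional as F

    x = paddle.randn([1, 16, 32, 32])
    # Subsample H and W by 2, then zero-pad channels 16 -> 32
    # (8 zeros before, 8 after), as in the stride != 1 branch.
    y = F.pad(x[:, :, ::2, ::2], (0, 0, 8, 8, 0, 0, 0, 0), "constant", 0)
    print(y.shape)  # [1, 32, 16, 16]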

paddleslim/nas/itpruner/Cifar/utils/utils.py

def sum_list(a, j):
    """Sum all elements of a except the one at index j (j = -1 skips none)."""
    b = 0
    for i in range(len(a)):
        if i != j:
            b += a[i]
    return b
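
For orientation, a quick check of cfg2flops on the unpruned ResNet-20 configuration (a sketch, assuming the package layout used by the test below):

    from paddleslim.nas.itpruner.Cifar.nets.resnet_cifar import ResNetCifar

    net = ResNetCifar(depth=20, num_classes=10)
    print(net.cfg)                 # [16, 16, 16, 32, 32, 32, 64, 64, 64]
    print(net.cfg2flops(net.cfg))  # 41567498 under this counting scheme

The target_flops of 20800000 used in the test below therefore asks for roughly half the network's FLOPs.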

paddleslim/nas/itpruner/__init__.py

from .itpruner import *

__all__ = []
__all__ += itpruner.__all__

paddleslim/nas/itpruner/itpruner.py

import numpy as np
from scipy import optimize
import math
import paddle

from .CKA import cka
from .Cifar.utils.utils import sum_list

__all__ = ['ITPruner']


class ITPruner:
    def __init__(self, net, data):
        self.net = net
        self.data = data
        self.important = []
        self.length = 0
        self.flops_singlecfg = None
        self.flops_doublecfg = None
        self.flops_squarecfg = None
        self.target_flops = 0

    def extract_feature(self, data):
        # Forward the calibration batch once and flatten every block's
        # output feature map to shape [batch, -1].
        n = data.shape[0]
        with paddle.no_grad():
            feature = self.net.feature_extract(data)
        for i in range(len(feature)):
            feature[i] = feature[i].reshape((n, -1))
            feature[i] = feature[i].cpu().numpy()
        return feature

    def func(self, x, sign=1.0):
        """Objective function: sum of importance-weighted width ratios."""
        sum_fuc = []
        for i in range(self.length):
            sum_fuc.append(x[i] * self.important[i])
        return sum(sum_fuc)

    def func_deriv(self, x, sign=1.0):
        """Derivative of the objective function."""
        diff = []
        for i in range(len(self.important)):
            diff.append(sign * (self.important[i]))
        return np.array(diff)

    def constrain_func(self, x):
        """Constraint: FLOPs of the scaled cfg must not exceed the target."""
        a = []
        for i in range(self.length):
            a.append(x[i] * self.flops_singlecfg[i])
            a.append(self.flops_squarecfg[i] * x[i] * x[i])
        for i in range(1, self.length):
            for j in range(i):
                a.append(x[i] * x[j] * self.flops_doublecfg[i][j])
        return np.array([self.target_flops - sum(a)])

    def compute_similar_matrix(self, feature):
        # Pairwise linear-CKA similarity between all extracted layers.
        similar_matrix = np.zeros((len(feature), len(feature)))
        for i in range(len(feature)):
            for j in range(len(feature)):
                similar_matrix[i][j] = cka.cka(cka.gram_linear(feature[i]), cka.gram_linear(feature[j]))
        return similar_matrix

    def linear_programming(self):
        bnds = []
        for i in range(self.length):
            bnds.append((0, 1))
        bnds = tuple(bnds)
        cons = ({'type': 'ineq', 'fun': self.constrain_func})
        result = optimize.minimize(self.func, x0=[1 for i in range(self.length)],
                                   jac=self.func_deriv, method='SLSQP',
                                   bounds=bnds, constraints=cons)
        return result

    @paddle.no_grad()
    def prune(self, target_flops, beta):
        self.target_flops = target_flops
        temp = []
        feature = self.extract_feature(self.data)
        similar_matrix = self.compute_similar_matrix(feature)
        # A layer's importance decays with its normalized CKA similarity
        # to the other layers: redundant layers can be pruned harder.
        for i in range(len(feature)):
            temp.append(sum_list(similar_matrix[i], i))
        b = sum_list(temp, -1)
        temp = [x / b for x in temp]
        for i in range(len(feature)):
            self.important.append(math.exp(-1 * beta * temp[i]))
        self.length = len(self.net.cfg)
        self.flops_singlecfg, self.flops_doublecfg, self.flops_squarecfg = \
            self.net.cfg2flops_perlayer(self.net.cfg, self.length)
        # Negate so that SLSQP minimization maximizes total importance.
        self.important = np.array(self.important)
        self.important = np.negative(self.important)
        result = self.linear_programming()
        prun_cfg = np.around(np.array(self.net.cfg) * result.x)
        optimize_cfg = []
        for i in range(len(prun_cfg)):
            b = list(prun_cfg)[i].tolist()
            optimize_cfg.append(int(b))
        print(optimize_cfg)                      # pruned per-block widths
        print(self.net.cfg2flops(prun_cfg))      # FLOPs after pruning
        print(self.net.cfg2flops(self.net.cfg))  # FLOPs before pruning
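
Despite the method name linear_programming, prune solves a quadratically constrained program over per-layer width ratios $x_i \in [0, 1]$: with importances $I_i = \exp(-\beta\,\bar{s}_i)$, where $\bar{s}_i$ is layer $i$'s normalized CKA similarity to the other layers, and FLOPs coefficients $s$, $q$, $d$ from cfg2flops_perlayer,

$$\max_{x\in[0,1]^{L}}\ \sum_{i=1}^{L} I_i\,x_i \quad\text{s.t.}\quad \sum_i s_i x_i + \sum_i q_i x_i^2 + \sum_{i>j} d_{ij}\,x_i x_j \le F_{\mathrm{target}},$$

maximized by minimizing the negated objective with SLSQP; the resulting ratios scale the original cfg and are rounded to integer channel counts.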

test_itpruner.py

import sys
sys.path.append("../")  # make the local paddleslim package importable
import unittest

import paddle

from paddleslim.nas.itpruner import ITPruner
from paddleslim.nas.itpruner.Cifar.nets.resnet_cifar import ResNetCifar


class TestITPruner(unittest.TestCase):
    def test_itpruner(self):
        net = ResNetCifar(depth=20, num_classes=10, cfg=None)
        # A batch of random CIFAR-sized images is enough to drive the
        # CKA feature comparison.
        data = paddle.normal(shape=[100, 3, 32, 32])
        itpruner = ITPruner(net, data)
        target_flops = 20800000
        beta = 243
        itpruner.prune(target_flops, beta)


if __name__ == '__main__':
    unittest.main()
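
Note that gram_linear and center_gram call .cuda() unconditionally, so this test assumes a GPU build of Paddle with a visible CUDA device.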