提交 4862c7b0 编写于 作者: L likesiwell

add distill

上级 840f652d
运行本目录下的程序示例使用PaddlePaddle v0.12.0 版本。如果您的PaddlePaddle安装版本低于此要求,请按照[安装文档](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html)中的说明更新PaddlePaddle安装版本。
# 神经网络知识蒸馏
本文介绍了如何使用PaddlePaddle实现文章《Distilling the Knowledge in a Neural Network》[^distill](2014 NIPS workshop)在MNIST[^mnist]数据上训练神经网络的方法与功能。
## 方法概述
知识蒸馏方法的目的是从比较大的复杂的大模型中提取信息,用于较小模型的训练,从而可以使得小模型达到更好的效果。以MNIST分类任务为例,大模型通常可以产生具有高置信度的分类结果,除了我们需要的结果之外,还有更多的信息隐藏在输出的软目标(soft targets)概率的比例中,我们可以利用大模型输出的软目标概率来作为小模型训练的监督信息,结合正常的训练目标从而提升小模型的性能。在该方法中,我们通过提升最终softmax层的temperature来获得更加合适的软目标概率。
## 基本训练流程
1.训练大模型: 使用hard targets, 也就是正常的label训练大模型。
2.计算训练样本的soft targets: 利用训练好的大模型,在softmax函数中不同的temperature下产生训练样本的soft targets,公式如下:
$$ q_i = \frac{exp(z_i / T)}{\sum_j exp(z_j / T)} $$
其中,$q_i$ 是产生的soft targets, $z_i$ 是network output, $T$ 是 temperature。
3.利用soft targets与 hard targets训练小模型: 小模型使用相同的temperature与soft targets进行crossentropy就算soft损失,使用hard targets计算hard 损失,由于soft targets被缩放了 $\frac{1}{T^2}$, 为了平衡训练的梯度,需要将soft 损失扩大 $T^2$.
4.在训练结束后,使用整成的推断方式测试小网络性能,并且与正常训练方式进行比较。
更多细节讨论参考论文[^distill]
## 代码结构
| File&Folder | Description |
| :-------- | :-------- |
| train.py | Training script |
| infer.py | Predcition using the trained model script |
| utils.py | Dataset loading classes and functions |
| mnist_prepare.py | Preparing Mnist Datasets |
| ./data/ | The folder stores dataset and soft targets dataset |
| ./models/ | The folder stores trained models |
| ./images/ | Illustration graphs |
## 数据准备
运行`mnist_prepare.py` 使用PaddlePaddle dataset 函数准备原始的MNIST数据 。MNIST数据存储在文件 `./data/mnist.npz` , 所有的数据都被标准化到[-1, 1].
```python
reader = paddle.dataset.mnist.train()
~
np.savez('./data/mnist.npz', train_x=train_x, train_y=train_y, test_x=test_x, test_y=test_y)
```
## Training
**首先**, 我们训练一个单个大的神经网络,我们使用一个两层,每层1200隐藏节点的全连接网络作为大网络。
```python
def mlp_teacher(img, drop_prob):
h1 = fluid.layers.fc(input=img, size=1200, act='relu')
drop1 = fluid.layers.dropout(h1, dropout_prob=drop_prob)
h2 = fluid.layers.fc(input=drop1, size=1200, act='relu')
drop2 = fluid.layers.dropout(h2, dropout_prob=drop_prob)
logits = fluid.layers.fc(input=drop2, size=10, act=None)
return logits
```
运行`python train.py --phase teacher --drop_prob 0.4` 来训练大的神经网络,训练好的模型参数存储在 `.\models\teacher_net`.
**然后**,我们运行 `python infer.py --phase teacher --temp 4.0` 来产生训练数据对应的soft targets。 我们可以设置不同的 `temp` 超参数来产生不同的soft targets,不同的`temp`参数产生的soft targets对小模型的影响是不同的。为了保持新产生的soft targets与训练样本一致,训练小网络的数据存储在`./data/mnist_soft_{temp}.npz`.
```python
print("Generating soft-targets.......")
g_iternum = train_set.num_examples // G_batch_size
soft_list = []
for i in range(g_iternum):
g_batch = list(zip(train_x[i*G_batch_size:(i+1)*G_batch_size], train_y[i*G_batch_size:(i+1)*G_batch_size]))
soft_targets, gb_acc = exe.run(inference_program, feed=feeder.feed(g_batch),
fetch_list=[temp_softmax_logits, batch_acc])
soft_list.append(soft_targets)
train_y_soft = np.vstack(soft_list)
print("saving soft targets")
np.savez(data_dir+'mnist_soft_{}.npz'.format(temp),
train_x=train_x, train_y=train_y, test_x=test_x, test_y=test_y, train_y_soft=train_y_soft)
```
**然后**, 运行 `python train.py --phase student --dropout 0.1 --stu_hsize 30 --temp 4.0 --use_soft True` 结合soft targets训练小网络。训练小网络时,temperature应该与大的网络产生的soft targets的temperature一致。 Softmax with temperature 函数定义如下:
```python
def softmax_with_temperature(logits, temp=1.0):
logits_with_temp = logits/temp
_softmax = fluid.layers.softmax(logits_with_temp)
return _softmax
```
定义目标函数如下:
```python
def soft_crossentropy(input, label):
epsilon = 1e-8
eps = fluid.layers.ones(shape=[1], dtype='float32') * epsilon
loss = reduce_sum(-1.0 * label * log(elementwise_max(input, eps)), dim=1, keep_dim=True)
return loss
~~
softmax_logits = fluid.layers.softmax(logits)
temp_softmax_logits = softmax_with_temperature(logits, temp=temp)
hard_loss = soft_crossentropy(input=softmax_logits, label=label)
soft_loss = soft_crossentropy(input=temp_softmax_logits, label=soft_label)
```
小网络也是一个两层全连接网络,但是每层的隐含单元数量要少很多。
```python
def mlp_student(img, drop_prob, h_size):
h1 = fluid.layers.fc(input=img, size=h_size, act='relu')
drop1 = fluid.layers.dropout(h1, dropout_prob=drop_prob)
h2 = fluid.layers.fc(input=drop1, size=h_size, act='relu')
drop2 = fluid.layers.dropout(h2, dropout_prob=drop_prob)
logits = fluid.layers.fc(input=drop2, size=10, act=None)
return logits
```
**最后**,为了验证使用soft targets的有效性,我们需要要比较小网络在使用和不适用soft targets时候测试集性能的表现。 运行 `python train.py --phase student --stu_hsize 30 --drop_prob 0.1 ` 训练小网络不使用 soft targets,运行`python infer.py --phase student --stu_hsize 30` 获得测试性能。运行 `python infer.py --phase student --stu_hsize 30 --use_soft True`. 获得使用soft targets训练的小网络的测试性能。
## Results
我们使用SGD-Momentum优化器,0.001的学习率训练大的网络200 个周期, dropout 比率设置为0.4。训练两层1200隐含单元的全连接的大网络, 最终测试结果为99.33%的正确率。为了展示知识蒸馏方法的有效性,我们用两层30隐含单元的全连接网络作为小网络。如果使用SGD 优化器,在0.001学习率下,正常训练小网络,得到的测试结果为94.38%。而结合temperature 为 4.0的soft target,在同样条件下训练的小网络,测试集结果为97.01%。可以看到,使用soft targets有了非常明显的提高。
| Methods | Test Accuracy |
| :-------- | --------:|
| TeacherNet | 99.33% |
| StudentNet 30units without soft targets | 94.38% |
| StudentNet 30units with 4.0temp soft targets | 97.01% |
训练过程中的测试集准确率如下图所示,可以看到使用soft targets训练的小网络的收敛速度快于不使用soft targets的网络。
![收敛](https://github.com/likesiwell/models/blob/develop/distill_knowledge/images/plots.png)
## 参考文献
[^distill]: [Distilling the Knowledge in a Neural Network](https://arxiv.org/abs/1503.02531)
[^mnist]: [THE MNIST DATABASE of handwritten digits](http://yann.lecun.com/exdb/mnist/)
The codes in this example is tested using PaddlePaddle v0.12.0. You can install this version according to [Installation Document](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html).
# Distill the Knowledge in a Neural Network
In this example, we introduce how to use PaddlePaddle to implement the approach described in Distilling the Knowledge in a Neural Network》[^distill](2014 NIPS workshop). We demonstrate the approach on MNIST[^mnist] classification problem.
## Introduction
The aim of knowledge distillation is to extract knowledge from the large models which can facilitate the training of small models. We can take the classification problem on MNIST as an example. The trained large models can usually generate results with high confidence. In addition to the classification results we demanded, more valuable knowledge are provided in the soft predicted probabilites. We can utilize these soft targets of training data produced by the large model as additional supervised information to train small models. The experiments shows that the performance of small models improves by using appropriate temperature soft targets.
## Training Procedure
1.Train a single large model: Using hard targets to train a single large model.
2.Compute the soft targets of training examples: Using the trained large model to generate the soft targets of training examples in different temperature. The softmax function with temperature is formuated as:
$$ q_i = \frac{exp(z_i / T)}{\sum_j exp(z_j / T)}, $$
where $q_i$ is the produced soft targets, $z_i$ is network ouputs, $T$ denotes temperature.
3.Train small a network with soft targets and hard targets: The soft targets loss is the crossentropy between soft targets and the same temperature softmax predictions. In order to balance the magnitude of the gradients between hard loss and soft loss, the soft loss is scaled by $T^2$.
4.After training, we test and compare the testing performance of small networks trained with and without soft targets.
More discussions and details, please refer to original paper[^distill].
## Directory Overview
| File&Folder | Description |
| :-------- | :-------- |
| train.py | Training script |
| infer.py | Predcition using the trained model script |
| utils.py | Dataset loading classes and functions |
| mnist_prepare.py | Preparing Mnist Datasets |
| ./data/ | The folder stores dataset and soft targets dataset |
| ./models/ | The folder stores trained models |
| ./images/ | Illustration graphs |
## DataSet Prepare
Run `mnist_prepare.py` to prepare the original mnist data by using paddle dataset. The MNIST is stored in `./data/mnist.npz` and the character images are standardized to range [-1, 1].
```python
reader = paddle.dataset.mnist.train()
~
np.savez('./data/mnist.npz', train_x=train_x, train_y=train_y, test_x=test_x, test_y=test_y)
```
## Training
**First**, we should train a single large neural network. We use a two layer fully connected network with 1200 hidden units each layer.
```python
def mlp_teacher(img, drop_prob):
h1 = fluid.layers.fc(input=img, size=1200, act='relu')
drop1 = fluid.layers.dropout(h1, dropout_prob=drop_prob)
h2 = fluid.layers.fc(input=drop1, size=1200, act='relu')
drop2 = fluid.layers.dropout(h2, dropout_prob=drop_prob)
logits = fluid.layers.fc(input=drop2, size=10, act=None)
return logits
```
Run `python train.py --phase teacher --drop_prob 0.4` to train the large network. The trained parameters are saved in `.\models\teacher_net`.
**Second**, we should produce the soft targets of training set by running `python infer.py --phase teacher --temp 4.0`. We can set different `temp` hyperparameters to generate different soft targets, since different temperature generated soft targets have different impacts on training small network. Since we need keep the same order of training data and soft targets, we construct the new dataset for student network training in `./data/mnist_soft_{temp}.npz`.
```python
print("Generating soft-targets.......")
g_iternum = train_set.num_examples // G_batch_size
soft_list = []
for i in range(g_iternum):
g_batch = list(zip(train_x[i*G_batch_size:(i+1)*G_batch_size], train_y[i*G_batch_size:(i+1)*G_batch_size]))
soft_targets, gb_acc = exe.run(inference_program, feed=feeder.feed(g_batch),
fetch_list=[temp_softmax_logits, batch_acc])
soft_list.append(soft_targets)
train_y_soft = np.vstack(soft_list)
print("saving soft targets")
np.savez(data_dir+'mnist_soft_{}.npz'.format(temp),
train_x=train_x, train_y=train_y, test_x=test_x, test_y=test_y, train_y_soft=train_y_soft)
```
**Third**, we should train the student network with soft targets by running `python train.py --phase student --dropout 0.1 --stu_hsize 30 --temp 4.0 --use_soft True` . The temperature need to be the same as which used in teacher network. The softmax with temperature function is defined as follow:
```python
def softmax_with_temperature(logits, temp=1.0):
logits_with_temp = logits/temp
_softmax = fluid.layers.softmax(logits_with_temp)
return _softmax
```
We define the hard loss and soft loss as:
```python
def soft_crossentropy(input, label):
epsilon = 1e-8
eps = fluid.layers.ones(shape=[1], dtype='float32') * epsilon
loss = reduce_sum(-1.0 * label * log(elementwise_max(input, eps)), dim=1, keep_dim=True)
return loss
~~
softmax_logits = fluid.layers.softmax(logits)
temp_softmax_logits = softmax_with_temperature(logits, temp=temp)
hard_loss = soft_crossentropy(input=softmax_logits, label=label)
soft_loss = soft_crossentropy(input=temp_softmax_logits, label=soft_label)
```
The structure of student network is also a two layer fully connected network with fewer units each layer.
```python
def mlp_student(img, drop_prob, h_size):
h1 = fluid.layers.fc(input=img, size=h_size, act='relu')
drop1 = fluid.layers.dropout(h1, dropout_prob=drop_prob)
h2 = fluid.layers.fc(input=drop1, size=h_size, act='relu')
drop2 = fluid.layers.dropout(h2, dropout_prob=drop_prob)
logits = fluid.layers.fc(input=drop2, size=10, act=None)
return logits
```
**Four**, to verifythe effectiveness of using soft targets, we should also compare with results of training student without soft targets. We run `python train.py --phase student --stu_hsize 30 --drop_prob 0.1 ` and `python infer.py --phase student --stu_hsize 30` to get the testing performance without soft targets. The testing performance with soft targets is by running `python infer.py --phase student --stu_hsize 30 --use_soft True`.
## Results
We train the teacher network using SGD-Momentum optimizer with 0.001 learning rate for 200 epochs. The dropout ratio is 0.4. The teacher network is the large network with 1200 units per fully connected layer. Finally, we get the testing accuracy 99.33%.
In order to show the effectiveness of knowledge distillation approach, we use a 30 units two layer fully connected network as the small network. The baseline student net is trained without soft targets by using SGD optimizer with 0.001 learning rate for 200 epochs. Finally, we get the testing accuracy 94.38%.
When we train the small network with soft targets of temperature 4.0, we get the testing accuracy 97.01%. It shows obvious improvement by using soft targets.
| Methods | Test Accuracy |
| :-------- | --------:|
| TeacherNet | 99.33% |
| StudentNet 30units without soft targets | 94.38% |
| StudentNet 30units with 4.0temp soft targets | 97.01% |
The testing accuracy during training procedure are visualized here. We can see that using soft targets speed up the convergence.
![convergence](https://github.com/likesiwell/models/blob/develop/distill_knowledge/images/plots.png)
## Reference
[^mnist]: [THE MNIST DATABASE of handwritten digits](http://yann.lecun.com/exdb/mnist/)
[^distill]: [Distilling the Knowledge in a Neural Network](https://arxiv.org/abs/1503.02531)
import numpy as np
import argparse
import paddle.fluid as fluid
from utils import read_data_sets
from paddle.fluid.layers import log, elementwise_max, create_tensor, assign, reduce_sum
def mlp_teacher(img, drop_prob):
h1 = fluid.layers.fc(input=img, size=1200, act="relu")
drop1 = fluid.layers.dropout(h1, dropout_prob=drop_prob)
h2 = fluid.layers.fc(input=drop1, size=1200, act="relu")
drop2 = fluid.layers.dropout(h2, dropout_prob=drop_prob)
logits = fluid.layers.fc(input=drop2, size=10, act=None)
return logits
def mlp_student(img, drop_prob, h_size):
h1 = fluid.layers.fc(input=img, size=h_size, act="relu")
drop1 = fluid.layers.dropout(h1, dropout_prob=drop_prob)
h2 = fluid.layers.fc(input=drop1, size=h_size, act="relu")
drop2 = fluid.layers.dropout(h2, dropout_prob=drop_prob)
logits = fluid.layers.fc(input=drop2, size=10, act=None)
return logits
def softmax_with_temperature(logits, temp=1.0):
logits_with_temp = logits / temp
_softmax = fluid.layers.softmax(logits_with_temp)
return _softmax
def soft_crossentropy(input, label):
epsilon = 1e-8
eps = fluid.layers.ones(shape=[1], dtype="float32") * epsilon
loss = reduce_sum(
-1.0 * label * log(elementwise_max(input, eps)), dim=1, keep_dim=True)
return loss
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--phase",
type=str,
default="teacher",
help="choose from teacher or student")
parser.add_argument(
"--stu_hsize",
type=int,
default=30,
help="The hidden layer size of student net")
parser.add_argument(
"--drop_prob",
type=float,
default=0.1,
help="The dropout probability of fully connected layers")
parser.add_argument(
"--temp",
type=float,
default=4.0,
help="The temperature of softmax which is used to generate soft targets")
parser.add_argument(
"--teacher_dir",
type=str,
default="./models/teacher_net",
help="Set the directory for saving teacher network parameters")
parser.add_argument(
"--student_dir",
type=str,
default="./models/student_net",
help="set the directory for saving student network parameters")
parser.add_argument(
"--data_dir",
type=str,
default="./data/",
help="The directory of datasets")
parser.add_argument(
"--use_soft",
type=bool,
default=False,
help="Whether using soft targets to train student network, set a switch to True"
)
args = parser.parse_args()
print(args)
if args.phase == "teacher":
print("Infer Teacher networks")
data_dir = args.data_dir
teacher_dir = args.teacher_dir
temp = args.temp
img = fluid.layers.data(name="img", shape=[1, 28, 28], dtype="float32")
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
logits = mlp_teacher(img, drop_prob=0.0)
softmax_logits = fluid.layers.softmax(logits)
temp_softmax_logits = softmax_with_temperature(logits, temp=temp)
cost = fluid.layers.cross_entropy(input=softmax_logits, label=label)
avg_cost = fluid.layers.mean(x=cost)
batch_acc = fluid.layers.accuracy(input=softmax_logits, label=label)
train_set, test_set = read_data_sets(
is_soft=False, one_hot=False, reshape=True)
G_batch_size = 100
T_batch_size = 100
train_x = train_set.images
train_y = train_set.labels
test_x = test_set.images
test_y = test_set.labels
print(train_set.num_examples, train_set.images.shape)
test_iternum = test_set.num_examples // T_batch_size
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
exe.run(fluid.default_startup_program())
inference_program = fluid.default_main_program().clone(for_test=False)
print("loading params")
fluid.io.load_params(
exe, dirname=teacher_dir, main_program=fluid.default_main_program())
print("Check testing")
test_loss_list = []
test_acc_list = []
for i in range(test_iternum):
test_batch = list(
zip(test_x[i * T_batch_size:(i + 1) * T_batch_size], test_y[
i * T_batch_size:(i + 1) * T_batch_size]))
loss, acc = exe.run(inference_program,
feed=feeder.feed(test_batch),
fetch_list=[avg_cost, batch_acc])
test_loss_list.append(loss)
test_acc_list.append(acc)
print("Testing Loss {}, Acc {}".format(
np.mean(test_loss_list), np.mean(test_acc_list)))
print("Generating soft-targets.......")
g_iternum = train_set.num_examples // G_batch_size
soft_list = []
for i in range(g_iternum):
g_batch = list(
zip(train_x[i * G_batch_size:(i + 1) * G_batch_size], train_y[
i * G_batch_size:(i + 1) * G_batch_size]))
soft_targets, gb_acc = exe.run(
inference_program,
feed=feeder.feed(g_batch),
fetch_list=[temp_softmax_logits, batch_acc])
soft_list.append(soft_targets)
train_y_soft = np.vstack(soft_list)
print("saving soft targets")
np.savez(
data_dir + "mnist_soft_{}.npz".format(temp),
train_x=train_x,
train_y=train_y,
test_x=test_x,
test_y=test_y,
train_y_soft=train_y_soft)
elif args.phase == "student":
print("Infer Student Network")
student_dir = args.student_dir
h_size = args.stu_hsize
temp = args.temp
use_soft = args.use_soft
img = fluid.layers.data(name="img", shape=[1, 28, 28], dtype="float32")
label = fluid.layers.data(name="label", shape=[10], dtype="float32")
soft_label = fluid.layers.data(
name="soft_label", shape=[10], dtype="float32")
logits = mlp_student(img, h_size=h_size, drop_prob=0.0)
softmax_logits = fluid.layers.softmax(logits)
temp_softmax_logits = softmax_with_temperature(logits, temp=temp)
hard_loss = soft_crossentropy(input=softmax_logits, label=label)
soft_loss = soft_crossentropy(
input=temp_softmax_logits, label=soft_label)
if use_soft:
use_soft = "T"
cost = hard_loss + soft_loss * temp**2
else:
use_soft = "F"
cost = hard_loss
avg_cost = fluid.layers.mean(x=cost)
top1_values, top1_indices = fluid.layers.topk(label, k=1)
print(top1_indices.shape)
batch_acc = fluid.layers.accuracy(
input=softmax_logits, label=top1_indices)
dirname = student_dir + "_h_{}_t_{}_soft_{}".format(
h_size, str(temp), use_soft) # "F" means soft loss false
inference_program = fluid.default_main_program().clone(for_test=False)
place = fluid.CUDAPlace(1)
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(
feed_list=[img, label, soft_label], place=place)
exe.run(fluid.default_startup_program())
train_set, test_set = read_data_sets(
is_soft=True, one_hot=True, reshape=True, temp=str(temp))
T_batch_size = 100
Tr_batch_size = 128
train_x = train_set.images
train_y = train_set.labels
test_x = test_set.images
test_y = test_set.labels
test_soft_y = test_set.soft_labels
train_iternum = train_set.num_examples // Tr_batch_size
test_iternum = test_set.num_examples // T_batch_size
print("loading params, ", dirname)
fluid.io.load_params(
exe, dirname=dirname, main_program=fluid.default_main_program())
print("Check testing")
test_loss_list = []
test_acc_list = []
sl_list = []
for i in range(test_iternum):
test_batch = list(
zip(test_x[i * T_batch_size:(i + 1) * T_batch_size], test_y[
i * T_batch_size:(i + 1) * T_batch_size], test_soft_y[
i * T_batch_size:(i + 1) * T_batch_size]))
loss, acc, sl = exe.run(
inference_program,
feed=feeder.feed(test_batch),
fetch_list=[avg_cost, batch_acc, softmax_logits])
test_loss_list.append(loss)
test_acc_list.append(acc)
sl_list.append(sl)
preds = np.argmax(np.vstack(sl_list), axis=1)
labels_ = np.argmax(test_y, axis=1)
print("Number of Correct predictions, ", np.sum(preds == labels_))
print("Testing Loss {}, Acc {}".format(
np.mean(test_loss_list), np.mean(test_acc_list)))
else:
print("Please choose teacher or student for --phase")
if __name__ == "__main__":
main()
import paddle
import numpy as np
reader = paddle.dataset.mnist.train()
img_list = []
label_list = []
for e in reader():
img_list.append(e[0])
label_list.append(e[1])
train_x = np.vstack(img_list)
train_y = np.vstack(label_list)
print(train_x.shape, train_y.shape)
print(np.min(train_x), np.max(train_x))
reader = paddle.dataset.mnist.test()
img_list = []
label_list = []
for e in reader():
img_list.append(e[0])
label_list.append(e[1])
test_x = np.vstack(img_list)
test_y = np.vstack(label_list)
print(test_x.shape, test_y.shape)
np.savez(
'./data/mnist.npz',
train_x=train_x,
train_y=train_y,
test_x=test_x,
test_y=test_y)
import numpy as np
import argparse
import paddle.fluid as fluid
from utils import read_data_sets
from paddle.fluid.layers import log, elementwise_max, create_tensor, assign, reduce_sum
def mlp_teacher(img, drop_prob):
h1 = fluid.layers.fc(input=img, size=1200, act='relu')
drop1 = fluid.layers.dropout(h1, dropout_prob=drop_prob)
h2 = fluid.layers.fc(input=drop1, size=1200, act='relu')
drop2 = fluid.layers.dropout(h2, dropout_prob=drop_prob)
logits = fluid.layers.fc(input=drop2, size=10, act=None)
return logits
def mlp_student(img, drop_prob, h_size):
h1 = fluid.layers.fc(input=img, size=h_size, act='relu')
drop1 = fluid.layers.dropout(h1, dropout_prob=drop_prob)
h2 = fluid.layers.fc(input=drop1, size=h_size, act='relu')
drop2 = fluid.layers.dropout(h2, dropout_prob=drop_prob)
logits = fluid.layers.fc(input=drop2, size=10, act=None)
return logits
def softmax_with_temperature(logits, temp=1.0):
logits_with_temp = logits / temp
_softmax = fluid.layers.softmax(logits_with_temp)
return _softmax
def soft_crossentropy(input, label):
epsilon = 1e-8
eps = fluid.layers.ones(shape=[1], dtype='float32') * epsilon
loss = reduce_sum(
-1.0 * label * log(elementwise_max(input, eps)), dim=1, keep_dim=True)
return loss
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--phase",
type=str,
default="teacher",
help="choose from teacher or student")
parser.add_argument(
"--stu_hsize",
type=int,
default=30,
help="The hidden layer size of student net")
parser.add_argument(
"--drop_prob",
type=float,
default=0.1,
help="The dropout probability of fully connected layers")
parser.add_argument(
"--temp",
type=float,
default=4.0,
help='The temperature of softmax which is used to generate soft targets')
parser.add_argument(
"--teacher_dir",
type=str,
default="./models/teacher_net",
help="Set the directory for saving teacher network parameters")
parser.add_argument(
"--student_dir",
type=str,
default="./models/student_net",
help="set the directory for saving student network parameters")
parser.add_argument(
"--use_soft",
type=bool,
default=False,
help="Whether using soft targets to train student network, set a switch to True"
)
parser.add_argument(
"--epoch_num", type=int, default=200, help="Number of training epoches")
args = parser.parse_args()
print(args)
if args.phase == 'teacher':
print("Training Teacher network")
drop_prob = args.drop_prob
teacher_dir = args.teacher_dir
epoch_num = args.epoch_num
img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
logits = mlp_teacher(img, drop_prob=drop_prob)
softmax_logits = fluid.layers.softmax(logits)
cost = fluid.layers.cross_entropy(input=softmax_logits, label=label)
avg_cost = fluid.layers.mean(x=cost)
optimizer = fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9)
optimizer.minimize(avg_cost)
batch_acc = fluid.layers.accuracy(input=softmax_logits, label=label)
train_set, test_set = read_data_sets(
is_soft=False, one_hot=False, reshape=True)
G_batch_size = 100
T_batch_size = 100
Tr_batch_size = 128
train_x = train_set.images
train_y = train_set.labels
test_x = test_set.images
test_y = test_set.labels
print(train_set.num_examples, train_set.images.shape)
train_iternum = train_set.num_examples // Tr_batch_size
test_iternum = test_set.num_examples // T_batch_size
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
exe.run(fluid.default_startup_program())
inference_program = fluid.default_main_program().clone(for_test=False)
print("Begin training teacher network")
for epoch in range(epoch_num):
train_loss_list = []
train_acc_list = []
for i in range(train_iternum):
train_batch = train_set.next_batch(Tr_batch_size)
loss, acc = exe.run(fluid.default_main_program(),
feed=feeder.feed(train_batch),
fetch_list=[avg_cost, batch_acc])
train_loss_list.append(loss)
train_acc_list.append(acc)
test_loss_list = []
test_acc_list = []
for i in range(test_iternum):
test_batch = list(
zip(test_x[i * T_batch_size:(i + 1) * T_batch_size], test_y[
i * T_batch_size:(i + 1) * T_batch_size]))
loss, acc = exe.run(inference_program,
feed=feeder.feed(test_batch),
fetch_list=[avg_cost, batch_acc])
test_loss_list.append(loss)
test_acc_list.append(acc)
print(
"Epoch {}, train acc {}, train loss {} ; test acc {}, test loss {} ".
format(epoch,
np.mean(train_acc_list),
np.mean(train_loss_list),
np.mean(test_acc_list), np.mean(test_loss_list)))
fluid.io.save_params(
exe, dirname=teacher_dir, main_program=fluid.default_main_program())
print('Train teacher network done')
elif args.phase == 'student':
print("Training Student Network")
drop_prob = args.drop_prob
student_dir = args.student_dir
h_size = args.stu_hsize
temp = args.temp
use_soft = args.use_soft
epoch_num = args.epoch_num
img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[10], dtype='float32')
soft_label = fluid.layers.data(
name='soft_label', shape=[10], dtype='float32')
logits = mlp_student(img, h_size=h_size, drop_prob=drop_prob)
softmax_logits = fluid.layers.softmax(logits)
temp_softmax_logits = softmax_with_temperature(logits, temp=temp)
hard_loss = soft_crossentropy(input=softmax_logits, label=label)
soft_loss = soft_crossentropy(
input=temp_softmax_logits, label=soft_label)
if use_soft:
use_soft = 'T'
cost = hard_loss + soft_loss * temp**2
else:
use_soft = 'F'
cost = hard_loss
avg_cost = fluid.layers.mean(x=cost)
optimizer = fluid.optimizer.SGD(learning_rate=0.001)
# optimizer = fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9)
optimizer.minimize(avg_cost)
top1_values, top1_indices = fluid.layers.topk(label, k=1)
print(top1_indices.shape)
batch_acc = fluid.layers.accuracy(
input=softmax_logits, label=top1_indices)
dirname = student_dir + '_h_{}_t_{}_soft_{}'.format(
h_size, str(temp), use_soft) # 'F' means soft loss false
inference_program = fluid.default_main_program().clone(for_test=False)
place = fluid.CUDAPlace(1)
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(
feed_list=[img, label, soft_label], place=place)
exe.run(fluid.default_startup_program())
train_set, test_set = read_data_sets(
is_soft=True, one_hot=True, reshape=True, temp=str(temp))
T_batch_size = 100
Tr_batch_size = 128
train_x = train_set.images
train_y = train_set.labels
test_x = test_set.images
test_y = test_set.labels
test_soft_y = test_set.soft_labels
train_iternum = train_set.num_examples // Tr_batch_size
test_iternum = test_set.num_examples // T_batch_size
for epoch in range(epoch_num):
train_loss_list = []
train_acc_list = []
for i in range(train_iternum):
train_batch = train_set.next_batch(Tr_batch_size)
loss, acc = exe.run(fluid.default_main_program(),
feed=feeder.feed(train_batch),
fetch_list=[avg_cost, batch_acc])
train_loss_list.append(loss)
train_acc_list.append(acc)
test_loss_list = []
test_acc_list = []
for i in range(test_iternum):
test_batch = list(
zip(test_x[i * T_batch_size:(i + 1) * T_batch_size], test_y[
i * T_batch_size:(i + 1) * T_batch_size], test_soft_y[
i * T_batch_size:(i + 1) * T_batch_size]))
loss, acc = exe.run(inference_program,
feed=feeder.feed(test_batch),
fetch_list=[avg_cost, batch_acc])
test_loss_list.append(loss)
test_acc_list.append(acc)
print(
"Epoch {}, train acc {}, train loss {} ; test acc {}, test loss {} ".
format(epoch,
np.mean(train_acc_list),
np.mean(train_loss_list),
np.mean(test_acc_list), np.mean(test_loss_list)))
fluid.io.save_params(
exe, dirname=dirname, main_program=fluid.default_main_program())
print('Train Student done')
print(dirname)
else:
print("Please choose teacher or student for --phase")
if __name__ == '__main__':
main()
"""
This script contains the dataset loader
"""
import numpy as np
def dense_to_one_hot(labels_dense, num_classes):
num_labels = labels_dense.shape[0]
index_offset = np.arange(num_labels) * num_classes
labels_one_hot = np.zeros((num_labels, num_classes))
labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
return labels_one_hot
class Dataset(object):
def __init__(self,
images,
labels,
soft_labels=None,
one_hot=False,
reshape=False,
seed=123):
np.random.seed(seed)
assert images.shape[0] == labels.shape[0], (
'images.shape: %s labels.shape %s' % (images.shape, labels.shape))
self._num_examples = images.shape[0]
self.images = images
self.labels = labels
self.soft_labels = soft_labels
if reshape:
self.images = self.images.reshape(images.shape[0], 1, 28, 28)
if one_hot:
self.labels = dense_to_one_hot(self.labels, 10)
self._epochs_completed = 0
self._index_in_epoch = 0
@property
def num_examples(self):
return self._num_examples
@property
def epochs_completed(self):
return self._epochs_completed
def next_batch(self, batch_size, shuffle=True):
start = self._index_in_epoch
self._index_in_epoch += batch_size
if self._index_in_epoch > self._num_examples or self._index_in_epoch == batch_size and shuffle:
# print("Data Shuffling")
self._epochs_completed += 1
perm0 = np.arange(self._num_examples)
np.random.shuffle(perm0)
self.images = self.images[perm0]
self.labels = self.labels[perm0]
if self.soft_labels is not None:
# print("soft shape", self.soft_labels.shape)
self.soft_labels = self.soft_labels[perm0]
# start next epoch
start = 0
self._index_in_epoch = batch_size
assert batch_size <= self._num_examples
end = self._index_in_epoch
img_batch = self.images[start:end]
label_batch = self.labels[start:end]
if self.soft_labels is not None:
soft_label_batch = self.soft_labels[start:end]
return list(zip(img_batch, label_batch, soft_label_batch))
else:
# print(img_batch.shape, label_batch.shape)
return list(zip(img_batch, label_batch))
def read_data_sets(is_soft=False, one_hot=False, reshape=False, temp=str(3.0)):
if is_soft:
data_dir = './data/mnist_soft_{}.npz'.format(temp)
else:
data_dir = './data/mnist.npz'
print("Loading ", data_dir)
mnist_data = np.load(data_dir)
train_x = mnist_data['train_x']
train_y = mnist_data['train_y']
test_x = mnist_data['test_x']
test_y = mnist_data['test_y']
if is_soft:
train_y_soft = mnist_data['train_y_soft']
test_y_soft = np.zeros(shape=(test_x.shape[0], 10))
train_set = Dataset(
train_x,
train_y,
soft_labels=train_y_soft,
one_hot=one_hot,
reshape=reshape)
test_set = Dataset(
test_x,
test_y,
soft_labels=test_y_soft,
one_hot=one_hot,
reshape=reshape)
# test_set = Dataset(test_x, test_y, one_hot=one_hot, reshape=reshape)
else:
train_set = Dataset(train_x, train_y, one_hot=one_hot, reshape=reshape)
test_set = Dataset(test_x, test_y, one_hot=one_hot, reshape=reshape)
return train_set, test_set
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册