未验证 提交 a769b284 编写于 作者: B Bai Yifan 提交者: GitHub

Fix deep mutual learning doc issue and code clean (#402)

* fix dml issue

* disable multiprocess reader

* add other model

* add dml unittest

* fix log_softmax

* make dml compatible with 1.8 and 2.0 version paddle

* fix coverage
上级 fad86fc1
# 深度互学习DML(Deep Mutual Learning)
本示例介绍如何使用PaddleSlim的深度互学习DML方法训练模型,算法原理请参考论文 [Deep Mutual Learning](https://arxiv.org/abs/1706.00384)
![dml_architect](./images/dml_architect.png)
## 使用数据
示例中使用cifar100数据集进行训练, 您可以在启动训练时等待自动下载,
也可以在自行下载[数据集](https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz)之后,放在当前目录的`./dataset/cifar100`路径下
## 启动命令
### 训练MobileNet-Mobilenet的组合
单卡训练, 以0号GPU为例:
```bash
CUDA_VISIBLE_DEVICES=0 python dml_train.py
```
多卡训练, 以0-3号GPU为例:
```bash
python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog dml_train.py --use_parallel=True
python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog dml_train.py --use_parallel=True --init_lr=0.4
```
### 训练MobileNet-ResNet50的组合
单卡训练, 以0号GPU为例:
```bash
CUDA_VISIBLE_DEVICES=0 python dml_train.py --models='mobilenet-resnet50'
```
多卡训练, 以0-3号GPU为例:
```bash
python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog dml_train.py --use_parallel=True --init_lr=0.4 --models='mobilenet-resnet50'
```
## 实验结果
以下实验结果可以由默认实验配置(学习率、优化器等)训练得到,仅调整了DML训练的模型组合
......
......@@ -102,8 +102,7 @@ def cifar100_reader(file_name, data_name, is_shuffle):
for name in names:
print("Reading file " + name)
try:
batch = cPickle.load(
f.extractfile(name), encoding='iso-8859-1')
batch = cPickle.load(f.extractfile(name), encoding='iso-8859-1')
except:
batch = cPickle.load(f.extractfile(name))
data = batch['data']
......
......@@ -26,6 +26,7 @@ from paddle.fluid.dygraph.base import to_variable
from paddleslim.common import AvgrageMeter, get_logger
from paddleslim.dist import DML
from paddleslim.models.dygraph import MobileNetV1
from paddleslim.models.dygraph import ResNet
import cifar100_reader as reader
sys.path[0] = os.path.join(os.path.dirname("__file__"), os.path.pardir)
from utility import add_arguments, print_arguments
......@@ -37,6 +38,7 @@ add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('log_freq', int, 100, "Log frequency.")
add_arg('models', str, "mobilenet-mobilenet", "model.")
add_arg('batch_size', int, 256, "Minibatch size.")
add_arg('init_lr', float, 0.1, "The start learning rate.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
......@@ -44,7 +46,6 @@ add_arg('epochs', int, 200, "Epoch number.")
add_arg('class_num', int, 100, "Class number of dataset.")
add_arg('trainset_num', int, 50000, "Images number of trainset.")
add_arg('model_save_dir', str, 'saved_models', "The path to save model.")
add_arg('use_multiprocess', bool, True, "Whether use multiprocess reader.")
add_arg('use_parallel', bool, False, "Whether to use data parallel mode to train the model.")
# yapf: enable
......@@ -78,13 +79,9 @@ def create_reader(place, args):
train_reader = fluid.contrib.reader.distributed_batch_reader(
train_reader)
train_loader = fluid.io.DataLoader.from_generator(
capacity=1024,
return_list=True,
use_multiprocess=args.use_multiprocess)
capacity=1024, return_list=True)
valid_loader = fluid.io.DataLoader.from_generator(
capacity=1024,
return_list=True,
use_multiprocess=args.use_multiprocess)
capacity=1024, return_list=True)
train_loader.set_batch_generator(train_reader, places=place)
valid_loader.set_batch_generator(valid_reader, places=place)
return train_loader, valid_loader
......@@ -160,10 +157,19 @@ def main(args):
train_loader, valid_loader = create_reader(place, args)
# 2. Define neural network
if args.models == "mobilenet-mobilenet":
models = [
MobileNetV1(class_dim=args.class_num),
MobileNetV1(class_dim=args.class_num)
]
elif args.models == "mobilenet-resnet50":
models = [
MobileNetV1(class_dim=args.class_num),
ResNet(class_dim=args.class_num)
]
else:
logger.info("You can define the model as you wish")
return
optimizers = create_optimizer(models, args)
# 3. Use PaddleSlim DML strategy
......
......@@ -17,11 +17,19 @@ from __future__ import division
from __future__ import print_function
import copy
import paddle
import paddle.fluid as fluid
PADDLE_VERSION = 1.8
try:
from paddle.fluid.layers import log_softmax
except:
from paddle.nn import LogSoftmax
PADDLE_VERSION = 2.0
class DML(fluid.dygraph.Layer):
def __init__(self, model, use_parallel):
def __init__(self, model, use_parallel=False):
super(DML, self).__init__()
self.model = model
self.use_parallel = use_parallel
......@@ -54,8 +62,7 @@ class DML(fluid.dygraph.Layer):
for i in range(self.model_num):
ce_losses.append(
fluid.layers.mean(
fluid.layers.softmax_with_cross_entropy(logits[i],
labels)))
fluid.layers.softmax_with_cross_entropy(logits[i], labels)))
return ce_losses
def kl_loss(self, logits):
......@@ -69,6 +76,10 @@ class DML(fluid.dygraph.Layer):
cur_kl_loss = 0
for j in range(self.model_num):
if i != j:
if PADDLE_VERSION == 2.0:
log_softmax = LogSoftmax(axis=1)
x = log_softmax(logits[i])
else:
x = fluid.layers.log_softmax(logits[i], axis=1)
y = fluid.layers.softmax(logits[j], axis=1)
cur_kl_loss += fluid.layers.kldiv_loss(
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
import logging
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.dataset.mnist as reader
from paddle.fluid.dygraph.base import to_variable
from paddleslim.models.dygraph import MobileNetV1
from paddleslim.dist import DML
from paddleslim.common import get_logger
logger = get_logger(__name__, level=logging.INFO)
class Model(fluid.dygraph.Layer):
def __init__(self):
super(Model, self).__init__()
self.conv = fluid.dygraph.nn.Conv2D(
num_channels=1,
num_filters=256,
filter_size=3,
stride=1,
padding=1,
use_cudnn=False)
self.pool2d_avg = fluid.dygraph.nn.Pool2D(
pool_type='avg', global_pooling=True)
self.out = fluid.dygraph.nn.Linear(256, 10)
def forward(self, inputs):
inputs = fluid.layers.reshape(inputs, shape=[0, 1, 28, 28])
y = self.conv(inputs)
y = self.pool2d_avg(y)
y = fluid.layers.reshape(y, shape=[-1, 256])
y = self.out(y)
return y
class TestDML(unittest.TestCase):
def test_dml(self):
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
with fluid.dygraph.guard(place):
train_reader = paddle.fluid.io.batch(
paddle.dataset.mnist.train(), batch_size=256)
train_loader = fluid.io.DataLoader.from_generator(
capacity=1024, return_list=True)
train_loader.set_sample_list_generator(train_reader, places=place)
models = [Model(), Model()]
optimizers = []
for cur_model in models:
opt = fluid.optimizer.MomentumOptimizer(
0.1, 0.9, parameter_list=cur_model.parameters())
optimizers.append(opt)
dml_model = DML(models)
dml_optimizer = dml_model.opt(optimizers)
def train(train_loader, dml_model, dml_optimizer):
dml_model.train()
for step_id, (images, labels) in enumerate(train_loader):
images, labels = to_variable(images), to_variable(labels)
labels = fluid.layers.reshape(labels, [0, 1])
logits = dml_model.forward(images)
precs = [
fluid.layers.accuracy(
input=l, label=labels, k=1).numpy() for l in logits
]
losses = dml_model.loss(logits, labels)
dml_optimizer.minimize(losses)
if step_id % 10 == 0:
print(step_id, precs)
for epoch_id in range(10):
current_step_lr = dml_optimizer.get_lr()
lr_msg = "Epoch {}".format(epoch_id)
for model_id, lr in enumerate(current_step_lr):
lr_msg += ", {} lr: {:.6f}".format(
dml_model.full_name()[model_id], lr)
logger.info(lr_msg)
train(train_loader, dml_model, dml_optimizer)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册