Unverified commit 8c6e3ab9 authored by ceci3, committed by GitHub

add tensorflow mobilenetv1 model for auto compress (#1183)

* add x2paddle

* update reader
Parent b18e0ff4
@@ -25,6 +25,25 @@
- Test environment: `SDM710 2*A75(2.2GHz) 6*A55(1.7GHz)`
- MobileNetV1 model

| Model | Strategy | Top-1 Acc | Latency (ms) threads=4 | Inference model |
|:------:|:------:|:------:|:------:|:------:|
| MobileNetV1 | Base model | 71.0 | - | [Model]() |
| MobileNetV1 | Quantization + distillation | 70.22 | - | [Model]() |

- Test environment:

Notes:
- The MobileNetV1 model comes from [tensorflow/models](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz). To convert it into a Paddle inference model with [X2Paddle](https://github.com/PaddlePaddle/X2Paddle):
  (1) Install X2Paddle 1.3.6 or later (`pip install x2paddle`);
  (2) Convert the model: `x2paddle --framework=tensorflow --model=tf_model.pb --save_dir=pd_model`
  This yields the MobileNetV1 inference model (`model.pdmodel` and `model.pdiparams`). For a quick start, you can instead download the Base inference model from the table above. A scripted version of the two conversion steps is sketched below.
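A minimal sketch of the two steps above in script form; the name of the frozen graph inside the TensorFlow archive (`mobilenet_v1_1.0_224_frozen.pb`) is an assumption, everything else mirrors the command shown in step (2):

```python
import subprocess
import tarfile
import urllib.request

# step (1) is a plain `pip install x2paddle`; here we fetch and unpack the TF archive
URL = ("http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/"
       "mobilenet_v1_1.0_224.tgz")
urllib.request.urlretrieve(URL, "mobilenet_v1_1.0_224.tgz")
with tarfile.open("mobilenet_v1_1.0_224.tgz") as tar:
    tar.extractall("tf_model")

# step (2): the same x2paddle command as above; the frozen-graph file name is assumed
subprocess.check_call([
    "x2paddle", "--framework=tensorflow",
    "--model=tf_model/mobilenet_v1_1.0_224_frozen.pb",
    "--save_dir=pd_model",
])
```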
## 3. Auto-compression workflow
#### 3.1 Prepare the environment
......
Distillation:
  alpha: 1.0
  loss: l2
  node:
  - softmax_0.tmp_0
  teacher_model_dir: MobileNetV1_infer
  teacher_model_filename: inference.pdmodel
  teacher_params_filename: inference.pdiparams

Quantization:
  activation_bits: 8
  is_full_quantize: false
......
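This configuration is consumed by run.py through PaddleSlim's `load_config` helper (the same call appears later in this diff). A minimal sketch of loading it, assuming the config path used in run.sh; the exact structure of the returned objects depends on the installed PaddleSlim version:

```python
from paddleslim.auto_compression.config_helpers import load_config

# mirrors the call in run.py
compress_config, train_config, _ = load_config("./configs/mobilenetv1_qat_dis.yaml")
print(compress_config)  # expected to carry the Distillation and Quantization settings above
print(train_config)
```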
@@ -10,7 +10,8 @@ import numpy as np
import paddle
import paddle.nn as nn
from paddle.io import Dataset, BatchSampler, DataLoader
import imagenet_reader as reader
import imagenet_reader as pd_imagenet_reader
import tf_imagenet_reader
from paddleslim.auto_compression.config_helpers import load_config
from paddleslim.auto_compression import AutoCompression
from utility import add_arguments, print_arguments
@@ -26,14 +27,17 @@ add_arg('save_dir', str, None, "directory to save
add_arg('batch_size', int, 1, "train batch size.")
add_arg('config_path', str, None, "path of compression strategy config.")
add_arg('data_dir', str, None, "path of dataset")
add_arg('input_name', str, "inputs", "input name of the model")
add_arg('input_name', str, "inputs", "input name of the model")
add_arg('input_shape', int, [3,224,224], "input shape of the model except batch_size", nargs='+')
add_arg('image_reader_type', str, "paddle", "the preprocess of data. choice in [\"paddle\", \"tensorflow\"]")
# yapf: enable
def reader_wrapper(reader, input_name):
def reader_wrapper(reader, input_name, input_shape):
    def gen():
        for i, data in enumerate(reader()):
            imgs = np.float32([item[0] for item in data])
            imgs = imgs.reshape([len(data)] + input_shape)
            yield {input_name: imgs}
    return gen
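# reader_wrapper adapts a paddle.batch reader into a generator of feed dicts: for a batch
# of N samples it yields {input_name: float32 ndarray of shape [N] + input_shape}.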
@@ -48,7 +52,7 @@ def eval_reader(data_dir, batch_size):
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
    val_reader = eval_reader(data_dir, batch_size=args.batch_size)
    image = paddle.static.data(
        name=args.input_name, shape=[None, 3, 224, 224], dtype='float32')
        name=args.input_name, shape=[None] + args.input_shape, dtype='float32')
    label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
    results = []
@@ -56,7 +60,7 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
        # top1_acc, top5_acc
        if len(test_feed_names) == 1:
            image = np.array([[d[0]] for d in data])
            image = image.reshape((len(data), 3, 224, 224))
            image = image.reshape([len(data)] + args.input_shape)
            label = [[d[1]] for d in data]
            pred = exe.run(compiled_test_program,
                           feed={test_feed_names[0]: image},
@@ -76,7 +80,8 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
        else:
            # eval "eval model", which inputs are image and label, output is top1 and top5 accuracy
            image = np.array([[d[0]] for d in data])
            image = image.reshape((len(data), 3, 224, 224))
            image = image.reshape([len(data)] + args.input_shape)
            label = [[d[1]] for d in data]
            label = [[d[1]] for d in data]
            result = exe.run(
                compiled_test_program,
@@ -98,9 +103,18 @@ if __name__ == '__main__':
    compress_config, train_config, _ = load_config(args.config_path)
    data_dir = args.data_dir
    if args.image_reader_type == 'paddle':
        reader = pd_imagenet_reader
    elif args.image_reader_type == 'tensorflow':
        reader = tf_imagenet_reader
    else:
        raise NotImplementedError(
            "image_reader_type only can be set to paddle or tensorflow, but now is {}".
            format(args.image_reader_type))
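    # pd_imagenet_reader is the Paddle-style preprocessing (CHW input, see --input_shape 3 224 224
    # in run.sh); tf_imagenet_reader, added below, keeps HWC images scaled to [-1, 1] for the
    # X2Paddle-converted model (--input_shape 224 224 3).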
    train_reader = paddle.batch(
        reader.train(data_dir=data_dir), batch_size=args.batch_size)
    train_dataloader = reader_wrapper(train_reader, args.input_name)
    train_dataloader = reader_wrapper(train_reader, args.input_name,
                                      args.input_shape)
    ac = AutoCompression(
        model_dir=args.model_dir,
@@ -111,6 +125,8 @@ if __name__ == '__main__':
        train_config=train_config,
        train_dataloader=train_dataloader,
        eval_callback=eval_function,
        eval_dataloader=reader_wrapper(eval_reader(data_dir, args.batch_size), args.input_name))
        eval_dataloader=reader_wrapper(
            eval_reader(data_dir, args.batch_size), args.input_name,
            args.input_shape))
    ac.compress()
@@ -6,7 +6,9 @@ python run.py \
--params_filename='inference.pdiparams' \
--save_dir='./save_quant_mobilev1/' \
--batch_size=128 \
--config_path='./configs/mobilev1.yaml'\
--config_path='./configs/mobilenetv1_qat_dis.yaml'\
--input_shape 3 224 224 \
--image_reader_type='paddle' \
--data_dir='ILSVRC2012'
# launch with multiple GPUs
@@ -16,6 +18,6 @@ python run.py \
# --params_filename='inference.pdiparams' \
# --save_dir='./save_quant_mobilev1/' \
# --batch_size=128 \
# --config_path='./configs/mobilev1.yaml'\
# --data_dir='/workspace/dataset/ILSVRC2012/'
# --config_path='./configs/mobilenetv1_qat_dis.yaml'\
# --data_dir='ILSVRC2012'
# launch with a single GPU
export CUDA_VISIBLE_DEVICES=0
python run.py \
--model_dir='inference_model_usex2paddle' \
--model_filename='model.pdmodel' \
--params_filename='model.pdiparams' \
--save_dir='./save_quant_mobilev1/' \
--batch_size=128 \
--config_path='./configs/mobilenetv1_qat_dis.yaml'\
--input_shape 224 224 3 \
--image_reader_type='tensorflow' \
--input_name "input" \
--data_dir='ILSVRC2012'
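# Note: the X2Paddle-converted model keeps TensorFlow's NHWC layout and input name "input",
# hence --input_shape 224 224 3, --input_name "input" and --image_reader_type='tensorflow' above.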
import os
import math
import random
import functools
import numpy as np
import paddle
from PIL import Image, ImageEnhance
import cv2
from paddle.io import Dataset
random.seed(0)
np.random.seed(0)
DATA_DIM = 224
THREAD = 16
BUF_SIZE = 10240
DATA_DIR = 'data/ILSVRC2012/'
DATA_DIR = os.path.join(os.path.split(os.path.realpath(__file__))[0], DATA_DIR)
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
def resize_short(img, target_size):
    percent = float(target_size) / min(img.size[0], img.size[1])
    resized_width = int(round(img.size[0] * percent))
    resized_height = int(round(img.size[1] * percent))
    img = pil_img_2_cv2(img)
    img = cv2.resize(
        img, (resized_width, resized_height), interpolation=cv2.INTER_LINEAR)
    img = cv2_img_2_pil(img)
    return img


def pil_img_2_cv2(img):
    return cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)


def cv2_img_2_pil(img):
    return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
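# crop_image: in center mode, keep the central fraction (0.875) of the image and resize it to
# target_size (this mirrors the common TensorFlow evaluation preprocessing, an assumption here);
# otherwise resize the short side to 256 and take a random target_size x target_size crop.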
def crop_image(img, target_size, center, central_fraction=0.875):
    width, height = img.size
    size = target_size
    if center:
        left = int((width - width * central_fraction) / 2.0)
        right = width - left
        top = int((height - height * central_fraction) / 2.0)
        bottom = height - top
        img = img.crop((left, top, right, bottom))
        img = pil_img_2_cv2(img)
        img = cv2.resize(img, (size, size), interpolation=cv2.INTER_LINEAR)
        img = cv2_img_2_pil(img)
    else:
        img = resize_short(img, target_size=256)
        width, height = img.size
        w_start = np.random.randint(0, width - size + 1)
        h_start = np.random.randint(0, height - size + 1)
        w_end = w_start + size
        h_end = h_start + size
        img = img.crop((w_start, h_start, w_end, h_end))
    return img
def rotate_image(img):
    angle = np.random.randint(-10, 11)
    img = img.rotate(angle)
    return img


def distort_color(img):
    def random_brightness(img, lower=0.5, upper=1.5):
        e = np.random.uniform(lower, upper)
        return ImageEnhance.Brightness(img).enhance(e)

    def random_contrast(img, lower=0.5, upper=1.5):
        e = np.random.uniform(lower, upper)
        return ImageEnhance.Contrast(img).enhance(e)

    def random_color(img, lower=0.5, upper=1.5):
        e = np.random.uniform(lower, upper)
        return ImageEnhance.Color(img).enhance(e)

    ops = [random_brightness, random_contrast, random_color]
    np.random.shuffle(ops)
    img = ops[0](img)
    img = ops[1](img)
    img = ops[2](img)
    return img
def process_image(sample, mode, color_jitter, rotate):
    img_path = sample[0]
    try:
        img = Image.open(img_path)
    except:
        print(img_path, "not exists!")
        return None

    if mode == 'train':
        if rotate:
            img = rotate_image(img)
        img = crop_image(img, target_size=DATA_DIM, center=False)
    else:
        img = crop_image(img, target_size=DATA_DIM, center=True)
    if mode == 'train':
        if color_jitter:
            img = distort_color(img)
        if np.random.randint(0, 2) == 1:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
    if img.mode != 'RGB':
        img = img.convert('RGB')
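    # TensorFlow-style preprocessing: keep the image in HWC layout and only rescale pixels to
    # [-1, 1]; the img_mean/img_std arrays defined at module level are left unused here.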
    img = np.float32(img)
    img = img / 255.0
    img -= 0.5
    img *= 2.0

    if mode == 'train' or mode == 'val':
        return img, sample[1]
    elif mode == 'test':
        return [img]
def _reader_creator(file_list,
                    mode,
                    shuffle=False,
                    color_jitter=False,
                    rotate=False,
                    data_dir=DATA_DIR,
                    batch_size=1):
    def reader():
        try:
            with open(file_list) as flist:
                full_lines = [line.strip() for line in flist]
                if shuffle:
                    np.random.shuffle(full_lines)
                lines = full_lines
                for line in lines:
                    if mode == 'train' or mode == 'val':
                        img_path, label = line.split()
                        img_path = os.path.join(data_dir, img_path)
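                        # +1 because the converted TensorFlow model predicts 1001 classes with
                        # index 0 reserved for background (an assumption based on the usual
                        # TF-Slim ImageNet label map)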
                        yield img_path, int(label) + 1
                    elif mode == 'test':
                        img_path = os.path.join(data_dir, line)
                        yield [img_path]
        except Exception as e:
            print("Reader failed!\n{}".format(str(e)))
            os._exit(1)

    mapper = functools.partial(
        process_image, mode=mode, color_jitter=color_jitter, rotate=rotate)
    return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
def train(data_dir=DATA_DIR):
    file_list = os.path.join(data_dir, 'train_list.txt')
    return _reader_creator(
        file_list,
        'train',
        shuffle=True,
        color_jitter=False,
        rotate=False,
        data_dir=data_dir)


def val(data_dir=DATA_DIR):
    file_list = os.path.join(data_dir, 'val_list.txt')
    return _reader_creator(file_list, 'val', shuffle=False, data_dir=data_dir)


def test(data_dir=DATA_DIR):
    file_list = os.path.join(data_dir, 'test_list.txt')
    return _reader_creator(file_list, 'test', shuffle=False, data_dir=data_dir)
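For reference, a minimal sketch (not part of the commit) of how this reader is consumed, mirroring the reader_wrapper logic in run.py; an `ILSVRC2012` directory containing `val_list.txt` is an assumed local dataset layout.

```python
import numpy as np
import paddle

import tf_imagenet_reader

# batch the validation reader the same way run.py does
val_reader = paddle.batch(
    tf_imagenet_reader.val(data_dir='ILSVRC2012'), batch_size=4)

for data in val_reader():
    # each item is (image, label); images are HWC float32 arrays scaled to [-1, 1]
    imgs = np.float32([item[0] for item in data])
    imgs = imgs.reshape([len(data)] + [224, 224, 3])
    print(imgs.shape, imgs.min(), imgs.max())  # (4, 224, 224, 3), values roughly in [-1, 1]
    break
```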