Unverified commit 6a477596 authored by LI Xuhong, committed by GitHub

add DELTA algorithm (#551)

* add DELTA
Parent d569573b
......@@ -13,9 +13,8 @@ module = hub.Module(name="pyramidbox_lite_server_mask")
def paint_chinese(im, chinese, position, fontsize, color_bgr):
    # Convert the image from OpenCV (BGR) format to PIL (RGB) format
img_PIL = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
font = ImageFont.truetype('SourceHanSansSC-Medium.otf',
fontsize,
encoding="utf-8")
font = ImageFont.truetype(
'SourceHanSansSC-Medium.otf', fontsize, encoding="utf-8")
    #color = (255,0,0)  # font color
    #position = (100,100)  # text output position
color = color_bgr[::-1]
......
......@@ -40,13 +40,17 @@ def predict(args):
# define model(program)
module = hub.Module(name=module_name)
if model_type == 'rcnn':
input_dict, output_dict, program = module.context(trainable=True, phase='train')
input_dict_pred, output_dict_pred, program_pred = module.context(trainable=False)
input_dict, output_dict, program = module.context(
trainable=True, phase='train')
input_dict_pred, output_dict_pred, program_pred = module.context(
trainable=False)
else:
input_dict, output_dict, program = module.context(trainable=True)
input_dict_pred = output_dict_pred = None
feed_list, pred_feed_list = get_feed_list(module_name, input_dict, input_dict_pred)
feature, pred_feature = get_mid_feature(module_name, output_dict, output_dict_pred)
feed_list, pred_feed_list = get_feed_list(module_name, input_dict,
input_dict_pred)
feature, pred_feature = get_mid_feature(module_name, output_dict,
output_dict_pred)
config = hub.RunConfig(
use_data_parallel=False,
......@@ -67,7 +71,10 @@ def predict(args):
model_type=model_type,
config=config)
data = ["./test/test_img_bird.jpg", "./test/test_img_cat.jpg",]
data = [
"./test/test_img_bird.jpg",
"./test/test_img_cat.jpg",
]
label_map = ds.label_dict()
run_states = task.predict(data=data, accelerate_mode=False)
results = [run_state.run_results for run_state in run_states]
......
......@@ -48,20 +48,24 @@ def finetune(args):
# define model(program)
module = hub.Module(name=module_name)
if model_type == 'rcnn':
input_dict, output_dict, program = module.context(trainable=True, phase='train')
input_dict_pred, output_dict_pred, program_pred = module.context(trainable=False)
input_dict, output_dict, program = module.context(
trainable=True, phase='train')
input_dict_pred, output_dict_pred, program_pred = module.context(
trainable=False)
else:
input_dict, output_dict, program = module.context(trainable=True)
input_dict_pred = output_dict_pred = None
print("input_dict keys", input_dict.keys())
print("output_dict keys", output_dict.keys())
feed_list, pred_feed_list = get_feed_list(module_name, input_dict, input_dict_pred)
feed_list, pred_feed_list = get_feed_list(module_name, input_dict,
input_dict_pred)
print("output_dict length:", len(output_dict))
print(output_dict.keys())
if output_dict_pred is not None:
print(output_dict_pred.keys())
feature, pred_feature = get_mid_feature(module_name, output_dict, output_dict_pred)
feature, pred_feature = get_mid_feature(module_name, output_dict,
output_dict_pred)
config = hub.RunConfig(
log_interval=10,
......@@ -73,7 +77,8 @@ def finetune(args):
batch_size=args.batch_size,
enable_memory_optim=False,
checkpoint_dir=args.checkpoint_dir,
strategy=hub.finetune.strategy.DefaultFinetuneStrategy(learning_rate=0.00025, optimizer_name="adam"))
strategy=hub.finetune.strategy.DefaultFinetuneStrategy(
learning_rate=0.00025, optimizer_name="adam"))
task = hub.DetectionTask(
data_reader=data_reader,
......
# Introduction
This page implements the [DELTA](https://arxiv.org/abs/1901.09229) algorithm in [PaddlePaddle](https://www.paddlepaddle.org.cn/install/quick).
> Li, Xingjian, et al. "DELTA: Deep learning transfer using feature map with attention for convolutional networks." ICLR 2019.
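
The regularizer implemented here is the unweighted feature-map variant: alongside the cross-entropy loss, `main.py` penalizes the squared L2 distance between the student's feature map and that of a frozen ImageNet-pretrained teacher (the paper additionally weights channels by attention). A sketch of the objective, where λ is `--delta_reg`, FM is the feature map just before global pooling, and ω* are the fixed pretrained weights:

```latex
\min_{\omega}\;
L_{\mathrm{CE}}\big(y, f(x;\omega)\big)
\;+\;
\lambda \cdot \mathrm{mean}\!\left(
  \left\|\mathrm{FM}(x;\omega) - \mathrm{FM}(x;\omega^{*})\right\|_2^2
\right)
```

In the code this is the mean squared difference of the `res5c` block outputs of the two networks.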
# Preparation of Data and Pre-trained Model
- Download the target datasets for transfer learning, such as [Caltech-256](http://www.vision.caltech.edu/Image_Datasets/Caltech256/), [CUB_200_2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html), or others, and arrange each dataset as follows:
```
root/train/dog/xxy.jpg
root/train/dog/xxz.jpg
...
root/train/cat/nsdf3.jpg
root/train/cat/asd932_.jpg
...
root/test/dog/xxx.jpg
...
root/test/cat/123.jpg
...
```
- Download [the pretrained models](https://github.com/PaddlePaddle/models/tree/release/1.7/PaddleCV/image_classification#resnet-series). We give the results of ResNet-101 below.
# Running Scripts
Modify `global_data_path` in `datasets/data_path.py` to the root directory that holds the datasets. In the commands below, `--use_cuda 0` selects GPU 0; pass `--use_cuda -1` to run on the CPU.
```bash
python -u main.py --dataset Caltech30 --delta_reg 0.1 --wd_rate 1e-4 --batch_size 64 --outdir outdir --num_epoch 100 --use_cuda 0
python -u main.py --dataset CUB_200_2011 --delta_reg 0.1 --wd_rate 1e-4 --batch_size 64 --outdir outdir --num_epoch 100 --use_cuda 0
```
These scripts give the results below:

| Dataset | l2 | delta |
| --- | --- | --- |
| Caltech-256 | 79.86 | 84.71 |
| CUB_200 | 77.41 | 80.05 |
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
'--prefix', default=None, type=str, help='prefix for model id')
parser.add_argument('--dataset', default='PetImages', type=str, help='dataset')
parser.add_argument(
'--seed',
default=None,
type=int,
help='random seed (default: None, i.e., not fix the randomness).')
parser.add_argument('--batch_size', default=20, type=int, help='batch_size.')
parser.add_argument('--delta_reg', default=0.1, type=float, help='delta_reg.')
parser.add_argument('--wd_rate', default=1e-4, type=float, help='wd_rate.')
parser.add_argument(
'--use_cuda', default=0, type=int, help='use_cuda device. -1 cpu.')
parser.add_argument('--num_epoch', default=100, type=int, help='num_epoch.')
parser.add_argument('--outdir', default='outdir', type=str, help='outdir')
parser.add_argument(
'--pretrained_model',
default='./pretrained_models/ResNet101_pretrained',
type=str,
help='pretrained model pathname')
args = parser.parse_args()
global_data_path = '[root_path]/datasets'
import cv2
import numpy as np
import six
import os
import glob
def resize_short(img, target_size, interpolation=None):
"""resize image
Args:
img: image data
target_size: resize short target size
interpolation: interpolation mode
Returns:
resized image data
"""
percent = float(target_size) / min(img.shape[0], img.shape[1])
resized_width = int(round(img.shape[1] * percent))
resized_height = int(round(img.shape[0] * percent))
if interpolation:
resized = cv2.resize(
img, (resized_width, resized_height), interpolation=interpolation)
else:
resized = cv2.resize(img, (resized_width, resized_height))
return resized
def crop_image(img, target_size, center):
"""crop image
Args:
img: images data
target_size: crop target size
center: crop mode
Returns:
img: cropped image data
"""
height, width = img.shape[:2]
size = target_size
    if center:
w_start = (width - size) // 2
h_start = (height - size) // 2
else:
w_start = np.random.randint(0, width - size + 1)
h_start = np.random.randint(0, height - size + 1)
w_end = w_start + size
h_end = h_start + size
img = img[h_start:h_end, w_start:w_end, :]
return img
def preprocess_image(img, random_mirror=True):
    """
    Normalize with the ImageNet mean/std after scaling by 1/255.
    :param img: np.array, shape [ns, h, w, 3], color order: rgb.
    :return: np.array, shape [ns, 3, h, w]
    """
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# transpose to [ns, 3, h, w]
img = img.astype('float32').transpose((0, 3, 1, 2)) / 255
img_mean = np.array(mean).reshape((3, 1, 1))
img_std = np.array(std).reshape((3, 1, 1))
img -= img_mean
img /= img_std
    if random_mirror:
        mirror = int(np.random.uniform(0, 2))
        if mirror == 1:
            # horizontal mirror: flip the width axis of the [ns, 3, h, w] array
            img = img[:, :, :, ::-1]
return img
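# Sketch of how the helpers above are composed per image (see
# ReaderConfig.get_reader() below):
#   img = cv2.imread(path)                            # HWC, BGR
#   img = resize_short(img, 256)                      # short side -> 256
#   img = crop_image(img, 224, center=True)           # 224x224 crop
#   img = img[:, :, ::-1]                             # BGR -> RGB
#   img = np.expand_dims(img, axis=0)                 # [1, 224, 224, 3]
#   img = preprocess_image(img, random_mirror=False)  # [1, 3, 224, 224]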
def _find_classes(dir):
# Faster and available in Python 3.5 and above
classes = [d.name for d in os.scandir(dir) if d.is_dir()]
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
return classes, class_to_idx
class ReaderConfig():
"""
A generic data loader where the images are arranged in this way:
root/train/dog/xxy.jpg
root/train/dog/xxz.jpg
...
root/train/cat/nsdf3.jpg
root/train/cat/asd932_.jpg
...
root/test/dog/xxx.jpg
...
root/test/cat/123.jpg
...
"""
def __init__(self, dataset_dir, is_test):
image_paths, labels, self.num_classes = self.reader_creator(
dataset_dir, is_test)
random_per = np.random.permutation(range(len(image_paths)))
self.image_paths = image_paths[random_per]
self.labels = labels[random_per]
self.is_test = is_test
def get_reader(self):
def reader():
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm',
'.tif', '.tiff', '.webp')
target_size = 256
crop_size = 224
for i, img_path in enumerate(self.image_paths):
if not img_path.lower().endswith(IMG_EXTENSIONS):
continue
img = cv2.imread(img_path)
if img is None:
print(img_path)
continue
img = resize_short(img, target_size, interpolation=None)
img = crop_image(img, crop_size, center=self.is_test)
img = img[:, :, ::-1]
img = np.expand_dims(img, axis=0)
img = preprocess_image(img, not self.is_test)
yield img, self.labels[i]
return reader
def reader_creator(self, dataset_dir, is_test=False):
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm',
'.tif', '.tiff', '.webp')
# read
if is_test:
datasubset_dir = os.path.join(dataset_dir, 'test')
else:
datasubset_dir = os.path.join(dataset_dir, 'train')
class_names, class_to_idx = _find_classes(datasubset_dir)
# num_classes = len(class_names)
image_paths = []
labels = []
for class_name in class_names:
classes_dir = os.path.join(datasubset_dir, class_name)
for img_path in glob.glob(os.path.join(classes_dir, '*')):
if not img_path.lower().endswith(IMG_EXTENSIONS):
continue
image_paths.append(img_path)
labels.append(class_to_idx[class_name])
image_paths = np.array(image_paths)
labels = np.array(labels)
return image_paths, labels, len(class_names)
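# Usage sketch (mirrors main.py): build a shuffled, batched train reader.
#   reader_config = ReaderConfig(f'{global_data_path}/Caltech30', is_test=False)
#   train_reader = paddle.batch(
#       paddle.reader.shuffle(reader_config.get_reader(), buf_size=64),
#       batch_size=64,
#       drop_last=True)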
import os
import time
import sys
import math
import numpy as np
import functools
import re
import logging
import glob
import paddle
import paddle.fluid as fluid
from models.resnet import ResNet101
from datasets.readers import ReaderConfig
# import cv2
# import skimage
# import matplotlib.pyplot as plt
# from paddle.fluid.core import PaddleTensor
# from paddle.fluid.core import AnalysisConfig
# from paddle.fluid.core import create_paddle_predictor
from args import args
from datasets.data_path import global_data_path
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
if args.seed is not None:
np.random.seed(args.seed)
print(os.environ.get('LD_LIBRARY_PATH', None))
print(os.environ.get('PATH', None))
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def load_vars_by_dict(executor, name_var_dict, main_program=None):
    """Load persistable variables one by one: name_var_dict maps an on-disk
    file path to the Variable that should receive its value."""
    from paddle.fluid.framework import Program, Variable
    from paddle.fluid import core
load_prog = Program()
load_block = load_prog.global_block()
if main_program is None:
main_program = fluid.default_main_program()
if not isinstance(main_program, Program):
raise TypeError("program should be as Program type or None")
for each_var_name in name_var_dict.keys():
assert isinstance(name_var_dict[each_var_name], Variable)
if name_var_dict[each_var_name].type == core.VarDesc.VarType.RAW:
continue
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [name_var_dict[each_var_name]]},
attrs={'file_path': each_var_name})
executor.run(load_prog)
def get_model_id():
prefix = ''
if args.prefix is not None:
prefix = args.prefix + '-' # for some notes.
model_id = prefix + args.dataset + \
'-epo_' + str(args.num_epoch) + \
'-b_' + str(args.batch_size) + \
'-reg_' + str(args.delta_reg) + \
'-wd_' + str(args.wd_rate)
return model_id
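# With the defaults in args.py this yields, e.g.,
# 'PetImages-epo_100-b_20-reg_0.1-wd_0.0001'.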
def train():
dataset = args.dataset
image_shape = [3, 224, 224]
pretrained_model = args.pretrained_model
class_map_path = f'{global_data_path}/{dataset}/readable_label.txt'
if os.path.exists(class_map_path):
logger.info(
"The map of readable label and numerical label has been found!")
with open(class_map_path) as f:
label_dict = {}
strinfo = re.compile(r"\d+ ")
for item in f.readlines():
key = int(item.split(" ")[0])
value = [
strinfo.sub("", l).replace("\n", "")
for l in item.split(", ")
]
label_dict[key] = value[0]
    assert os.path.isdir(
        pretrained_model), "please provide a valid pretrained model directory"
# data reader
batch_size = args.batch_size
reader_config = ReaderConfig(f'{global_data_path}/{dataset}', is_test=False)
reader = reader_config.get_reader()
train_reader = paddle.batch(
paddle.reader.shuffle(reader, buf_size=batch_size),
batch_size,
drop_last=True)
# model ops
image = fluid.data(
name='image', shape=[None] + image_shape, dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
model = ResNet101(is_test=False)
features, logits = model.net(
input=image, class_dim=reader_config.num_classes)
out = fluid.layers.softmax(logits)
# loss, metric
cost = fluid.layers.mean(fluid.layers.cross_entropy(out, label))
accuracy = fluid.layers.accuracy(input=out, label=label)
# delta regularization
# teacher model pre-trained on Imagenet, 1000 classes.
global_name = 't_'
t_model = ResNet101(is_test=True, global_name=global_name)
t_features, _ = t_model.net(input=image, class_dim=1000)
for f in t_features.keys():
t_features[f].stop_gradient = True
    # delta loss: the layer name is hard-coded to the feature map just before global pooling.
delta_loss = fluid.layers.square(t_features['t_res5c.add.output.5.tmp_0'] -
features['res5c.add.output.5.tmp_0'])
delta_loss = fluid.layers.reduce_mean(delta_loss)
params = fluid.default_main_program().global_block().all_parameters()
parameters = []
for param in params:
if param.trainable:
if global_name in param.name:
print('\tfixing', param.name)
else:
print('\ttraining', param.name)
parameters.append(param.name)
# optimizer, with piecewise_decay learning rate.
total_steps = len(reader_config.image_paths) * args.num_epoch // batch_size
boundaries = [int(total_steps * 2 / 3)]
print('\ttotal learning steps:', total_steps)
print('\tlr decays at:', boundaries)
values = [0.01, 0.001]
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.piecewise_decay(
boundaries=boundaries, values=values),
momentum=0.9,
parameter_list=parameters,
regularization=fluid.regularizer.L2Decay(args.wd_rate))
cur_lr = optimizer._global_learning_rate()
optimizer.minimize(
cost + args.delta_reg * delta_loss, parameter_list=parameters)
# data reader
feed_order = ['image', 'label']
# executor (session)
place = fluid.CUDAPlace(
args.use_cuda) if args.use_cuda >= 0 else fluid.CPUPlace()
exe = fluid.Executor(place)
# running
main_program = fluid.default_main_program()
start_program = fluid.default_startup_program()
feed_var_list_loop = [
main_program.global_block().var(var_name) for var_name in feed_order
]
feeder = fluid.DataFeeder(feed_list=feed_var_list_loop, place=place)
exe.run(start_program)
loading_parameters = {}
t_loading_parameters = {}
for p in main_program.all_parameters():
if 'fc' not in p.name:
if global_name in p.name:
new_name = os.path.join(pretrained_model,
p.name.split(global_name)[-1])
t_loading_parameters[new_name] = p
print(new_name, p.name)
else:
name = os.path.join(pretrained_model, p.name)
loading_parameters[name] = p
print(name, p.name)
else:
print(f'not loading {p.name}')
load_vars_by_dict(exe, loading_parameters, main_program=main_program)
load_vars_by_dict(exe, t_loading_parameters, main_program=main_program)
step = 0
# test_data = reader_creator_all_in_memory('./datasets/PetImages', is_test=True)
for e_id in range(args.num_epoch):
avg_delta_loss = AverageMeter()
avg_loss = AverageMeter()
avg_accuracy = AverageMeter()
batch_time = AverageMeter()
end = time.time()
for step_id, data_train in enumerate(train_reader()):
wrapped_results = exe.run(
main_program,
feed=feeder.feed(data_train),
fetch_list=[cost, accuracy, delta_loss, cur_lr])
# print(avg_loss_value[2])
batch_time.update(time.time() - end)
end = time.time()
avg_loss.update(wrapped_results[0][0], len(data_train))
avg_accuracy.update(wrapped_results[1][0], len(data_train))
avg_delta_loss.update(wrapped_results[2][0], len(data_train))
if step % 100 == 0:
print(
f"\tEpoch {e_id}, Global_Step {step}, Batch_Time {batch_time.avg: .2f},"
f" LR {wrapped_results[3][0]}, "
f"Loss {avg_loss.avg: .4f}, Acc {avg_accuracy.avg: .4f}, Delta_Loss {avg_delta_loss.avg: .4f}"
)
step += 1
if args.outdir is not None:
try:
os.makedirs(args.outdir, exist_ok=True)
fluid.io.save_params(
executor=exe, dirname=args.outdir + '/' + get_model_id())
        except Exception:
            print('\t Not saving trained parameters.')
if e_id == args.num_epoch - 1:
print("kpis\ttrain_cost\t%f" % avg_loss.avg)
print("kpis\ttrain_acc\t%f" % avg_accuracy.avg)
def test():
image_shape = [3, 224, 224]
pretrained_model = args.outdir + '/' + get_model_id()
# data reader
batch_size = args.batch_size
reader_config = ReaderConfig(
f'{global_data_path}/{args.dataset}', is_test=True)
reader = reader_config.get_reader()
test_reader = paddle.batch(reader, batch_size)
# model ops
image = fluid.data(
name='image', shape=[None] + image_shape, dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
model = ResNet101(is_test=True)
_, logits = model.net(input=image, class_dim=reader_config.num_classes)
out = fluid.layers.softmax(logits)
# loss, metric
cost = fluid.layers.mean(fluid.layers.cross_entropy(out, label))
accuracy = fluid.layers.accuracy(input=out, label=label)
# data reader
feed_order = ['image', 'label']
# executor (session)
place = fluid.CUDAPlace(
args.use_cuda) if args.use_cuda >= 0 else fluid.CPUPlace()
exe = fluid.Executor(place)
# running
main_program = fluid.default_main_program()
start_program = fluid.default_startup_program()
feed_var_list_loop = [
main_program.global_block().var(var_name) for var_name in feed_order
]
feeder = fluid.DataFeeder(feed_list=feed_var_list_loop, place=place)
exe.run(start_program)
fluid.io.load_params(exe, pretrained_model)
step = 0
avg_loss = AverageMeter()
avg_accuracy = AverageMeter()
for step_id, data_train in enumerate(test_reader()):
avg_loss_value = exe.run(
main_program,
feed=feeder.feed(data_train),
fetch_list=[cost, accuracy])
avg_loss.update(avg_loss_value[0], len(data_train))
avg_accuracy.update(avg_loss_value[1], len(data_train))
if step_id % 10 == 0:
print("\nBatch %d, Loss %f, Acc %f" % (step_id, avg_loss.avg,
avg_accuracy.avg))
step += 1
print("test counts:", avg_loss.count)
print("test_cost\t%f" % avg_loss.avg)
print("test_acc\t%f" % avg_accuracy.avg)
if __name__ == '__main__':
print(args)
train()
test()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
# from https://github.com/PaddlePaddle/models/blob/release/1.7/PaddleCV/image_classification/models/resnet.py.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
__all__ = [
"ResNet", "ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"
]
class ResNet():
def __init__(self, layers=50, is_test=True, global_name=''):
self.layers = layers
self.is_test = is_test
self.features = {}
self.global_name = global_name
def net(self, input, class_dim=1000, data_format="NCHW"):
layers = self.layers
supported_layers = [18, 34, 50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
num_filters = [64, 128, 256, 512]
conv = self.conv_bn_layer(
input=input,
num_filters=64,
filter_size=7,
stride=2,
act='relu',
name="conv1",
data_format=data_format)
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max',
        name=self.global_name + 'pool1',
data_format=data_format)
self.features[conv.name] = conv
if layers >= 50:
for block in range(len(depth)):
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
name=conv_name,
data_format=data_format)
self.features[conv.name] = conv
pool = fluid.layers.pool2d(
input=conv,
pool_type='avg',
global_pooling=True,
name=self.global_name + 'global_pooling',
data_format=data_format)
self.features[pool.name] = pool
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(
input=pool,
size=class_dim,
bias_attr=fluid.param_attr.ParamAttr(
name=self.global_name + 'fc_0.b_0'),
param_attr=fluid.param_attr.ParamAttr(
name=self.global_name + 'fc_0.w_0',
initializer=fluid.initializer.Uniform(-stdv, stdv)))
else:
for block in range(len(depth)):
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.basic_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
is_first=block == i == 0,
name=conv_name,
data_format=data_format)
self.features[conv.name] = conv
pool = fluid.layers.pool2d(
input=conv,
pool_type='avg',
global_pooling=True,
name=self.global_name + 'global_pooling',
data_format=data_format)
self.features[pool.name] = pool
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(
input=pool,
size=class_dim,
bias_attr=fluid.param_attr.ParamAttr(
name=self.global_name + 'fc_0.b_0'),
param_attr=fluid.param_attr.ParamAttr(
name=self.global_name + 'fc_0.w_0',
initializer=fluid.initializer.Uniform(-stdv, stdv)))
return self.features, out
def conv_bn_layer(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None,
data_format='NCHW'):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=self.global_name + name + "_weights"),
bias_attr=False,
        name=self.global_name + name + '.conv2d.output.1',
data_format=data_format)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
input=conv,
act=act,
name=self.global_name + bn_name + '.output.1',
param_attr=ParamAttr(self.global_name + bn_name + '_scale'),
bias_attr=ParamAttr(self.global_name + bn_name + '_offset'),
moving_mean_name=self.global_name + bn_name + '_mean',
moving_variance_name=self.global_name + bn_name + '_variance',
data_layout=data_format,
use_global_stats=self.is_test)
def shortcut(self, input, ch_out, stride, is_first, name, data_format):
if data_format == 'NCHW':
ch_in = input.shape[1]
else:
ch_in = input.shape[-1]
        if ch_in != ch_out or stride != 1 or is_first:
return self.conv_bn_layer(
input, ch_out, 1, stride, name=name, data_format=data_format)
else:
return input
def bottleneck_block(self, input, num_filters, stride, name, data_format):
conv0 = self.conv_bn_layer(
input=input,
num_filters=num_filters,
filter_size=1,
act='relu',
name=name + "_branch2a",
data_format=data_format)
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu',
name=name + "_branch2b",
data_format=data_format)
conv2 = self.conv_bn_layer(
input=conv1,
num_filters=num_filters * 4,
filter_size=1,
act=None,
name=name + "_branch2c",
data_format=data_format)
short = self.shortcut(
input,
num_filters * 4,
stride,
is_first=False,
name=name + "_branch1",
data_format=data_format)
return fluid.layers.elementwise_add(
x=short,
y=conv2,
act='relu',
name=self.global_name + name + ".add.output.5")
def basic_block(self, input, num_filters, stride, is_first, name,
data_format):
conv0 = self.conv_bn_layer(
input=input,
num_filters=num_filters,
filter_size=3,
act='relu',
stride=stride,
name=name + "_branch2a",
data_format=data_format)
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
act=None,
name=name + "_branch2b",
data_format=data_format)
short = self.shortcut(
input,
num_filters,
stride,
is_first,
name=name + "_branch1",
data_format=data_format)
return fluid.layers.elementwise_add(
x=short,
y=conv1,
act='relu',
name=self.global_name + name + ".add.output.5")
def ResNet18(is_test=True, global_name=''):
model = ResNet(layers=18, is_test=is_test, global_name=global_name)
return model
def ResNet34(is_test=True, global_name=''):
model = ResNet(layers=34, is_test=is_test, global_name=global_name)
return model
def ResNet50(is_test=True, global_name=''):
model = ResNet(layers=50, is_test=is_test, global_name=global_name)
return model
def ResNet101(is_test=True, global_name=''):
model = ResNet(layers=101, is_test=is_test, global_name=global_name)
return model
def ResNet152(is_test=True, global_name=''):
model = ResNet(layers=152, is_test=is_test, global_name=global_name)
return model
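# Usage sketch (as in main.py): the same topology serves both roles; passing
# global_name='t_' prefixes every parameter and op name, so a frozen teacher
# can coexist with the trainable student in one program while loading weights
# from the same pretrained directory.
#   student = ResNet101(is_test=False)
#   teacher = ResNet101(is_test=True, global_name='t_')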
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
# from https://github.com/PaddlePaddle/models/blob/release/1.7/PaddleCV/image_classification/models/resnet_vc.py.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
__all__ = ["ResNet", "ResNet50_vc", "ResNet101_vc", "ResNet152_vc"]
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class ResNet():
def __init__(self, layers=50, is_test=False, global_name=''):
self.params = train_parameters
self.layers = layers
self.is_test = is_test
self.features = {}
self.global_name = global_name
def net(self, input, class_dim=1000):
layers = self.layers
supported_layers = [50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
if layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
num_filters = [64, 128, 256, 512]
conv = self.conv_bn_layer(
input=input,
num_filters=32,
filter_size=3,
stride=2,
act='relu',
name='conv1_1')
conv = self.conv_bn_layer(
input=conv,
num_filters=32,
filter_size=3,
stride=1,
act='relu',
name='conv1_2')
conv = self.conv_bn_layer(
input=conv,
num_filters=64,
filter_size=3,
stride=1,
act='relu',
name='conv1_3')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max',
        name=self.global_name + 'pool1')
self.features[conv.name] = conv
for block in range(len(depth)):
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
name=conv_name)
self.features[conv.name] = conv
pool = fluid.layers.pool2d(
input=conv,
pool_type='avg',
global_pooling=True,
name=self.global_name + 'global_pooling')
self.features[pool.name] = pool
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(
input=pool,
size=class_dim,
bias_attr=fluid.param_attr.ParamAttr(
name=self.global_name + 'fc_0.b_0'),
param_attr=fluid.param_attr.ParamAttr(
name=self.global_name + 'fc_0.w_0',
initializer=fluid.initializer.Uniform(-stdv, stdv)))
return self.features, out
def conv_bn_layer(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=self.global_name + name + "_weights"),
bias_attr=False,
name=self.global_name + name + '.conv2d.output.1')
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
input=conv,
act=act,
name=self.global_name + bn_name + '.output.1',
param_attr=ParamAttr(self.global_name + bn_name + '_scale'),
bias_attr=ParamAttr(self.global_name + bn_name + '_offset'),
moving_mean_name=self.global_name + bn_name + '_mean',
moving_variance_name=self.global_name + bn_name + '_variance',
use_global_stats=self.is_test)
def shortcut(self, input, ch_out, stride, name):
ch_in = input.shape[1]
if ch_in != ch_out or stride != 1:
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else:
return input
def bottleneck_block(self, input, num_filters, stride, name):
conv0 = self.conv_bn_layer(
input=input,
num_filters=num_filters,
filter_size=1,
act='relu',
name=name + "_branch2a")
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu',
name=name + "_branch2b")
conv2 = self.conv_bn_layer(
input=conv1,
num_filters=num_filters * 4,
filter_size=1,
act=None,
name=name + "_branch2c")
short = self.shortcut(
input, num_filters * 4, stride, name=name + "_branch1")
return fluid.layers.elementwise_add(
x=short,
y=conv2,
act='relu',
name=self.global_name + name + ".add.output.5")
def ResNet50_vc(is_test=True, global_name=''):
model = ResNet(layers=50, is_test=is_test, global_name=global_name)
return model
def ResNet101_vc(is_test=True, global_name=''):
model = ResNet(layers=101, is_test=is_test, global_name=global_name)
return model
def ResNet152_vc(is_test=True, global_name=''):
model = ResNet(layers=152, is_test=is_test, global_name=global_name)
return model
......@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
conf = {
"ssd": {
"with_background": True,
......@@ -65,7 +64,9 @@ ssd_train_ops = [
dict(op='ArrangeSSD')
]
ssd_eval_fields = ['image', 'im_shape', 'im_id', 'gt_box', 'gt_label', 'is_difficult']
ssd_eval_fields = [
'image', 'im_shape', 'im_id', 'gt_box', 'gt_label', 'is_difficult'
]
ssd_eval_ops = [
dict(op='DecodeImage', to_rgb=True, with_mixup=False),
dict(op='NormalizeBox'),
......@@ -139,9 +140,10 @@ yolo_train_ops = [
dict(op='RandomCrop'),
dict(op='RandomFlipImage', is_normalized=False),
dict(op='Resize', target_dim=608, interp='random'),
dict(op='NormalizePermute',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.120, 57.375]),
dict(
op='NormalizePermute',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.120, 57.375]),
dict(op='NormalizeBox'),
dict(op='ArrangeYOLO'),
]
......@@ -194,18 +196,28 @@ feed_config = {
},
"rcnn": {
"train": {
"fields": ['image', 'im_info', 'im_id', 'gt_box', 'gt_label', 'is_crowd'],
"OPS": rcnn_train_ops,
"IS_PADDING": True,
"COARSEST_STRIDE": 32,
"fields":
['image', 'im_info', 'im_id', 'gt_box', 'gt_label', 'is_crowd'],
"OPS":
rcnn_train_ops,
"IS_PADDING":
True,
"COARSEST_STRIDE":
32,
},
"dev": {
"fields": ['image', 'im_info', 'im_id', 'im_shape', 'gt_box',
'gt_label', 'is_difficult'],
"OPS": rcnn_eval_ops,
"IS_PADDING": True,
"COARSEST_STRIDE": 32,
"USE_PADDED_IM_INFO": True,
"fields": [
'image', 'im_info', 'im_id', 'im_shape', 'gt_box', 'gt_label',
'is_difficult'
],
"OPS":
rcnn_eval_ops,
"IS_PADDING":
True,
"COARSEST_STRIDE":
32,
"USE_PADDED_IM_INFO":
True,
},
"predict": {
"fields": ['image', 'im_info', 'im_id', 'im_shape'],
......@@ -222,8 +234,10 @@ feed_config = {
"RANDOM_SHAPES": [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
},
"dev": {
"fields": ['image', 'im_size', 'im_id', 'gt_box', 'gt_label', 'is_difficult'],
"OPS": yolo_eval_ops,
"fields":
['image', 'im_size', 'im_id', 'gt_box', 'gt_label', 'is_difficult'],
"OPS":
yolo_eval_ops,
},
"predict": {
"fields": ['image', 'im_size', 'im_id'],
......@@ -261,7 +275,9 @@ def get_feed_list(module_name, input_dict, input_dict_pred=None):
gt_bbox = input_dict['gt_bbox']
gt_class = input_dict['gt_class']
is_crowd = input_dict['is_crowd']
feed_list = [image.name, im_info.name, gt_bbox.name, gt_class.name, is_crowd.name]
feed_list = [
image.name, im_info.name, gt_bbox.name, gt_class.name, is_crowd.name
]
assert input_dict_pred is not None
image = input_dict_pred['image']
im_info = input_dict_pred['im_info']
......@@ -283,7 +299,9 @@ def get_mid_feature(module_name, output_dict, output_dict_pred=None):
rpn_cls_loss = output_dict['rpn_cls_loss']
rpn_reg_loss = output_dict['rpn_reg_loss']
generate_proposal_labels = output_dict['generate_proposal_labels']
feature = [head_feat, rpn_cls_loss, rpn_reg_loss, generate_proposal_labels]
feature = [
head_feat, rpn_cls_loss, rpn_reg_loss, generate_proposal_labels
]
assert output_dict_pred is not None
head_feat = output_dict_pred['head_feat']
rois = output_dict_pred['rois']
......@@ -291,4 +309,3 @@ def get_mid_feature(module_name, output_dict, output_dict_pred=None):
else:
raise NotImplementedError
return feature, feature_pred
......@@ -36,5 +36,6 @@ class Coco10(ObjectDetectionDataset):
validate_list_file = 'annotations/val.json'
test_image_dir = 'val'
test_list_file = 'annotations/val.json'
super(Coco10, self).__init__(base_path, train_image_dir, train_list_file, validate_image_dir, validate_list_file,
test_image_dir, test_list_file, model_type)
super(Coco10, self).__init__(
base_path, train_image_dir, train_list_file, validate_image_dir,
validate_list_file, test_image_dir, test_list_file, model_type)