未验证 提交 d6c65111 编写于 作者: R ruri 提交者: GitHub

Refine Image classification (#2974)

* Refine Image classification
上级 e85cf404
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import paddle.fluid as fluid
import utils.utility as utility
def _calc_label_smoothing_loss(softmax_out, label, class_dim, epsilon):
"""Calculate label smoothing loss
Returns:
label smoothing loss
"""
label_one_hot = fluid.layers.one_hot(input=label, depth=class_dim)
smooth_label = fluid.layers.label_smooth(
label=label_one_hot, epsilon=epsilon, dtype="float32")
loss = fluid.layers.cross_entropy(
input=softmax_out, label=smooth_label, soft_label=True)
return loss
def _basic_model(data, model, args, is_train):
image = data[0]
label = data[1]
net_out = model.net(input=image, class_dim=args.class_dim)
softmax_out = fluid.layers.softmax(net_out, use_cudnn=False)
if is_train and args.use_label_smoothing:
cost = _calc_label_smoothing_loss(softmax_out, label, args.class_dim,
args.epsilon)
else:
cost = fluid.layers.cross_entropy(input=softmax_out, label=label)
avg_cost = fluid.layers.mean(cost)
acc_top1 = fluid.layers.accuracy(input=softmax_out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=softmax_out, label=label, k=5)
return [avg_cost, acc_top1, acc_top5]
def _googlenet_model(data, model, args, is_train):
"""GoogLeNet model output, include avg_cost, acc_top1 and acc_top5
Returns:
GoogLeNet model output
"""
image = data[0]
label = data[1]
out0, out1, out2 = model.net(input=image, class_dim=args.class_dim)
cost0 = fluid.layers.cross_entropy(input=out0, label=label)
cost1 = fluid.layers.cross_entropy(input=out1, label=label)
cost2 = fluid.layers.cross_entropy(input=out2, label=label)
avg_cost0 = fluid.layers.mean(x=cost0)
avg_cost1 = fluid.layers.mean(x=cost1)
avg_cost2 = fluid.layers.mean(x=cost2)
avg_cost = avg_cost0 + 0.3 * avg_cost1 + 0.3 * avg_cost2
acc_top1 = fluid.layers.accuracy(input=out0, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out0, label=label, k=5)
return [avg_cost, acc_top1, acc_top5]
def _mixup_model(data, model, args, is_train):
"""output of Mixup processing network, include avg_cost
"""
image = data[0]
y_a = data[1]
y_b = data[2]
lam = data[3]
net_out = model.net(input=image, class_dim=args.class_dim)
softmax_out = fluid.layers.softmax(net_out, use_cudnn=False)
if not args.use_label_smoothing:
loss_a = fluid.layers.cross_entropy(input=softmax_out, label=y_a)
loss_b = fluid.layers.cross_entropy(input=softmax_out, label=y_b)
else:
loss_a = _calc_label_smoothing_loss(softmax_out, y_a, args.class_dim,
args.epsilon)
loss_b = _calc_label_smoothing_loss(softmax_out, y_b, args.class_dim,
args.epsilon)
loss_a_mean = fluid.layers.mean(x=loss_a)
loss_b_mean = fluid.layers.mean(x=loss_b)
cost = lam * loss_a_mean + (1 - lam) * loss_b_mean
avg_cost = fluid.layers.mean(x=cost)
return [avg_cost]
def create_model(model, args, is_train):
"""Create model, include basic model, googlenet model and mixup model
"""
py_reader, data = utility.create_pyreader(is_train, args)
if args.model == "GoogLeNet":
loss_out = _googlenet_model(data, model, args, is_train)
else:
if args.use_mixup and is_train:
loss_out = _mixup_model(data, model, args, is_train)
else:
loss_out = _basic_model(data, model, args, is_train)
return py_reader, loss_out
......@@ -26,43 +26,48 @@ import functools
import paddle
import paddle.fluid as fluid
import reader_cv2 as reader
import reader
import models
from utils.learning_rate import cosine_decay
from utils.utility import add_arguments, print_arguments, check_gpu
from utils import *
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 256, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('class_dim', int, 1000, "Class number.")
add_arg('image_shape', str, "3,224,224", "Input image size")
add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.")
add_arg('resize_short_size', int, 256, "Set resize short size")
add_arg('data_dir', str, "./data/ILSVRC2012/", "The ImageNet datset")
add_arg('batch_size', int, 256, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('class_dim', int, 1000, "Class number.")
add_arg('image_shape', str, "3,224,224", "Input image size")
parser.add_argument("--pretrained_model", default=None, required=True, type=str, help="The path to load pretrained model")
add_arg('model', str, "ResNet50", "Set the network to use.")
add_arg('resize_short_size', int, 256, "Set resize short size")
add_arg('reader_thread', int, 8, "The number of multi thread reader")
add_arg('reader_buf_size', int, 2048, "The buf size of multi thread reader")
parser.add_argument('--image_mean', nargs='+', type=float, default=[0.485, 0.456, 0.406], help="The mean of input image data")
parser.add_argument('--image_std', nargs='+', type=float, default=[0.229, 0.224, 0.225], help="The std of input image data")
add_arg('crop_size', int, 224, "The value of crop size")
# yapf: enable
def eval(args):
# parameters from arguments
class_dim = args.class_dim
model_name = args.model
pretrained_model = args.pretrained_model
image_shape = [int(m) for m in args.image_shape.split(",")]
model_list = [m for m in dir(models) if "__" not in m]
assert model_name in model_list, "{} is not in lists: {}".format(args.model,
assert args.model in model_list, "{} is not in lists: {}".format(args.model,
model_list)
assert os.path.isdir(
args.pretrained_model
), "{} doesn't exist, please load right pretrained model path for eval".format(
args.pretrained_model)
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# model definition
model = models.__dict__[model_name]()
model = models.__dict__[args.model]()
if model_name == "GoogleNet":
out0, out1, out2 = model.net(input=image, class_dim=class_dim)
if args.model == "GoogLeNet":
out0, out1, out2 = model.net(input=image, class_dim=args.class_dim)
cost0 = fluid.layers.cross_entropy(input=out0, label=label)
cost1 = fluid.layers.cross_entropy(input=out1, label=label)
cost2 = fluid.layers.cross_entropy(input=out2, label=label)
......@@ -74,7 +79,8 @@ def eval(args):
acc_top1 = fluid.layers.accuracy(input=out0, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out0, label=label, k=5)
else:
out = model.net(input=image, class_dim=class_dim)
out = model.net(input=image, class_dim=args.class_dim)
cost, pred = fluid.layers.softmax_with_cross_entropy(
out, label, return_softmax=True)
avg_cost = fluid.layers.mean(x=cost)
......@@ -89,9 +95,10 @@ def eval(args):
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
fluid.io.load_persistables(exe, pretrained_model)
fluid.io.load_persistables(exe, args.pretrained_model)
val_reader = reader.val(settings=args, batch_size=args.batch_size)
val_reader = paddle.batch(
reader.val(settings=args), batch_size=args.batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
test_info = [[], [], []]
......@@ -129,7 +136,7 @@ def eval(args):
def main():
args = parser.parse_args()
print_arguments(args)
check_gpu(args.use_gpu)
check_gpu()
eval(args)
......
......@@ -26,43 +26,44 @@ import functools
import paddle
import paddle.fluid as fluid
import reader_cv2 as reader
import reader
import models
import utils
from utils.utility import add_arguments, print_arguments, check_gpu
from utils import *
parser = argparse.ArgumentParser(description=__doc__)
# yapf: disable
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('data_dir', str, "./data/ILSVRC2012/", "The ImageNet data")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('class_dim', int, 1000, "Class number.")
add_arg('image_shape', str, "3,224,224", "Input image size")
add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.")
add_arg('save_inference', bool, False, "Whether to save inference model or not")
add_arg('resize_short_size', int, 256, "Set resize short size")
parser.add_argument("--pretrained_model", default=None, required=True, type=str, help="The path to load pretrained model")
add_arg('model', str, "ResNet50", "Set the network to use.")
add_arg('save_inference', bool, False, "Whether to save inference model or not")
add_arg('resize_short_size',int, 256, "Set resize short size")
add_arg('reader_thread', int, 1, "The number of multi thread reader")
add_arg('reader_buf_size', int, 2048, "The buf size of multi thread reader")
parser.add_argument('--image_mean', nargs='+', type=float, default=[0.485, 0.456, 0.406], help="The mean of input image data")
parser.add_argument('--image_std', nargs='+', type=float, default=[0.229, 0.224, 0.225], help="The std of input image data")
add_arg('crop_size', int, 224, "The value of crop size")
add_arg('topk', int, 1, "topk")
add_arg('label_path', str, "./utils/tools/readable_label.txt", "readable label filepath")
# yapf: enable
def infer(args):
# parameters from arguments
class_dim = args.class_dim
model_name = args.model
save_inference = args.save_inference
pretrained_model = args.pretrained_model
image_shape = [int(m) for m in args.image_shape.split(",")]
model_list = [m for m in dir(models) if "__" not in m]
assert model_name in model_list, "{} is not in lists: {}".format(args.model,
assert args.model in model_list, "{} is not in lists: {}".format(args.model,
model_list)
assert os.path.isdir(args.pretrained_model
), "please load right pretrained model path for infer"
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
# model definition
model = models.__dict__[model_name]()
if model_name == "GoogleNet":
out, _, _ = model.net(input=image, class_dim=class_dim)
model = models.__dict__[args.model]()
if args.model == "GoogLeNet":
out, _, _ = model.net(input=image, class_dim=args.class_dim)
else:
out = model.net(input=image, class_dim=class_dim)
out = model.net(input=image, class_dim=args.class_dim)
out = fluid.layers.softmax(out)
test_program = fluid.default_main_program().clone(for_test=True)
......@@ -73,39 +74,51 @@ def infer(args):
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
fluid.io.load_persistables(exe, pretrained_model)
if save_inference:
fluid.io.load_persistables(exe, args.pretrained_model)
if args.save_inference:
fluid.io.save_inference_model(
dirname=model_name,
dirname=args.model,
feeded_var_names=['image'],
main_program=test_program,
target_vars=out,
executor=exe,
model_filename='model',
params_filename='params')
print("model: ", model_name, " is already saved")
print("model: ", args.model, " is already saved")
exit(0)
test_batch_size = 1
test_reader = reader.test(settings=args, batch_size=test_batch_size)
test_batch_size = 1
test_reader = paddle.batch(
reader.test(settings=args), batch_size=test_batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image])
TOPK = 1
TOPK = args.topk
assert os.path.exists(args.label_path), "Index file doesn't exist!"
f = open(args.label_path)
label_dict = {}
for item in f.readlines():
key = item.split(" ")[0]
value = [l.replace("\n", "") for l in item.split(" ")[1:]]
label_dict[key] = value
for batch_id, data in enumerate(test_reader()):
result = exe.run(test_program,
fetch_list=fetch_list,
feed=feeder.feed(data))
result = result[0][0]
pred_label = np.argsort(result)[::-1][:TOPK]
print("Test-{0}-score: {1}, class {2}"
.format(batch_id, result[pred_label], pred_label))
readable_pred_label = []
for label in pred_label:
readable_pred_label.append(label_dict[str(label)])
print("Test-{0}-score: {1}, class{2} {3}".format(batch_id, result[
pred_label], pred_label, readable_pred_label))
sys.stdout.flush()
def main():
args = parser.parse_args()
print_arguments(args)
check_gpu(args.use_gpu)
check_gpu()
infer(args)
......
......@@ -8,3 +8,9 @@ For historical reasons, We keep "no name" models here, which are different from
|- |:-: |:-:|
|[ResNet152](http://paddle-imagenet-models.bj.bcebos.com/ResNet152_pretrained.zip) | 78.18%/93.93% | 78.11%/94.04% |
|[SE_ResNeXt50_32x4d](http://paddle-imagenet-models.bj.bcebos.com/se_resnext_50_model.tar) | 78.32%/93.96% | 77.58%/93.73% |
---
2019/08/08
We move the dist_train and fp16 part to PaddlePaddle Fleet now.
and dist_train folder is temporary stored here.
......@@ -112,7 +112,7 @@ Speed-ups of Multiple-GPU Training of Resnet50 on Imagenet
#### Environment
- GPU: NVIDIA® Tesla® V100
- GPU: NVIDIA® Tesla® V100
- Machine number * Card number: 4 * 4
- System: Centos 6u3
- Cuda/Cudnn: 9.0/7.1
......@@ -127,5 +127,3 @@ Speed-ups of Multiple-GPU Training of Resnet50 on Imagenet
<img src="../images/resnet_dgc.png" width=528> <br />
Performance using DGC for resnet-fp32 under different bandwidth
</p>
......@@ -15,6 +15,7 @@
import paddle.fluid as fluid
import numpy as np
def copyback_repeat_bn_params(main_prog):
repeat_vars = set()
for op in main_prog.global_block().ops:
......@@ -22,9 +23,11 @@ def copyback_repeat_bn_params(main_prog):
repeat_vars.add(op.input("Mean")[0])
repeat_vars.add(op.input("Variance")[0])
for vname in repeat_vars:
real_var = fluid.global_scope().find_var("%s.repeat.0" % vname).get_tensor()
real_var = fluid.global_scope().find_var("%s.repeat.0" %
vname).get_tensor()
orig_var = fluid.global_scope().find_var(vname).get_tensor()
orig_var.set(np.array(real_var), fluid.CUDAPlace(0)) # test on GPU0
orig_var.set(np.array(real_var), fluid.CUDAPlace(0)) # test on GPU0
def append_bn_repeat_init_op(main_prog, startup_prog, num_repeats):
repeat_vars = set()
......@@ -32,7 +35,7 @@ def append_bn_repeat_init_op(main_prog, startup_prog, num_repeats):
if op.type == "batch_norm":
repeat_vars.add(op.input("Mean")[0])
repeat_vars.add(op.input("Variance")[0])
for i in range(num_repeats):
for op in startup_prog.global_block().ops:
if op.type == "fill_constant":
......@@ -45,13 +48,10 @@ def append_bn_repeat_init_op(main_prog, startup_prog, num_repeats):
type=var.type,
dtype=var.dtype,
shape=var.shape,
persistable=var.persistable
)
persistable=var.persistable)
main_prog.global_block()._clone_variable(repeat_var)
startup_prog.global_block().append_op(
type="fill_constant",
inputs={},
outputs={"Out": repeat_var},
attrs=op.all_attrs()
)
attrs=op.all_attrs())
......@@ -23,7 +23,7 @@ def dist_env():
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
num_trainers = 1
training_role = os.getenv("PADDLE_TRAINING_ROLE", "TRAINER")
assert(training_role == "PSERVER" or training_role == "TRAINER")
assert (training_role == "PSERVER" or training_role == "TRAINER")
# - PADDLE_TRAINER_ENDPOINTS means nccl2 mode.
# - PADDLE_PSERVER_ENDPOINTS means pserver mode.
......@@ -36,7 +36,7 @@ def dist_env():
num_trainers = len(trainer_endpoints)
elif pserver_endpoints:
num_trainers = int(os.getenv("PADDLE_TRAINERS_NUM"))
return {
"trainer_id": trainer_id,
"num_trainers": num_trainers,
......
......@@ -17,11 +17,9 @@ import math
import random
import functools
import numpy as np
import cv2
import io
from PIL import Image, ImageEnhance
import paddle
import paddle.fluid as fluid
random.seed(0)
np.random.seed(0)
......@@ -31,202 +29,133 @@ DATA_DIM = 224
THREAD = 8
BUF_SIZE = 2048
DATA_DIR = './data/ILSVRC2012'
DATA_DIR = 'data/ILSVRC2012'
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
def rotate_image(img):
""" rotate_image """
(h, w) = img.shape[:2]
center = (w / 2, h / 2)
angle = np.random.randint(-10, 11)
M = cv2.getRotationMatrix2D(center, angle, 1.0)
rotated = cv2.warpAffine(img, M, (w, h))
return rotated
def resize_short(img, target_size):
percent = float(target_size) / min(img.size[0], img.size[1])
resized_width = int(round(img.size[0] * percent))
resized_height = int(round(img.size[1] * percent))
img = img.resize((resized_width, resized_height), Image.LANCZOS)
return img
def crop_image(img, target_size, center):
width, height = img.size
size = target_size
if center == True:
w_start = (width - size) / 2
h_start = (height - size) / 2
else:
w_start = np.random.randint(0, width - size + 1)
h_start = np.random.randint(0, height - size + 1)
w_end = w_start + size
h_end = h_start + size
img = img.crop((w_start, h_start, w_end, h_end))
return img
def random_crop(img, size, settings, scale=None, ratio=None):
""" random_crop """
lower_scale = settings.lower_scale
lower_ratio = settings.lower_ratio
upper_ratio = settings.upper_ratio
scale = [lower_scale, 1.0] if scale is None else scale
ratio = [lower_ratio, upper_ratio] if ratio is None else ratio
def random_crop(img, size, scale=[0.08, 1.0], ratio=[3. / 4., 4. / 3.]):
aspect_ratio = math.sqrt(np.random.uniform(*ratio))
w = 1. * aspect_ratio
h = 1. / aspect_ratio
bound = min((float(img.shape[0]) / img.shape[1]) / (h**2),
(float(img.shape[1]) / img.shape[0]) / (w**2))
bound = min((float(img.size[0]) / img.size[1]) / (w**2),
(float(img.size[1]) / img.size[0]) / (h**2))
scale_max = min(scale[1], bound)
scale_min = min(scale[0], bound)
target_area = img.shape[0] * img.shape[1] * np.random.uniform(scale_min,
scale_max)
target_area = img.size[0] * img.size[1] * np.random.uniform(scale_min,
scale_max)
target_size = math.sqrt(target_area)
w = int(target_size * w)
h = int(target_size * h)
i = np.random.randint(0, img.shape[0] - h + 1)
j = np.random.randint(0, img.shape[1] - w + 1)
img = img[i:i + h, j:j + w, :]
i = np.random.randint(0, img.size[0] - w + 1)
j = np.random.randint(0, img.size[1] - h + 1)
resized = cv2.resize(
img,
(size, size)
#, interpolation=cv2.INTER_LANCZOS4
)
return resized
img = img.crop((i, j, i + w, j + h))
img = img.resize((size, size), Image.LANCZOS)
return img
def distort_color(img):
def rotate_image(img):
angle = np.random.randint(-10, 11)
img = img.rotate(angle)
return img
def resize_short(img, target_size):
""" resize_short """
percent = float(target_size) / min(img.shape[0], img.shape[1])
resized_width = int(round(img.shape[1] * percent))
resized_height = int(round(img.shape[0] * percent))
resized = cv2.resize(
img,
(resized_width, resized_height),
#interpolation=cv2.INTER_LANCZOS4
)
return resized
def distort_color(img):
def random_brightness(img, lower=0.5, upper=1.5):
e = np.random.uniform(lower, upper)
return ImageEnhance.Brightness(img).enhance(e)
def random_contrast(img, lower=0.5, upper=1.5):
e = np.random.uniform(lower, upper)
return ImageEnhance.Contrast(img).enhance(e)
def crop_image(img, target_size, center):
""" crop_image """
height, width = img.shape[:2]
size = target_size
if center == True:
w_start = (width - size) // 2
h_start = (height - size) // 2
else:
w_start = np.random.randint(0, width - size + 1)
h_start = np.random.randint(0, height - size + 1)
w_end = w_start + size
h_end = h_start + size
img = img[h_start:h_end, w_start:w_end, :]
return img
def random_color(img, lower=0.5, upper=1.5):
e = np.random.uniform(lower, upper)
return ImageEnhance.Color(img).enhance(e)
ops = [random_brightness, random_contrast, random_color]
np.random.shuffle(ops)
img = ops[0](img)
img = ops[1](img)
img = ops[2](img)
return img
def create_mixup_reader(settings, rd):
class context:
tmp_mix = []
tmp_l1 = []
tmp_l2 = []
tmp_lam = []
batch_size = settings.batch_size
alpha = settings.mixup_alpha
def fetch_data():
data_list = []
for i, item in enumerate(rd()):
data_list.append(item)
if i % batch_size == batch_size - 1:
yield data_list
data_list = []
def mixup_data():
for data_list in fetch_data():
if alpha > 0.:
lam = np.random.beta(alpha, alpha)
else:
lam = 1.
l1 = np.array(data_list)
l2 = np.random.permutation(l1)
mixed_l = [
l1[i][0] * lam + (1 - lam) * l2[i][0] for i in range(len(l1))
]
yield mixed_l, l1, l2, lam
def mixup_reader():
for context.tmp_mix, context.tmp_l1, context.tmp_l2, context.tmp_lam in mixup_data(
):
for i in range(len(context.tmp_mix)):
mixed_l = context.tmp_mix[i]
l1 = context.tmp_l1[i]
l2 = context.tmp_l2[i]
lam = context.tmp_lam
yield mixed_l, l1[1], l2[1], lam
return mixup_reader
def process_image(sample,
settings,
mode,
color_jitter,
rotate,
crop_size=224,
mean=None,
std=None):
""" process_image """
mean = [0.485, 0.456, 0.406] if mean is None else mean
std = [0.229, 0.224, 0.225] if std is None else std
def process_image(sample, mode, color_jitter, rotate):
img_path = sample[0]
img = cv2.imread(img_path)
img = Image.open(img_path)
if mode == 'train':
if rotate: img = rotate_image(img)
img = random_crop(img, DATA_DIM)
else:
img = resize_short(img, target_size=256)
img = crop_image(img, target_size=DATA_DIM, center=True)
if mode == 'train':
if rotate:
img = rotate_image(img)
if crop_size > 0:
img = random_crop(img, crop_size, settings)
if color_jitter:
img = distort_color(img)
if np.random.randint(0, 2) == 1:
img = img[:, ::-1, :]
else:
if crop_size > 0:
target_size = settings.resize_short_size
img = resize_short(img, target_size)
img = crop_image(img, target_size=crop_size, center=True)
img = img[:, :, ::-1].astype('float32').transpose((2, 0, 1)) / 255
img_mean = np.array(mean).reshape((3, 1, 1))
img_std = np.array(std).reshape((3, 1, 1))
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if img.mode != 'RGB':
img = img.convert('RGB')
img = np.array(img).astype('float32').transpose((2, 0, 1)) / 255
img -= img_mean
img /= img_std
if mode == 'train' or mode == 'val':
return (img, sample[1])
return img, sample[1]
elif mode == 'test':
return (img, )
return [img]
def process_batch_data(input_data, settings, mode, color_jitter, rotate):
def process_batch_data(input_data, mode, color_jitter, rotate):
batch_data = []
crop_size = int(settings.image_shape.split(',')[-1])
for sample in input_data:
if os.path.isfile(sample[0]):
batch_data.append(
process_image(sample, settings, mode, color_jitter, rotate, crop_size))
else:
print("File not exist : %s" % sample[0])
batch_data.append(process_image(sample, mode, color_jitter, rotate))
return batch_data
def _reader_creator(settings,
file_list,
def _reader_creator(file_list,
batch_size,
mode,
shuffle=False,
color_jitter=False,
rotate=False,
data_dir=DATA_DIR,
shuffle_seed=0):
shuffle_seed=0,
infinite=False):
def reader():
def read_file_list():
with open(file_list) as flist:
......@@ -241,9 +170,10 @@ def _reader_creator(settings,
img_path = os.path.join(data_dir, img_path)
batch_data.append([img_path, int(label)])
if len(batch_data) == batch_size:
if mode == 'train' or mode == 'val' or mode == 'test':
if mode == 'train' or mode == 'val':
yield batch_data
elif mode == 'test':
yield [sample[0] for sample in batch_data]
batch_data = []
return read_file_list
......@@ -258,20 +188,14 @@ def _reader_creator(settings,
data_reader = fluid.contrib.reader.distributed_batch_reader(data_reader)
mapper = functools.partial(
process_batch_data,
settings=settings,
mode=mode,
color_jitter=color_jitter,
rotate=rotate)
process_batch_data, mode=mode, color_jitter=color_jitter, rotate=rotate)
return paddle.reader.xmap_readers(
mapper, data_reader, THREAD, BUF_SIZE, order=False)
return paddle.reader.xmap_readers(mapper, data_reader, THREAD, BUF_SIZE)
def train(settings, batch_size, data_dir=DATA_DIR, shuffle_seed=0):
def train(batch_size, data_dir=DATA_DIR, shuffle_seed=0, infinite=False):
file_list = os.path.join(data_dir, 'train_list.txt')
reader = _reader_creator(
settings,
return _reader_creator(
file_list,
batch_size,
'train',
......@@ -279,29 +203,17 @@ def train(settings, batch_size, data_dir=DATA_DIR, shuffle_seed=0):
color_jitter=False,
rotate=False,
data_dir=data_dir,
shuffle_seed=shuffle_seed)
if settings.use_mixup == True:
reader = create_mixup_reader(settings, reader)
return reader
shuffle_seed=shuffle_seed,
infinite=infinite)
def val(settings, batch_size, data_dir=DATA_DIR):
def val(batch_size, data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'val_list.txt')
return _reader_creator(
settings,
file_list,
batch_size,
'val',
shuffle=False,
data_dir=data_dir)
file_list, batch_size, 'val', shuffle=False, data_dir=data_dir)
def test(settings, batch_size, data_dir=DATA_DIR):
def test(batch_size, data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'val_list.txt')
return _reader_creator(
settings,
file_list,
batch_size,
'test',
shuffle=False,
data_dir=data_dir)
file_list, batch_size, 'test', shuffle=False, data_dir=data_dir)
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from .alexnet import AlexNet
from .mobilenet import MobileNet
from .mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x1_0, MobileNetV2_x1_5, MobileNetV2_x2_0, MobileNetV2_scale
from .googlenet import GoogleNet
from .mobilenet import MobileNet, MobileNetV1
from .mobilenet_v2 import MobileNetV2, MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x1_0, MobileNetV2_x1_5, MobileNetV2_x2_0, MobileNetV2_scale
from .googlenet import GoogLeNet
from .vgg import VGG11, VGG13, VGG16, VGG19
from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
from .resnet_vc import ResNet50_vc, ResNet101_vc, ResNet152_vc
......@@ -11,14 +25,13 @@ from .resnext_vd import ResNeXt50_vd_64x4d, ResNeXt101_vd_64x4d, ResNeXt152_vd_6
from .resnet_dist import DistResNet
from .inception_v4 import InceptionV4
from .se_resnext import SE_ResNeXt50_32x4d, SE_ResNeXt101_32x4d, SE_ResNeXt152_32x4d
from .se_resnext_vd import SE_ResNeXt50_32x4d_vd, SE_ResNeXt101_32x4d_vd, SENet154_vd
from .se_resnext_vd import SE_ResNeXt50_32x4d_vd, SE_ResNeXt101_32x4d_vd, SE_154_vd
from .dpn import DPN68, DPN92, DPN98, DPN107, DPN131
from .shufflenet_v2_swish import ShuffleNetV2, ShuffleNetV2_x0_5_swish, ShuffleNetV2_x1_0_swish, ShuffleNetV2_x1_5_swish, ShuffleNetV2_x2_0_swish
from .shufflenet_v2 import ShuffleNetV2_x0_25, ShuffleNetV2_x0_33, ShuffleNetV2_x0_5, ShuffleNetV2_x1_0, ShuffleNetV2_x1_5, ShuffleNetV2_x2_0
from .shufflenet_v2_swish import ShuffleNetV2_swish, ShuffleNetV2_x0_5_swish, ShuffleNetV2_x1_0_swish, ShuffleNetV2_x1_5_swish, ShuffleNetV2_x2_0_swish
from .shufflenet_v2 import ShuffleNetV2_x0_25, ShuffleNetV2_x0_33, ShuffleNetV2_x0_5, ShuffleNetV2_x1_0, ShuffleNetV2_x1_5, ShuffleNetV2_x2_0, ShuffleNetV2
from .fast_imagenet import FastImageNet
from .xception import Xception_41, Xception_65, Xception_71
from .densenet import DenseNet121, DenseNet161, DenseNet169, DenseNet201, DenseNet264
from .squeezenet import SqueezeNet1_0, SqueezeNet1_1
from .darknet import DarkNet53
from .resnext101_wsl import ResNeXt101_32x8d_wsl, ResNeXt101_32x16d_wsl, ResNeXt101_32x32d_wsl, ResNeXt101_32x48d_wsl, Fix_ResNeXt101_32x48d_wsl
......@@ -23,22 +23,10 @@ import paddle.fluid as fluid
__all__ = ['AlexNet']
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [40, 70, 100],
"steps": [0.01, 0.001, 0.0001, 0.00001]
}
}
class AlexNet():
def __init__(self):
self.params = train_parameters
pass
def net(self, input, class_dim=1000):
stdv = 1.0 / math.sqrt(input.shape[1] * 11 * 11)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
import math
__all__ = ["DarkNet53"]
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class DarkNet53():
def __init__(self):
self.params = train_parameters
pass
def net(self, input, class_dim=1000):
DarkNet_cfg = {53: ([1, 2, 8, 8, 4], self.basicblock)}
......@@ -45,17 +38,11 @@ class DarkNet53():
padding=1,
name="yolo_input")
conv = self.downsample(
conv1,
ch_out=conv1.shape[1] * 2,
name="yolo_input.downsample")
conv1, ch_out=conv1.shape[1] * 2, name="yolo_input.downsample")
for i, stage in enumerate(stages):
conv = self.layer_warp(
block_func,
conv,
32 * (2**i),
stage,
name="stage.{}".format(i))
block_func, conv, 32 * (2**i), stage, name="stage.{}".format(i))
if i < len(stages) - 1: # do not downsaple in the last stage
conv = self.downsample(
conv,
......@@ -64,18 +51,22 @@ class DarkNet53():
pool = fluid.layers.pool2d(
input=conv, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(input=pool,
size=class_dim,
param_attr=ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),name='fc_weights'),
bias_attr=ParamAttr(name='fc_offset'))
out = fluid.layers.fc(
input=pool,
size=class_dim,
param_attr=ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),
name='fc_weights'),
bias_attr=ParamAttr(name='fc_offset'))
return out
def conv_bn_layer(self, input, ch_out, filter_size, stride, padding, name=None):
def conv_bn_layer(self,
input,
ch_out,
filter_size,
stride,
padding,
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=ch_out,
......@@ -96,9 +87,13 @@ class DarkNet53():
moving_variance_name=bn_name + '.var')
return out
def downsample(self, input, ch_out, filter_size=3, stride=2, padding=1, name=None):
def downsample(self,
input,
ch_out,
filter_size=3,
stride=2,
padding=1,
name=None):
return self.conv_bn_layer(
input,
ch_out=ch_out,
......@@ -107,22 +102,14 @@ class DarkNet53():
padding=padding,
name=name)
def basicblock(self, input, ch_out, name=None):
conv1 = self.conv_bn_layer(
input, ch_out, 1, 1, 0, name=name + ".0")
conv2 = self.conv_bn_layer(
conv1, ch_out * 2, 3, 1, 1, name=name + ".1")
conv1 = self.conv_bn_layer(input, ch_out, 1, 1, 0, name=name + ".0")
conv2 = self.conv_bn_layer(conv1, ch_out * 2, 3, 1, 1, name=name + ".1")
out = fluid.layers.elementwise_add(x=input, y=conv2, act=None)
return out
def layer_warp(self, block_func, input, ch_out, count, name=None):
res_out = block_func(
input, ch_out, name='{}.0'.format(name))
res_out = block_func(input, ch_out, name='{}.0'.format(name))
for j in range(1, count):
res_out = block_func(
res_out, ch_out, name='{}.{}'.format(name, j))
res_out = block_func(res_out, ch_out, name='{}.{}'.format(name, j))
return res_out
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import math
from paddle.fluid.param_attr import ParamAttr
__all__ = ["DenseNet", "DenseNet121", "DenseNet161", "DenseNet169", "DenseNet201", "DenseNet264"]
__all__ = [
"DenseNet", "DenseNet121", "DenseNet161", "DenseNet169", "DenseNet201",
"DenseNet264"
]
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class DenseNet():
def __init__(self, layers=121):
self.params = train_parameters
self.layers = layers
def net(self, input, bn_size=4, dropout=0, class_dim=1000):
layers = self.layers
supported_layers = [121, 161, 169, 201, 264]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
densenet_spec = {121: (64, 32, [6, 12, 24, 16]),
161: (96, 48, [6, 12, 36, 24]),
169: (64, 32, [6, 12, 32, 32]),
201: (64, 32, [6, 12, 48, 32]),
264: (64, 32, [6, 12, 64, 48])}
densenet_spec = {
121: (64, 32, [6, 12, 24, 16]),
161: (96, 48, [6, 12, 36, 24]),
169: (64, 32, [6, 12, 32, 32]),
201: (64, 32, [6, 12, 48, 32]),
264: (64, 32, [6, 12, 64, 48])
}
num_init_features, growth_rate, block_config = densenet_spec[layers]
conv = fluid.layers.conv2d(
input=input,
......@@ -58,46 +53,61 @@ class DenseNet():
act=None,
param_attr=ParamAttr(name="conv1_weights"),
bias_attr=False)
conv = fluid.layers.batch_norm(input=conv,
act='relu',
param_attr=ParamAttr(name='conv1_bn_scale'),
bias_attr=ParamAttr(name='conv1_bn_offset'),
moving_mean_name='conv1_bn_mean',
moving_variance_name='conv1_bn_variance')
conv = fluid.layers.pool2d(input=conv, pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
conv = fluid.layers.batch_norm(
input=conv,
act='relu',
param_attr=ParamAttr(name='conv1_bn_scale'),
bias_attr=ParamAttr(name='conv1_bn_offset'),
moving_mean_name='conv1_bn_mean',
moving_variance_name='conv1_bn_variance')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
num_features = num_init_features
for i, num_layers in enumerate(block_config):
conv = self.make_dense_block(conv, num_layers, bn_size, growth_rate, dropout, name='conv'+str(i+2))
conv = self.make_dense_block(
conv,
num_layers,
bn_size,
growth_rate,
dropout,
name='conv' + str(i + 2))
num_features = num_features + num_layers * growth_rate
if i != len(block_config) - 1:
conv = self.make_transition(conv, num_features // 2, name='conv'+str(i+2)+'_blk')
conv = self.make_transition(
conv, num_features // 2, name='conv' + str(i + 2) + '_blk')
num_features = num_features // 2
conv = fluid.layers.batch_norm(input=conv,
act='relu',
param_attr=ParamAttr(name='conv5_blk_bn_scale'),
bias_attr=ParamAttr(name='conv5_blk_bn_offset'),
moving_mean_name='conv5_blk_bn_mean',
moving_variance_name='conv5_blk_bn_variance')
conv = fluid.layers.pool2d(input=conv, pool_type='avg', global_pooling=True)
conv = fluid.layers.batch_norm(
input=conv,
act='relu',
param_attr=ParamAttr(name='conv5_blk_bn_scale'),
bias_attr=ParamAttr(name='conv5_blk_bn_offset'),
moving_mean_name='conv5_blk_bn_mean',
moving_variance_name='conv5_blk_bn_variance')
conv = fluid.layers.pool2d(
input=conv, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(conv.shape[1] * 1.0)
out = fluid.layers.fc(input=conv,
size=class_dim,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv), name="fc_weights"),
bias_attr=ParamAttr(name='fc_offset'))
out = fluid.layers.fc(
input=conv,
size=class_dim,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),
name="fc_weights"),
bias_attr=ParamAttr(name='fc_offset'))
return out
def make_transition(self, input, num_output_features, name=None):
bn_ac = fluid.layers.batch_norm(input,
act='relu',
param_attr=ParamAttr(name=name + '_bn_scale'),
bias_attr=ParamAttr(name + '_bn_offset'),
moving_mean_name=name + '_bn_mean',
moving_variance_name=name + '_bn_variance'
)
bn_ac = fluid.layers.batch_norm(
input,
act='relu',
param_attr=ParamAttr(name=name + '_bn_scale'),
bias_attr=ParamAttr(name + '_bn_offset'),
moving_mean_name=name + '_bn_mean',
moving_variance_name=name + '_bn_variance')
bn_ac_conv = fluid.layers.conv2d(
input=bn_ac,
num_filters=num_output_features,
......@@ -105,25 +115,36 @@ class DenseNet():
stride=1,
act=None,
bias_attr=False,
param_attr=ParamAttr(name=name + "_weights")
)
pool = fluid.layers.pool2d(input=bn_ac_conv, pool_size=2, pool_stride=2, pool_type='avg')
param_attr=ParamAttr(name=name + "_weights"))
pool = fluid.layers.pool2d(
input=bn_ac_conv, pool_size=2, pool_stride=2, pool_type='avg')
return pool
def make_dense_block(self, input, num_layers, bn_size, growth_rate, dropout, name=None):
def make_dense_block(self,
input,
num_layers,
bn_size,
growth_rate,
dropout,
name=None):
conv = input
for layer in range(num_layers):
conv = self.make_dense_layer(conv, growth_rate, bn_size, dropout, name=name + '_' + str(layer+1))
conv = self.make_dense_layer(
conv,
growth_rate,
bn_size,
dropout,
name=name + '_' + str(layer + 1))
return conv
def make_dense_layer(self, input, growth_rate, bn_size, dropout, name=None):
bn_ac = fluid.layers.batch_norm(input,
act='relu',
param_attr=ParamAttr(name=name + '_x1_bn_scale'),
bias_attr=ParamAttr(name + '_x1_bn_offset'),
moving_mean_name=name + '_x1_bn_mean',
moving_variance_name=name + '_x1_bn_variance')
bn_ac = fluid.layers.batch_norm(
input,
act='relu',
param_attr=ParamAttr(name=name + '_x1_bn_scale'),
bias_attr=ParamAttr(name + '_x1_bn_offset'),
moving_mean_name=name + '_x1_bn_mean',
moving_variance_name=name + '_x1_bn_variance')
bn_ac_conv = fluid.layers.conv2d(
input=bn_ac,
num_filters=bn_size * growth_rate,
......@@ -132,12 +153,13 @@ class DenseNet():
act=None,
bias_attr=False,
param_attr=ParamAttr(name=name + "_x1_weights"))
bn_ac = fluid.layers.batch_norm(bn_ac_conv,
act='relu',
param_attr=ParamAttr(name=name + '_x2_bn_scale'),
bias_attr=ParamAttr(name + '_x2_bn_offset'),
moving_mean_name=name + '_x2_bn_mean',
moving_variance_name=name + '_x2_bn_variance')
bn_ac = fluid.layers.batch_norm(
bn_ac_conv,
act='relu',
param_attr=ParamAttr(name=name + '_x2_bn_scale'),
bias_attr=ParamAttr(name + '_x2_bn_offset'),
moving_mean_name=name + '_x2_bn_mean',
moving_variance_name=name + '_x2_bn_variance')
bn_ac_conv = fluid.layers.conv2d(
input=bn_ac,
num_filters=growth_rate,
......@@ -148,33 +170,32 @@ class DenseNet():
bias_attr=False,
param_attr=ParamAttr(name=name + "_x2_weights"))
if dropout:
bn_ac_conv = fluid.layers.dropout(x=bn_ac_conv, dropout_prob=dropout)
bn_ac_conv = fluid.layers.dropout(
x=bn_ac_conv, dropout_prob=dropout)
bn_ac_conv = fluid.layers.concat([input, bn_ac_conv], axis=1)
return bn_ac_conv
def DenseNet121():
model=DenseNet(layers=121)
model = DenseNet(layers=121)
return model
def DenseNet161():
model=DenseNet(layers=161)
model = DenseNet(layers=161)
return model
def DenseNet169():
model=DenseNet(layers=169)
model = DenseNet(layers=169)
return model
def DenseNet201():
model=DenseNet(layers=201)
model = DenseNet(layers=201)
return model
def DenseNet264():
model=DenseNet(layers=264)
model = DenseNet(layers=264)
return model
......@@ -27,22 +27,9 @@ from paddle.fluid.param_attr import ParamAttr
__all__ = ["DPN", "DPN68", "DPN92", "DPN98", "DPN107", "DPN131"]
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class DPN(object):
def __init__(self, layers=68):
self.params = train_parameters
self.layers = layers
def net(self, input, class_dim=1000):
......
......@@ -20,24 +20,13 @@ import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
__all__ = ['GoogleNet']
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 70, 100],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class GoogleNet():
__all__ = ['GoogLeNet']
class GoogLeNet():
def __init__(self):
self.params = train_parameters
pass
def conv_layer(self,
input,
......
......@@ -24,22 +24,11 @@ from paddle.fluid.param_attr import ParamAttr
__all__ = ['InceptionV4']
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class InceptionV4():
def __init__(self):
self.params = train_parameters
pass
def net(self, input, class_dim=1000):
x = self.inception_stem(input)
......
......@@ -20,24 +20,12 @@ import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
__all__ = ['MobileNet']
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
__all__ = ['MobileNet', 'MobileNetV1']
class MobileNet():
def __init__(self):
self.params = train_parameters
pass
def net(self, input, class_dim=1000, scale=1.0):
# conv1: 112x112
......@@ -208,3 +196,8 @@ class MobileNet():
padding=0,
name=name + "_sep")
return pointwise_conv
def MobileNetV1():
model = MobileNet()
return model
......@@ -19,28 +19,17 @@ import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
__all__ = ['MobileNetV2', 'MobileNetV2_x0_25, ''MobileNetV2_x0_5', 'MobileNetV2_x1_0', 'MobileNetV2_x1_5', 'MobileNetV2_x2_0',
'MobileNetV2_scale']
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
__all__ = [
'MobileNetV2', 'MobileNetV2_x0_25, '
'MobileNetV2_x0_5', 'MobileNetV2_x1_0', 'MobileNetV2_x1_5',
'MobileNetV2_x2_0', 'MobileNetV2_scale'
]
class MobileNetV2():
def __init__(self, scale=1.0, change_depth=False):
self.params = train_parameters
self.scale = scale
self.change_depth=change_depth
self.change_depth = change_depth
def net(self, input, class_dim=1000):
scale = self.scale
......@@ -55,13 +44,13 @@ class MobileNetV2():
(6, 160, 3, 2),
(6, 320, 1, 1),
] if change_depth == False else [
(1, 16, 1, 1),
(6, 24, 2, 2),
(6, 32, 5, 2),
(6, 64, 7, 2),
(6, 96, 5, 1),
(6, 160, 3, 2),
(6, 320, 1, 1),
(1, 16, 1, 1),
(6, 24, 2, 2),
(6, 32, 5, 2),
(6, 64, 7, 2),
(6, 96, 5, 1),
(6, 160, 3, 2),
(6, 320, 1, 1),
]
#conv1
......@@ -224,29 +213,33 @@ class MobileNetV2():
expansion_factor=t,
name=name + '_' + str(i + 1))
return last_residual_block
def MobileNetV2_x0_25():
model = MobileNetV2(scale=0.25)
return model
def MobileNetV2_x0_5():
model = MobileNetV2(scale=0.5)
return model
def MobileNetV2_x1_0():
model = MobileNetV2(scale=1.0)
return model
def MobileNetV2_x1_5():
model = MobileNetV2(scale=1.5)
return model
def MobileNetV2_x2_0():
model = MobileNetV2(scale=2.0)
return model
def MobileNetV2_scale():
model = MobileNetV2(scale=1.2, change_depth=True)
return model
\ No newline at end of file
return model
......@@ -22,24 +22,13 @@ import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
__all__ = ["ResNet", "ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"]
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
__all__ = [
"ResNet", "ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"
]
class ResNet():
def __init__(self, layers=50):
self.params = train_parameters
self.layers = layers
def net(self, input, class_dim=1000):
......@@ -59,7 +48,12 @@ class ResNet():
num_filters = [64, 128, 256, 512]
conv = self.conv_bn_layer(
input=input, num_filters=64, filter_size=7, stride=2, act='relu',name="conv1")
input=input,
num_filters=64,
filter_size=7,
stride=2,
act='relu',
name="conv1")
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
......@@ -71,41 +65,44 @@ class ResNet():
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name="res"+str(block+2)+"a"
conv_name = "res" + str(block + 2) + "a"
else:
conv_name="res"+str(block+2)+"b"+str(i)
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name="res"+str(block+2)+chr(97+i)
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1, name=conv_name)
stride=2 if i == 0 and block != 0 else 1,
name=conv_name)
pool = fluid.layers.pool2d(
input=conv, pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(input=pool,
size=class_dim,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)))
out = fluid.layers.fc(
input=pool,
size=class_dim,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)))
else:
for block in range(len(depth)):
for i in range(depth[block]):
conv_name="res"+str(block+2)+chr(97+i)
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.basic_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
is_first=block==i==0,
is_first=block == i == 0,
name=conv_name)
pool = fluid.layers.pool2d(
input=conv, pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(input=pool,
size=class_dim,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)))
out = fluid.layers.fc(
input=pool,
size=class_dim,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)))
return out
def conv_bn_layer(self,
......@@ -127,18 +124,19 @@ class ResNet():
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False,
name=name + '.conv2d.output.1')
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(input=conv,
act=act,
name=bn_name+'.output.1',
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance',)
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
input=conv,
act=act,
name=bn_name + '.output.1',
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance', )
def shortcut(self, input, ch_out, stride, is_first, name):
ch_in = input.shape[1]
......@@ -149,29 +147,53 @@ class ResNet():
def bottleneck_block(self, input, num_filters, stride, name):
conv0 = self.conv_bn_layer(
input=input, num_filters=num_filters, filter_size=1, act='relu',name=name+"_branch2a")
input=input,
num_filters=num_filters,
filter_size=1,
act='relu',
name=name + "_branch2a")
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu',
name=name+"_branch2b")
name=name + "_branch2b")
conv2 = self.conv_bn_layer(
input=conv1, num_filters=num_filters * 4, filter_size=1, act=None, name=name+"_branch2c")
input=conv1,
num_filters=num_filters * 4,
filter_size=1,
act=None,
name=name + "_branch2c")
short = self.shortcut(
input,
num_filters * 4,
stride,
is_first=False,
name=name + "_branch1")
short = self.shortcut(input, num_filters * 4, stride, is_first=False, name=name + "_branch1")
return fluid.layers.elementwise_add(
x=short, y=conv2, act='relu', name=name + ".add.output.5")
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu',name=name+".add.output.5")
def basic_block(self, input, num_filters, stride, is_first, name):
conv0 = self.conv_bn_layer(input=input, num_filters=num_filters, filter_size=3, act='relu', stride=stride,
name=name+"_branch2a")
conv1 = self.conv_bn_layer(input=conv0, num_filters=num_filters, filter_size=3, act=None,
name=name+"_branch2b")
short = self.shortcut(input, num_filters, stride, is_first, name=name + "_branch1")
conv0 = self.conv_bn_layer(
input=input,
num_filters=num_filters,
filter_size=3,
act='relu',
stride=stride,
name=name + "_branch2a")
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
act=None,
name=name + "_branch2b")
short = self.shortcut(
input, num_filters, stride, is_first, name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')
def ResNet18():
model = ResNet(layers=18)
......
......@@ -22,26 +22,16 @@ import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
__all__ = ["ResNet", "ResNet50_vd","ResNet101_vd", "ResNet152_vd", "ResNet200_vd"]
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
__all__ = [
"ResNet", "ResNet50_vd", "ResNet101_vd", "ResNet152_vd", "ResNet200_vd"
]
class ResNet():
def __init__(self, layers=50, is_3x3 = False):
self.params = train_parameters
def __init__(self, layers=50, is_3x3=False):
self.layers = layers
self.is_3x3 = is_3x3
def net(self, input, class_dim=1000):
is_3x3 = self.is_3x3
layers = self.layers
......@@ -60,14 +50,33 @@ class ResNet():
num_filters = [64, 128, 256, 512]
if is_3x3 == False:
conv = self.conv_bn_layer(
input=input, num_filters=64, filter_size=7, stride=2, act='relu')
input=input,
num_filters=64,
filter_size=7,
stride=2,
act='relu')
else:
conv = self.conv_bn_layer(
input=input, num_filters=32, filter_size=3, stride=2, act='relu', name='conv1_1')
input=input,
num_filters=32,
filter_size=3,
stride=2,
act='relu',
name='conv1_1')
conv = self.conv_bn_layer(
input=conv, num_filters=32, filter_size=3, stride=1, act='relu', name='conv1_2')
input=conv,
num_filters=32,
filter_size=3,
stride=1,
act='relu',
name='conv1_2')
conv = self.conv_bn_layer(
input=conv, num_filters=64, filter_size=3, stride=1, act='relu', name='conv1_3')
input=conv,
num_filters=64,
filter_size=3,
stride=1,
act='relu',
name='conv1_3')
conv = fluid.layers.pool2d(
input=conv,
......@@ -80,32 +89,29 @@ class ResNet():
for i in range(depth[block]):
if layers in [101, 152, 200] and block == 2:
if i == 0:
conv_name="res"+str(block+2)+"a"
conv_name = "res" + str(block + 2) + "a"
else:
conv_name="res"+str(block+2)+"b"+str(i)
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name="res"+str(block+2)+chr(97+i)
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
if_first=block==0,
if_first=block == 0,
name=conv_name)
pool = fluid.layers.pool2d(
input=conv, pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(input=pool,
size=class_dim,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)))
out = fluid.layers.fc(
input=pool,
size=class_dim,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)))
return out
def conv_bn_layer(self,
input,
......@@ -128,29 +134,30 @@ class ResNet():
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def conv_bn_layer_new(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None):
pool = fluid.layers.pool2d(input=input,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None):
pool = fluid.layers.pool2d(
input=input,
pool_size=2,
pool_stride=2,
pool_padding=0,
pool_type='avg')
conv = fluid.layers.conv2d(
input=pool,
num_filters=num_filters,
......@@ -165,14 +172,13 @@ class ResNet():
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
return fluid.layers.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def shortcut(self, input, ch_out, stride, name, if_first=False):
ch_in = input.shape[1]
......@@ -180,43 +186,57 @@ class ResNet():
if if_first:
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else:
return self.conv_bn_layer_new(input, ch_out, 1, stride, name=name)
return self.conv_bn_layer_new(
input, ch_out, 1, stride, name=name)
else:
return input
def bottleneck_block(self, input, num_filters, stride, name, if_first):
conv0 = self.conv_bn_layer(
input=input, num_filters=num_filters, filter_size=1, act='relu', name=name+"_branch2a")
input=input,
num_filters=num_filters,
filter_size=1,
act='relu',
name=name + "_branch2a")
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu',
name=name+"_branch2b")
name=name + "_branch2b")
conv2 = self.conv_bn_layer(
input=conv1, num_filters=num_filters * 4, filter_size=1, act=None, name=name+"_branch2c")
input=conv1,
num_filters=num_filters * 4,
filter_size=1,
act=None,
name=name + "_branch2c")
short = self.shortcut(input, num_filters * 4, stride, if_first=if_first, name=name + "_branch1")
short = self.shortcut(
input,
num_filters * 4,
stride,
if_first=if_first,
name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
def ResNet50_vd():
model = ResNet(layers=50, is_3x3 = True)
model = ResNet(layers=50, is_3x3=True)
return model
def ResNet101_vd():
model = ResNet(layers=101, is_3x3 = True)
model = ResNet(layers=101, is_3x3=True)
return model
def ResNet152_vd():
model = ResNet(layers=152, is_3x3 = True)
model = ResNet(layers=152, is_3x3=True)
return model
def ResNet200_vd():
model = ResNet(layers=200, is_3x3 = True)
model = ResNet(layers=200, is_3x3=True)
return model
......@@ -22,25 +22,14 @@ import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
__all__ = ["ResNeXt", "ResNeXt50_64x4d", "ResNeXt101_64x4d", "ResNeXt152_64x4d", "ResNeXt50_32x4d", "ResNeXt101_32x4d",
"ResNeXt152_32x4d"]
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
__all__ = [
"ResNeXt", "ResNeXt50_64x4d", "ResNeXt101_64x4d", "ResNeXt152_64x4d",
"ResNeXt50_32x4d", "ResNeXt101_32x4d", "ResNeXt152_32x4d"
]
class ResNeXt():
def __init__(self, layers=50, cardinality=64):
self.params = train_parameters
self.layers = layers
self.cardinality = cardinality
......@@ -57,7 +46,7 @@ class ResNeXt():
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
num_filters1 = [256, 512, 1024, 2048]
num_filters2 = [128, 256, 512, 1024]
......@@ -67,7 +56,7 @@ class ResNeXt():
filter_size=7,
stride=2,
act='relu',
name="res_conv1") #debug
name="res_conv1") #debug
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
......@@ -86,7 +75,8 @@ class ResNeXt():
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters1[block] if cardinality == 64 else num_filters2[block],
num_filters=num_filters1[block]
if cardinality == 64 else num_filters2[block],
stride=2 if i == 0 and block != 0 else 1,
cardinality=cardinality,
name=conv_name)
......@@ -94,11 +84,13 @@ class ResNeXt():
pool = fluid.layers.pool2d(
input=conv, pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(input=pool,
size=class_dim,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),name='fc_weights'),
bias_attr=fluid.param_attr.ParamAttr(name='fc_offset'))
out = fluid.layers.fc(
input=pool,
size=class_dim,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),
name='fc_weights'),
bias_attr=fluid.param_attr.ParamAttr(name='fc_offset'))
return out
def conv_bn_layer(self,
......@@ -158,13 +150,16 @@ class ResNeXt():
name=name + "_branch2b")
conv2 = self.conv_bn_layer(
input=conv1,
num_filters=num_filters if cardinality == 64 else num_filters*2,
num_filters=num_filters if cardinality == 64 else num_filters * 2,
filter_size=1,
act=None,
name=name + "_branch2c")
short = self.shortcut(
input, num_filters if cardinality == 64 else num_filters*2, stride, name=name + "_branch1")
input,
num_filters if cardinality == 64 else num_filters * 2,
stride,
name=name + "_branch1")
return fluid.layers.elementwise_add(
x=short, y=conv2, act='relu', name=name + ".add.output.5")
......@@ -174,6 +169,7 @@ def ResNeXt50_64x4d():
model = ResNeXt(layers=50, cardinality=64)
return model
def ResNeXt50_32x4d():
model = ResNeXt(layers=50, cardinality=32)
return model
......@@ -183,6 +179,7 @@ def ResNeXt101_64x4d():
model = ResNeXt(layers=101, cardinality=64)
return model
def ResNeXt101_32x4d():
model = ResNeXt(layers=101, cardinality=32)
return model
......@@ -192,6 +189,7 @@ def ResNeXt152_64x4d():
model = ResNeXt(layers=152, cardinality=64)
return model
def ResNeXt152_32x4d():
model = ResNeXt(layers=152, cardinality=32)
return model
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......@@ -19,24 +19,14 @@ import paddle.fluid as fluid
import math
from paddle.fluid.param_attr import ParamAttr
__all__ = ["ResNeXt101_32x8d_wsl", "ResNeXt101_32x16d_wsl", "ResNeXt101_32x32d_wsl", "ResNeXt101_32x48d_wsl", "Fix_ResNeXt101_32x48d_wsl"]
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
__all__ = [
"ResNeXt101_32x8d_wsl", "ResNeXt101_32x16d_wsl", "ResNeXt101_32x32d_wsl",
"ResNeXt101_32x48d_wsl", "Fix_ResNeXt101_32x48d_wsl"
]
class ResNeXt101_wsl():
def __init__(self, layers=101, cardinality=32, width=48):
self.params = train_parameters
self.layers = layers
self.cardinality = cardinality
self.width = width
......@@ -49,7 +39,6 @@ class ResNeXt101_wsl():
depth = [3, 4, 23, 3]
base_width = cardinality * width
num_filters = [base_width * i for i in [1, 2, 4, 8]]
conv = self.conv_bn_layer(
input=input,
......@@ -57,7 +46,7 @@ class ResNeXt101_wsl():
filter_size=7,
stride=2,
act='relu',
name="conv1") #debug
name="conv1") #debug
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
......@@ -67,7 +56,7 @@ class ResNeXt101_wsl():
for block in range(len(depth)):
for i in range(depth[block]):
conv_name = 'layer' + str(block+1) + "." + str(i)
conv_name = 'layer' + str(block + 1) + "." + str(i)
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
......@@ -78,11 +67,13 @@ class ResNeXt101_wsl():
pool = fluid.layers.pool2d(
input=conv, pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(input=pool,
size=class_dim,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),name='fc.weight'),
bias_attr=fluid.param_attr.ParamAttr(name='fc.bias'))
out = fluid.layers.fc(
input=pool,
size=class_dim,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),
name='fc.weight'),
bias_attr=fluid.param_attr.ParamAttr(name='fc.bias'))
return out
def conv_bn_layer(self,
......@@ -113,7 +104,8 @@ class ResNeXt101_wsl():
if "conv1" == name:
bn_name = 'bn' + name[-1]
else:
bn_name = (name[:10] if name[7:9].isdigit() else name[:9]) + 'bn' + name[-1]
bn_name = (name[:10] if name[7:9].isdigit() else name[:9]
) + 'bn' + name[-1]
return fluid.layers.batch_norm(
input=conv,
act=act,
......@@ -148,32 +140,35 @@ class ResNeXt101_wsl():
name=name + ".conv2")
conv2 = self.conv_bn_layer(
input=conv1,
num_filters=num_filters//(width//8),
num_filters=num_filters // (width // 8),
filter_size=1,
act=None,
name=name + ".conv3")
short = self.shortcut(
input, num_filters//(width//8), stride, name=name + ".downsample")
return fluid.layers.elementwise_add(
x=short, y=conv2, act='relu')
input,
num_filters // (width // 8),
stride,
name=name + ".downsample")
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
def ResNeXt101_32x8d_wsl():
model = ResNeXt101_wsl(cardinality=32, width=8)
return model
def ResNeXt101_32x16d_wsl():
model = ResNeXt101_wsl(cardinality=32, width=16)
return model
def ResNeXt101_32x32d_wsl():
model = ResNeXt101_wsl(cardinality=32, width=32)
return model
def ResNeXt101_32x48d_wsl():
model = ResNeXt101_wsl(cardinality=32, width=48)
return model
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
import math
__all__ = ["ResNeXt","ResNeXt50_vd_64x4d","ResNeXt101_vd_64x4d","ResNeXt152_vd_64x4d","ResNeXt50_vd_32x4d","ResNeXt101_vd_32x4d", "ResNeXt152_vd_32x4d"]
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
__all__ = [
"ResNeXt", "ResNeXt50_vd_64x4d", "ResNeXt101_vd_64x4d",
"ResNeXt152_vd_64x4d", "ResNeXt50_vd_32x4d", "ResNeXt101_vd_32x4d",
"ResNeXt152_vd_32x4d"
]
class ResNeXt():
def __init__(self, layers=50, is_3x3 = False, cardinality=64):
self.params = train_parameters
def __init__(self, layers=50, is_3x3=False, cardinality=64):
self.layers = layers
self.is_3x3 = is_3x3
self.cardinality = cardinality
def net(self, input, class_dim=1000):
is_3x3 = self.is_3x3
layers = self.layers
......@@ -52,17 +45,36 @@ class ResNeXt():
depth = [3, 8, 36, 3]
num_filters1 = [256, 512, 1024, 2048]
num_filters2 = [128, 256, 512, 1024]
if is_3x3 == False:
conv = self.conv_bn_layer(
input=input, num_filters=64, filter_size=7, stride=2, act='relu')
input=input,
num_filters=64,
filter_size=7,
stride=2,
act='relu')
else:
conv = self.conv_bn_layer(
input=input, num_filters=32, filter_size=3, stride=2, act='relu', name='conv1_1')
input=input,
num_filters=32,
filter_size=3,
stride=2,
act='relu',
name='conv1_1')
conv = self.conv_bn_layer(
input=conv, num_filters=32, filter_size=3, stride=1, act='relu', name='conv1_2')
input=conv,
num_filters=32,
filter_size=3,
stride=1,
act='relu',
name='conv1_2')
conv = self.conv_bn_layer(
input=conv, num_filters=64, filter_size=3, stride=1, act='relu', name='conv1_3')
input=conv,
num_filters=64,
filter_size=3,
stride=1,
act='relu',
name='conv1_3')
conv = fluid.layers.pool2d(
input=conv,
......@@ -75,32 +87,32 @@ class ResNeXt():
for i in range(depth[block]):
if layers in [101, 152, 200] and block == 2:
if i == 0:
conv_name="res"+str(block+2)+"a"
conv_name = "res" + str(block + 2) + "a"
else:
conv_name="res"+str(block+2)+"b"+str(i)
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name="res"+str(block+2)+chr(97+i)
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters1[block] if cardinality == 64 else num_filters2[block],
num_filters=num_filters1[block]
if cardinality == 64 else num_filters2[block],
stride=2 if i == 0 and block != 0 else 1,
cardinality=cardinality,
if_first=block==0,
if_first=block == 0,
name=conv_name)
pool = fluid.layers.pool2d(
input=conv, pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(input=pool,
size=class_dim,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),name='fc_weights'),
bias_attr=fluid.param_attr.ParamAttr(name='fc_offset'))
out = fluid.layers.fc(
input=pool,
size=class_dim,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),
name='fc_weights'),
bias_attr=fluid.param_attr.ParamAttr(name='fc_offset'))
return out
def conv_bn_layer(self,
input,
......@@ -118,34 +130,36 @@ class ResNeXt():
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
use_cudnn=False,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def conv_bn_layer_new(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None):
pool = fluid.layers.pool2d(input=input,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None):
pool = fluid.layers.pool2d(
input=input,
pool_size=2,
pool_stride=2,
pool_padding=0,
pool_type='avg')
conv = fluid.layers.conv2d(
input=pool,
num_filters=num_filters,
......@@ -154,20 +168,20 @@ class ResNeXt():
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
use_cudnn=False,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
return fluid.layers.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def shortcut(self, input, ch_out, stride, name, if_first=False):
ch_in = input.shape[1]
......@@ -175,13 +189,19 @@ class ResNeXt():
if if_first:
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else:
return self.conv_bn_layer_new(input, ch_out, 1, stride, name=name)
return self.conv_bn_layer_new(
input, ch_out, 1, stride, name=name)
else:
return input
def bottleneck_block(self, input, num_filters, stride, cardinality, name, if_first):
def bottleneck_block(self, input, num_filters, stride, cardinality, name,
if_first):
conv0 = self.conv_bn_layer(
input=input, num_filters=num_filters, filter_size=1, act='relu', name=name+"_branch2a")
input=input,
num_filters=num_filters,
filter_size=1,
act='relu',
name=name + "_branch2a")
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
......@@ -189,36 +209,49 @@ class ResNeXt():
stride=stride,
act='relu',
groups=cardinality,
name=name+"_branch2b")
name=name + "_branch2b")
conv2 = self.conv_bn_layer(
input=conv1, num_filters=num_filters if cardinality == 64 else num_filters*2, filter_size=1, act=None, name=name+"_branch2c")
input=conv1,
num_filters=num_filters if cardinality == 64 else num_filters * 2,
filter_size=1,
act=None,
name=name + "_branch2c")
short = self.shortcut(input, num_filters if cardinality == 64 else num_filters*2, stride, if_first=if_first, name=name + "_branch1")
short = self.shortcut(
input,
num_filters if cardinality == 64 else num_filters * 2,
stride,
if_first=if_first,
name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
def ResNeXt50_vd_64x4d():
model = ResNeXt(layers=50, is_3x3 = True)
model = ResNeXt(layers=50, is_3x3=True)
return model
def ResNeXt50_vd_32x4d():
model = ResNeXt(layers=50, cardinality=32, is_3x3 = True)
model = ResNeXt(layers=50, cardinality=32, is_3x3=True)
return model
def ResNeXt101_vd_64x4d():
model = ResNeXt(layers=101, is_3x3 = True)
model = ResNeXt(layers=101, is_3x3=True)
return model
def ResNeXt101_vd_32x4d():
model = ResNeXt(layers=101, cardinality=32, is_3x3 = True)
model = ResNeXt(layers=101, cardinality=32, is_3x3=True)
return model
def ResNeXt152_vd_64x4d():
model = ResNeXt(layers=152, is_3x3 = True)
model = ResNeXt(layers=152, is_3x3=True)
return model
def ResNeXt152_vd_32x4d():
model = ResNeXt(layers=152, cardinality=32, is_3x3 = True)
model = ResNeXt(layers=152, cardinality=32, is_3x3=True)
return model
......@@ -27,23 +27,9 @@ __all__ = [
"SE_ResNeXt152_32x4d"
]
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"dropout_seed": None,
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [40, 80, 100],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class SE_ResNeXt():
def __init__(self, layers=50):
self.params = train_parameters
self.layers = layers
def net(self, input, class_dim=1000):
......@@ -139,8 +125,7 @@ class SE_ResNeXt():
pool_type='avg',
global_pooling=True,
use_cudnn=False)
drop = fluid.layers.dropout(
x=pool, dropout_prob=0.5, seed=self.params['dropout_seed'])
drop = fluid.layers.dropout(x=pool, dropout_prob=0.5)
stdv = 1.0 / math.sqrt(drop.shape[1] * 1.0)
out = fluid.layers.fc(
input=drop,
......
......@@ -23,26 +23,12 @@ import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
__all__ = [
"SE_ResNeXt", "SE_ResNeXt50_32x4d_vd", "SE_ResNeXt101_32x4d_vd",
"SENet154_vd"
"SE_ResNeXt", "SE_ResNeXt50_32x4d_vd", "SE_ResNeXt101_32x4d_vd", "SE154_vd"
]
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [10, 16, 20],
"steps": [0.01, 0.001, 0.0001, 0.00001]
}
}
class SE_ResNeXt():
def __init__(self, layers=50):
self.params = train_parameters
self.layers = layers
def net(self, input, class_dim=1000):
......@@ -57,11 +43,26 @@ class SE_ResNeXt():
num_filters = [128, 256, 512, 1024]
conv = self.conv_bn_layer(
input=input, num_filters=64, filter_size=3, stride=2, act='relu', name='conv1_1')
input=input,
num_filters=64,
filter_size=3,
stride=2,
act='relu',
name='conv1_1')
conv = self.conv_bn_layer(
input=conv, num_filters=64, filter_size=3, stride=1, act='relu', name='conv1_2')
input=conv,
num_filters=64,
filter_size=3,
stride=1,
act='relu',
name='conv1_2')
conv = self.conv_bn_layer(
input=conv, num_filters=128, filter_size=3, stride=1, act='relu', name='conv1_3')
input=conv,
num_filters=128,
filter_size=3,
stride=1,
act='relu',
name='conv1_3')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
......@@ -75,11 +76,26 @@ class SE_ResNeXt():
num_filters = [128, 256, 512, 1024]
conv = self.conv_bn_layer(
input=input, num_filters=64, filter_size=3, stride=2, act='relu', name='conv1_1')
input=input,
num_filters=64,
filter_size=3,
stride=2,
act='relu',
name='conv1_1')
conv = self.conv_bn_layer(
input=conv, num_filters=64, filter_size=3, stride=1, act='relu', name='conv1_2')
input=conv,
num_filters=64,
filter_size=3,
stride=1,
act='relu',
name='conv1_2')
conv = self.conv_bn_layer(
input=conv, num_filters=128, filter_size=3, stride=1, act='relu', name='conv1_3')
input=conv,
num_filters=128,
filter_size=3,
stride=1,
act='relu',
name='conv1_3')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
......@@ -100,7 +116,12 @@ class SE_ResNeXt():
act='relu',
name='conv1_1')
conv = self.conv_bn_layer(
input=conv, num_filters=64, filter_size=3, stride=1, act='relu',name='conv1_2')
input=conv,
num_filters=64,
filter_size=3,
stride=1,
act='relu',
name='conv1_2')
conv = self.conv_bn_layer(
input=conv,
num_filters=128,
......@@ -121,20 +142,22 @@ class SE_ResNeXt():
stride=2 if i == 0 and block != 0 else 1,
cardinality=cardinality,
reduction_ratio=reduction_ratio,
if_first=block==0,
name=str(n)+'_'+str(i+1))
if_first=block == 0,
name=str(n) + '_' + str(i + 1))
pool = fluid.layers.pool2d(
input=conv, pool_size=7, pool_type='avg', global_pooling=True)
if layers == 152:
pool = fluid.layers.dropout(x=pool, dropout_prob=0.2)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(input=pool,
size=class_dim,
param_attr=ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),name='fc6_weights'),
bias_attr=ParamAttr(name='fc6_offset'))
out = fluid.layers.fc(
input=pool,
size=class_dim,
param_attr=ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),
name='fc6_weights'),
bias_attr=ParamAttr(name='fc6_offset'))
return out
def shortcut(self, input, ch_out, stride, name, if_first=False):
......@@ -142,17 +165,36 @@ class SE_ResNeXt():
if ch_in != ch_out or stride != 1:
filter_size = 1
if if_first:
return self.conv_bn_layer(input, ch_out, filter_size, stride, name='conv'+name+'_prj')
return self.conv_bn_layer(
input,
ch_out,
filter_size,
stride,
name='conv' + name + '_prj')
else:
return self.conv_bn_layer_new(input, ch_out, filter_size, stride, name='conv'+name+'_prj')
return self.conv_bn_layer_new(
input,
ch_out,
filter_size,
stride,
name='conv' + name + '_prj')
else:
return input
def bottleneck_block(self, input, num_filters, stride, cardinality,
reduction_ratio,if_first, name=None):
def bottleneck_block(self,
input,
num_filters,
stride,
cardinality,
reduction_ratio,
if_first,
name=None):
conv0 = self.conv_bn_layer(
input=input, num_filters=num_filters, filter_size=1, act='relu',name='conv'+name+'_x1')
input=input,
num_filters=num_filters,
filter_size=1,
act='relu',
name='conv' + name + '_x1')
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
......@@ -160,18 +202,23 @@ class SE_ResNeXt():
stride=stride,
groups=cardinality,
act='relu',
name='conv'+name+'_x2')
name='conv' + name + '_x2')
if cardinality == 64:
num_filters = num_filters // 2
conv2 = self.conv_bn_layer(
input=conv1, num_filters=num_filters * 2, filter_size=1, act=None, name='conv'+name+'_x3')
input=conv1,
num_filters=num_filters * 2,
filter_size=1,
act=None,
name='conv' + name + '_x3')
scale = self.squeeze_excitation(
input=conv2,
num_channels=num_filters * 2,
reduction_ratio=reduction_ratio,
name='fc'+name)
name='fc' + name)
short = self.shortcut(input, num_filters * 2, stride, if_first=if_first, name=name)
short = self.shortcut(
input, num_filters * 2, stride, if_first=if_first, name=name)
return fluid.layers.elementwise_add(x=short, y=scale, act='relu')
......@@ -192,29 +239,31 @@ class SE_ResNeXt():
groups=groups,
act=None,
bias_attr=False,
param_attr=ParamAttr(name=name + '_weights'),
)
param_attr=ParamAttr(name=name + '_weights'), )
bn_name = name + "_bn"
return fluid.layers.batch_norm(input=conv, act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
return fluid.layers.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def conv_bn_layer_new(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None):
pool = fluid.layers.pool2d(input=input,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None):
pool = fluid.layers.pool2d(
input=input,
pool_size=2,
pool_stride=2,
pool_padding=0,
pool_type='avg')
conv = fluid.layers.conv2d(
input=pool,
num_filters=num_filters,
......@@ -226,33 +275,39 @@ class SE_ResNeXt():
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
bn_name = name + "_bn"
return fluid.layers.batch_norm(input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
return fluid.layers.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def squeeze_excitation(self, input, num_channels, reduction_ratio, name=None):
def squeeze_excitation(self,
input,
num_channels,
reduction_ratio,
name=None):
pool = fluid.layers.pool2d(
input=input, pool_size=0, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
squeeze = fluid.layers.fc(input=pool,
size=num_channels // reduction_ratio,
act='relu',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(
-stdv, stdv),name=name+'_sqz_weights'),
bias_attr=ParamAttr(name=name+'_sqz_offset'))
squeeze = fluid.layers.fc(
input=pool,
size=num_channels // reduction_ratio,
act='relu',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),
name=name + '_sqz_weights'),
bias_attr=ParamAttr(name=name + '_sqz_offset'))
stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0)
excitation = fluid.layers.fc(input=squeeze,
size=num_channels,
act='sigmoid',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(
-stdv, stdv),name=name+'_exc_weights'),
bias_attr=ParamAttr(name=name+'_exc_offset'))
excitation = fluid.layers.fc(
input=squeeze,
size=num_channels,
act='sigmoid',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),
name=name + '_exc_weights'),
bias_attr=ParamAttr(name=name + '_exc_offset'))
scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0)
return scale
......@@ -267,6 +322,6 @@ def SE_ResNeXt101_32x4d_vd():
return model
def SENet154_vd():
def SE_154_vd():
model = SE_ResNeXt(layers=152)
return model
......@@ -22,36 +22,27 @@ import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
__all__ = ['ShuffleNetV2_x0_25', 'ShuffleNetV2_x0_33', 'ShuffleNetV2_x0_5', 'ShuffleNetV2_x1_0', 'ShuffleNetV2_x1_5', 'ShuffleNetV2_x2_0']
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
__all__ = [
'ShuffleNetV2_x0_25', 'ShuffleNetV2_x0_33', 'ShuffleNetV2_x0_5',
'ShuffleNetV2_x1_0', 'ShuffleNetV2_x1_5', 'ShuffleNetV2_x2_0',
'ShuffleNetV2'
]
class ShuffleNetV2():
def __init__(self, scale=1.0):
self.params = train_parameters
self.scale = scale
def net(self, input, class_dim=1000):
scale = self.scale
scale = self.scale
stage_repeats = [4, 8, 4]
if scale == 0.25:
stage_out_channels = [-1, 24, 24, 48, 96, 512]
stage_out_channels = [-1, 24, 24, 48, 96, 512]
elif scale == 0.33:
stage_out_channels = [-1, 24, 32, 64, 128, 512]
stage_out_channels = [-1, 24, 32, 64, 128, 512]
elif scale == 0.5:
stage_out_channels = [-1, 24, 48, 96, 192, 1024]
stage_out_channels = [-1, 24, 48, 96, 192, 1024]
elif scale == 1.0:
stage_out_channels = [-1, 24, 116, 232, 464, 1024]
elif scale == 1.5:
......@@ -59,50 +50,77 @@ class ShuffleNetV2():
elif scale == 2.0:
stage_out_channels = [-1, 24, 224, 488, 976, 2048]
else:
raise ValueError(
"""{} groups is not supported for
raise ValueError("""{} groups is not supported for
1x1 Grouped Convolutions""".format(num_groups))
#conv1
input_channel = stage_out_channels[1]
conv1 = self.conv_bn_layer(input=input, filter_size=3, num_filters=input_channel, padding=1, stride=2,name='stage1_conv')
pool1 = fluid.layers.pool2d(input=conv1, pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
conv1 = self.conv_bn_layer(
input=input,
filter_size=3,
num_filters=input_channel,
padding=1,
stride=2,
name='stage1_conv')
pool1 = fluid.layers.pool2d(
input=conv1,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
conv = pool1
# bottleneck sequences
for idxstage in range(len(stage_repeats)):
numrepeat = stage_repeats[idxstage]
output_channel = stage_out_channels[idxstage+2]
output_channel = stage_out_channels[idxstage + 2]
for i in range(numrepeat):
if i == 0:
conv = self.inverted_residual_unit(input=conv, num_filters=output_channel, stride=2,
benchmodel=2,name=str(idxstage+2)+'_'+str(i+1))
conv = self.inverted_residual_unit(
input=conv,
num_filters=output_channel,
stride=2,
benchmodel=2,
name=str(idxstage + 2) + '_' + str(i + 1))
else:
conv = self.inverted_residual_unit(input=conv, num_filters=output_channel, stride=1,
benchmodel=1,name=str(idxstage+2)+'_'+str(i+1))
conv_last = self.conv_bn_layer(input=conv, filter_size=1, num_filters=stage_out_channels[-1],
padding=0, stride=1, name='conv5')
pool_last = fluid.layers.pool2d(input=conv_last, pool_size=7, pool_stride=1, pool_padding=0, pool_type='avg')
conv = self.inverted_residual_unit(
input=conv,
num_filters=output_channel,
stride=1,
benchmodel=1,
name=str(idxstage + 2) + '_' + str(i + 1))
conv_last = self.conv_bn_layer(
input=conv,
filter_size=1,
num_filters=stage_out_channels[-1],
padding=0,
stride=1,
name='conv5')
pool_last = fluid.layers.pool2d(
input=conv_last,
pool_size=7,
pool_stride=1,
pool_padding=0,
pool_type='avg')
output = fluid.layers.fc(input=pool_last,
size=class_dim,
param_attr=ParamAttr(initializer=MSRA(),name='fc6_weights'),
param_attr=ParamAttr(
initializer=MSRA(), name='fc6_weights'),
bias_attr=ParamAttr(name='fc6_offset'))
return output
def conv_bn_layer(self,
input,
filter_size,
num_filters,
stride,
padding,
num_groups=1,
use_cudnn=True,
if_act=True,
name=None):
input,
filter_size,
num_filters,
stride,
padding,
num_groups=1,
use_cudnn=True,
if_act=True,
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
......@@ -112,162 +130,179 @@ class ShuffleNetV2():
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=ParamAttr(initializer=MSRA(),name=name+'_weights'),
param_attr=ParamAttr(
initializer=MSRA(), name=name + '_weights'),
bias_attr=False)
out = int((input.shape[2] - 1)/float(stride) + 1)
out = int((input.shape[2] - 1) / float(stride) + 1)
bn_name = name + '_bn'
if if_act:
return fluid.layers.batch_norm(input=conv, act='relu',
param_attr = ParamAttr(name=bn_name+"_scale"),
bias_attr=ParamAttr(name=bn_name+"_offset"),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
return fluid.layers.batch_norm(
input=conv,
act='relu',
param_attr=ParamAttr(name=bn_name + "_scale"),
bias_attr=ParamAttr(name=bn_name + "_offset"),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
else:
return fluid.layers.batch_norm(input=conv,
param_attr = ParamAttr(name=bn_name+"_scale"),
bias_attr=ParamAttr(name=bn_name+"_offset"),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
return fluid.layers.batch_norm(
input=conv,
param_attr=ParamAttr(name=bn_name + "_scale"),
bias_attr=ParamAttr(name=bn_name + "_offset"),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def channel_shuffle(self, x, groups):
batchsize, num_channels, height, width = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
batchsize, num_channels, height, width = x.shape[0], x.shape[
1], x.shape[2], x.shape[3]
channels_per_group = num_channels // groups
# reshape
x = fluid.layers.reshape(x=x, shape=[batchsize, groups, channels_per_group, height, width])
x = fluid.layers.reshape(
x=x, shape=[batchsize, groups, channels_per_group, height, width])
x = fluid.layers.transpose(x=x, perm=[0,2,1,3,4])
x = fluid.layers.transpose(x=x, perm=[0, 2, 1, 3, 4])
# flatten
x = fluid.layers.reshape(x=x, shape=[batchsize, num_channels, height, width])
x = fluid.layers.reshape(
x=x, shape=[batchsize, num_channels, height, width])
return x
def inverted_residual_unit(self, input, num_filters, stride, benchmodel, name=None):
def inverted_residual_unit(self,
input,
num_filters,
stride,
benchmodel,
name=None):
assert stride in [1, 2], \
"supported stride are {} but your stride is {}".format([1,2], stride)
oup_inc = num_filters//2
oup_inc = num_filters // 2
inp = input.shape[1]
if benchmodel == 1:
x1, x2 = fluid.layers.split(
input, num_or_sections=[input.shape[1]//2, input.shape[1]//2], dim=1)
input,
num_or_sections=[input.shape[1] // 2, input.shape[1] // 2],
dim=1)
conv_pw = self.conv_bn_layer(
input=x2,
num_filters=oup_inc,
filter_size=1,
input=x2,
num_filters=oup_inc,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=True,
name='stage_'+name+'_conv1')
name='stage_' + name + '_conv1')
conv_dw = self.conv_bn_layer(
input=conv_pw,
num_filters=oup_inc,
filter_size=3,
stride=stride,
input=conv_pw,
num_filters=oup_inc,
filter_size=3,
stride=stride,
padding=1,
num_groups=oup_inc,
num_groups=oup_inc,
if_act=False,
use_cudnn=False,
name='stage_'+name+'_conv2')
name='stage_' + name + '_conv2')
conv_linear = self.conv_bn_layer(
input=conv_dw,
num_filters=oup_inc,
filter_size=1,
stride=1,
input=conv_dw,
num_filters=oup_inc,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
num_groups=1,
if_act=True,
name='stage_'+name+'_conv3')
name='stage_' + name + '_conv3')
out = fluid.layers.concat([x1, conv_linear], axis=1)
else:
#branch1
conv_dw_1 = self.conv_bn_layer(
input=input,
num_filters=inp,
filter_size=3,
input=input,
num_filters=inp,
filter_size=3,
stride=stride,
padding=1,
num_groups=inp,
if_act=False,
use_cudnn=False,
name='stage_'+name+'_conv4')
name='stage_' + name + '_conv4')
conv_linear_1 = self.conv_bn_layer(
input=conv_dw_1,
num_filters=oup_inc,
filter_size=1,
input=conv_dw_1,
num_filters=oup_inc,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=True,
name='stage_'+name+'_conv5')
name='stage_' + name + '_conv5')
#branch2
conv_pw_2 = self.conv_bn_layer(
input=input,
num_filters=oup_inc,
filter_size=1,
input=input,
num_filters=oup_inc,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=True,
name='stage_'+name+'_conv1')
name='stage_' + name + '_conv1')
conv_dw_2 = self.conv_bn_layer(
input=conv_pw_2,
num_filters=oup_inc,
filter_size=3,
stride=stride,
input=conv_pw_2,
num_filters=oup_inc,
filter_size=3,
stride=stride,
padding=1,
num_groups=oup_inc,
num_groups=oup_inc,
if_act=False,
use_cudnn=False,
name='stage_'+name+'_conv2')
name='stage_' + name + '_conv2')
conv_linear_2 = self.conv_bn_layer(
input=conv_dw_2,
num_filters=oup_inc,
filter_size=1,
stride=1,
input=conv_dw_2,
num_filters=oup_inc,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
num_groups=1,
if_act=True,
name='stage_'+name+'_conv3')
name='stage_' + name + '_conv3')
out = fluid.layers.concat([conv_linear_1, conv_linear_2], axis=1)
return self.channel_shuffle(out, 2)
def ShuffleNetV2_x0_25():
model = ShuffleNetV2(scale=0.25)
return model
def ShuffleNetV2_x0_33():
model = ShuffleNetV2(scale=0.33)
return model
def ShuffleNetV2_x0_5():
model = ShuffleNetV2(scale=0.5)
return model
def ShuffleNetV2_x1_0():
model = ShuffleNetV2(scale=1.0)
return model
def ShuffleNetV2_x1_5():
model = ShuffleNetV2(scale=1.5)
return model
def ShuffleNetV2_x2_0():
model = ShuffleNetV2(scale=2.0)
return model
......@@ -22,32 +22,22 @@ import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
__all__ = ['ShuffleNetV2_x0_5_swish', 'ShuffleNetV2_x1_0_swish', 'ShuffleNetV2_x1_5_swish', 'ShuffleNetV2_x2_0_swish']
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class ShuffleNetV2():
__all__ = [
'ShuffleNetV2_x0_5_swish', 'ShuffleNetV2_x1_0_swish',
'ShuffleNetV2_x1_5_swish', 'ShuffleNetV2_x2_0_swish', 'ShuffleNetV2_swish'
]
class ShuffleNetV2_swish():
def __init__(self, scale=1.0):
self.params = train_parameters
self.scale = scale
def net(self, input, class_dim=1000):
scale = self.scale
scale = self.scale
stage_repeats = [4, 8, 4]
if scale == 0.5:
stage_out_channels = [-1, 24, 48, 96, 192, 1024]
stage_out_channels = [-1, 24, 48, 96, 192, 1024]
elif scale == 1.0:
stage_out_channels = [-1, 24, 116, 232, 464, 1024]
elif scale == 1.5:
......@@ -55,50 +45,77 @@ class ShuffleNetV2():
elif scale == 2.0:
stage_out_channels = [-1, 24, 224, 488, 976, 2048]
else:
raise ValueError(
"""{} groups is not supported for
raise ValueError("""{} groups is not supported for
1x1 Grouped Convolutions""".format(num_groups))
#conv1
input_channel = stage_out_channels[1]
conv1 = self.conv_bn_layer(input=input, filter_size=3, num_filters=input_channel, padding=1, stride=2,name='stage1_conv')
pool1 = fluid.layers.pool2d(input=conv1, pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
conv1 = self.conv_bn_layer(
input=input,
filter_size=3,
num_filters=input_channel,
padding=1,
stride=2,
name='stage1_conv')
pool1 = fluid.layers.pool2d(
input=conv1,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
conv = pool1
# bottleneck sequences
for idxstage in range(len(stage_repeats)):
numrepeat = stage_repeats[idxstage]
output_channel = stage_out_channels[idxstage+2]
output_channel = stage_out_channels[idxstage + 2]
for i in range(numrepeat):
if i == 0:
conv = self.inverted_residual_unit(input=conv, num_filters=output_channel, stride=2,
benchmodel=2,name=str(idxstage+2)+'_'+str(i+1))
conv = self.inverted_residual_unit(
input=conv,
num_filters=output_channel,
stride=2,
benchmodel=2,
name=str(idxstage + 2) + '_' + str(i + 1))
else:
conv = self.inverted_residual_unit(input=conv, num_filters=output_channel, stride=1,
benchmodel=1,name=str(idxstage+2)+'_'+str(i+1))
conv_last = self.conv_bn_layer(input=conv, filter_size=1, num_filters=stage_out_channels[-1],
padding=0, stride=1, name='conv5')
pool_last = fluid.layers.pool2d(input=conv_last, pool_size=7, pool_stride=1, pool_padding=0, pool_type='avg')
conv = self.inverted_residual_unit(
input=conv,
num_filters=output_channel,
stride=1,
benchmodel=1,
name=str(idxstage + 2) + '_' + str(i + 1))
conv_last = self.conv_bn_layer(
input=conv,
filter_size=1,
num_filters=stage_out_channels[-1],
padding=0,
stride=1,
name='conv5')
pool_last = fluid.layers.pool2d(
input=conv_last,
pool_size=7,
pool_stride=1,
pool_padding=0,
pool_type='avg')
output = fluid.layers.fc(input=pool_last,
size=class_dim,
param_attr=ParamAttr(initializer=MSRA(),name='fc6_weights'),
param_attr=ParamAttr(
initializer=MSRA(), name='fc6_weights'),
bias_attr=ParamAttr(name='fc6_offset'))
return output
def conv_bn_layer(self,
input,
filter_size,
num_filters,
stride,
padding,
num_groups=1,
use_cudnn=True,
if_act=True,
name=None):
input,
filter_size,
num_filters,
stride,
padding,
num_groups=1,
use_cudnn=True,
if_act=True,
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
......@@ -108,154 +125,169 @@ class ShuffleNetV2():
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=ParamAttr(initializer=MSRA(),name=name+'_weights'),
param_attr=ParamAttr(
initializer=MSRA(), name=name + '_weights'),
bias_attr=False)
out = int((input.shape[2] - 1)/float(stride) + 1)
out = int((input.shape[2] - 1) / float(stride) + 1)
bn_name = name + '_bn'
if if_act:
return fluid.layers.batch_norm(input=conv, act='swish',
param_attr = ParamAttr(name=bn_name+"_scale"),
bias_attr=ParamAttr(name=bn_name+"_offset"),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
return fluid.layers.batch_norm(
input=conv,
act='swish',
param_attr=ParamAttr(name=bn_name + "_scale"),
bias_attr=ParamAttr(name=bn_name + "_offset"),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
else:
return fluid.layers.batch_norm(input=conv,
param_attr = ParamAttr(name=bn_name+"_scale"),
bias_attr=ParamAttr(name=bn_name+"_offset"),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
return fluid.layers.batch_norm(
input=conv,
param_attr=ParamAttr(name=bn_name + "_scale"),
bias_attr=ParamAttr(name=bn_name + "_offset"),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def channel_shuffle(self, x, groups):
batchsize, num_channels, height, width = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
batchsize, num_channels, height, width = x.shape[0], x.shape[
1], x.shape[2], x.shape[3]
channels_per_group = num_channels // groups
# reshape
x = fluid.layers.reshape(x=x, shape=[batchsize, groups, channels_per_group, height, width])
x = fluid.layers.reshape(
x=x, shape=[batchsize, groups, channels_per_group, height, width])
x = fluid.layers.transpose(x=x, perm=[0,2,1,3,4])
x = fluid.layers.transpose(x=x, perm=[0, 2, 1, 3, 4])
# flatten
x = fluid.layers.reshape(x=x, shape=[batchsize, num_channels, height, width])
x = fluid.layers.reshape(
x=x, shape=[batchsize, num_channels, height, width])
return x
def inverted_residual_unit(self, input, num_filters, stride, benchmodel, name=None):
def inverted_residual_unit(self,
input,
num_filters,
stride,
benchmodel,
name=None):
assert stride in [1, 2], \
"supported stride are {} but your stride is {}".format([1,2], stride)
oup_inc = num_filters//2
oup_inc = num_filters // 2
inp = input.shape[1]
if benchmodel == 1:
x1, x2 = fluid.layers.split(
input, num_or_sections=[input.shape[1]//2, input.shape[1]//2], dim=1)
input,
num_or_sections=[input.shape[1] // 2, input.shape[1] // 2],
dim=1)
conv_pw = self.conv_bn_layer(
input=x2,
num_filters=oup_inc,
filter_size=1,
input=x2,
num_filters=oup_inc,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=True,
name='stage_'+name+'_conv1')
name='stage_' + name + '_conv1')
conv_dw = self.conv_bn_layer(
input=conv_pw,
num_filters=oup_inc,
filter_size=3,
stride=stride,
input=conv_pw,
num_filters=oup_inc,
filter_size=3,
stride=stride,
padding=1,
num_groups=oup_inc,
num_groups=oup_inc,
if_act=False,
use_cudnn=False,
name='stage_'+name+'_conv2')
name='stage_' + name + '_conv2')
conv_linear = self.conv_bn_layer(
input=conv_dw,
num_filters=oup_inc,
filter_size=1,
stride=1,
input=conv_dw,
num_filters=oup_inc,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
num_groups=1,
if_act=True,
name='stage_'+name+'_conv3')
name='stage_' + name + '_conv3')
out = fluid.layers.concat([x1, conv_linear], axis=1)
else:
#branch1
conv_dw_1 = self.conv_bn_layer(
input=input,
num_filters=inp,
filter_size=3,
input=input,
num_filters=inp,
filter_size=3,
stride=stride,
padding=1,
num_groups=inp,
if_act=False,
use_cudnn=False,
name='stage_'+name+'_conv4')
name='stage_' + name + '_conv4')
conv_linear_1 = self.conv_bn_layer(
input=conv_dw_1,
num_filters=oup_inc,
filter_size=1,
input=conv_dw_1,
num_filters=oup_inc,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=True,
name='stage_'+name+'_conv5')
name='stage_' + name + '_conv5')
#branch2
conv_pw_2 = self.conv_bn_layer(
input=input,
num_filters=oup_inc,
filter_size=1,
input=input,
num_filters=oup_inc,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=True,
name='stage_'+name+'_conv1')
name='stage_' + name + '_conv1')
conv_dw_2 = self.conv_bn_layer(
input=conv_pw_2,
num_filters=oup_inc,
filter_size=3,
stride=stride,
input=conv_pw_2,
num_filters=oup_inc,
filter_size=3,
stride=stride,
padding=1,
num_groups=oup_inc,
num_groups=oup_inc,
if_act=False,
use_cudnn=False,
name='stage_'+name+'_conv2')
name='stage_' + name + '_conv2')
conv_linear_2 = self.conv_bn_layer(
input=conv_dw_2,
num_filters=oup_inc,
filter_size=1,
stride=1,
input=conv_dw_2,
num_filters=oup_inc,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
num_groups=1,
if_act=True,
name='stage_'+name+'_conv3')
name='stage_' + name + '_conv3')
out = fluid.layers.concat([conv_linear_1, conv_linear_2], axis=1)
return self.channel_shuffle(out, 2)
def ShuffleNetV2_x0_5_swish():
model = ShuffleNetV2(scale=0.5)
model = ShuffleNetV2_swish(scale=0.5)
return model
def ShuffleNetV2_x1_0_swish():
model = ShuffleNetV2(scale=1.0)
model = ShuffleNetV2_swish(scale=1.0)
return model
def ShuffleNetV2_x1_5_swish():
model = ShuffleNetV2(scale=1.5)
model = ShuffleNetV2_swish(scale=1.5)
return model
def ShuffleNetV2_x2_0_swish():
model = ShuffleNetV2(scale=2.0)
model = ShuffleNetV2_swish(scale=2.0)
return model
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import math
......@@ -18,99 +22,111 @@ from paddle.fluid.param_attr import ParamAttr
__all__ = ["SqueezeNet", "SqueezeNet1_0", "SqueezeNet1_1"]
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class SqueezeNet():
def __init__(self, version='1.0'):
self.params = train_parameters
self.version = version
def net(self, input, class_dim=1000):
version = self.version
assert version in ['1.0', '1.1'], \
"supported version are {} but input version is {}".format(['1.0', '1.1'], version)
if version == '1.0':
conv = fluid.layers.conv2d(input,
num_filters=96,
filter_size=7,
stride=2,
act='relu',
param_attr=fluid.param_attr.ParamAttr(name="conv1_weights"),
bias_attr=ParamAttr(name='conv1_offset'))
conv = fluid.layers.pool2d(conv, pool_size=3, pool_stride=2,pool_type='max')
conv = fluid.layers.conv2d(
input,
num_filters=96,
filter_size=7,
stride=2,
act='relu',
param_attr=fluid.param_attr.ParamAttr(name="conv1_weights"),
bias_attr=ParamAttr(name='conv1_offset'))
conv = fluid.layers.pool2d(
conv, pool_size=3, pool_stride=2, pool_type='max')
conv = self.make_fire(conv, 16, 64, 64, name='fire2')
conv = self.make_fire(conv, 16, 64, 64, name='fire3')
conv = self.make_fire(conv, 32, 128, 128, name='fire4')
conv = fluid.layers.pool2d(conv, pool_size=3, pool_stride=2, pool_type='max')
conv = fluid.layers.pool2d(
conv, pool_size=3, pool_stride=2, pool_type='max')
conv = self.make_fire(conv, 32, 128, 128, name='fire5')
conv = self.make_fire(conv, 48, 192, 192, name='fire6')
conv = self.make_fire(conv, 48, 192, 192, name='fire7')
conv = self.make_fire(conv, 64, 256, 256, name='fire8')
conv = fluid.layers.pool2d(conv, pool_size=3, pool_stride=2, pool_type='max')
conv = fluid.layers.pool2d(
conv, pool_size=3, pool_stride=2, pool_type='max')
conv = self.make_fire(conv, 64, 256, 256, name='fire9')
else:
conv = fluid.layers.conv2d(input,
num_filters=64,
filter_size=3,
stride=2,
padding=1,
act='relu',
param_attr=fluid.param_attr.ParamAttr(name="conv1_weights"),
bias_attr=ParamAttr(name='conv1_offset'))
conv = fluid.layers.pool2d(conv, pool_size=3, pool_stride=2, pool_type='max')
conv = fluid.layers.conv2d(
input,
num_filters=64,
filter_size=3,
stride=2,
padding=1,
act='relu',
param_attr=fluid.param_attr.ParamAttr(name="conv1_weights"),
bias_attr=ParamAttr(name='conv1_offset'))
conv = fluid.layers.pool2d(
conv, pool_size=3, pool_stride=2, pool_type='max')
conv = self.make_fire(conv, 16, 64, 64, name='fire2')
conv = self.make_fire(conv, 16, 64, 64, name='fire3')
conv = fluid.layers.pool2d(conv, pool_size=3, pool_stride=2, pool_type='max')
conv = fluid.layers.pool2d(
conv, pool_size=3, pool_stride=2, pool_type='max')
conv = self.make_fire(conv, 32, 128, 128, name='fire4')
conv = self.make_fire(conv, 32, 128, 128, name='fire5')
conv = fluid.layers.pool2d(conv, pool_size=3, pool_stride=2, pool_type='max')
conv = fluid.layers.pool2d(
conv, pool_size=3, pool_stride=2, pool_type='max')
conv = self.make_fire(conv, 48, 192, 192, name='fire6')
conv = self.make_fire(conv, 48, 192, 192, name='fire7')
conv = self.make_fire(conv, 64, 256, 256, name='fire8')
conv = self.make_fire(conv, 64, 256, 256, name='fire9')
conv = fluid.layers.dropout(conv, dropout_prob=0.5)
conv = fluid.layers.conv2d(conv,
num_filters=class_dim,
filter_size=1,
act='relu',
param_attr=fluid.param_attr.ParamAttr(name="conv10_weights"),
bias_attr=ParamAttr(name='conv10_offset'))
conv = fluid.layers.conv2d(
conv,
num_filters=class_dim,
filter_size=1,
act='relu',
param_attr=fluid.param_attr.ParamAttr(name="conv10_weights"),
bias_attr=ParamAttr(name='conv10_offset'))
conv = fluid.layers.pool2d(conv, pool_type='avg', global_pooling=True)
out = fluid.layers.flatten(conv)
return out
def make_fire_conv(self, input, num_filters, filter_size, padding=0, name=None):
conv = fluid.layers.conv2d(input,
num_filters=num_filters,
filter_size=filter_size,
padding=padding,
act='relu',
param_attr=fluid.param_attr.ParamAttr(name=name + "_weights"),
bias_attr=ParamAttr(name=name + '_offset'))
def make_fire_conv(self,
input,
num_filters,
filter_size,
padding=0,
name=None):
conv = fluid.layers.conv2d(
input,
num_filters=num_filters,
filter_size=filter_size,
padding=padding,
act='relu',
param_attr=fluid.param_attr.ParamAttr(name=name + "_weights"),
bias_attr=ParamAttr(name=name + '_offset'))
return conv
def make_fire(self, input, squeeze_channels, expand1x1_channels, expand3x3_channels, name=None):
conv = self.make_fire_conv(input, squeeze_channels, 1, name=name+'_squeeze1x1')
conv_path1 = self.make_fire_conv(conv, expand1x1_channels, 1, name=name+'_expand1x1')
conv_path2 = self.make_fire_conv(conv, expand3x3_channels, 3, 1, name=name+'_expand3x3')
def make_fire(self,
input,
squeeze_channels,
expand1x1_channels,
expand3x3_channels,
name=None):
conv = self.make_fire_conv(
input, squeeze_channels, 1, name=name + '_squeeze1x1')
conv_path1 = self.make_fire_conv(
conv, expand1x1_channels, 1, name=name + '_expand1x1')
conv_path2 = self.make_fire_conv(
conv, expand3x3_channels, 3, 1, name=name + '_expand3x3')
out = fluid.layers.concat([conv_path1, conv_path2], axis=1)
return out
def SqueezeNet1_0():
model = SqueezeNet(version='1.0')
return model
def SqueezeNet1_1():
model = SqueezeNet(version='1.1')
return model
......@@ -21,22 +21,9 @@ import paddle.fluid as fluid
__all__ = ["VGGNet", "VGG11", "VGG13", "VGG16", "VGG19"]
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class VGGNet():
def __init__(self, layers=16):
self.params = train_parameters
self.layers = layers
def net(self, input, class_dim=1000):
......@@ -93,8 +80,7 @@ class VGGNet():
act='relu',
param_attr=fluid.param_attr.ParamAttr(
name=name + str(i + 1) + "_weights"),
bias_attr=fluid.param_attr.ParamAttr(
name=name + str(i + 1) + "_offset"))
bias_attr=False)
return fluid.layers.pool2d(
input=conv, pool_size=2, pool_type='max', pool_stride=2)
......
此差异已折叠。
#Training details
#GPU: NVIDIA® Tesla® P40 8cards 120epochs 55h
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export FLAGS_fast_eager_deletion_mode=1
export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_fraction_of_gpu_memory_to_use=0.98
#AlexNet:
python train.py \
--model=AlexNet \
--batch_size=256 \
--total_images=1281167 \
--class_dim=1000 \
--image_shape=3,224,224 \
--model_save_dir=output/ \
--lr_strategy=piecewise_decay \
--num_epochs=120 \
--lr=0.01 \
--l2_decay=1e-4
#Training details
#DarkNet53
python train.py \
--model=DarkNet53 \
--batch_size=256 \
--total_images=1281167 \
--image_shape=3,256,256 \
--class_dim=1000 \
--lr_strategy=cosine_decay \
--lr=0.1 \
--num_epochs=200 \
--model_save_dir=output/ \
--l2_decay=1e-4 \
--use_mixup=True \
--resize_short_size=256 \
--use_label_smoothing=True \
--label_smoothing_epsilon=0.1 \
#Training details
#DenseNet121
python train.py \
--model=DenseNet121 \
--batch_size=256 \
--total_images=1281167 \
--image_shape=3,224,224 \
--class_dim=1000 \
--lr_strategy=piecewise_decay \
--lr=0.1 \
--num_epochs=120 \
--model_save_dir=output/ \
--l2_decay=1e-4
#Traing details
#DenseNet161
python train.py \
--model=DenseNet161 \
--batch_size=256 \
--total_images=1281167 \
--image_shape=3,224,224 \
--class_dim=1000 \
--lr_strategy=piecewise_decay \
--lr=0.1 \
--num_epochs=120 \
--model_save_dir=output/ \
--l2_decay=1e-4
#Training details
#DenseNet169
python train.py \
--model=DenseNet169 \
--batch_size=256 \
--total_images=1281167 \
--image_shape=3,224,224 \
--class_dim=1000 \
--lr_strategy=piecewise_decay \
--lr=0.1 \
--num_epochs=120 \
--model_save_dir=output/ \
--l2_decay=1e-4
python train.py \
--model=ResNeXt50_32x4d \
--batch_size=256 \
--total_images=1281167 \
--image_shape=3,224,224 \
--class_dim=1000 \
--lr_strategy=piecewise_decay \
--lr=0.1 \
--num_epochs=120 \
--model_save_dir=output/ \
--l2_decay=1e-4
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册