Unverified commit 86c06294, authored by W wuzewu, committed by GitHub

add detection module

import paddle
import paddlehub as hub
if __name__ == '__main__':
    place = paddle.CUDAPlace(0)
    paddle.disable_static(place)
    model = hub.Module(name='yolov3_darknet53_pascalvoc', is_train=False)
    model.eval()
    model.predict(imgpath="4026.jpeg", filelist="/PATH/TO/JSON/FILE")
import paddle
import paddlehub as hub
import paddle.nn as nn
from paddlehub.finetune.trainer import Trainer
from paddlehub.datasets.pascalvoc import DetectionData
import paddlehub.process.detect_transforms as T
if __name__ == "__main__":
    paddle.disable_static()
    transform = T.Compose([
        T.RandomDistort(),
        T.RandomExpand(fill=[0.485, 0.456, 0.406]),
        T.RandomCrop(),
        T.Resize(target_size=416),
        T.RandomFlip(),
        T.ShuffleBox(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    train_reader = DetectionData(transform)
    model = hub.Module(name='yolov3_darknet53_pascalvoc')
    optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
    trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_det')
    trainer.train(train_reader, epochs=5, batch_size=4, eval_dataset=train_reader, log_interval=1, save_interval=1)
import os
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import Normal, Constant
from paddle.regularizer import L2Decay
from paddlehub.module.cv_module import Yolov3Module
import paddlehub.process.detect_transforms as T
from paddlehub.module.module import moduleinfo
class ConvBNLayer(nn.Layer):
    """Basic block for Darknet"""
    def __init__(self,
                 ch_in: int,
                 ch_out: int,
                 filter_size: int = 3,
                 stride: int = 1,
                 groups: int = 1,
                 padding: int = 0,
                 act: str = 'leaky',
                 is_test: bool = False):
        super(ConvBNLayer, self).__init__()
        self.conv = nn.Conv2d(ch_in,
                              ch_out,
                              filter_size,
                              padding=padding,
                              stride=stride,
                              groups=groups,
                              weight_attr=paddle.ParamAttr(initializer=Normal(0., 0.02)),
                              bias_attr=False)
        self.batch_norm = nn.BatchNorm(num_channels=ch_out,
                                       is_test=is_test,
                                       param_attr=paddle.ParamAttr(initializer=Normal(0., 0.02),
                                                                   regularizer=L2Decay(0.)))
        self.act = act

    def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
        out = self.conv(inputs)
        out = self.batch_norm(out)
        if self.act == "leaky":
            out = F.leaky_relu(x=out, negative_slope=0.1)
        return out
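# Illustrative usage sketch (not part of the module; assumes the paddle
# 2.0-alpha dygraph API this commit is written against):
#
#     import numpy as np
#     conv_bn = ConvBNLayer(ch_in=3, ch_out=32, filter_size=3, stride=1, padding=1)
#     x = paddle.to_tensor(np.random.randn(1, 3, 416, 416).astype('float32'))
#     y = conv_bn(x)  # conv + BN + leaky ReLU, shape [1, 32, 416, 416]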
class DownSample(nn.Layer):
    """Downsample block for Darknet"""
    def __init__(self,
                 ch_in: int,
                 ch_out: int,
                 filter_size: int = 3,
                 stride: int = 2,
                 padding: int = 1,
                 is_test: bool = False):
        super(DownSample, self).__init__()
        self.conv_bn_layer = ConvBNLayer(ch_in=ch_in,
                                         ch_out=ch_out,
                                         filter_size=filter_size,
                                         stride=stride,
                                         padding=padding,
                                         is_test=is_test)
        self.ch_out = ch_out

    def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
        out = self.conv_bn_layer(inputs)
        return out
class BasicBlock(nn.Layer):
    """Basic residual block for Darknet"""
    def __init__(self, ch_in: int, ch_out: int, is_test: bool = False):
        super(BasicBlock, self).__init__()
        self.conv1 = ConvBNLayer(ch_in=ch_in, ch_out=ch_out, filter_size=1, stride=1, padding=0, is_test=is_test)
        self.conv2 = ConvBNLayer(ch_in=ch_out, ch_out=ch_out * 2, filter_size=3, stride=1, padding=1, is_test=is_test)

    def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
        conv1 = self.conv1(inputs)
        conv2 = self.conv2(conv1)
        out = paddle.elementwise_add(x=inputs, y=conv2, act=None)
        return out
class LayerWarp(nn.Layer):
    """Warp layer composed of basic residual blocks"""
    def __init__(self, ch_in: int, ch_out: int, count: int, is_test: bool = False):
        super(LayerWarp, self).__init__()
        self.basicblock0 = BasicBlock(ch_in, ch_out, is_test=is_test)
        self.res_out_list = []
        for i in range(1, count):
            res_out = self.add_sublayer("basic_block_%d" % (i), BasicBlock(ch_out * 2, ch_out, is_test=is_test))
            self.res_out_list.append(res_out)
        self.ch_out = ch_out

    def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
        y = self.basicblock0(inputs)
        for basic_block_i in self.res_out_list:
            y = basic_block_i(y)
        return y
class DarkNet53_conv_body(nn.Layer):
    """Darknet53 backbone.

    Args:
        ch_in(int): Input channels, default is 3.
        is_test(bool): Set the test mode, default is False.
    """
    def __init__(self, ch_in: int = 3, is_test: bool = False):
        super(DarkNet53_conv_body, self).__init__()
        self.stages = [1, 2, 8, 8, 4]
        self.stages = self.stages[0:5]
        self.conv0 = ConvBNLayer(ch_in=ch_in, ch_out=32, filter_size=3, stride=1, padding=1, is_test=is_test)
        self.downsample0 = DownSample(ch_in=32, ch_out=32 * 2, is_test=is_test)
        self.darknet53_conv_block_list = []
        self.downsample_list = []
        ch_in = [64, 128, 256, 512, 1024]
        for i, stage in enumerate(self.stages):
            conv_block = self.add_sublayer("stage_%d" % (i),
                                           LayerWarp(int(ch_in[i]), 32 * (2**i), stage, is_test=is_test))
            self.darknet53_conv_block_list.append(conv_block)
        for i in range(len(self.stages) - 1):
            downsample = self.add_sublayer(
                "stage_%d_downsample" % i,
                DownSample(ch_in=32 * (2**(i + 1)), ch_out=32 * (2**(i + 2)), is_test=is_test))
            self.downsample_list.append(downsample)

    def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
        out = self.conv0(inputs)
        out = self.downsample0(out)
        blocks = []
        for i, conv_block_i in enumerate(self.darknet53_conv_block_list):
            out = conv_block_i(out)
            blocks.append(out)
            if i < len(self.stages) - 1:
                out = self.downsample_list[i](out)
        return blocks[-1:-4:-1]  # the three deepest feature maps, deepest first
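# Shape sketch (illustrative, for a [1, 3, 416, 416] input): the body returns
# the three deepest stage outputs, deepest first, at strides 32/16/8:
#
#     body = DarkNet53_conv_body(ch_in=3, is_test=True)
#     c5, c4, c3 = body(x)  # [1, 1024, 13, 13], [1, 512, 26, 26], [1, 256, 52, 52]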
class YoloDetectionBlock(nn.Layer):
    """Basic block for Yolov3"""
    def __init__(self, ch_in: int, channel: int, is_test: bool = True):
        super(YoloDetectionBlock, self).__init__()
        assert channel % 2 == 0, \
            "channel {} cannot be divided by 2".format(channel)
        self.conv0 = ConvBNLayer(ch_in=ch_in, ch_out=channel, filter_size=1, stride=1, padding=0, is_test=is_test)
        self.conv1 = ConvBNLayer(ch_in=channel, ch_out=channel * 2, filter_size=3, stride=1, padding=1, is_test=is_test)
        self.conv2 = ConvBNLayer(ch_in=channel * 2, ch_out=channel, filter_size=1, stride=1, padding=0, is_test=is_test)
        self.conv3 = ConvBNLayer(ch_in=channel, ch_out=channel * 2, filter_size=3, stride=1, padding=1, is_test=is_test)
        self.route = ConvBNLayer(ch_in=channel * 2, ch_out=channel, filter_size=1, stride=1, padding=0, is_test=is_test)
        self.tip = ConvBNLayer(ch_in=channel, ch_out=channel * 2, filter_size=3, stride=1, padding=1, is_test=is_test)

    def forward(self, inputs):
        out = self.conv0(inputs)
        out = self.conv1(out)
        out = self.conv2(out)
        out = self.conv3(out)
        route = self.route(out)
        tip = self.tip(route)
        return route, tip
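# Usage sketch (illustrative): `route` stays at `channel` filters and feeds the
# lateral/upsample path, while `tip` doubles to `channel * 2` filters and feeds
# the detection head:
#
#     block = YoloDetectionBlock(ch_in=1024, channel=512)
#     route, tip = block(c5)  # route: [N, 512, 13, 13], tip: [N, 1024, 13, 13]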
class Upsample(nn.Layer):
    """Upsample block for Yolov3"""
    def __init__(self, scale: int = 2):
        super(Upsample, self).__init__()
        self.scale = scale

    def forward(self, inputs: paddle.Tensor):
        shape_nchw = paddle.to_tensor(inputs.shape)
        shape_hw = paddle.slice(shape_nchw, axes=[0], starts=[2], ends=[4])
        shape_hw.stop_gradient = True
        in_shape = paddle.cast(shape_hw, dtype='int32')
        out_shape = in_shape * self.scale
        out_shape.stop_gradient = True
        out = F.resize_nearest(input=inputs, scale=self.scale, actual_shape=out_shape)
        return out
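# Sketch (illustrative): nearest-neighbour upsampling doubles the spatial dims
# by default, so a route tensor of [N, 256, 13, 13] becomes [N, 256, 26, 26]
# before being concatenated with the next, shallower backbone feature map:
#
#     up = Upsample(scale=2)
#     route = up(route)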
@moduleinfo(name="yolov3_darknet53_pascalvoc",
            type="CV/object_detection",
            author="paddlepaddle",
            author_email="",
            summary="Yolov3 is a detection model; this module is trained on the PASCAL VOC dataset.",
            version="1.0.0",
            meta=Yolov3Module)
class YOLOv3(nn.Layer):
    """YOLOv3 for object detection.

    Args:
        ch_in(int): Input channels, default is 3.
        class_num(int): Number of categories to detect; for the VOC dataset, class_num is 20.
        ignore_thresh(float): IoU threshold above which the confidence loss is ignored.
        valid_thresh(float): Threshold to filter out bounding boxes with low confidence scores.
        nms_topk(int): Maximum number of detections kept per image, ranked by confidence, after
            filtering by valid_thresh.
        nms_posk(int): Number of total bboxes kept per image after the NMS step; -1 means keeping
            all bboxes after NMS.
        nms_thresh(float): The IoU threshold used in NMS. Default: 0.45.
        is_train(bool): Set the train mode, default is True.
        load_checkpoint(str): Path of a custom checkpoint to load; if None, the pretrained VOC
            checkpoint is downloaded.
    """
    def __init__(self,
                 ch_in: int = 3,
                 class_num: int = 20,
                 ignore_thresh: float = 0.7,
                 valid_thresh: float = 0.005,
                 nms_topk: int = 400,
                 nms_posk: int = 100,
                 nms_thresh: float = 0.45,
                 is_train: bool = True,
                 load_checkpoint: str = None):
        super(YOLOv3, self).__init__()
        self.is_train = is_train
        self.block = DarkNet53_conv_body(ch_in=ch_in, is_test=not self.is_train)
        self.block_outputs = []
        self.yolo_blocks = []
        self.route_blocks_2 = []
        self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
        self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326]
        self.class_num = class_num
        self.ignore_thresh = ignore_thresh
        self.valid_thresh = valid_thresh
        self.nms_topk = nms_topk
        self.nms_posk = nms_posk
        self.nms_thresh = nms_thresh
        ch_in_list = [1024, 768, 384]
        for i in range(3):
            yolo_block = self.add_sublayer(
                "yolo_detecton_block_%d" % (i),
                YoloDetectionBlock(ch_in_list[i], channel=512 // (2**i), is_test=not self.is_train))
            self.yolo_blocks.append(yolo_block)
            num_filters = len(self.anchor_masks[i]) * (self.class_num + 5)
            block_out = self.add_sublayer(
                "block_out_%d" % (i),
                nn.Conv2d(1024 // (2**i),
                          num_filters,
                          1,
                          stride=1,
                          padding=0,
                          weight_attr=paddle.ParamAttr(initializer=Normal(0., 0.02)),
                          bias_attr=paddle.ParamAttr(initializer=Constant(0.0), regularizer=L2Decay(0.))))
            self.block_outputs.append(block_out)
            if i < 2:
                route = self.add_sublayer(
                    "route2_%d" % i,
                    ConvBNLayer(ch_in=512 // (2**i),
                                ch_out=256 // (2**i),
                                filter_size=1,
                                stride=1,
                                padding=0,
                                is_test=(not self.is_train)))
                self.route_blocks_2.append(route)
        self.upsample = Upsample()

        if load_checkpoint is not None:
            model_dict = paddle.load(load_checkpoint)[0]
            self.set_dict(model_dict)
            print("load custom checkpoint success")
        else:
            checkpoint = os.path.join(self.directory, 'yolov3_darknet53_voc.pdparams')
            if not os.path.exists(checkpoint):
                os.system(
                    'wget https://paddlehub.bj.bcebos.com/dygraph/detection/yolov3_darknet53_voc.pdparams -O ' \
                    + checkpoint)
            model_dict = paddle.load(checkpoint)[0]
            self.set_dict(model_dict)
            print("load pretrained checkpoint success")
    def transform(self, img):
        if self.is_train:
            transform = T.Compose([
                T.RandomDistort(),
                T.RandomExpand(fill=[0.485, 0.456, 0.406]),
                T.RandomCrop(),
                T.Resize(target_size=416),
                T.RandomFlip(),
                T.ShuffleBox(),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        else:
            transform = T.Compose([
                T.Resize(target_size=416, interp='CUBIC'),
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        return transform(img)
    def forward(self, inputs: paddle.Tensor):
        outputs = []
        blocks = self.block(inputs)
        route = None
        for i, block in enumerate(blocks):
            if i > 0:
                block = paddle.concat([route, block], axis=1)
            route, tip = self.yolo_blocks[i](block)
            block_out = self.block_outputs[i](tip)
            outputs.append(block_out)
            if i < 2:
                route = self.route_blocks_2[i](route)
                route = self.upsample(route)
        return outputs
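For reference, a minimal forward-pass sketch against the module defined above; the dummy input and the printed shapes are illustrative assumptions (3 anchors per head, so each output has 3 * (class_num + 5) = 75 channels for VOC):

import numpy as np
import paddle
import paddlehub as hub

if __name__ == '__main__':
    paddle.disable_static()
    model = hub.Module(name='yolov3_darknet53_pascalvoc', is_train=False)
    model.eval()
    x = paddle.to_tensor(np.random.randn(1, 3, 416, 416).astype('float32'))
    outputs = model(x)
    print([o.shape for o in outputs])  # [[1, 75, 13, 13], [1, 75, 26, 26], [1, 75, 52, 52]]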
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
from typing import Callable
import paddle
import numpy as np
from paddlehub.env import DATA_HOME
from pycocotools.coco import COCO
class DetectCatagory:
    """Load label names, ids and the category-to-id map from a detection dataset.

    Args:
        attrbox(COCO): COCO API handle used to fetch detection annotations.
        data_dir(str): Image dataset path.

    Returns:
        label_names(List(str)): The dataset label names.
        label_ids(List(int)): The dataset label ids.
        category_to_id_map(dict): Mapping from category id to contiguous label index.
    """
    def __init__(self, attrbox: Callable, data_dir: str):
        self.attrbox = attrbox
        self.img_dir = data_dir

    def __call__(self):
        self.categories = self.attrbox.loadCats(self.attrbox.getCatIds())
        self.num_category = len(self.categories)
        label_names = []
        label_ids = []
        for category in self.categories:
            label_names.append(category['name'])
            label_ids.append(int(category['id']))
        category_to_id_map = {v: i for i, v in enumerate(label_ids)}
        return label_names, label_ids, category_to_id_map
class ParseImages:
    """Prepare image records for detection.

    Args:
        attrbox(COCO): COCO API handle used to fetch detection annotations.
        data_dir(str): Image dataset path.
        category_to_id_map(dict): Mapping from category id to contiguous label index.

    Returns:
        imgs(list[dict]): The input records for the detection model, one dict per image.
    """
    def __init__(self, attrbox: Callable, data_dir: str, category_to_id_map: dict):
        self.attrbox = attrbox
        self.img_dir = data_dir
        self.category_to_id_map = category_to_id_map
        self.parse_gt_annotations = GTAnotations(self.attrbox, self.category_to_id_map)

    def __call__(self):
        image_ids = self.attrbox.getImgIds()
        image_ids.sort()
        imgs = copy.deepcopy(self.attrbox.loadImgs(image_ids))
        for img in imgs:
            img['image'] = os.path.join(self.img_dir, img['file_name'])
            assert os.path.exists(img['image']), "image {} not found.".format(img['image'])
            box_num = 50
            img['gt_boxes'] = np.zeros((box_num, 4), dtype=np.float32)
            img['gt_labels'] = np.zeros((box_num), dtype=np.int32)
            img = self.parse_gt_annotations(img)
        return imgs
class GTAnotations:
    """Fill in ground truth boxes and labels for training.

    Args:
        attrbox(COCO): COCO API handle used to fetch detection annotations.
        category_to_id_map(dict): Mapping from category id to contiguous label index.
        img(dict): Input record for the detection model.

    Returns:
        img(dict): The record with its 'gt_boxes' and 'gt_labels' entries filled in.
    """
    def __init__(self, attrbox: Callable, category_to_id_map: dict):
        self.attrbox = attrbox
        self.category_to_id_map = category_to_id_map

    def box_to_center_relative(self, box: list, img_height: int, img_width: int) -> np.ndarray:
        """
        Convert a COCO annotation box with format [x1, y1, w, h] to
        center mode [center_x, center_y, w, h], dividing by image width
        and height to get relative values in the range [0, 1].
        """
        assert len(box) == 4, "box should be a len(4) list or tuple"
        x, y, w, h = box
        x1 = max(x, 0)
        x2 = min(x + w - 1, img_width - 1)
        y1 = max(y, 0)
        y2 = min(y + h - 1, img_height - 1)
        x = (x1 + x2) / 2 / img_width
        y = (y1 + y2) / 2 / img_height
        w = (x2 - x1) / img_width
        h = (y2 - y1) / img_height
        return np.array([x, y, w, h])
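    # Worked example (hypothetical numbers): in a 200x400 (width x height)
    # image, the COCO box [x1, y1, w, h] = [50, 100, 100, 50] clips to corners
    # (50, 100)-(149, 149) and converts to the relative center form
    # [0.4975, 0.31125, 0.495, 0.1225].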
    def __call__(self, img: dict):
        img_height = img['height']
        img_width = img['width']
        anno = self.attrbox.loadAnns(self.attrbox.getAnnIds(imgIds=img['id'], iscrowd=None))
        gt_index = 0
        for target in anno:
            if target['area'] < -1:
                continue
            if 'ignore' in target and target['ignore']:
                continue
            box = self.box_to_center_relative(target['bbox'], img_height, img_width)
            if box[2] <= 0 and box[3] <= 0:
                continue
            img['gt_boxes'][gt_index] = box
            img['gt_labels'][gt_index] = \
                self.category_to_id_map[target['category_id']]
            gt_index += 1
            if gt_index >= 50:
                break
        return img
class DetectionData(paddle.io.Dataset):
    """
    Dataset for image detection.

    Args:
        transform(Callable): The preprocessing method for images.
        size(int): Image size, default is 416.
        mode(str): The mode for preparing the dataset, 'train' or 'test'.

    Returns:
        DataSet: An iterable dataset object.
    """
    def __init__(self, transform: Callable, size: int = 416, mode: str = 'train'):
        self.mode = mode
        self.transform = transform
        self.size = size
        if self.mode == 'train':
            train_file_list = 'annotations/instances_train2017.json'
            train_data_dir = 'train2017'
            self.train_file_list = os.path.join(DATA_HOME, 'voc', train_file_list)
            self.train_data_dir = os.path.join(DATA_HOME, 'voc', train_data_dir)
            self.COCO = COCO(self.train_file_list)
            self.img_dir = self.train_data_dir
        elif self.mode == 'test':
            val_file_list = 'annotations/instances_val2017.json'
            val_data_dir = 'val2017'
            self.val_file_list = os.path.join(DATA_HOME, 'voc', val_file_list)
            self.val_data_dir = os.path.join(DATA_HOME, 'voc', val_data_dir)
            self.COCO = COCO(self.val_file_list)
            self.img_dir = self.val_data_dir
        parse_dataset_catagory = DetectCatagory(self.COCO, self.img_dir)
        self.label_names, self.label_ids, self.category_to_id_map = parse_dataset_catagory()
        parse_images = ParseImages(self.COCO, self.img_dir, self.category_to_id_map)
        self.data = parse_images()

    def __getitem__(self, idx: int):
        img = self.data[idx]
        im, data = self.transform(img)
        out_img, gt_boxes, gt_labels, gt_scores = im, data['gt_boxes'], data['gt_labels'], data['gt_scores']
        return out_img, gt_boxes, gt_labels, gt_scores

    def __len__(self):
        return len(self.data)
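A short consumption sketch for the dataset above (illustrative; it assumes the annotation json and images already sit under DATA_HOME/voc as the paths above expect):

import paddlehub.process.detect_transforms as T
from paddlehub.datasets.pascalvoc import DetectionData

if __name__ == '__main__':
    transform = T.Compose([
        T.Resize(target_size=416, interp='CUBIC'),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    data = DetectionData(transform, size=416, mode='test')
    img, gt_boxes, gt_labels, gt_scores = data[0]
    print(img.shape, gt_boxes.shape, gt_labels.shape)  # (3, 416, 416) (50, 4) (50,)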
@@ -27,8 +27,8 @@ from PIL import Image
from paddlehub.module.module import serving, RunModule
from paddlehub.utils.utils import base64_to_cv2
-from paddlehub.process.transforms import ConvertColorSpace, ColorPostprocess, Resize
-from paddlehub.process.functional import subtract_imagenet_mean_batch, gram_matrix
+import paddlehub.process.transforms as T
+import paddlehub.process.functional as Func
class ImageServing(object):
@@ -136,8 +136,8 @@ class ImageColorizeModule(RunModule, ImageServing):
        visual_ret = OrderedDict()
        psnrs = []
-        lab2rgb = ConvertColorSpace(mode='LAB2RGB')
-        process = ColorPostprocess()
+        lab2rgb = T.ConvertColorSpace(mode='LAB2RGB')
+        process = T.ColorPostprocess()
        for i in range(batch[0].numpy().shape[0]):
            real = lab2rgb(np.concatenate((batch[0].numpy(), batch[3].numpy()), axis=1))[i]
@@ -163,9 +163,9 @@ class ImageColorizeModule(RunModule, ImageServing):
        Returns:
            results(list[dict]): The prediction result of each input image.
        '''
-        lab2rgb = ConvertColorSpace(mode='LAB2RGB')
-        process = ColorPostprocess()
-        resize = Resize((256, 256))
+        lab2rgb = T.ConvertColorSpace(mode='LAB2RGB')
+        process = T.ColorPostprocess()
+        resize = T.Resize((256, 256))
        visual_ret = OrderedDict()
        im = self.transforms(images, is_train=False)
        out_class, out_reg = self(paddle.to_tensor(im['A']), paddle.to_variable(im['hint_B']),
@@ -196,6 +196,124 @@ class ImageColorizeModule(RunModule, ImageServing):
        return result
class Yolov3Module(RunModule, ImageServing):
    def training_step(self, batch: int, batch_idx: int) -> dict:
        '''
        One step for training, which should be called as forward computation.

        Args:
            batch(list[paddle.Tensor]): The one-batch data, which contains images, ground truth boxes, labels and scores.
            batch_idx(int): The index of the batch.

        Returns:
            results(dict): The model outputs, such as loss.
        '''
        return self.validation_step(batch, batch_idx)

    def validation_step(self, batch: int, batch_idx: int) -> dict:
        '''
        One step for validation, which should be called as forward computation.

        Args:
            batch(list[paddle.Tensor]): The one-batch data, which contains images, ground truth boxes, labels and scores.
            batch_idx(int): The index of the batch.

        Returns:
            results(dict): The model outputs, such as metrics.
        '''
        img = batch[0].astype('float32')
        gtbox = batch[1].astype('float32')
        gtlabel = batch[2].astype('int32')
        gtscore = batch[3].astype("float32")
        losses = []
        outputs = self(img)
        self.downsample = 32
        for i, out in enumerate(outputs):
            anchor_mask = self.anchor_masks[i]
            loss = F.yolov3_loss(x=out,
                                 gt_box=gtbox,
                                 gt_label=gtlabel,
                                 gt_score=gtscore,
                                 anchors=self.anchors,
                                 anchor_mask=anchor_mask,
                                 class_num=self.class_num,
                                 ignore_thresh=self.ignore_thresh,
                                 downsample_ratio=32,
                                 use_label_smooth=False)
            losses.append(paddle.reduce_mean(loss))
            self.downsample //= 2
        return {'loss': sum(losses)}
    def predict(self, imgpath: str, filelist: str, visualization: bool = True, save_path: str = 'result'):
        '''
        Detect images.

        Args:
            imgpath(str): Image path.
            filelist(str): Path of the annotation file used to look up label names.
            visualization(bool): Whether to save the visualized result image.
            save_path(str): Path to save detected images.

        Returns:
            boxes(np.ndarray): Predicted box information.
            scores(np.ndarray): Predicted scores.
            labels(np.ndarray): Predicted labels.
        '''
        boxes = []
        scores = []
        self.downsample = 32
        im = self.transform(imgpath)
        h, w, c = Func.img_shape(imgpath)
        im_shape = paddle.to_tensor(np.array([[h, w]]).astype('int32'))
        label_names = Func.get_label_infos(filelist)
        img_data = paddle.to_tensor(np.array([im]).astype('float32'))
        outputs = self(img_data)
        for i, out in enumerate(outputs):
            anchor_mask = self.anchor_masks[i]
            mask_anchors = []
            for m in anchor_mask:
                mask_anchors.append(self.anchors[2 * m])
                mask_anchors.append(self.anchors[2 * m + 1])
            box, score = F.yolo_box(x=out,
                                    img_size=im_shape,
                                    anchors=mask_anchors,
                                    class_num=self.class_num,
                                    conf_thresh=self.valid_thresh,
                                    downsample_ratio=self.downsample,
                                    name="yolo_box" + str(i))
            boxes.append(box)
            scores.append(paddle.transpose(score, perm=[0, 2, 1]))
            self.downsample //= 2
        yolo_boxes = paddle.concat(boxes, axis=1)
        yolo_scores = paddle.concat(scores, axis=2)
        pred = F.multiclass_nms(bboxes=yolo_boxes,
                                scores=yolo_scores,
                                score_threshold=self.valid_thresh,
                                nms_top_k=self.nms_topk,
                                keep_top_k=self.nms_posk,
                                nms_threshold=self.nms_thresh,
                                background_label=-1)
        bboxes = pred.numpy()
        labels = bboxes[:, 0].astype('int32')
        scores = bboxes[:, 1].astype('float32')
        boxes = bboxes[:, 2:].astype('float32')
        if visualization:
            Func.draw_boxes_on_image(imgpath, boxes, scores, labels, label_names, 0.5)
        return boxes, scores, labels
class StyleTransferModule(RunModule, ImageServing):
    def training_step(self, batch: int, batch_idx: int) -> dict:
        '''
@@ -228,19 +346,19 @@ class StyleTransferModule(RunModule, ImageServing):
        y = self(batch[0])
        xc = paddle.to_tensor(batch[0].numpy().copy())
-        y = subtract_imagenet_mean_batch(y)
-        xc = subtract_imagenet_mean_batch(xc)
+        y = Func.subtract_imagenet_mean_batch(y)
+        xc = Func.subtract_imagenet_mean_batch(xc)
        features_y = self.getFeature(y)
        features_xc = self.getFeature(xc)
        f_xc_c = paddle.to_tensor(features_xc[1].numpy(), stop_gradient=True)
        content_loss = mse_loss(features_y[1], f_xc_c)
-        batch[1] = subtract_imagenet_mean_batch(batch[1])
+        batch[1] = Func.subtract_imagenet_mean_batch(batch[1])
        features_style = self.getFeature(batch[1])
-        gram_style = [gram_matrix(y) for y in features_style]
+        gram_style = [Func.gram_matrix(y) for y in features_style]
        style_loss = 0.
        for m in range(len(features_y)):
-            gram_y = gram_matrix(features_y[m])
+            gram_y = Func.gram_matrix(features_y[m])
            gram_s = paddle.to_tensor(np.tile(gram_style[m].numpy(), (N, 1, 1, 1)))
            style_loss += mse_loss(gram_y, gram_s[:N, :, :])
......
import os
import random
from typing import Callable
import cv2
import numpy as np
import matplotlib
import PIL
from PIL import Image, ImageEnhance
from matplotlib import pyplot as plt
from paddlehub.process.functional import *
matplotlib.use('Agg')
class RandomDistort:
    """
    Distort the input image randomly.

    Args:
        lower(float): The lower bound value for enhancement, default is 0.5.
        upper(float): The upper bound value for enhancement, default is 1.5.

    Returns:
        img(np.ndarray): Distorted image.
        data(dict): Image info and label info.
    """
    def __init__(self, lower: float = 0.5, upper: float = 1.5):
        self.lower = lower
        self.upper = upper

    def random_brightness(self, img: PIL.Image):
        e = np.random.uniform(self.lower, self.upper)
        return ImageEnhance.Brightness(img).enhance(e)

    def random_contrast(self, img: PIL.Image):
        e = np.random.uniform(self.lower, self.upper)
        return ImageEnhance.Contrast(img).enhance(e)

    def random_color(self, img: PIL.Image):
        e = np.random.uniform(self.lower, self.upper)
        return ImageEnhance.Color(img).enhance(e)

    def __call__(self, img: np.ndarray, data: dict):
        ops = [self.random_brightness, self.random_contrast, self.random_color]
        np.random.shuffle(ops)
        img = Image.fromarray(img)
        img = ops[0](img)
        img = ops[1](img)
        img = ops[2](img)
        img = np.asarray(img)
        return img, data
class RandomExpand:
    """
    Randomly expand images and gt boxes by a random ratio. This is a data augmentation operation for model training.

    Args:
        max_ratio(float): Max value for the expansion ratio, default is 4.
        fill(list): Fill the expanded canvas with the given per-channel pixel value, default is None.
        keep_ratio(bool): Whether the image keeps its aspect ratio, default is True.
        thresh(float): If the random draw does not exceed the thresh, return the original image and gt boxes, default is 0.5.

    Returns:
        img(np.ndarray): Expanded image.
        data(dict): Image info and label info.
    """
    def __init__(self, max_ratio: float = 4., fill: list = None, keep_ratio: bool = True, thresh: float = 0.5):
        self.max_ratio = max_ratio
        self.fill = fill
        self.keep_ratio = keep_ratio
        self.thresh = thresh

    def __call__(self, img: np.ndarray, data: dict):
        gtboxes = data['gt_boxes']
        if random.random() > self.thresh:
            return img, data
        if self.max_ratio < 1.0:
            return img, data
        h, w, c = img.shape
        ratio_x = random.uniform(1, self.max_ratio)
        if self.keep_ratio:
            ratio_y = ratio_x
        else:
            ratio_y = random.uniform(1, self.max_ratio)
        oh = int(h * ratio_y)
        ow = int(w * ratio_x)
        off_x = random.randint(0, ow - w)
        off_y = random.randint(0, oh - h)
        out_img = np.zeros((oh, ow, c))
        if self.fill and len(self.fill) == c:
            for i in range(c):
                out_img[:, :, i] = self.fill[i] * 255.0
        out_img[off_y:off_y + h, off_x:off_x + w, :] = img
        gtboxes[:, 0] = ((gtboxes[:, 0] * w) + off_x) / float(ow)
        gtboxes[:, 1] = ((gtboxes[:, 1] * h) + off_y) / float(oh)
        gtboxes[:, 2] = gtboxes[:, 2] / ratio_x
        gtboxes[:, 3] = gtboxes[:, 3] / ratio_y
        data['gt_boxes'] = gtboxes
        img = out_img.astype('uint8')
        return img, data
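# Worked example of the box remapping above (hypothetical numbers): expanding a
# 100x100 image by ratio 2 (ow = oh = 200) with offsets off_x = 30, off_y = 50
# maps a relative center x of 0.5 to (0.5 * 100 + 30) / 200 = 0.4, while the
# relative width halves, e.g. 0.6 -> 0.3.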
class RandomCrop:
    """
    Randomly crop the input image according to constraints.

    Args:
        scales(list): The crop area relative to the original area, expressed as [min, max]. The default value is [0.3, 1.0].
        max_ratio(float): Max aspect-ratio factor of the crop area, default is 2.0.
        constraints(list): List of (min_iou, max_iou) constraints, default is None.
        max_trial(int): The max number of trials for finding a valid crop area. The default value is 50.

    Returns:
        img(np.ndarray): Cropped image.
        data(dict): Image info and label info.
    """
    def __init__(self,
                 scales: list = [0.3, 1.0],
                 max_ratio: float = 2.0,
                 constraints: list = None,
                 max_trial: int = 50):
        self.scales = scales
        self.max_ratio = max_ratio
        self.constraints = constraints
        self.max_trial = max_trial

    def __call__(self, img: np.ndarray, data: dict):
        boxes = data['gt_boxes']
        labels = data['gt_labels']
        scores = data['gt_scores']
        if len(boxes) == 0:
            return img, data
        if not self.constraints:
            self.constraints = [(0.1, 1.0), (0.3, 1.0), (0.5, 1.0), (0.7, 1.0), (0.9, 1.0), (0.0, 1.0)]
        img = Image.fromarray(img)
        w, h = img.size
        crops = [(0, 0, w, h)]
        for min_iou, max_iou in self.constraints:
            for _ in range(self.max_trial):
                scale = random.uniform(self.scales[0], self.scales[1])
                aspect_ratio = random.uniform(max(1 / self.max_ratio, scale * scale), \
                                              min(self.max_ratio, 1 / scale / scale))
                crop_h = int(h * scale / np.sqrt(aspect_ratio))
                crop_w = int(w * scale * np.sqrt(aspect_ratio))
                crop_x = random.randrange(w - crop_w)
                crop_y = random.randrange(h - crop_h)
                crop_box = np.array([[(crop_x + crop_w / 2.0) / w, (crop_y + crop_h / 2.0) / h, crop_w / float(w),
                                      crop_h / float(h)]])
                iou = box_iou_xywh(crop_box, boxes)
                if min_iou <= iou.min() and max_iou >= iou.max():
                    crops.append((crop_x, crop_y, crop_w, crop_h))
                    break
        while crops:
            crop = crops.pop(np.random.randint(0, len(crops)))
            crop_boxes, crop_labels, crop_scores, box_num = box_crop(boxes, labels, scores, crop, (w, h))
            if box_num < 1:
                continue
            img = img.crop((crop[0], crop[1], crop[0] + crop[2], crop[1] + crop[3])).resize(img.size, Image.LANCZOS)
            img = np.asarray(img)
            data['gt_boxes'] = crop_boxes
            data['gt_labels'] = crop_labels
            data['gt_scores'] = crop_scores
            return img, data
        img = np.asarray(img)
        data['gt_boxes'] = boxes
        data['gt_labels'] = labels
        data['gt_scores'] = scores
        return img, data
class RandomFlip:
    """Flip the image and gt boxes horizontally at random.

    Args:
        thresh(float): Threshold for the random draw; the flip happens when the draw exceeds it, default is 0.5.

    Returns:
        img(np.ndarray): Flipped image.
        data(dict): Image info and label info.
    """
    def __init__(self, thresh: float = 0.5):
        self.thresh = thresh

    def __call__(self, img, data):
        gtboxes = data['gt_boxes']
        if random.random() > self.thresh:
            img = img[:, ::-1, :]
            gtboxes[:, 0] = 1.0 - gtboxes[:, 0]
        data['gt_boxes'] = gtboxes
        return img, data
class Compose:
    """
    Preprocess the input data according to the given operators.

    Args:
        transforms(list): Preprocessing operators.

    Returns:
        img(np.ndarray): Preprocessed image.
        data(dict): Image info and label info, returned only when the input carries label info.
    """
    def __init__(self, transforms: list):
        if not isinstance(transforms, list):
            raise TypeError('The transforms must be a list!')
        if len(transforms) < 1:
            raise ValueError('The length of transforms ' + \
                             'must be equal or larger than 1!')
        self.transforms = transforms

    def __call__(self, data: dict):
        if isinstance(data, dict):
            if isinstance(data['image'], str):
                img = cv2.imread(data['image'])
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            gt_labels = data['gt_labels'].copy()
            data['gt_scores'] = np.ones_like(gt_labels)
            for op in self.transforms:
                img, data = op(img, data)
            img = img.transpose((2, 0, 1))
            return img, data
        if isinstance(data, str):
            img = cv2.imread(data)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            for op in self.transforms:
                img, data = op(img, data)
            img = img.transpose((2, 0, 1))
            return img
class Resize:
    """
    Resize the input image.

    Args:
        target_size(int): Target image size.
        interp(str): Interpolation method, default is 'RANDOM'.

    Returns:
        img(np.ndarray): Preprocessed image.
        data(dict): Image info and label info, returned only when label info is passed in.
    """
    def __init__(self, target_size: int = 512, interp: str = 'RANDOM'):
        self.interp_dict = {
            'NEAREST': cv2.INTER_NEAREST,
            'LINEAR': cv2.INTER_LINEAR,
            'CUBIC': cv2.INTER_CUBIC,
            'AREA': cv2.INTER_AREA,
            'LANCZOS4': cv2.INTER_LANCZOS4
        }
        self.interp = interp
        if not (interp == "RANDOM" or interp in self.interp_dict):
            raise ValueError("interp should be one of {}".format(self.interp_dict.keys()))
        if isinstance(target_size, list) or isinstance(target_size, tuple):
            if len(target_size) != 2:
                raise TypeError(
                    'when target_size is a list or tuple, it should include 2 elements, but it is {}'.format(target_size))
        elif not isinstance(target_size, int):
            raise TypeError("Type of target_size is invalid. Must be Integer or List or tuple, now is {}".format(
                type(target_size)))
        self.target_size = target_size

    def __call__(self, img, data=None):
        if self.interp == "RANDOM":
            interp = random.choice(list(self.interp_dict.keys()))
        else:
            interp = self.interp
        img = resize(img, self.target_size, self.interp_dict[interp])
        if data is not None:
            return img, data
        else:
            return img
class Normalize:
    """
    Normalize the input image.

    Args:
        mean(list): Mean values for normalization, default is [0.5, 0.5, 0.5].
        std(list): Standard deviations for normalization, default is [0.5, 0.5, 0.5].

    Returns:
        img(np.ndarray): Preprocessed image.
        data(dict): Image info and label info, returned only when label info is passed in.
    """
    def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
        self.mean = mean
        self.std = std
        if not (isinstance(self.mean, list) and isinstance(self.std, list)):
            raise ValueError("{}: input type is invalid.".format(self))
        from functools import reduce
        if reduce(lambda x, y: x * y, self.std) == 0:
            raise ValueError('{}: std is invalid!'.format(self))

    def __call__(self, im, data=None):
        mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
        std = np.array(self.std)[np.newaxis, np.newaxis, :]
        im = normalize(im, mean, std)
        if data is not None:
            return im, data
        else:
            return im
class ShuffleBox:
    """Shuffle the ground truth box information for the corresponding input image."""
    def __call__(self, img, data):
        gt = np.concatenate([data['gt_boxes'], data['gt_labels'][:, np.newaxis], data['gt_scores'][:, np.newaxis]],
                            axis=1)
        idx = np.arange(gt.shape[0])
        np.random.shuffle(idx)
        gt = gt[idx, :]
        data['gt_boxes'], data['gt_labels'], data['gt_scores'] = gt[:, :4], gt[:, 4], gt[:, 5]
        return img, data
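A minimal sketch of the transforms above on synthetic data (values are illustrative; thresh=0.0 makes RandomFlip effectively always fire, since the flip happens when the random draw exceeds thresh):

import numpy as np
import paddlehub.process.detect_transforms as T

if __name__ == '__main__':
    img = np.zeros((416, 416, 3), dtype='uint8')
    data = {
        'gt_boxes': np.array([[0.25, 0.5, 0.2, 0.2]], dtype='float32'),
        'gt_labels': np.array([3.0], dtype='float32'),
        'gt_scores': np.array([1.0], dtype='float32')
    }
    img, data = T.RandomFlip(thresh=0.0)(img, data)  # center x: 0.25 -> 0.75
    img, data = T.ShuffleBox()(img, data)
    print(data['gt_boxes'][0])  # [0.75, 0.5, 0.2, 0.2]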
@@ -11,13 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import paddle
import matplotlib
import numpy as np
from pycocotools.coco import COCO
from PIL import Image, ImageEnhance
from matplotlib import pyplot as plt
matplotlib.use('Agg')
def normalize(im, mean, std):
@@ -120,6 +124,129 @@ def get_img_file(dir_name: str) -> list:
    return images
def box_crop(boxes: np.ndarray, labels: np.ndarray, scores: np.ndarray, crop: list, img_shape: list):
    """Crop the boxes, labels and scores according to the given crop window and image shape."""
    x, y, w, h = map(float, crop)
    im_w, im_h = map(float, img_shape)
    boxes = boxes.copy()
    boxes[:, 0], boxes[:, 2] = (boxes[:, 0] - boxes[:, 2] / 2) * im_w, (boxes[:, 0] + boxes[:, 2] / 2) * im_w
    boxes[:, 1], boxes[:, 3] = (boxes[:, 1] - boxes[:, 3] / 2) * im_h, (boxes[:, 1] + boxes[:, 3] / 2) * im_h
    crop_box = np.array([x, y, x + w, y + h])
    centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0
    mask = np.logical_and(crop_box[:2] <= centers, centers <= crop_box[2:]).all(axis=1)
    boxes[:, :2] = np.maximum(boxes[:, :2], crop_box[:2])
    boxes[:, 2:] = np.minimum(boxes[:, 2:], crop_box[2:])
    boxes[:, :2] -= crop_box[:2]
    boxes[:, 2:] -= crop_box[:2]
    mask = np.logical_and(mask, (boxes[:, :2] < boxes[:, 2:]).all(axis=1))
    boxes = boxes * np.expand_dims(mask.astype('float32'), axis=1)
    labels = labels * mask.astype('float32')
    scores = scores * mask.astype('float32')
    boxes[:, 0], boxes[:, 2] = (boxes[:, 0] + boxes[:, 2]) / 2 / w, (boxes[:, 2] - boxes[:, 0]) / w
    boxes[:, 1], boxes[:, 3] = (boxes[:, 1] + boxes[:, 3]) / 2 / h, (boxes[:, 3] - boxes[:, 1]) / h
    return boxes, labels, scores, mask.sum()
def box_iou_xywh(box1: np.ndarray, box2: np.ndarray) -> float:
    """Calculate IoU of boxes given in xywh format."""
    assert box1.shape[-1] == 4, "Box1 shape[-1] should be 4."
    assert box2.shape[-1] == 4, "Box2 shape[-1] should be 4."
    b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
    b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
    b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
    b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
    inter_x1 = np.maximum(b1_x1, b2_x1)
    inter_x2 = np.minimum(b1_x2, b2_x2)
    inter_y1 = np.maximum(b1_y1, b2_y1)
    inter_y2 = np.minimum(b1_y2, b2_y2)
    inter_w = inter_x2 - inter_x1
    inter_h = inter_y2 - inter_y1
    inter_w[inter_w < 0] = 0
    inter_h[inter_h < 0] = 0
    inter_area = inter_w * inter_h
    b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
    b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
    return inter_area / (b1_area + b2_area - inter_area)
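# Quick check (hypothetical boxes, xywh in relative coordinates): a 0.2 x 0.2
# box centred inside a 0.4 x 0.4 box at the same centre gives
# IoU = (0.2 * 0.2) / (0.4 * 0.4) = 0.25:
#
#     box_iou_xywh(np.array([[0.5, 0.5, 0.4, 0.4]]),
#                  np.array([[0.5, 0.5, 0.2, 0.2]]))  # -> array([0.25])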
def draw_boxes_on_image(image_path: str,
                        boxes: np.ndarray,
                        scores: np.ndarray,
                        labels: np.ndarray,
                        label_names: list,
                        score_thresh: float = 0.5):
    """Draw boxes on the image and save the visualization."""
    image = np.array(Image.open(image_path))
    plt.figure()
    _, ax = plt.subplots(1)
    ax.imshow(image)
    image_name = image_path.split('/')[-1]
    print("Image {} detect: ".format(image_name))
    colors = {}
    for box, score, label in zip(boxes, scores, labels):
        if score < score_thresh:
            continue
        if box[2] <= box[0] or box[3] <= box[1]:
            continue
        label = int(label)
        if label not in colors:
            colors[label] = plt.get_cmap('hsv')(label / len(label_names))
        x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
        rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, linewidth=2.0, edgecolor=colors[label])
        ax.add_patch(rect)
        ax.text(x1,
                y1,
                '{} {:.4f}'.format(label_names[label], score),
                verticalalignment='bottom',
                horizontalalignment='left',
                bbox={
                    'facecolor': colors[label],
                    'alpha': 0.5,
                    'pad': 0
                },
                fontsize=8,
                color='white')
        print("\t {:15s} at {:25} score: {:.5f}".format(label_names[int(label)], str(list(map(int, list(box)))), score))
    image_name = image_name.replace('jpg', 'png')
    plt.axis('off')
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    plt.savefig("./output/{}".format(image_name), bbox_inches='tight', pad_inches=0.0)
    print("Detect result save at ./output/{}\n".format(image_name))
    plt.cla()
    plt.close('all')
def img_shape(img_path: str):
    """Get image shape."""
    im = cv2.imread(img_path)
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    h, w, c = im.shape
    return h, w, c


def get_label_infos(file_list: str):
    """Get label names from the corresponding category ids."""
    map_label = COCO(file_list)
    label_names = []
    categories = map_label.loadCats(map_label.getCatIds())
    for category in categories:
        label_names.append(category['name'])
    return label_names
def subtract_imagenet_mean_batch(batch: paddle.Tensor) -> paddle.Tensor:
    """Subtract the ImageNet mean pixel-wise from a BGR image batch."""
    mean = np.zeros(shape=batch.shape, dtype='float32')
......
@@ -12,13 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import copy
from typing import Callable
from collections import OrderedDict
import cv2
import numpy as np
-from PIL import Image
+from PIL import Image, ImageEnhance
from paddlehub.process.functional import *
......