提交 5148424a 编写于 作者: H haoyuying

add yolov3_darknet_pascalvoc

上级 9a1eac7b
import paddle
import paddlehub as hub
if __name__ == '__main__':
    # Demo entry point: load the pretrained detection module and run it on one image.
    # NOTE(review): `place` is constructed but never handed to Paddle below —
    # confirm whether it was meant to be passed to `paddle.disable_static`.
    place = paddle.CUDAPlace(0)
    paddle.disable_static()

    # The original line read `model = model = hub.Module(...)`; the chained
    # self-assignment was redundant and has been collapsed.
    model = hub.Module(name='yolov3_darknet53_pascalvoc', is_train=False)
    model.eval()

    # Both paths are placeholders that must be replaced by the user.
    model.predict(imgpath="/PATH/TO/IMAGE", filelist="/PATH/TO/JSON/FILE")
import paddle
import paddlehub as hub
import paddle.nn as nn
from paddlehub.finetune.trainer import Trainer
from paddlehub.datasets.pascalvoc import DetectionData
from paddlehub.process.transforms import DetectTrainReader, DetectTestReader
if __name__ == "__main__":
    # Demo entry point: fine-tune the YOLOv3 module on the detection dataset.
    # NOTE(review): `place` is constructed but never handed to Paddle below —
    # confirm whether it was meant to be passed to `paddle.disable_static`.
    place = paddle.CUDAPlace(0)
    paddle.disable_static()

    is_train = True
    if is_train:
        transform = DetectTrainReader()
        train_reader = DetectionData(transform)

        model = hub.Module(name='yolov3_darknet53_pascalvoc')
        model.train()
        optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
        trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_cls')
        # The training set doubles as the evaluation set in this demo.
        trainer.train(train_reader, epochs=5, batch_size=4, eval_dataset=train_reader, log_interval=1, save_interval=1)
    else:
        # Previously this branch fell through to `trainer.train(train_reader, ...)`
        # with `train_reader` undefined (NameError), and built the dataset in its
        # default 'train' mode even though the test transform was supplied.
        transform = DetectTestReader()
        test_reader = DetectionData(transform, mode='test')
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.regularizer import L2Decay
from paddle.nn.initializer import Normal
class ConvBNLayer(nn.Layer):
    """Convolution followed by batch normalization; the basic Darknet unit.

    Args:
        ch_in(int): Number of input channels.
        ch_out(int): Number of output channels.
        filter_size(int): Convolution kernel size, default is 3.
        stride(int): Convolution stride, default is 1.
        groups(int): Convolution groups, default is 1.
        padding(int): Convolution padding, default is 0.
        act(str): Activation selector. Leaky ReLU (negative slope 0.1) is applied
            only when this equals 'leakly'; any other value skips the activation.
        is_test(bool): Forwarded to the batch normalization layer, default is False.
    """

    def __init__(self,
                 ch_in: int,
                 ch_out: int,
                 filter_size: int = 3,
                 stride: int = 1,
                 groups: int = 1,
                 padding: int = 0,
                 act: str = 'leakly',
                 is_test: bool = False):
        super().__init__()

        # Bias-free convolution; weights drawn from Normal(0, 0.02).
        conv_weight_attr = paddle.ParamAttr(initializer=Normal(0., 0.02))
        self.conv = nn.Conv2d(ch_in,
                              ch_out,
                              filter_size,
                              padding=padding,
                              stride=stride,
                              groups=groups,
                              weight_attr=conv_weight_attr,
                              bias_attr=False)

        # Batch-norm parameter attribute: same initializer, L2Decay(0.) regularizer.
        bn_param_attr = paddle.ParamAttr(initializer=Normal(0., 0.02), regularizer=L2Decay(0.))
        self.batch_norm = nn.BatchNorm(num_channels=ch_out, is_test=is_test, param_attr=bn_param_attr)

        self.act = act

    def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
        """Apply conv -> batch norm -> optional leaky ReLU to ``inputs``."""
        normalized = self.batch_norm(self.conv(inputs))
        if self.act != "leakly":
            return normalized
        return F.leaky_relu(x=normalized, negative_slope=0.1)
class DownSample(nn.Layer):
    """Downsampling unit for Darknet: a single ConvBNLayer, stride 2 by default.

    Args:
        ch_in(int): Number of input channels.
        ch_out(int): Number of output channels.
        filter_size(int): Convolution kernel size, default is 3.
        stride(int): Convolution stride, default is 2.
        padding(int): Convolution padding, default is 1.
        is_test(bool): Forwarded to the wrapped ConvBNLayer, default is False.
    """

    def __init__(self,
                 ch_in: int,
                 ch_out: int,
                 filter_size: int = 3,
                 stride: int = 2,
                 padding: int = 1,
                 is_test: bool = False):
        super().__init__()
        self.ch_out = ch_out
        self.conv_bn_layer = ConvBNLayer(ch_in=ch_in,
                                         ch_out=ch_out,
                                         filter_size=filter_size,
                                         stride=stride,
                                         padding=padding,
                                         is_test=is_test)

    def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
        """Run ``inputs`` through the wrapped conv + batch-norm layer."""
        return self.conv_bn_layer(inputs)
class BasicBlock(nn.Layer):
    """Residual block for Darknet: 1x1 conv, 3x3 conv, then a skip connection.

    Args:
        ch_in(int): Number of input channels.
        ch_out(int): Channels produced by the 1x1 conv; the 3x3 conv outputs
            ``ch_out * 2`` channels, which is what gets added to the input.
        is_test(bool): Forwarded to both ConvBNLayer instances, default is False.
    """

    def __init__(self, ch_in: int, ch_out: int, is_test: bool = False):
        super().__init__()
        shared = dict(stride=1, is_test=is_test)
        self.conv1 = ConvBNLayer(ch_in=ch_in, ch_out=ch_out, filter_size=1, padding=0, **shared)
        self.conv2 = ConvBNLayer(ch_in=ch_out, ch_out=ch_out * 2, filter_size=3, padding=1, **shared)

    def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
        """Return ``inputs`` plus the output of the two stacked convolutions."""
        residual = self.conv2(self.conv1(inputs))
        return paddle.elementwise_add(x=inputs, y=residual, act=None)
class LayerWarp(nn.Layer):
    """A Darknet stage: ``count`` residual BasicBlocks applied in sequence.

    Args:
        ch_in(int): Input channels of the first block.
        ch_out(int): ``ch_out`` argument passed to every BasicBlock; later blocks
            take ``ch_out * 2`` input channels.
        count(int): Total number of blocks in the stage (the first one is always built).
        is_test(bool): Forwarded to every BasicBlock, default is False.
    """

    def __init__(self, ch_in: int, ch_out: int, count: int, is_test: bool = False):
        super().__init__()
        self.basicblock0 = BasicBlock(ch_in, ch_out, is_test=is_test)
        # Remaining blocks are registered as named sublayers and also kept in a
        # plain list so forward() can iterate them in order.
        self.res_out_list = [
            self.add_sublayer("basic_block_%d" % idx, BasicBlock(ch_out * 2, ch_out, is_test=is_test))
            for idx in range(1, count)
        ]
        self.ch_out = ch_out

    def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
        """Feed ``inputs`` through the first block and then each remaining block."""
        features = self.basicblock0(inputs)
        for residual_block in self.res_out_list:
            features = residual_block(features)
        return features
# Number of residual blocks per stage, keyed by network depth.
DarkNet_cfg = {53: [1, 2, 8, 8, 4]}
class DarkNet53_conv_body(nn.Layer):
    """Darknet53 convolutional backbone.

    Args:
        ch_in(int): Input channels, default is 3.
        is_test(bool): Forwarded to every sub-layer, default is False.
    """

    def __init__(self, ch_in: int = 3, is_test: bool = False):
        super().__init__()
        self.stages = DarkNet_cfg[53][0:5]

        # Stem: one stride-1 conv followed by the first downsample.
        self.conv0 = ConvBNLayer(ch_in=ch_in, ch_out=32, filter_size=3, stride=1, padding=1, is_test=is_test)
        self.downsample0 = DownSample(ch_in=32, ch_out=32 * 2, is_test=is_test)

        # Input channel count of each stage.
        stage_in_channels = [64, 128, 256, 512, 1024]
        self.darknet53_conv_block_list = [
            self.add_sublayer("stage_%d" % idx,
                              LayerWarp(int(stage_in_channels[idx]), 32 * (2**idx), repeats, is_test=is_test))
            for idx, repeats in enumerate(self.stages)
        ]

        # One downsample between consecutive stages (none after the last stage).
        self.downsample_list = [
            self.add_sublayer("stage_%d_downsample" % idx,
                              DownSample(ch_in=32 * (2**(idx + 1)), ch_out=32 * (2**(idx + 2)), is_test=is_test))
            for idx in range(len(self.stages) - 1)
        ]

    def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
        """Run the backbone and return the last three stage outputs, deepest first.

        Note that the value returned is a list slice of stage outputs, not a
        single tensor.
        """
        features = self.downsample0(self.conv0(inputs))
        final_stage = len(self.stages) - 1
        stage_outputs = []
        for idx, stage in enumerate(self.darknet53_conv_block_list):
            features = stage(features)
            stage_outputs.append(features)
            if idx < final_stage:
                features = self.downsample_list[idx](features)
        return stage_outputs[-1:-4:-1]
import os
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import Normal, Constant
from paddle.regularizer import L2Decay
from pycocotools.coco import COCO
from darknet import DarkNet53_conv_body
from darknet import ConvBNLayer
from paddlehub.module.cv_module import Yolov3Module
from paddlehub.process.transforms import DetectTrainReader, DetectTestReader
from paddlehub.module.module import moduleinfo
class YoloDetectionBlock(nn.Layer):
    """Detection block for Yolov3: alternating 1x1 / 3x3 ConvBNLayers.

    Args:
        ch_in(int): Number of input channels.
        channel(int): Base channel count; must be even. 1x1 convs output
            ``channel`` channels and 3x3 convs output ``channel * 2``.
        is_test(bool): Forwarded to every ConvBNLayer, default is True.
    """

    def __init__(self, ch_in: int, channel: int, is_test: bool = True):
        super().__init__()
        assert channel % 2 == 0, "channel {} cannot be divided by 2".format(channel)

        wide = channel * 2
        pointwise = dict(filter_size=1, stride=1, padding=0, is_test=is_test)
        spatial = dict(filter_size=3, stride=1, padding=1, is_test=is_test)

        self.conv0 = ConvBNLayer(ch_in=ch_in, ch_out=channel, **pointwise)
        self.conv1 = ConvBNLayer(ch_in=channel, ch_out=wide, **spatial)
        self.conv2 = ConvBNLayer(ch_in=wide, ch_out=channel, **pointwise)
        self.conv3 = ConvBNLayer(ch_in=channel, ch_out=wide, **spatial)
        self.route = ConvBNLayer(ch_in=wide, ch_out=channel, **pointwise)
        self.tip = ConvBNLayer(ch_in=channel, ch_out=wide, **spatial)

    def forward(self, inputs):
        """Return ``(route, tip)``: the route feature and the tip computed from it."""
        features = inputs
        for conv in (self.conv0, self.conv1, self.conv2, self.conv3):
            features = conv(features)
        route = self.route(features)
        return route, self.tip(route)
class Upsample(nn.Layer):
    """Nearest-neighbour upsample block for Yolov3.

    Args:
        scale(int): Multiplier applied to the last two axes of the input shape,
            default is 2.
    """

    def __init__(self, scale: int = 2):
        super().__init__()
        self.scale = scale

    def forward(self, inputs: paddle.Tensor):
        """Resize ``inputs`` by ``self.scale`` using nearest-neighbour interpolation."""
        # Take entries 2 and 3 of the input shape and scale them to get the target size.
        full_shape = paddle.to_tensor(inputs.shape)
        spatial_shape = paddle.slice(full_shape, axes=[0], starts=[2], ends=[4])
        spatial_shape.stop_gradient = True

        target_shape = paddle.cast(spatial_shape, dtype='int32') * self.scale
        target_shape.stop_gradient = True

        return F.resize_nearest(input=inputs, scale=self.scale, actual_shape=target_shape)
@moduleinfo(name="yolov3_darknet53_pascalvoc",
            type="CV/image_editing",
            author="paddlepaddle",
            author_email="",
            summary="Yolov3 is a detection model, this module is trained with VOC dataset.",
            version="1.0.0",
            meta=Yolov3Module)
class YOLOv3(nn.Layer):
    """YOLOV3 for detection

    Args:
        ch_in(int): Input channels, default is 3.
        class_num(int): Categories for detection, if dataset is voc, class_num is 20.
        ignore_thresh(float): The ignore threshold to ignore confidence loss. Default: 0.7.
        valid_thresh(float): Threshold to filter out bounding boxes with low confidence score. Default: 0.005.
        nms_topk(int): Maximum number of detections to be kept according to the confidences after the filtering
                       detections based on score_threshold. Default: 400.
        nms_posk(int): Number of total bboxes to be kept per image after NMS step. -1 means keeping all bboxes
                       after NMS step. Default: 100.
        nms_thresh(float): The threshold to be used in NMS. Default: 0.45.
        is_train(bool): Set the train mode, default is True.
        load_checkpoint(str): Path of a custom checkpoint to load. When None, the pretrained checkpoint is
                              used and downloaded first if it is not already on disk.
    """

    def __init__(self,
                 ch_in: int = 3,
                 class_num: int = 20,
                 ignore_thresh: float = 0.7,
                 valid_thresh: float = 0.005,
                 nms_topk: int = 400,
                 nms_posk: int = 100,
                 nms_thresh: float = 0.45,
                 is_train: bool = True,
                 load_checkpoint: str = None):
        super(YOLOv3, self).__init__()

        self.is_train = is_train
        self.block = DarkNet53_conv_body(ch_in=ch_in, is_test=not self.is_train)
        self.block_outputs = []
        self.yolo_blocks = []
        self.route_blocks_2 = []
        self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
        # Flat list of anchor (width, height) pairs; forward() indexes it as 2*m / 2*m+1.
        self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326]
        self.class_num = class_num
        self.ignore_thresh = ignore_thresh
        self.valid_thresh = valid_thresh
        self.nms_topk = nms_topk
        self.nms_posk = nms_posk
        self.nms_thresh = nms_thresh

        # Input channels of the three detection blocks.
        ch_in_list = [1024, 768, 384]

        for i in range(3):
            yolo_block = self.add_sublayer(
                "yolo_detecton_block_%d" % (i),
                YoloDetectionBlock(ch_in_list[i], channel=512 // (2**i), is_test=not self.is_train))
            self.yolo_blocks.append(yolo_block)

            num_filters = len(self.anchor_masks[i]) * (self.class_num + 5)
            block_out = self.add_sublayer(
                "block_out_%d" % (i),
                nn.Conv2d(1024 // (2**i),
                          num_filters,
                          1,
                          stride=1,
                          padding=0,
                          weight_attr=paddle.ParamAttr(initializer=Normal(0., 0.02)),
                          bias_attr=paddle.ParamAttr(initializer=Constant(0.0), regularizer=L2Decay(0.))))
            self.block_outputs.append(block_out)

            if i < 2:
                route = self.add_sublayer(
                    "route2_%d" % i,
                    ConvBNLayer(ch_in=512 // (2**i),
                                ch_out=256 // (2**i),
                                filter_size=1,
                                stride=1,
                                padding=0,
                                is_test=(not self.is_train)))
                self.route_blocks_2.append(route)
            self.upsample = Upsample()

        if load_checkpoint is not None:
            model_dict = paddle.load(load_checkpoint)[0]
            self.set_dict(model_dict)
            print("load custom checkpoint success")
        else:
            # NOTE(review): `self.directory` is not assigned in this class; presumably
            # it is provided by the `moduleinfo` machinery — confirm.
            checkpoint = os.path.join(self.directory, 'yolov3_70000.pdparams')
            if not os.path.exists(checkpoint):
                # Invoke wget with an argument list rather than a shell string so
                # that a checkpoint path containing spaces or shell metacharacters
                # is passed through verbatim. As before, the exit status is not checked.
                import subprocess
                subprocess.call([
                    'wget',
                    'https://bj.bcebos.com/paddlehub/model/image/object_detection/yolov3_70000.pdparams',
                    '-O',
                    checkpoint,
                ])
            model_dict = paddle.load(checkpoint)[0]
            self.set_dict(model_dict)
            print("load pretrained checkpoint success")

    def transform(self, img: paddle.Tensor, size: int):
        """Preprocess ``img`` with the train or test reader, depending on ``self.is_train``."""
        if self.is_train:
            transforms = DetectTrainReader()
        else:
            transforms = DetectTestReader()
        return transforms(img, size)

    def get_label_infos(self, file_list: str):
        """Load the COCO-format annotation file ``file_list`` and return its category names.

        The COCO handle is also stored on ``self.COCO``.
        """
        self.COCO = COCO(file_list)
        categories = self.COCO.loadCats(self.COCO.getCatIds())
        return [category['name'] for category in categories]

    def forward(self,
                inputs: paddle.Tensor,
                gtbox: paddle.Tensor = None,
                gtlabel: paddle.Tensor = None,
                gtscore: paddle.Tensor = None,
                im_shape: paddle.Tensor = None):
        """Run detection and, in train mode, compute the YOLOv3 loss.

        Args:
            inputs(paddle.Tensor): Input images.
            gtbox(paddle.Tensor): Ground truth boxes, passed to the loss in train mode.
            gtlabel(paddle.Tensor): Ground truth labels, passed to the loss in train mode.
            gtscore(paddle.Tensor): Ground truth scores, passed to the loss in train mode.
            im_shape(paddle.Tensor): Image sizes passed to ``F.yolo_box`` as ``img_size``.

        Returns:
            tuple: ``(loss, pred)`` where ``loss`` is the sum of the per-scale mean losses
            (zeros when not training) and ``pred`` is a list with one NMS result per image.
        """
        # Inputs and intermediate results are kept on the instance, as in the
        # original implementation.
        self.gtbox = gtbox
        self.gtlabel = gtlabel
        self.gtscore = gtscore
        self.im_shape = im_shape
        self.outputs = []
        self.boxes = []
        self.scores = []
        self.losses = []
        self.pred = []
        self.downsample = 32

        blocks = self.block(inputs)
        route = None
        for i, block in enumerate(blocks):
            if i > 0:
                # Fuse the upsampled route from the previous scale with this backbone output.
                block = paddle.concat([route, block], axis=1)
            route, tip = self.yolo_blocks[i](block)
            block_out = self.block_outputs[i](tip)
            self.outputs.append(block_out)
            if i < 2:
                route = self.route_blocks_2[i](route)
                route = self.upsample(route)

        for i, out in enumerate(self.outputs):
            anchor_mask = self.anchor_masks[i]
            if self.is_train:
                loss = F.yolov3_loss(x=out,
                                     gt_box=self.gtbox,
                                     gt_label=self.gtlabel,
                                     gt_score=self.gtscore,
                                     anchors=self.anchors,
                                     anchor_mask=anchor_mask,
                                     class_num=self.class_num,
                                     ignore_thresh=self.ignore_thresh,
                                     downsample_ratio=self.downsample,
                                     use_label_smooth=False)
            else:
                loss = paddle.to_tensor(0.0)
            self.losses.append(paddle.reduce_mean(loss))

            # Gather the anchor pairs selected by this scale's mask.
            mask_anchors = []
            for m in anchor_mask:
                mask_anchors.extend((self.anchors[2 * m], self.anchors[2 * m + 1]))

            boxes, scores = F.yolo_box(x=out,
                                       img_size=self.im_shape,
                                       anchors=mask_anchors,
                                       class_num=self.class_num,
                                       conf_thresh=self.valid_thresh,
                                       downsample_ratio=self.downsample,
                                       name="yolo_box" + str(i))
            self.boxes.append(boxes)
            self.scores.append(paddle.transpose(scores, perm=[0, 2, 1]))
            # Each subsequent output uses half the downsample ratio of the previous one.
            self.downsample //= 2

        # Run NMS per image over the boxes/scores gathered from all three scales.
        for i in range(self.boxes[0].shape[0]):
            yolo_boxes = paddle.unsqueeze(
                paddle.concat([self.boxes[0][i], self.boxes[1][i], self.boxes[2][i]], axis=0), 0)
            yolo_scores = paddle.unsqueeze(
                paddle.concat([self.scores[0][i], self.scores[1][i], self.scores[2][i]], axis=1), 0)
            pred = F.multiclass_nms(bboxes=yolo_boxes,
                                    scores=yolo_scores,
                                    score_threshold=self.valid_thresh,
                                    nms_top_k=self.nms_topk,
                                    keep_top_k=self.nms_posk,
                                    nms_threshold=self.nms_thresh,
                                    background_label=-1)
            self.pred.append(pred)

        return sum(self.losses), self.pred
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Callable
import paddle
from paddlehub.env import DATA_HOME
from pycocotools.coco import COCO
from paddlehub.process.transforms import DetectCatagory, ParseImages
class DetectionData(paddle.io.Dataset):
    """
    Dataset for image detection.

    Args:
       transform(callmethod) : The method of preprocess images.
       size(int): Size passed to ``transform`` for training samples, default is 416.
       mode(str): The mode for preparing dataset, either 'train' or 'test'.

    Returns:
        DataSet: An iterable object for data iterating

    Raises:
        ValueError: If ``mode`` is neither 'train' nor 'test'.
    """

    def __init__(self, transform: Callable, size: int = 416, mode: str = 'train'):
        # Fail early with a clear message; previously an unknown mode surfaced
        # later as an AttributeError on the missing `self.COCO`.
        if mode not in ('train', 'test'):
            raise ValueError("mode must be 'train' or 'test', got {!r}".format(mode))

        self.mode = mode
        self.transform = transform
        self.size = size

        if self.mode == 'train':
            train_file_list = 'annotations/instances_train2017.json'
            train_data_dir = 'train2017'
            self.train_file_list = os.path.join(DATA_HOME, 'voc', train_file_list)
            self.train_data_dir = os.path.join(DATA_HOME, 'voc', train_data_dir)
            self.COCO = COCO(self.train_file_list)
            self.img_dir = self.train_data_dir
        elif self.mode == 'test':
            val_file_list = 'annotations/instances_val2017.json'
            val_data_dir = 'val2017'
            self.val_file_list = os.path.join(DATA_HOME, 'voc', val_file_list)
            self.val_data_dir = os.path.join(DATA_HOME, 'voc', val_data_dir)
            self.COCO = COCO(self.val_file_list)
            self.img_dir = self.val_data_dir

        parse_dataset_catagory = DetectCatagory(self.COCO, self.img_dir)
        self.label_names, self.label_ids, self.category_to_id_map = parse_dataset_catagory()
        parse_images = ParseImages(self.COCO, self.mode, self.img_dir, self.category_to_id_map)
        self.data = parse_images()

    def __getitem__(self, idx: int):
        """Return the transformed sample at ``idx``.

        In 'train' mode this is ``(image, gt_boxes, gt_labels, gt_scores)``;
        in 'test' mode it is ``(image, image_id, (h, w))``.
        """
        if self.mode == "train":
            img = self.data[idx]
            # Use the configured size; it was previously hard-coded to 416, which
            # silently ignored the `size` constructor argument.
            out_img, gt_boxes, gt_labels, gt_scores = self.transform(img, self.size)
            return out_img, gt_boxes, gt_labels, gt_scores
        elif self.mode == "test":
            img = self.data[idx]
            # NOTE(review): the test transform is called without a size argument —
            # confirm it does not need `self.size`.
            out_img, img_id, (h, w) = self.transform(img)
            return out_img, img_id, (h, w)

    def __len__(self):
        """Number of parsed image records."""
        return len(self.data)
...@@ -26,7 +26,7 @@ from PIL import Image ...@@ -26,7 +26,7 @@ from PIL import Image
from paddlehub.module.module import serving, RunModule from paddlehub.module.module import serving, RunModule
from paddlehub.utils.utils import base64_to_cv2 from paddlehub.utils.utils import base64_to_cv2
from paddlehub.process.transforms import ConvertColorSpace, ColorPostprocess, Resize from paddlehub.process.transforms import ConvertColorSpace, ColorPostprocess, Resize, BoxTool
class ImageServing(object): class ImageServing(object):
...@@ -103,11 +103,11 @@ class ImageColorizeModule(RunModule, ImageServing): ...@@ -103,11 +103,11 @@ class ImageColorizeModule(RunModule, ImageServing):
def training_step(self, batch: int, batch_idx: int) -> dict: def training_step(self, batch: int, batch_idx: int) -> dict:
''' '''
One step for training, which should be called as forward computation. One step for training, which should be called as forward computation.
Args: Args:
batch(list[paddle.Tensor]): The one batch data, which contains images and labels. batch(list[paddle.Tensor]): The one batch data, which contains images and labels.
batch_idx(int): The index of batch. batch_idx(int): The index of batch.
Returns: Returns:
results(dict) : The model outputs, such as loss and metrics. results(dict) : The model outputs, such as loss and metrics.
''' '''
...@@ -116,22 +116,22 @@ class ImageColorizeModule(RunModule, ImageServing): ...@@ -116,22 +116,22 @@ class ImageColorizeModule(RunModule, ImageServing):
def validation_step(self, batch: int, batch_idx: int) -> dict: def validation_step(self, batch: int, batch_idx: int) -> dict:
''' '''
One step for validation, which should be called as forward computation. One step for validation, which should be called as forward computation.
Args: Args:
batch(list[paddle.Tensor]): The one batch data, which contains images and labels. batch(list[paddle.Tensor]): The one batch data, which contains images and labels.
batch_idx(int): The index of batch. batch_idx(int): The index of batch.
Returns: Returns:
results(dict) : The model outputs, such as metrics. results(dict) : The model outputs, such as metrics.
''' '''
out_class, out_reg = self(batch[0], batch[1], batch[2]) out_class, out_reg = self(batch[0], batch[1], batch[2])
criterionCE = nn.loss.CrossEntropyLoss() criterionCE = nn.loss.CrossEntropyLoss()
loss_ce = criterionCE(out_class, batch[4][:, 0, :, :]) loss_ce = criterionCE(out_class, batch[4][:, 0, :, :])
loss_G_L1_reg = paddle.sum(paddle.abs(batch[3] - out_reg), axis=1, keepdim=True) loss_G_L1_reg = paddle.sum(paddle.abs(batch[3] - out_reg), axis=1, keepdim=True)
loss_G_L1_reg = paddle.mean(loss_G_L1_reg) loss_G_L1_reg = paddle.mean(loss_G_L1_reg)
loss = loss_ce + loss_G_L1_reg loss = loss_ce + loss_G_L1_reg
visual_ret = OrderedDict() visual_ret = OrderedDict()
psnrs = [] psnrs = []
lab2rgb = ConvertColorSpace(mode='LAB2RGB') lab2rgb = ConvertColorSpace(mode='LAB2RGB')
...@@ -141,7 +141,7 @@ class ImageColorizeModule(RunModule, ImageServing): ...@@ -141,7 +141,7 @@ class ImageColorizeModule(RunModule, ImageServing):
visual_ret['real'] = process(real) visual_ret['real'] = process(real)
fake = lab2rgb(np.concatenate((batch[0].numpy(), out_reg.numpy()), axis=1))[i] fake = lab2rgb(np.concatenate((batch[0].numpy(), out_reg.numpy()), axis=1))[i]
visual_ret['fake_reg'] = process(fake) visual_ret['fake_reg'] = process(fake)
mse = np.mean((visual_ret['real'] * 1.0 - visual_ret['fake_reg'] * 1.0) ** 2) mse = np.mean((visual_ret['real'] * 1.0 - visual_ret['fake_reg'] * 1.0)**2)
psnr_value = 20 * np.log10(255. / np.sqrt(mse)) psnr_value = 20 * np.log10(255. / np.sqrt(mse))
psnrs.append(psnr_value) psnrs.append(psnr_value)
psnr = paddle.to_variable(np.array(psnrs)) psnr = paddle.to_variable(np.array(psnrs))
...@@ -150,12 +150,12 @@ class ImageColorizeModule(RunModule, ImageServing): ...@@ -150,12 +150,12 @@ class ImageColorizeModule(RunModule, ImageServing):
def predict(self, images: str, visualization: bool = True, save_path: str = 'result'): def predict(self, images: str, visualization: bool = True, save_path: str = 'result'):
''' '''
Colorize images Colorize images
Args: Args:
images(str) : Images path to be colorized. images(str) : Images path to be colorized.
visualization(bool): Whether to save colorized images. visualization(bool): Whether to save colorized images.
save_path(str) : Path to save colorized images. save_path(str) : Path to save colorized images.
Returns: Returns:
results(list[dict]) : The prediction result of each input image results(list[dict]) : The prediction result of each input image
''' '''
...@@ -177,7 +177,7 @@ class ImageColorizeModule(RunModule, ImageServing): ...@@ -177,7 +177,7 @@ class ImageColorizeModule(RunModule, ImageServing):
visual_ret['real'] = resize(process(real)) visual_ret['real'] = resize(process(real))
fake = lab2rgb(np.concatenate((im['A'], out_reg.numpy()), axis=1))[i] fake = lab2rgb(np.concatenate((im['A'], out_reg.numpy()), axis=1))[i]
visual_ret['fake_reg'] = resize(process(fake)) visual_ret['fake_reg'] = resize(process(fake))
if visualization: if visualization:
fake_name = "fake_" + str(time.time()) + ".png" fake_name = "fake_" + str(time.time()) + ".png"
if not os.path.exists(save_path): if not os.path.exists(save_path):
...@@ -185,8 +185,107 @@ class ImageColorizeModule(RunModule, ImageServing): ...@@ -185,8 +185,107 @@ class ImageColorizeModule(RunModule, ImageServing):
fake_path = os.path.join(save_path, fake_name) fake_path = os.path.join(save_path, fake_name)
visual_gray = Image.fromarray(visual_ret['fake_reg']) visual_gray = Image.fromarray(visual_ret['fake_reg'])
visual_gray.save(fake_path) visual_gray.save(fake_path)
mse = np.mean((visual_ret['real'] * 1.0 - visual_ret['fake_reg'] * 1.0) ** 2) mse = np.mean((visual_ret['real'] * 1.0 - visual_ret['fake_reg'] * 1.0)**2)
psnr_value = 20 * np.log10(255. / np.sqrt(mse)) psnr_value = 20 * np.log10(255. / np.sqrt(mse))
result.append(visual_ret) result.append(visual_ret)
return result return result
class Yolov3Module(RunModule, ImageServing):
    """Training, validation and prediction steps shared by YOLOv3 detection modules.

    The concrete module is expected to provide ``__call__`` (forward), ``transform``
    and ``get_label_infos``.
    """

    def training_step(self, batch: list, batch_idx: int) -> dict:
        '''
        One step for training, which should be called as forward computation.

        Args:
            batch(list[paddle.Tensor]): The one batch data, which contains images, ground truth boxes, labels and scores.
            batch_idx(int): The index of batch.

        Returns:
            results(dict): The model outputs, such as loss.
        '''
        return self.validation_step(batch, batch_idx)

    def validation_step(self, batch: list, batch_idx: int) -> dict:
        '''
        One step for validation, which should be called as forward computation.

        Args:
            batch(list[paddle.Tensor]): The one batch data, which contains images, ground truth boxes, labels and scores.
            batch_idx(int): The index of batch.

        Returns:
            results(dict) : The model outputs, such as metrics.
        '''
        ious = []
        boxtool = BoxTool()

        img = batch[0].astype('float32')
        # NOTE(review): axis 2 is bound to W and axis 3 to H here; if the batch is
        # laid out as NCHW these names are swapped — confirm. Values are passed on
        # exactly as in the original code.
        B, C, W, H = img.shape
        im_shape = np.array([(W, H)] * B).astype('int32')
        im_shape = paddle.to_tensor(im_shape)

        gt_box = batch[1].astype('float32')
        gt_label = batch[2].astype('int32')
        gt_score = batch[3].astype("float32")
        loss, pred = self(img, gt_box, gt_label, gt_score, im_shape)

        for i, image_pred in enumerate(pred):
            bboxes = image_pred.numpy()
            # Each row is read as (label, score, x1, y1, x2, y2). Anything without
            # those six columns is treated as "no detection" for this image.
            if bboxes.ndim != 2 or bboxes.shape[1] < 6:
                ious.append(0.0)
                continue

            boxes = bboxes[:, 2:].astype('float32')
            # Ground truth for this image is the same for every predicted box.
            gt = gt_box[i].numpy()

            iou = []
            for x1, y1, x2, y2 in boxes:
                w = x2 - x1 + 1
                h = y2 - y1 + 1
                bbox = [x1, y1, w, h]
                bbox = np.expand_dims(boxtool.coco_anno_box_to_center_relative(bbox, H, W), 0)
                iou.append(max(boxtool.box_iou_xywh(bbox, gt)))

            # An image with no predicted boxes contributes an IoU of 0 instead of
            # raising ValueError from max() on an empty list.
            ious.append(max(iou, default=0.0))

        ious = paddle.to_tensor(np.array(ious))

        return {'loss': loss, 'metrics': {'iou': ious}}

    def predict(self, imgpath: str, filelist: str, visualization: bool = True, save_path: str = 'result'):
        '''
        Detect images

        Args:
            imgpath(str): Image path .
            filelist(str): Path to get label name.
            visualization(bool): Whether to save result image.
            save_path(str) : Path to save detected images.
                NOTE(review): this argument is not used in the body below — confirm.

        Returns:
            boxes(np.ndarray): Predict box information.
            scores(np.ndarray): Predict score.
            labels(np.ndarray): Predict labels.
        '''
        boxtool = BoxTool()
        img = {'image': imgpath, 'id': 0}
        im, im_id, im_shape = self.transform(img, 416)
        label_names = self.get_label_infos(filelist)

        # Wrap the single image into a batch of one.
        img_data = np.array([im]).astype('float32')
        img_data = paddle.to_tensor(img_data)
        im_shape = np.array([im_shape]).astype('int32')
        im_shape = paddle.to_tensor(im_shape)

        output, pred = self(img_data, None, None, None, im_shape)

        for i in range(len(pred)):
            bboxes = pred[i].numpy()
            labels = bboxes[:, 0].astype('int32')
            scores = bboxes[:, 1].astype('float32')
            boxes = bboxes[:, 2:].astype('float32')

            if visualization:
                boxtool.draw_boxes_on_image(imgpath, boxes, scores, labels, label_names, 0.5)

        return boxes, scores, labels
...@@ -12,16 +12,22 @@ ...@@ -12,16 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import random import random
import copy
from typing import Callable
from collections import OrderedDict from collections import OrderedDict
import cv2 import cv2
import numpy as np import numpy as np
from PIL import Image import matplotlib
from PIL import Image, ImageEnhance
from matplotlib import pyplot as plt
from paddlehub.process.functional import * from paddlehub.process.functional import *
matplotlib.use('Agg')
class Compose: class Compose:
def __init__(self, transforms, to_rgb=True, stay_rgb=False): def __init__(self, transforms, to_rgb=True, stay_rgb=False):
...@@ -45,15 +51,13 @@ class Compose: ...@@ -45,15 +51,13 @@ class Compose:
for op in self.transforms: for op in self.transforms:
im = op(im) im = op(im)
if not self.stay_rgb: if not self.stay_rgb:
im = permute(im) im = permute(im)
return im return im
class RandomHorizontalFlip: class RandomHorizontalFlip:
def __init__(self, prob=0.5): def __init__(self, prob=0.5):
self.prob = prob self.prob = prob
...@@ -239,8 +243,13 @@ class RandomPaddingCrop: ...@@ -239,8 +243,13 @@ class RandomPaddingCrop:
pad_height = max(crop_height - img_height, 0) pad_height = max(crop_height - img_height, 0)
pad_width = max(crop_width - img_width, 0) pad_width = max(crop_width - img_width, 0)
if (pad_height > 0 or pad_width > 0): if (pad_height > 0 or pad_width > 0):
im = cv2.copyMakeBorder( im = cv2.copyMakeBorder(im,
im, 0, pad_height, 0, pad_width, cv2.BORDER_CONSTANT, value=self.im_padding_value) 0,
pad_height,
0,
pad_width,
cv2.BORDER_CONSTANT,
value=self.im_padding_value)
if crop_height > 0 and crop_width > 0: if crop_height > 0 and crop_width > 0:
h_off = np.random.randint(img_height - crop_height + 1) h_off = np.random.randint(img_height - crop_height + 1)
...@@ -295,13 +304,12 @@ class RandomRotation: ...@@ -295,13 +304,12 @@ class RandomRotation:
r[0, 2] += (nw / 2) - cx r[0, 2] += (nw / 2) - cx
r[1, 2] += (nh / 2) - cy r[1, 2] += (nh / 2) - cy
dsize = (nw, nh) dsize = (nw, nh)
im = cv2.warpAffine( im = cv2.warpAffine(im,
im, r,
r, dsize=dsize,
dsize=dsize, flags=cv2.INTER_LINEAR,
flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT,
borderMode=cv2.BORDER_CONSTANT, borderValue=self.im_padding_value)
borderValue=self.im_padding_value)
return im return im
...@@ -403,14 +411,14 @@ class RandomDistort: ...@@ -403,14 +411,14 @@ class RandomDistort:
return im return im
class ConvertColorSpace: class ConvertColorSpace:
""" """
Convert color space from RGB to LAB or from LAB to RGB. Convert color space from RGB to LAB or from LAB to RGB.
Args: Args:
mode(str): Color space convert mode, it can be 'RGB2LAB' or 'LAB2RGB'. mode(str): Color space convert mode, it can be 'RGB2LAB' or 'LAB2RGB'.
Return: Return:
img(np.ndarray): converted image. img(np.ndarray): converted image.
""" """
...@@ -429,7 +437,7 @@ class ConvertColorSpace: ...@@ -429,7 +437,7 @@ class ConvertColorSpace:
""" """
mask = (rgb > 0.04045) mask = (rgb > 0.04045)
np.seterr(invalid='ignore') np.seterr(invalid='ignore')
rgb = (((rgb + .055) / 1.055) ** 2.4) * mask + rgb / 12.92 * (1 - mask) rgb = (((rgb + .055) / 1.055)**2.4) * mask + rgb / 12.92 * (1 - mask)
rgb = np.nan_to_num(rgb) rgb = np.nan_to_num(rgb)
x = .412453 * rgb[:, 0, :, :] + .357580 * rgb[:, 1, :, :] + .180423 * rgb[:, 2, :, :] x = .412453 * rgb[:, 0, :, :] + .357580 * rgb[:, 1, :, :] + .180423 * rgb[:, 2, :, :]
y = .212671 * rgb[:, 0, :, :] + .715160 * rgb[:, 1, :, :] + .072169 * rgb[:, 2, :, :] y = .212671 * rgb[:, 0, :, :] + .715160 * rgb[:, 1, :, :] + .072169 * rgb[:, 2, :, :]
...@@ -490,7 +498,7 @@ class ConvertColorSpace: ...@@ -490,7 +498,7 @@ class ConvertColorSpace:
rgb = np.maximum(rgb, 0) # sometimes reaches a small negative number, which causes NaNs rgb = np.maximum(rgb, 0) # sometimes reaches a small negative number, which causes NaNs
mask = (rgb > .0031308).astype(np.float32) mask = (rgb > .0031308).astype(np.float32)
np.seterr(invalid='ignore') np.seterr(invalid='ignore')
out = (1.055 * (rgb ** (1. / 2.4)) - 0.055) * mask + 12.92 * rgb * (1 - mask) out = (1.055 * (rgb**(1. / 2.4)) - 0.055) * mask + 12.92 * rgb * (1 - mask)
out = np.nan_to_num(out) out = np.nan_to_num(out)
return out return out
...@@ -511,7 +519,7 @@ class ConvertColorSpace: ...@@ -511,7 +519,7 @@ class ConvertColorSpace:
out = np.concatenate((x_int[:, None, :, :], y_int[:, None, :, :], z_int[:, None, :, :]), axis=1) out = np.concatenate((x_int[:, None, :, :], y_int[:, None, :, :], z_int[:, None, :, :]), axis=1)
mask = (out > .2068966).astype(np.float32) mask = (out > .2068966).astype(np.float32)
np.seterr(invalid='ignore') np.seterr(invalid='ignore')
out = (out ** 3.) * mask + (out - 16. / 116.) / 7.787 * (1 - mask) out = (out**3.) * mask + (out - 16. / 116.) / 7.787 * (1 - mask)
out = np.nan_to_num(out) out = np.nan_to_num(out)
sc = np.array((0.95047, 1., 1.08883))[None, :, None, None] sc = np.array((0.95047, 1., 1.08883))[None, :, None, None]
out = out * sc out = out * sc
...@@ -546,27 +554,37 @@ class ConvertColorSpace: ...@@ -546,27 +554,37 @@ class ConvertColorSpace:
class ColorizeHint: class ColorizeHint:
"""Get hint and mask images for colorization. """Get hint and mask images for colorization.
This method is prepared for user guided colorization tasks. Take the original RGB images as imput, we will obtain the local hints and correspoding mask to guid colorization process. This method is prepared for user guided colorization tasks. Take the original RGB images as imput, we will obtain the local hints and correspoding mask to guid colorization process.
Args: Args:
percent(float): Probability for ignoring hint in an iteration. percent(float): Probability for ignoring hint in an iteration.
num_points(int): Number of selected hints in an iteration. num_points(int): Number of selected hints in an iteration.
samp(str): Sample method, default is normal. samp(str): Sample method, default is normal.
use_avg(bool): Whether to use mean in selected hint area. use_avg(bool): Whether to use mean in selected hint area.
Return: Return:
hint(np.ndarray): hint images hint(np.ndarray): hint images
mask(np.ndarray): mask images mask(np.ndarray): mask images
""" """
def __init__(self, percent: float, num_points: int = None, samp: str = 'normal', use_avg: bool = True): def __init__(self, percent: float, num_points: int = None, samp: str = 'normal', use_avg: bool = True):
self.percent = percent self.percent = percent
self.num_points = num_points self.num_points = num_points
self.samp = samp self.samp = samp
self.use_avg = use_avg self.use_avg = use_avg
def __call__(self, data: np.ndarray, hint: np.ndarray, mask: np.ndarray): def __call__(self, data: np.ndarray, hint: np.ndarray, mask: np.ndarray):
sample_Ps = [1, 2, 3, 4, 5, 6, 7, 8, 9, ] sample_Ps = [
1,
2,
3,
4,
5,
6,
7,
8,
9,
]
self.data = data self.data = data
self.hint = hint self.hint = hint
self.mask = mask self.mask = mask
...@@ -593,9 +611,11 @@ class ColorizeHint: ...@@ -593,9 +611,11 @@ class ColorizeHint:
# add color point # add color point
if self.use_avg: if self.use_avg:
# embed() # embed()
hint[nn, :, h:h + P, w:w + P] = np.mean( hint[nn, :, h:h + P, w:w + P] = np.mean(np.mean(data[nn, :, h:h + P, w:w + P],
np.mean(data[nn, :, h:h + P, w:w + P], axis=2, keepdims=True), axis=1, keepdims=True).reshape( axis=2,
1, C, 1, 1) keepdims=True),
axis=1,
keepdims=True).reshape(1, C, 1, 1)
else: else:
hint[nn, :, h:h + P, w:w + P] = data[nn, :, h:h + P, w:w + P] hint[nn, :, h:h + P, w:w + P] = data[nn, :, h:h + P, w:w + P]
mask[nn, :, h:h + P, w:w + P] = 1 mask[nn, :, h:h + P, w:w + P] = 1
...@@ -609,10 +629,10 @@ class ColorizeHint: ...@@ -609,10 +629,10 @@ class ColorizeHint:
class SqueezeAxis: class SqueezeAxis:
""" """
Squeeze the specific axis when it equal to 1. Squeeze the specific axis when it equal to 1.
Args: Args:
axis(int): Which axis should be squeezed. axis(int): Which axis should be squeezed.
""" """
def __init__(self, axis: int): def __init__(self, axis: int):
self.axis = axis self.axis = axis
...@@ -628,7 +648,7 @@ class SqueezeAxis: ...@@ -628,7 +648,7 @@ class SqueezeAxis:
class ColorizePreprocess: class ColorizePreprocess:
"""Prepare dataset for image Colorization. """Prepare dataset for image Colorization.
Args: Args:
ab_thresh(float): Thresh value for setting mask value. ab_thresh(float): Thresh value for setting mask value.
p(float): Probability for ignoring hint in an iteration. p(float): Probability for ignoring hint in an iteration.
...@@ -636,12 +656,13 @@ class ColorizePreprocess: ...@@ -636,12 +656,13 @@ class ColorizePreprocess:
samp(str): Sample method, default is normal. samp(str): Sample method, default is normal.
use_avg(bool): Whether to use mean in selected hint area. use_avg(bool): Whether to use mean in selected hint area.
is_train(bool): Training process or not. is_train(bool): Training process or not.
Return: Return:
data(dict):The preprocessed data for colorization. data(dict):The preprocessed data for colorization.
""" """
def __init__(self, ab_thresh: float = 0., def __init__(self,
ab_thresh: float = 0.,
p: float = .125, p: float = .125,
num_points: int = None, num_points: int = None,
samp: str = 'normal', samp: str = 'normal',
...@@ -668,11 +689,14 @@ class ColorizePreprocess: ...@@ -668,11 +689,14 @@ class ColorizePreprocess:
""" """
data = {} data = {}
A = 2 * 110 / 10 + 1 A = 2 * 110 / 10 + 1
data['A'] = data_lab[:, [0, ], :, :] data['A'] = data_lab[:, [
0,
], :, :]
data['B'] = data_lab[:, 1:, :, :] data['B'] = data_lab[:, 1:, :, :]
if self.ab_thresh > 0: # mask out grayscale images if self.ab_thresh > 0: # mask out grayscale images
thresh = 1. * self.ab_thresh / 110 thresh = 1. * self.ab_thresh / 110
mask = np.sum(np.abs(np.max(np.max(data['B'], axis=3), axis=2) - np.min(np.min(data['B'], axis=3), axis=2)),axis=1) mask = np.sum(np.abs(np.max(np.max(data['B'], axis=3), axis=2) - np.min(np.min(data['B'], axis=3), axis=2)),
axis=1)
mask = (mask >= thresh) mask = (mask >= thresh)
data['A'] = data['A'][mask, :, :, :] data['A'] = data['A'][mask, :, :, :]
data['B'] = data['B'][mask, :, :, :] data['B'] = data['B'][mask, :, :, :]
...@@ -698,10 +722,10 @@ class ColorizePreprocess: ...@@ -698,10 +722,10 @@ class ColorizePreprocess:
class ColorPostprocess: class ColorPostprocess:
""" """
Transform images from [0, 1] to [0, 255] Transform images from [0, 1] to [0, 255]
Args: Args:
type(type): Type of Image value. type(type): Type of Image value.
Return: Return:
img(np.ndarray): Image in range of 0-255. img(np.ndarray): Image in range of 0-255.
""" """
...@@ -713,3 +737,505 @@ class ColorPostprocess: ...@@ -713,3 +737,505 @@ class ColorPostprocess:
img = np.clip(img, 0, 1) * 255 img = np.clip(img, 0, 1) * 255
img = img.astype(self.type) img = img.astype(self.type)
return img return img
class DetectCatagory:
    """Collect label names, label ids and a category lookup table from a detection dataset.

    Args:
        COCO(Callable): Annotation handle; ``getCatIds()`` and ``loadCats()`` are called on it.
        data_dir(str): Image dataset path, stored as ``img_dir``.

    Returns:
        label_names(List(str)): The dataset label names.
        label_ids(List(int)): The dataset label ids.
        category_to_id_map(dict): Maps every category id to its position in ``label_ids``.
    """

    def __init__(self, COCO: Callable, data_dir: str):
        self.COCO = COCO
        self.img_dir = data_dir

    def __call__(self):
        # Keep the raw category records and their count on the instance, as callers may read them.
        records = self.COCO.loadCats(self.COCO.getCatIds())
        self.categories = records
        self.num_category = len(records)

        label_names = [record['name'] for record in records]
        label_ids = [int(record['id']) for record in records]
        # Position of each id inside label_ids; a repeated id keeps its last position.
        category_to_id_map = dict(zip(label_ids, range(len(label_ids))))
        return label_names, label_ids, category_to_id_map
class ParseImages:
    """Build the per-image records consumed by the detection readers.

    Args:
        COCO(Callable): Annotation handle; ``getImgIds()`` and ``loadImgs()`` are called on it.
        is_train(bool): When True, ground-truth boxes and labels are filled in for every image.
        data_dir(str): Image dataset path.
        category_to_id_map(dict): Mapping relations of category and id for images.

    Returns:
        imgs(list): One dict per image (sorted by image id) extended with the keys
            'image' (file path), 'gt_boxes' (50 x 4 float32) and 'gt_labels' (50 int32).
    """

    def __init__(self, COCO: Callable, is_train: bool, data_dir: str, category_to_id_map: dict):
        self.COCO = COCO
        self.is_train = is_train
        self.img_dir = data_dir
        self.category_to_id_map = category_to_id_map
        self.parse_gt_annotations = GTAnotations(self.COCO, self.category_to_id_map)

    def __call__(self):
        # Fixed number of ground-truth slots reserved for each image.
        box_num = 50

        image_ids = self.COCO.getImgIds()
        image_ids.sort()
        # Deep copy so the records held by the annotation handle are never modified.
        imgs = copy.deepcopy(self.COCO.loadImgs(image_ids))

        for record in imgs:
            image_path = os.path.join(self.img_dir, record['file_name'])
            record['image'] = image_path
            assert os.path.exists(image_path), \
                "image {} not found.".format(image_path)
            record['gt_boxes'] = np.zeros((box_num, 4), dtype=np.float32)
            record['gt_labels'] = np.zeros((box_num), dtype=np.int32)
            if self.is_train:
                # The annotation parser fills the zero-initialised arrays in place.
                self.parse_gt_annotations(record)
        return imgs
class GTAnotations:
    """Set gt boxes and gt labels for train.

    Args:
        COCO(Callable): Annotation handle; ``getAnnIds()`` and ``loadAnns()`` are called on it.
        category_to_id_map(dict): Mapping relations of category and id for images.
        img(dict): Input for detection model. Must provide 'height', 'width', 'id' and the
            preallocated arrays 'gt_boxes' and 'gt_labels'.

    Returns:
        img(dict): The same dict, with 'gt_boxes' and 'gt_labels' filled in place.
    """

    def __init__(self, COCO: Callable, category_to_id_map: dict):
        self.COCO = COCO
        self.category_to_id_map = category_to_id_map
        self.boxtool = BoxTool()

    def __call__(self, img: dict):
        img_height = img['height']
        img_width = img['width']
        anno = self.COCO.loadAnns(self.COCO.getAnnIds(imgIds=img['id'], iscrowd=None))

        # Never write past the preallocated ground-truth buffers: the limit comes from the
        # arrays themselves instead of a hard-coded slot count.
        capacity = min(len(img['gt_boxes']), len(img['gt_labels']))

        gt_index = 0
        for target in anno:
            if gt_index >= capacity:
                break
            if target['area'] < -1:
                continue
            if 'ignore' in target and target['ignore']:
                continue

            box = self.boxtool.coco_anno_box_to_center_relative(target['bbox'], img_height, img_width)
            # Only boxes whose width AND height are both non-positive are dropped.
            if box[2] <= 0 and box[3] <= 0:
                continue

            img['gt_boxes'][gt_index] = box
            img['gt_labels'][gt_index] = \
                self.category_to_id_map[target['category_id']]
            gt_index += 1
        return img
class DetectTestReader:
    """Preprocess for detection dataset on test mode.

    Args:
        mean(list): Mean values for normalization, default is [0.485, 0.456, 0.406].
        std(list): Standard deviation for normalization, default is [0.229, 0.224, 0.225].
        img(dict): Prepared input for detection model; the keys 'image' and 'id' are read.
        size(int): Image size for detection; the image is resized to size x size.

    Returns:
        out_img(np.ndarray): Normalized image, shape is [C, H, W].
        id(int): Id number for corresponding out_img.
        (h, w)(tuple): Height and width of the image before resizing.
    """

    def __init__(self, mean: list = [0.485, 0.456, 0.406], std: list = [0.229, 0.224, 0.225]):
        self.mean = mean
        self.std = std

    def __call__(self, img, size):
        # Load as float32 and switch OpenCV's BGR channel order to RGB.
        image = cv2.cvtColor(cv2.imread(img['image']).astype('float32'), cv2.COLOR_BGR2RGB)
        h, w, _ = image.shape

        # Independent scale factors per axis, so the aspect ratio is not preserved.
        resized = cv2.resize(image,
                             None,
                             None,
                             fx=size / float(w),
                             fy=size / float(h),
                             interpolation=cv2.INTER_CUBIC)

        # Broadcast the per-channel statistics over the H x W x C image.
        channel_mean = np.array(self.mean).reshape((1, 1, -1))
        channel_std = np.array(self.std).reshape((1, 1, -1))
        normalized = (resized / 255.0 - channel_mean) / channel_std

        # HWC -> CHW
        chw_image = normalized.transpose((2, 0, 1))
        image_id = int(img['id'])
        return chw_image, image_id, (h, w)
class DetectTrainReader:
    """Preprocess for detection dataset on train mode.

    Args:
        mean(list): Mean values for normalization, default is [0.485, 0.456, 0.406].
        std(list): Standard deviation for normalization, default is [0.229, 0.224, 0.225].
        img(dict): Prepared input for detection model; the keys 'image', 'gt_boxes' and
            'gt_labels' are read.
        size(int): Image size for detection.

    Returns:
        out_img(np.ndarray): Normalized float32 image, shape is [C, H, W].
        gt_boxes(np.ndarray): Ground truth boxes information.
        gt_labels(np.ndarray): Ground truth labels.
        gt_scores(np.ndarray): Ground truth scores.
    """

    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
        self.mean = mean
        self.std = std
        self.boxtool = BoxTool()

    def __call__(self, img, size):
        # OpenCV loads BGR; the rest of the pipeline works on RGB.
        rgb_image = cv2.cvtColor(cv2.imread(img['image']), cv2.COLOR_BGR2RGB)

        # Work on copies so augmentation never touches the cached annotations.
        boxes = img['gt_boxes'].copy()
        labels = img['gt_labels'].copy()
        # Every ground-truth slot starts with a score of one.
        scores = np.ones_like(labels)

        augmented, boxes, labels, scores = self.boxtool.image_augment(rgb_image, boxes, labels, scores, size,
                                                                      self.mean)

        channel_mean = np.array(self.mean).reshape((1, 1, -1))
        channel_std = np.array(self.std).reshape((1, 1, -1))
        normalized = (augmented / 255.0 - channel_mean) / channel_std

        # HWC -> CHW, as float32
        chw_image = normalized.astype('float32').transpose((2, 0, 1))
        return chw_image, boxes, labels, scores
class BoxTool:
    """This class provides common methods for box processing in detection tasks.

    Boxes are handled either in center mode ``[center_x, center_y, w, h]`` ("xywh") or in
    corner mode ``[x1, y1, x2, y2]`` ("xyxy"), as stated by each method.
    """

    def __init__(self):
        super(BoxTool, self).__init__()

    def coco_anno_box_to_center_relative(self, box: list, img_height: int, img_width: int) -> np.ndarray:
        """
        Convert COCO annotations box with format [x1, y1, w, h] to
        center mode [center_x, center_y, w, h] and divide image width
        and height to get relative value in range[0, 1]
        """
        assert len(box) == 4, "box should be a len(4) list or tuple"
        x, y, w, h = box

        # Clamp the corners to the image before converting.
        x1 = max(x, 0)
        x2 = min(x + w - 1, img_width - 1)
        y1 = max(y, 0)
        y2 = min(y + h - 1, img_height - 1)

        x = (x1 + x2) / 2 / img_width
        y = (y1 + y2) / 2 / img_height
        w = (x2 - x1) / img_width
        h = (y2 - y1) / img_height

        return np.array([x, y, w, h])

    def clip_relative_box_in_image(self, x: float, y: float, w: float, h: float) -> tuple:
        """Clip relative box coordinates x, y, w, h to [0, 1]

        Returns:
            tuple: The clipped box as (x, y, w, h) in center mode.
        """
        x1 = max(x - w / 2, 0.)
        x2 = min(x + w / 2, 1.)
        # The upper edge is bounded below by 0 and the lower edge above by 1, exactly like
        # the x axis (the two bounds used to be swapped, so y was never clipped).
        y1 = max(y - h / 2, 0.)
        y2 = min(y + h / 2, 1.)
        x = (x1 + x2) / 2
        y = (y1 + y2) / 2
        w = x2 - x1
        h = y2 - y1
        return x, y, w, h

    def box_xywh_to_xyxy(self, box: np.ndarray) -> np.ndarray:
        """Change box from xywh to xyxy

        Note: the conversion writes into a reshaped version of ``box``; when the reshape is a
        view, the caller's array is modified in place.
        """
        shape = box.shape
        assert shape[-1] == 4, "Box shape[-1] should be 4."

        box = box.reshape((-1, 4))
        # Both right-hand sides are evaluated before either column is overwritten.
        box[:, 0], box[:, 2] = box[:, 0] - box[:, 2] / 2, box[:, 0] + box[:, 2] / 2
        box[:, 1], box[:, 3] = box[:, 1] - box[:, 3] / 2, box[:, 1] + box[:, 3] / 2
        box = box.reshape(shape)
        return box

    def box_iou_xywh(self, box1: np.ndarray, box2: np.ndarray) -> float:
        """Calculate iou by xywh

        Both inputs are 2-D arrays with 4 columns; rows are combined by broadcasting.
        """
        assert box1.shape[-1] == 4, "Box1 shape[-1] should be 4."
        assert box2.shape[-1] == 4, "Box2 shape[-1] should be 4."

        b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
        b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
        b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
        b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2

        inter_x1 = np.maximum(b1_x1, b2_x1)
        inter_x2 = np.minimum(b1_x2, b2_x2)
        inter_y1 = np.maximum(b1_y1, b2_y1)
        inter_y2 = np.minimum(b1_y2, b2_y2)
        inter_w = inter_x2 - inter_x1
        inter_h = inter_y2 - inter_y1
        # Disjoint boxes give a negative extent; treat it as no overlap.
        inter_w[inter_w < 0] = 0
        inter_h[inter_h < 0] = 0

        inter_area = inter_w * inter_h
        b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
        b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)

        return inter_area / (b1_area + b2_area - inter_area)

    def box_iou_xyxy(self, box1: np.ndarray, box2: np.ndarray) -> float:
        """Calculate iou by xyxy

        Both inputs are 2-D arrays with 4 columns; rows are combined by broadcasting.
        """
        assert box1.shape[-1] == 4, "Box1 shape[-1] should be 4."
        assert box2.shape[-1] == 4, "Box2 shape[-1] should be 4."

        b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]

        inter_x1 = np.maximum(b1_x1, b2_x1)
        inter_x2 = np.minimum(b1_x2, b2_x2)
        inter_y1 = np.maximum(b1_y1, b2_y1)
        inter_y2 = np.minimum(b1_y2, b2_y2)
        inter_w = inter_x2 - inter_x1
        inter_h = inter_y2 - inter_y1
        # Disjoint boxes give a negative extent; treat it as no overlap.
        inter_w[inter_w < 0] = 0
        inter_h[inter_h < 0] = 0

        inter_area = inter_w * inter_h
        b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
        b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)

        return inter_area / (b1_area + b2_area - inter_area)

    def box_crop(self, boxes: np.ndarray, labels: np.ndarray, scores: np.ndarray, crop: list, img_shape: list):
        """Crop the boxes ,labels, scores according to the given shape

        Args:
            boxes(np.ndarray): Relative center-mode boxes, one per row.
            labels(np.ndarray): Label of every box.
            scores(np.ndarray): Score of every box.
            crop(list): Crop window in pixels as (x, y, w, h).
            img_shape(list): Image size in pixels as (width, height).

        Returns:
            The boxes relative to the crop window (center mode), the labels, the scores and
            the number of boxes kept. Entries of dropped boxes are zeroed, not removed.
        """
        x, y, w, h = map(float, crop)
        im_w, im_h = map(float, img_shape)

        # Relative center mode -> absolute corner mode, on a copy of the input.
        boxes = boxes.copy()
        boxes[:, 0], boxes[:, 2] = (boxes[:, 0] - boxes[:, 2] / 2) * im_w, (boxes[:, 0] + boxes[:, 2] / 2) * im_w
        boxes[:, 1], boxes[:, 3] = (boxes[:, 1] - boxes[:, 3] / 2) * im_h, (boxes[:, 1] + boxes[:, 3] / 2) * im_h

        crop_box = np.array([x, y, x + w, y + h])
        centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0
        # A box is kept only when its center lies inside the crop window.
        mask = np.logical_and(crop_box[:2] <= centers, centers <= crop_box[2:]).all(axis=1)

        # Clip to the window, then shift into the window's coordinate frame.
        boxes[:, :2] = np.maximum(boxes[:, :2], crop_box[:2])
        boxes[:, 2:] = np.minimum(boxes[:, 2:], crop_box[2:])
        boxes[:, :2] -= crop_box[:2]
        boxes[:, 2:] -= crop_box[:2]

        # Drop boxes that collapsed to an empty area after clipping.
        mask = np.logical_and(mask, (boxes[:, :2] < boxes[:, 2:]).all(axis=1))
        boxes = boxes * np.expand_dims(mask.astype('float32'), axis=1)
        labels = labels * mask.astype('float32')
        scores = scores * mask.astype('float32')

        # Absolute corner mode -> center mode relative to the crop size.
        boxes[:, 0], boxes[:, 2] = (boxes[:, 0] + boxes[:, 2]) / 2 / w, (boxes[:, 2] - boxes[:, 0]) / w
        boxes[:, 1], boxes[:, 3] = (boxes[:, 1] + boxes[:, 3]) / 2 / h, (boxes[:, 3] - boxes[:, 1]) / h

        return boxes, labels, scores, mask.sum()

    def random_distort(self, img):
        """ Distort the input image randomly.

        Brightness, contrast and color are each scaled by a random factor, applied in a
        random order.
        """

        def random_brightness(img, lower=0.5, upper=1.5):
            e = np.random.uniform(lower, upper)
            return ImageEnhance.Brightness(img).enhance(e)

        def random_contrast(img, lower=0.5, upper=1.5):
            e = np.random.uniform(lower, upper)
            return ImageEnhance.Contrast(img).enhance(e)

        def random_color(img, lower=0.5, upper=1.5):
            e = np.random.uniform(lower, upper)
            return ImageEnhance.Color(img).enhance(e)

        ops = [random_brightness, random_contrast, random_color]
        np.random.shuffle(ops)

        img = Image.fromarray(img)
        img = ops[0](img)
        img = ops[1](img)
        img = ops[2](img)
        img = np.asarray(img)

        return img

    def random_crop(self, img, boxes, labels, scores, scales=[0.3, 1.0], max_ratio=2.0, constraints=None, max_trial=50):
        """Random crop the input image according to constraints.

        Args:
            img(np.ndarray): Input image.
            boxes(np.ndarray): Relative center-mode boxes, one per row.
            labels(np.ndarray): Label of every box.
            scores(np.ndarray): Score of every box.
            scales(list): Lower and upper bound of the random crop scale.
            max_ratio(float): Bound used when sampling the crop aspect ratio.
            constraints(list): (min_iou, max_iou) pairs a candidate crop has to satisfy.
            max_trial(int): Number of candidates sampled for every constraint.

        Returns:
            The (possibly cropped and resized) image, boxes, labels and scores.
        """
        if len(boxes) == 0:
            # Always hand back four values: callers unpack image, boxes, labels and scores.
            return img, boxes, labels, scores

        if not constraints:
            constraints = [(0.1, 1.0), (0.3, 1.0), (0.5, 1.0), (0.7, 1.0), (0.9, 1.0), (0.0, 1.0)]

        img = Image.fromarray(img)
        w, h = img.size
        # The full image is always one of the candidates.
        crops = [(0, 0, w, h)]
        for min_iou, max_iou in constraints:
            for _ in range(max_trial):
                scale = random.uniform(scales[0], scales[1])
                aspect_ratio = random.uniform(max(1 / max_ratio, scale * scale), \
                                              min(max_ratio, 1 / scale / scale))
                crop_h = int(h * scale / np.sqrt(aspect_ratio))
                crop_w = int(w * scale * np.sqrt(aspect_ratio))
                crop_x = random.randrange(w - crop_w)
                crop_y = random.randrange(h - crop_h)
                crop_box = np.array([[(crop_x + crop_w / 2.0) / w, (crop_y + crop_h / 2.0) / h, crop_w / float(w),
                                      crop_h / float(h)]])

                iou = self.box_iou_xywh(crop_box, boxes)
                if min_iou <= iou.min() and max_iou >= iou.max():
                    crops.append((crop_x, crop_y, crop_w, crop_h))
                    break

        # Try the candidates in random order until one still contains a box.
        while crops:
            crop = crops.pop(np.random.randint(0, len(crops)))
            crop_boxes, crop_labels, crop_scores, box_num = \
                self.box_crop(boxes, labels, scores, crop, (w, h))
            if box_num < 1:
                continue
            img = img.crop((crop[0], crop[1], crop[0] + crop[2], crop[1] + crop[3])).resize(img.size, Image.LANCZOS)
            img = np.asarray(img)
            return img, crop_boxes, crop_labels, crop_scores

        img = np.asarray(img)
        return img, boxes, labels, scores

    def random_flip(self, img, gtboxes, thresh=0.5):
        """Flip the images randomly

        The image is mirrored horizontally when a random draw exceeds ``thresh``; the box
        centers in ``gtboxes`` are mirrored in place.
        """
        if random.random() > thresh:
            img = img[:, ::-1, :]
            gtboxes[:, 0] = 1.0 - gtboxes[:, 0]
        return img, gtboxes

    def random_interp(self, img, size, interp=None):
        """Resize the image to size x size, with a random interpolation unless a valid one is given."""
        interp_method = [
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_AREA,
            cv2.INTER_CUBIC,
            cv2.INTER_LANCZOS4,
        ]
        # Membership test only: a truthiness test would reject cv2.INTER_NEAREST, whose value
        # is 0. None is not in the list, so the default still selects a random method.
        if interp not in interp_method:
            interp = interp_method[random.randint(0, len(interp_method) - 1)]
        h, w, _ = img.shape
        im_scale_x = size / float(w)
        im_scale_y = size / float(h)
        img = cv2.resize(img, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=interp)
        return img

    def random_expand(self, img, gtboxes, max_ratio=4., fill=None, keep_ratio=True, thresh=0.5):
        """Expand input image and ground truth box by random ratio.

        The image is pasted at a random offset on a larger canvas; ``gtboxes`` is rescaled
        in place. ``fill`` gives per-channel values in [0, 1] for the canvas background.
        """
        if random.random() > thresh:
            return img, gtboxes

        if max_ratio < 1.0:
            return img, gtboxes

        h, w, c = img.shape
        ratio_x = random.uniform(1, max_ratio)
        if keep_ratio:
            ratio_y = ratio_x
        else:
            ratio_y = random.uniform(1, max_ratio)

        oh = int(h * ratio_y)
        ow = int(w * ratio_x)
        off_x = random.randint(0, ow - w)
        off_y = random.randint(0, oh - h)

        out_img = np.zeros((oh, ow, c))
        if fill and len(fill) == c:
            for i in range(c):
                out_img[:, :, i] = fill[i] * 255.0

        out_img[off_y:off_y + h, off_x:off_x + w, :] = img

        # Re-express the relative boxes in the coordinate frame of the larger canvas.
        gtboxes[:, 0] = ((gtboxes[:, 0] * w) + off_x) / float(ow)
        gtboxes[:, 1] = ((gtboxes[:, 1] * h) + off_y) / float(oh)
        gtboxes[:, 2] = gtboxes[:, 2] / ratio_x
        gtboxes[:, 3] = gtboxes[:, 3] / ratio_y

        return out_img.astype('uint8'), gtboxes

    def shuffle_gtbox(self, gtbox, gtlabel, gtscore):
        """Shuffle gt box.

        Boxes, labels and scores are permuted together so the rows stay aligned.
        """
        gt = np.concatenate([gtbox, gtlabel[:, np.newaxis], gtscore[:, np.newaxis]], axis=1)
        idx = np.arange(gt.shape[0])
        np.random.shuffle(idx)
        gt = gt[idx, :]
        return gt[:, :4], gt[:, 4], gt[:, 5]

    def image_augment(self, img, gtboxes, gtlabels, gtscores, size, means=None):
        """Random processes for input image.

        Applies distort, expand, crop, resize, flip and ground-truth shuffle, in that order.

        Returns:
            The image and boxes as float32, the labels as int32 and the scores as float32.
        """
        img = self.random_distort(img)
        img, gtboxes = self.random_expand(img, gtboxes, fill=means)
        img, gtboxes, gtlabels, gtscores = \
            self.random_crop(img, gtboxes, gtlabels, gtscores)
        img = self.random_interp(img, size)
        img, gtboxes = self.random_flip(img, gtboxes)
        gtboxes, gtlabels, gtscores = self.shuffle_gtbox(gtboxes, gtlabels, gtscores)

        return img.astype('float32'), gtboxes.astype('float32'), \
            gtlabels.astype('int32'), gtscores.astype('float32')

    def draw_boxes_on_image(self,
                            image_path: str,
                            boxes: np.ndarray,
                            scores: np.ndarray,
                            labels: np.ndarray,
                            label_names: list,
                            score_thresh: float = 0.5):
        """Draw boxes on images

        Boxes are expected in corner mode. Detections scoring below ``score_thresh`` and
        boxes with a non-positive extent are skipped. The figure is saved under ``./output/``;
        this method does not create that directory.
        """
        image = np.array(Image.open(image_path))
        plt.figure()
        _, ax = plt.subplots(1)
        ax.imshow(image)

        image_name = image_path.split('/')[-1]
        print("Image {} detect: ".format(image_name))
        # One color per label, chosen on first use.
        colors = {}
        for box, score, label in zip(boxes, scores, labels):
            if score < score_thresh:
                continue
            if box[2] <= box[0] or box[3] <= box[1]:
                continue
            label = int(label)
            if label not in colors:
                colors[label] = plt.get_cmap('hsv')(label / len(label_names))
            x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
            rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, linewidth=2.0, edgecolor=colors[label])
            ax.add_patch(rect)
            ax.text(x1,
                    y1,
                    '{} {:.4f}'.format(label_names[label], score),
                    verticalalignment='bottom',
                    horizontalalignment='left',
                    bbox={
                        'facecolor': colors[label],
                        'alpha': 0.5,
                        'pad': 0
                    },
                    fontsize=8,
                    color='white')
            print("\t {:15s} at {:25} score: {:.5f}".format(label_names[int(label)], str(list(map(int, list(box)))),
                                                            score))
        image_name = image_name.replace('jpg', 'png')
        plt.axis('off')
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.savefig("./output/{}".format(image_name), bbox_inches='tight', pad_inches=0.0)
        print("Detect result save at ./output/{}\n".format(image_name))
        plt.cla()
        plt.close('all')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册