Unverified commit 69cc99f9, authored by wangxinxin08, committed via GitHub

add reference of some code and remove some code (#4467)

Parent: 8ebccc9f
@@ -436,7 +436,5 @@ python tools/anchor_cluster.py -c configs/ppyolo/ppyolo.yml -n 9 -s 608 -m v2 -i
| -c/--config | Model config file | no default | must be specified |
| -n/--n | Number of clusters | 9 | number of anchors |
| -s/--size | Input image size | None | if specified, the given size is used; otherwise the image size is read from the config file |
- | -m/--method | Anchor clustering method | v2 | currently only the yolov2/v5 clustering algorithms are supported |
+ | -m/--method | Anchor clustering method | v2 | currently only the yolov2 clustering algorithm is supported |
| -i/--iters | Number of kmeans iterations | 1000 | stops when kmeans converges or the iteration limit is reached |
- | -gi/--gen_iters | Number of genetic algorithm iterations | 1000 | only used by the yolov5 anchor clustering algorithm |
- | -t/--thresh | Anchor scale threshold | 0.25 | only used by the yolov5 anchor clustering algorithm |
@@ -464,65 +464,6 @@ def gaussian2D(shape, sigma_x=1, sigma_y=1):
return h
def transform_bbox(sample,
M,
w,
h,
area_thr=0.25,
wh_thr=2,
ar_thr=20,
perspective=False):
"""
transform bbox according to the transformation matrix M,
refer to https://github.com/ultralytics/yolov5/blob/develop/utils/datasets.py
"""
bbox = sample['gt_bbox']
label = sample['gt_class']
# rotate bbox
n = len(bbox)
xy = np.ones((n * 4, 3), dtype=np.float32)
xy[:, :2] = bbox[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2)
# xy = xy @ M.T
xy = np.matmul(xy, M.T)
if perspective:
xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8)
else:
xy = xy[:, :2].reshape(n, 8)
# get new bboxes
x = xy[:, [0, 2, 4, 6]]
y = xy[:, [1, 3, 5, 7]]
bbox = np.concatenate(
(x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
# clip boxes
mask = filter_bbox(bbox, w, h, area_thr)
sample['gt_bbox'] = bbox[mask]
sample['gt_class'] = sample['gt_class'][mask]
if 'is_crowd' in sample:
sample['is_crowd'] = sample['is_crowd'][mask]
if 'difficult' in sample:
sample['difficult'] = sample['difficult'][mask]
return sample
def filter_bbox(bbox, w, h, area_thr=0.25, wh_thr=2, ar_thr=20):
"""
filter bbox, refer to https://github.com/ultralytics/yolov5/blob/develop/utils/datasets.py
"""
# clip boxes
area1 = (bbox[:, 2:4] - bbox[:, 0:2]).prod(1)
bbox[:, [0, 2]] = bbox[:, [0, 2]].clip(0, w)
bbox[:, [1, 3]] = bbox[:, [1, 3]].clip(0, h)
# compute
area2 = (bbox[:, 2:4] - bbox[:, 0:2]).prod(1)
area_ratio = area2 / (area1 + 1e-16)
wh = bbox[:, 2:4] - bbox[:, 0:2]
ar_ratio = np.maximum(wh[:, 1] / (wh[:, 0] + 1e-16),
wh[:, 0] / (wh[:, 1] + 1e-16))
mask = (area_ratio > area_thr) & (
(wh > wh_thr).all(1)) & (ar_ratio < ar_thr)
return mask
def draw_umich_gaussian(heatmap, center, radius, k=1):
"""
draw_umich_gaussian, refer to https://github.com/xingyizhou/CenterNet/blob/master/src/lib/utils/image.py#L126
......
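For reference, the geometric filter implemented by the removed `filter_bbox` helper above can be reproduced as a standalone sketch: a box survives only if, after clipping to the image, it keeps at least `area_thr` of its original area, both sides are larger than `wh_thr` pixels, and its aspect ratio stays below `ar_thr`. This is an illustration derived from the removed code, not part of the repository after this commit.

```python
import numpy as np

def keep_boxes(bbox, w, h, area_thr=0.25, wh_thr=2, ar_thr=20):
    """Boolean mask of boxes that survive clipping to an (w, h) image.

    Mirrors the criteria of the removed filter_bbox helper:
    area ratio, minimum side length, and maximum aspect ratio.
    """
    bbox = bbox.astype(np.float32)
    area_before = (bbox[:, 2:4] - bbox[:, 0:2]).prod(1)
    # clip to the image canvas
    bbox[:, [0, 2]] = bbox[:, [0, 2]].clip(0, w)
    bbox[:, [1, 3]] = bbox[:, [1, 3]].clip(0, h)
    area_after = (bbox[:, 2:4] - bbox[:, 0:2]).prod(1)
    wh = bbox[:, 2:4] - bbox[:, 0:2]
    aspect = np.maximum(wh[:, 1] / (wh[:, 0] + 1e-16),
                        wh[:, 0] / (wh[:, 1] + 1e-16))
    return ((area_after / (area_before + 1e-16) > area_thr)
            & (wh > wh_thr).all(1)
            & (aspect < ar_thr))

boxes = np.array([[10., 10., 60., 40.],    # fully inside: kept
                  [-50., 5., 5., 200.]])   # mostly outside: dropped
print(keep_boxes(boxes, w=100, h=100))     # [ True False]
```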
@@ -48,7 +48,7 @@ from .op_helper import (satisfy_sample_constraint, filter_and_process,
generate_sample_bbox, clip_bbox, data_anchor_sampling,
satisfy_sample_constraint_coverage, crop_image_sampling,
generate_sample_bbox_square, bbox_area_sampling,
- is_poly, transform_bbox, get_border)
+ is_poly, get_border)
from ppdet.utils.logger import setup_logger
from ppdet.modeling.keypoint_utils import get_affine_transform, affine_transform
@@ -2476,6 +2476,9 @@ class RandomSelect(BaseOperator):
"""
Randomly choose a transformation between transforms1 and transforms2,
and the probability of choosing transforms1 is p.
+ The code is based on https://github.com/facebookresearch/detr/blob/main/datasets/transforms.py
"""
def __init__(self, transforms1, transforms2, p=0.5):
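As a rough illustration of what the `RandomSelect` docstring describes (a minimal sketch, not PaddleDetection's actual operator), a selector of this kind applies `transforms1` with probability `p` and `transforms2` otherwise:

```python
import random

class RandomSelectSketch:
    """Apply transforms1 with probability p, otherwise transforms2.

    Minimal sketch of the behaviour described in the docstring above;
    the real RandomSelect integrates with PaddleDetection's BaseOperator.
    """
    def __init__(self, transforms1, transforms2, p=0.5):
        self.transforms1 = transforms1
        self.transforms2 = transforms2
        self.p = p

    def __call__(self, sample):
        chosen = self.transforms1 if random.random() < self.p else self.transforms2
        for t in chosen:
            sample = t(sample)
        return sample

# usage: pick a flip-like pipeline 70% of the time, identity otherwise
flip = [lambda s: {**s, 'flipped': True}]
identity = [lambda s: s]
op = RandomSelectSketch(flip, identity, p=0.7)
print(op({'image': 'img.jpg'}))
```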
@@ -2833,6 +2836,10 @@ class WarpAffine(BaseOperator):
shift=0.1):
"""WarpAffine
Warp affine the image
+ The code is based on https://github.com/xingyizhou/CenterNet/blob/master/src/lib/datasets/sample/ctdet.py
"""
super(WarpAffine, self).__init__()
self.keep_res = keep_res
......
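The `WarpAffine` operator maps an image onto a fixed network input size with an affine transform, following the CenterNet reference added above. The snippet below is only a sketch of that general idea with OpenCV; the scale-and-shift construction here is an assumption for illustration, not the operator's actual affine-matrix logic.

```python
import cv2
import numpy as np

# Sketch: warp an arbitrary image onto a fixed network input canvas.
img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
input_w, input_h = 512, 512

# scale about the image center so the image fits the target size
center = (img.shape[1] / 2.0, img.shape[0] / 2.0)
scale = min(input_w / img.shape[1], input_h / img.shape[0])
M = cv2.getRotationMatrix2D(center, angle=0, scale=scale)
# shift the image center to the center of the target canvas
M[0, 2] += input_w / 2.0 - center[0]
M[1, 2] += input_h / 2.0 - center[1]

warped = cv2.warpAffine(img, M, (input_w, input_h), flags=cv2.INTER_LINEAR)
print(warped.shape)  # (512, 512, 3)
```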
- /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+ // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+ //
+ // Licensed under the Apache License, Version 2.0 (the "License");
+ // you may not use this file except in compliance with the License.
+ // You may obtain a copy of the License at
+ //
+ //     http://www.apache.org/licenses/LICENSE-2.0
+ //
+ // Unless required by applicable law or agreed to in writing, software
+ // distributed under the License is distributed on an "AS IS" BASIS,
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ // See the License for the specific language governing permissions and
+ // limitations under the License.
+ //
+ // The code is based on https://github.com/csuhan/s2anet/blob/master/mmdet/ops/box_iou_rotated
#include "rbox_iou_op.h"
#include "paddle/extension.h"
......
- /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+ // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+ //
+ // Licensed under the Apache License, Version 2.0 (the "License");
+ // you may not use this file except in compliance with the License.
+ // You may obtain a copy of the License at
+ //
+ //     http://www.apache.org/licenses/LICENSE-2.0
+ //
+ // Unless required by applicable law or agreed to in writing, software
+ // distributed under the License is distributed on an "AS IS" BASIS,
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ // See the License for the specific language governing permissions and
+ // limitations under the License.
+ //
+ // The code is based on https://github.com/csuhan/s2anet/blob/master/mmdet/ops/box_iou_rotated
#include "rbox_iou_op.h"
#include "paddle/extension.h"
......
- /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+ // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+ //
+ // Licensed under the Apache License, Version 2.0 (the "License");
+ // you may not use this file except in compliance with the License.
+ // You may obtain a copy of the License at
+ //
+ //     http://www.apache.org/licenses/LICENSE-2.0
+ //
+ // Unless required by applicable law or agreed to in writing, software
+ // distributed under the License is distributed on an "AS IS" BASIS,
+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ // See the License for the specific language governing permissions and
+ // limitations under the License.
+ //
+ // The code is based on https://github.com/csuhan/s2anet/blob/master/mmdet/ops/box_iou_rotated
#pragma once
......
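The `rbox_iou` custom op referenced above is based on s2anet's `box_iou_rotated`. Assuming it computes IoU between rotated boxes given as `(cx, cy, w, h, angle)` — an inference from the name and reference, not from the kernel itself — the same quantity can be cross-checked on the CPU with shapely:

```python
import numpy as np
from shapely.geometry import Polygon

def rbox_to_polygon(cx, cy, w, h, angle):
    """Corner polygon of a rotated box (center, size, rotation in radians)."""
    dx, dy = w / 2.0, h / 2.0
    corners = np.array([[-dx, -dy], [dx, -dy], [dx, dy], [-dx, dy]])
    rot = np.array([[np.cos(angle), -np.sin(angle)],
                    [np.sin(angle),  np.cos(angle)]])
    return Polygon(corners @ rot.T + [cx, cy])

def rbox_iou(box1, box2):
    p1, p2 = rbox_to_polygon(*box1), rbox_to_polygon(*box2)
    inter = p1.intersection(p2).area
    return inter / (p1.area + p2.area - inter + 1e-16)

# identical boxes -> IoU 1.0; rotating a 4x2 box by 90 degrees -> IoU 1/3
print(rbox_iou((0, 0, 4, 2, 0.0), (0, 0, 4, 2, 0.0)))
print(round(rbox_iou((0, 0, 4, 2, 0.0), (0, 0, 4, 2, np.pi / 2)), 3))
```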
@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ #
+ # The code is based on https://github.com/csuhan/s2anet/blob/master/mmdet/models/anchor_heads_rotated/s2anet_head.py
import paddle
from paddle import ParamAttr
import paddle.nn as nn
@@ -625,7 +628,8 @@ class S2ANetHead(nn.Layer):
fam_bbox_total = self.gwd_loss(fam_bbox_decode,
bbox_gt_bboxes_level)
fam_bbox_total = fam_bbox_total * feat_bbox_weights
- fam_bbox_total = paddle.sum(fam_bbox_total) / num_total_samples
+ fam_bbox_total = paddle.sum(
+     fam_bbox_total) / num_total_samples
fam_bbox_losses.append(fam_bbox_total)
st_idx += feat_anchor_num
@@ -739,7 +743,8 @@ class S2ANetHead(nn.Layer):
odm_bbox_total = self.gwd_loss(odm_bbox_decode,
bbox_gt_bboxes_level)
odm_bbox_total = odm_bbox_total * feat_bbox_weights
- odm_bbox_total = paddle.sum(odm_bbox_total) / num_total_samples
+ odm_bbox_total = paddle.sum(
+     odm_bbox_total) / num_total_samples
odm_bbox_losses.append(odm_bbox_total)
st_idx += feat_anchor_num
......
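The rewrapped lines above keep the same computation: the per-anchor GWD regression loss is weighted per box and then averaged over the total number of sampled anchors. A tiny numeric sketch of that normalization, with illustrative values only:

```python
import numpy as np

# Per-anchor regression loss, weighted so that only positive anchors count,
# then divided by the total number of sampled anchors.
elementwise_loss = np.array([0.8, 0.2, 0.5, 0.1])   # e.g. per-anchor GWD loss
weights = np.array([1.0, 1.0, 0.0, 0.0])            # positive-anchor weights
num_total_samples = 2.0
print((elementwise_loss * weights).sum() / num_total_samples)  # 0.5
```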
@@ -180,7 +180,7 @@ class CoordConv(nn.Layer):
name='',
data_format='NCHW'):
"""
- CoordConv layer
+ CoordConv layer, see https://arxiv.org/abs/1807.03247
Args:
ch_in (int): input channel
......
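The newly cited paper (https://arxiv.org/abs/1807.03247) describes CoordConv as a convolution applied to features augmented with coordinate channels. Below is a minimal numpy sketch of that augmentation step, assuming coordinates normalized to [-1, 1]; the PaddleDetection layer wraps this idea into its own conv block, so the details here are illustrative only.

```python
import numpy as np

def add_coord_channels(feat):
    """Append normalized x/y coordinate channels to an NCHW feature map,
    the core idea of CoordConv."""
    n, c, h, w = feat.shape
    ys = np.linspace(-1, 1, h, dtype=feat.dtype)[None, None, :, None]
    xs = np.linspace(-1, 1, w, dtype=feat.dtype)[None, None, None, :]
    y_grid = np.broadcast_to(ys, (n, 1, h, w))
    x_grid = np.broadcast_to(xs, (n, 1, h, w))
    return np.concatenate([feat, x_grid, y_grid], axis=1)

feat = np.zeros((2, 64, 20, 20), dtype=np.float32)
print(add_coord_channels(feat).shape)  # (2, 66, 20, 20): conv sees ch_in + 2
```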
@@ -31,10 +31,8 @@ python tools/anchor_cluster.py -c ${config} -m ${method} -s ${size}
| -c/--config | Model config file | no default | must be specified |
| -n/--n | Number of clusters | 9 | number of anchors |
| -s/--size | Input image size | None | if specified, the given size is used; otherwise the image size is read from the config file |
- | -m/--method | Anchor clustering method | v2 | currently only the yolov2/v5 clustering algorithms are supported |
+ | -m/--method | Anchor clustering method | v2 | currently only the yolov2 clustering algorithm is supported |
| -i/--iters | Number of kmeans iterations | 1000 | stops when kmeans converges or the iteration limit is reached |
- | -gi/--gen_iters | Number of genetic algorithm iterations | 1000 | only used by the yolov5 anchor clustering algorithm |
- | -t/--thresh | Anchor scale threshold | 0.25 | only used by the yolov5 anchor clustering algorithm |

## Model Zoo
The table below shows the currently supported network architectures.
......
@@ -139,10 +139,8 @@ python tools/anchor_cluster.py -c configs/ppyolo/ppyolo.yml -n 9 -s 608 -m v2 -i
| -c/--config | Model config file | no default | must be specified |
| -n/--n | Number of clusters | 9 | number of anchors |
| -s/--size | Input image size | None | if specified, the given size is used; otherwise the image size is read from the config file |
- | -m/--method | Anchor clustering method | v2 | currently only the yolov2/v5 clustering algorithms are supported |
+ | -m/--method | Anchor clustering method | v2 | currently only the yolov2 clustering algorithm is supported |
| -i/--iters | Number of kmeans iterations | 1000 | stops when kmeans converges or the iteration limit is reached |
- | -gi/--gen_iters | Number of genetic algorithm iterations | 1000 | only used by the yolov5 anchor clustering algorithm |
- | -t/--thresh | Anchor scale threshold | 0.25 | only used by the yolov5 anchor clustering algorithm |

## 4. Modify the parameter configuration
......
@@ -126,8 +126,7 @@ class YOLOv2AnchorCluster(BaseAnchorCluster):
"""
YOLOv2 Anchor Cluster
- Reference:
-     https://github.com/AlexeyAB/darknet/blob/master/scripts/gen_anchors.py
+ The code is based on https://github.com/AlexeyAB/darknet/blob/master/scripts/gen_anchors.py
Args:
n (int): number of clusters
@@ -196,103 +195,6 @@ class YOLOv2AnchorCluster(BaseAnchorCluster):
return centers
class YOLOv5AnchorCluster(BaseAnchorCluster):
def __init__(self,
n,
dataset,
size,
cache_path,
cache,
iters=300,
gen_iters=1000,
thresh=0.25,
verbose=True):
super(YOLOv5AnchorCluster, self).__init__(
n, cache_path, cache, verbose=verbose)
"""
YOLOv5 Anchor Cluster
Reference:
https://github.com/ultralytics/yolov5/blob/master/utils/general.py
Args:
n (int): number of clusters
dataset (DataSet): DataSet instance, VOC or COCO
size (list): [w, h]
cache_path (str): cache directory path
cache (bool): whether using cache
iters (int): iters of kmeans algorithm
gen_iters (int): iters of genetic algorithm
threshold (float): anchor scale threshold
verbose (bool): whether print results
"""
self.dataset = dataset
self.size = size
self.iters = iters
self.gen_iters = gen_iters
self.thresh = thresh
def print_result(self, centers):
whs = self.whs
centers = centers[np.argsort(centers.prod(1))]
x, best = self.metric(whs, centers)
bpr, aat = (
best > self.thresh).mean(), (x > self.thresh).mean() * self.n
logger.info(
'thresh=%.2f: %.4f best possible recall, %.2f anchors past thr' %
(self.thresh, bpr, aat))
logger.info(
'n=%g, img_size=%s, metric_all=%.3f/%.3f-mean/best, past_thresh=%.3f-mean: '
% (self.n, self.size, x.mean(), best.mean(),
x[x > self.thresh].mean()))
logger.info('%d anchor cluster result: [w, h]' % self.n)
for w, h in centers:
logger.info('[%d, %d]' % (round(w), round(h)))
def metric(self, whs, centers):
r = whs[:, None] / centers[None]
x = np.minimum(r, 1. / r).min(2)
return x, x.max(1)
def fitness(self, whs, centers):
_, best = self.metric(whs, centers)
return (best * (best > self.thresh)).mean()
def calc_anchors(self):
self.whs = self.whs * self.shapes / self.shapes.max(
1, keepdims=True) * np.array([self.size])
wh0 = self.whs
i = (wh0 < 3.0).any(1).sum()
if i:
logger.warning('Extremely small objects found. %d of %d'
'labels are < 3 pixels in width or height' %
(i, len(wh0)))
wh = wh0[(wh0 >= 2.0).any(1)]
logger.info('Running kmeans for %g anchors on %g points...' %
(self.n, len(wh)))
s = wh.std(0)
centers, dist = kmeans(wh / s, self.n, iter=self.iters)
centers *= s
f, sh, mp, s = self.fitness(wh, centers), centers.shape, 0.9, 0.1
pbar = tqdm(
range(self.gen_iters),
desc='Evolving anchors with Genetic Algorithm')
for _ in pbar:
v = np.ones(sh)
while (v == 1).all():
v = ((np.random.random(sh) < mp) * np.random.random() *
np.random.randn(*sh) * s + 1).clip(0.3, 3.0)
new_centers = (centers.copy() * v).clip(min=2.0)
new_f = self.fitness(wh, new_centers)
if new_f > f:
f, centers = new_f, new_centers.copy()
pbar.desc = 'Evolving anchors with Genetic Algorithm: fitness = %.4f' % f
return centers
def main():
parser = ArgsParser()
parser.add_argument(
@@ -303,18 +205,6 @@ def main():
default=1000,
type=int,
help='num of iterations for kmeans')
parser.add_argument(
'--gen_iters',
'-gi',
default=1000,
type=int,
help='num of iterations for genetic algorithm')
parser.add_argument(
'--thresh',
'-t',
default=0.25,
type=float,
help='anchor scale threshold')
parser.add_argument(
'--verbose', '-v', default=True, type=bool, help='whether print result')
parser.add_argument(
@@ -328,7 +218,7 @@ def main():
'-m',
default='v2',
type=str,
- help='cluster method, [v2, v5] are supported now')
+ help='cluster method, v2 is only supported now')
parser.add_argument(
'--cache_path', default='cache', type=str, help='cache path')
parser.add_argument(
@@ -353,18 +243,14 @@ def main():
size = int(FLAGS.size)
size = [size, size]
- elif 'image_shape' in cfg['TrainReader']['inputs_def']:
-     size = cfg['TrainReader']['inputs_def']['image_shape'][1:]
+ elif 'image_shape' in cfg['TestReader']['inputs_def']:
+     size = cfg['TestReader']['inputs_def']['image_shape'][1:]
else:
raise ValueError('size is not specified')
if FLAGS.method == 'v2':
cluster = YOLOv2AnchorCluster(FLAGS.n, dataset, size, FLAGS.cache_path,
FLAGS.cache, FLAGS.iters, FLAGS.verbose)
elif FLAGS.method == 'v5':
cluster = YOLOv5AnchorCluster(FLAGS.n, dataset, size, FLAGS.cache_path,
FLAGS.cache, FLAGS.iters, FLAGS.gen_iters,
FLAGS.thresh, FLAGS.verbose)
else:
raise ValueError('cluster method: %s is not supported' % FLAGS.method)
......
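Both copies of `anchor_cluster.py` drop the `YOLOv5AnchorCluster` class. Its core width/height ratio metric, shown in the removed block above, can be reproduced standalone as follows; this is an illustration derived from the removed code, not functionality that remains after this commit.

```python
import numpy as np

def anchor_metric(whs, centers, thresh=0.25):
    """Ratio metric used by the removed YOLOv5AnchorCluster: for every gt box,
    take the worst of the w and h ratios against each anchor, then keep the
    best anchor per box; bpr is the best possible recall at the threshold."""
    r = whs[:, None] / centers[None]       # (num_boxes, num_anchors, 2)
    x = np.minimum(r, 1.0 / r).min(2)      # worst of w-ratio and h-ratio
    best = x.max(1)                        # best anchor for each box
    bpr = (best > thresh).mean()           # best possible recall
    return x, best, bpr

whs = np.array([[30., 60.], [100., 90.], [10., 300.]])
centers = np.array([[32., 64.], [96., 96.]])
_, best, bpr = anchor_metric(whs, centers)
print(best.round(3), bpr)  # the thin 10x300 box matches no anchor well
```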
@@ -111,8 +111,7 @@ class YOLOv2AnchorCluster(BaseAnchorCluster):
"""
YOLOv2 Anchor Cluster
- Reference:
-     https://github.com/AlexeyAB/darknet/blob/master/scripts/gen_anchors.py
+ The code is based on https://github.com/AlexeyAB/darknet/blob/master/scripts/gen_anchors.py
Args:
n (int): number of clusters
@@ -182,103 +181,6 @@ class YOLOv2AnchorCluster(BaseAnchorCluster):
return centers
class YOLOv5AnchorCluster(BaseAnchorCluster):
def __init__(self,
n,
dataset,
size,
cache_path,
cache,
iters=300,
gen_iters=1000,
thresh=0.25,
verbose=True):
super(YOLOv5AnchorCluster, self).__init__(
n, cache_path, cache, verbose=verbose)
"""
YOLOv5 Anchor Cluster
Reference:
https://github.com/ultralytics/yolov5/blob/master/utils/general.py
Args:
n (int): number of clusters
dataset (DataSet): DataSet instance, VOC or COCO
size (list): [w, h]
cache_path (str): cache directory path
cache (bool): whether using cache
iters (int): iters of kmeans algorithm
gen_iters (int): iters of genetic algorithm
threshold (float): anchor scale threshold
verbose (bool): whether print results
"""
self.dataset = dataset
self.size = size
self.iters = iters
self.gen_iters = gen_iters
self.thresh = thresh
def print_result(self, centers):
whs = self.whs
centers = centers[np.argsort(centers.prod(1))]
x, best = self.metric(whs, centers)
bpr, aat = (
best > self.thresh).mean(), (x > self.thresh).mean() * self.n
logger.info(
'thresh=%.2f: %.4f best possible recall, %.2f anchors past thr' %
(self.thresh, bpr, aat))
logger.info(
'n=%g, img_size=%s, metric_all=%.3f/%.3f-mean/best, past_thresh=%.3f-mean: '
% (self.n, self.size, x.mean(), best.mean(),
x[x > self.thresh].mean()))
logger.info('%d anchor cluster result: [w, h]' % self.n)
for w, h in centers:
logger.info('[%d, %d]' % (round(w), round(h)))
def metric(self, whs, centers):
r = whs[:, None] / centers[None]
x = np.minimum(r, 1. / r).min(2)
return x, x.max(1)
def fitness(self, whs, centers):
_, best = self.metric(whs, centers)
return (best * (best > self.thresh)).mean()
def calc_anchors(self):
self.whs = self.whs * self.shapes / self.shapes.max(
1, keepdims=True) * np.array([self.size])
wh0 = self.whs
i = (wh0 < 3.0).any(1).sum()
if i:
logger.warning('Extremely small objects found. %d of %d'
'labels are < 3 pixels in width or height' %
(i, len(wh0)))
wh = wh0[(wh0 >= 2.0).any(1)]
logger.info('Running kmeans for %g anchors on %g points...' %
(self.n, len(wh)))
s = wh.std(0)
centers, dist = kmeans(wh / s, self.n, iter=self.iters)
centers *= s
f, sh, mp, s = self.fitness(wh, centers), centers.shape, 0.9, 0.1
pbar = tqdm(
range(self.gen_iters),
desc='Evolving anchors with Genetic Algorithm')
for _ in pbar:
v = np.ones(sh)
while (v == 1).all():
v = ((np.random.random(sh) < mp) * np.random.random() *
np.random.randn(*sh) * s + 1).clip(0.3, 3.0)
new_centers = (centers.copy() * v).clip(min=2.0)
new_f = self.fitness(wh, new_centers)
if new_f > f:
f, centers = new_f, new_centers.copy()
pbar.desc = 'Evolving anchors with Genetic Algorithm: fitness = %.4f' % f
return centers
def main():
parser = ArgsParser()
parser.add_argument(
@@ -289,18 +191,6 @@ def main():
default=1000,
type=int,
help='num of iterations for kmeans')
parser.add_argument(
'--gen_iters',
'-gi',
default=1000,
type=int,
help='num of iterations for genetic algorithm')
parser.add_argument(
'--thresh',
'-t',
default=0.25,
type=float,
help='anchor scale threshold')
parser.add_argument(
'--verbose', '-v', default=True, type=bool, help='whether print result')
parser.add_argument(
@@ -314,7 +204,7 @@ def main():
'-m',
default='v2',
type=str,
- help='cluster method, [v2, v5] are supported now')
+ help='cluster method, v2 is only supported now')
parser.add_argument(
'--cache_path', default='cache', type=str, help='cache path')
parser.add_argument(
@@ -338,19 +228,15 @@ def main():
else:
size = int(FLAGS.size)
size = [size, size]
- elif 'inputs_def' in cfg['TrainReader'] and 'image_shape' in cfg[
-         'TrainReader']['inputs_def']:
-     size = cfg['TrainReader']['inputs_def']['image_shape'][1:]
+ elif 'inputs_def' in cfg['TestReader'] and 'image_shape' in cfg[
+         'TestReader']['inputs_def']:
+     size = cfg['TestReader']['inputs_def']['image_shape'][1:]
else:
raise ValueError('size is not specified')
if FLAGS.method == 'v2':
cluster = YOLOv2AnchorCluster(FLAGS.n, dataset, size, FLAGS.cache_path,
FLAGS.cache, FLAGS.iters, FLAGS.verbose)
elif FLAGS.method == 'v5':
cluster = YOLOv5AnchorCluster(FLAGS.n, dataset, size, FLAGS.cache_path,
FLAGS.cache, FLAGS.iters, FLAGS.gen_iters,
FLAGS.thresh, FLAGS.verbose)
else:
raise ValueError('cluster method: %s is not supported' % FLAGS.method)
......