提交 be2d9dc2 编写于 作者: B baiyf 提交者: qingqing01

Add prior_box output order control (#12032)

* Add flag to set prior_box output order.
上级 8e4b225f
...@@ -149,6 +149,13 @@ class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -149,6 +149,13 @@ class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {
"(float) " "(float) "
"Prior boxes center offset.") "Prior boxes center offset.")
.SetDefault(0.5); .SetDefault(0.5);
AddAttr<bool>(
"min_max_aspect_ratios_order",
"(bool) If set True, the output prior box is in order of"
"[min, max, aspect_ratios], which is consistent with Caffe."
"Please note, this order affects the weights order of convolution layer"
"followed by and does not affect the final detection results.")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
Prior box operator Prior box operator
Generate prior boxes for SSD(Single Shot MultiBox Detector) algorithm. Generate prior boxes for SSD(Single Shot MultiBox Detector) algorithm.
......
...@@ -28,8 +28,8 @@ __global__ void GenPriorBox(T* out, const T* aspect_ratios, const int height, ...@@ -28,8 +28,8 @@ __global__ void GenPriorBox(T* out, const T* aspect_ratios, const int height,
const int im_width, const int as_num, const int im_width, const int as_num,
const T offset, const T step_width, const T offset, const T step_width,
const T step_height, const T* min_sizes, const T step_height, const T* min_sizes,
const T* max_sizes, const int min_num, const T* max_sizes, const int min_num, bool is_clip,
bool is_clip) { bool min_max_aspect_ratios_order) {
int num_priors = max_sizes ? as_num * min_num + min_num : as_num * min_num; int num_priors = max_sizes ? as_num * min_num + min_num : as_num * min_num;
int box_num = height * width * num_priors; int box_num = height * width * num_priors;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < box_num; for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < box_num;
...@@ -44,14 +44,28 @@ __global__ void GenPriorBox(T* out, const T* aspect_ratios, const int height, ...@@ -44,14 +44,28 @@ __global__ void GenPriorBox(T* out, const T* aspect_ratios, const int height,
T min_size = min_sizes[m]; T min_size = min_sizes[m];
if (max_sizes) { if (max_sizes) {
int s = p % (as_num + 1); int s = p % (as_num + 1);
if (s < as_num) { if (!min_max_aspect_ratios_order) {
T ar = aspect_ratios[s]; if (s < as_num) {
bw = min_size * sqrt(ar) / 2.; T ar = aspect_ratios[s];
bh = min_size / sqrt(ar) / 2.; bw = min_size * sqrt(ar) / 2.;
bh = min_size / sqrt(ar) / 2.;
} else {
T max_size = max_sizes[m];
bw = sqrt(min_size * max_size) / 2.;
bh = bw;
}
} else { } else {
T max_size = max_sizes[m]; if (s == 0) {
bw = sqrt(min_size * max_size) / 2.; bw = bh = min_size / 2.;
bh = bw; } else if (s == 1) {
T max_size = max_sizes[m];
bw = sqrt(min_size * max_size) / 2.;
bh = bw;
} else {
T ar = aspect_ratios[s - 1];
bw = min_size * sqrt(ar) / 2.;
bh = min_size / sqrt(ar) / 2.;
}
} }
} else { } else {
int s = p % as_num; int s = p % as_num;
...@@ -94,6 +108,8 @@ class PriorBoxOpCUDAKernel : public framework::OpKernel<T> { ...@@ -94,6 +108,8 @@ class PriorBoxOpCUDAKernel : public framework::OpKernel<T> {
auto variances = ctx.Attr<std::vector<float>>("variances"); auto variances = ctx.Attr<std::vector<float>>("variances");
auto flip = ctx.Attr<bool>("flip"); auto flip = ctx.Attr<bool>("flip");
auto clip = ctx.Attr<bool>("clip"); auto clip = ctx.Attr<bool>("clip");
auto min_max_aspect_ratios_order =
ctx.Attr<bool>("min_max_aspect_ratios_order");
std::vector<float> aspect_ratios; std::vector<float> aspect_ratios;
ExpandAspectRatios(input_aspect_ratio, flip, &aspect_ratios); ExpandAspectRatios(input_aspect_ratio, flip, &aspect_ratios);
...@@ -149,7 +165,7 @@ class PriorBoxOpCUDAKernel : public framework::OpKernel<T> { ...@@ -149,7 +165,7 @@ class PriorBoxOpCUDAKernel : public framework::OpKernel<T> {
GenPriorBox<T><<<grid, block, 0, stream>>>( GenPriorBox<T><<<grid, block, 0, stream>>>(
boxes->data<T>(), r.data<T>(), height, width, im_height, im_width, boxes->data<T>(), r.data<T>(), height, width, im_height, im_width,
aspect_ratios.size(), offset, step_width, step_height, min.data<T>(), aspect_ratios.size(), offset, step_width, step_height, min.data<T>(),
max_data, min_num, clip); max_data, min_num, clip, min_max_aspect_ratios_order);
framework::Tensor v; framework::Tensor v;
framework::TensorFromVector(variances, ctx.device_context(), &v); framework::TensorFromVector(variances, ctx.device_context(), &v);
......
...@@ -68,6 +68,8 @@ class PriorBoxOpKernel : public framework::OpKernel<T> { ...@@ -68,6 +68,8 @@ class PriorBoxOpKernel : public framework::OpKernel<T> {
auto variances = ctx.Attr<std::vector<float>>("variances"); auto variances = ctx.Attr<std::vector<float>>("variances");
auto flip = ctx.Attr<bool>("flip"); auto flip = ctx.Attr<bool>("flip");
auto clip = ctx.Attr<bool>("clip"); auto clip = ctx.Attr<bool>("clip");
auto min_max_aspect_ratios_order =
ctx.Attr<bool>("min_max_aspect_ratios_order");
std::vector<float> aspect_ratios; std::vector<float> aspect_ratios;
ExpandAspectRatios(input_aspect_ratio, flip, &aspect_ratios); ExpandAspectRatios(input_aspect_ratio, flip, &aspect_ratios);
...@@ -108,26 +110,59 @@ class PriorBoxOpKernel : public framework::OpKernel<T> { ...@@ -108,26 +110,59 @@ class PriorBoxOpKernel : public framework::OpKernel<T> {
int idx = 0; int idx = 0;
for (size_t s = 0; s < min_sizes.size(); ++s) { for (size_t s = 0; s < min_sizes.size(); ++s) {
auto min_size = min_sizes[s]; auto min_size = min_sizes[s];
// priors with different aspect ratios if (min_max_aspect_ratios_order) {
for (size_t r = 0; r < aspect_ratios.size(); ++r) { box_width = box_height = min_size / 2.;
float ar = aspect_ratios[r];
box_width = min_size * sqrt(ar) / 2.;
box_height = min_size / sqrt(ar) / 2.;
e_boxes(h, w, idx, 0) = (center_x - box_width) / img_width;
e_boxes(h, w, idx, 1) = (center_y - box_height) / img_height;
e_boxes(h, w, idx, 2) = (center_x + box_width) / img_width;
e_boxes(h, w, idx, 3) = (center_y + box_height) / img_height;
idx++;
}
if (max_sizes.size() > 0) {
auto max_size = max_sizes[s];
// square prior with size sqrt(minSize * maxSize)
box_width = box_height = sqrt(min_size * max_size) / 2.;
e_boxes(h, w, idx, 0) = (center_x - box_width) / img_width; e_boxes(h, w, idx, 0) = (center_x - box_width) / img_width;
e_boxes(h, w, idx, 1) = (center_y - box_height) / img_height; e_boxes(h, w, idx, 1) = (center_y - box_height) / img_height;
e_boxes(h, w, idx, 2) = (center_x + box_width) / img_width; e_boxes(h, w, idx, 2) = (center_x + box_width) / img_width;
e_boxes(h, w, idx, 3) = (center_y + box_height) / img_height; e_boxes(h, w, idx, 3) = (center_y + box_height) / img_height;
idx++; idx++;
if (max_sizes.size() > 0) {
auto max_size = max_sizes[s];
// square prior with size sqrt(minSize * maxSize)
box_width = box_height = sqrt(min_size * max_size) / 2.;
e_boxes(h, w, idx, 0) = (center_x - box_width) / img_width;
e_boxes(h, w, idx, 1) = (center_y - box_height) / img_height;
e_boxes(h, w, idx, 2) = (center_x + box_width) / img_width;
e_boxes(h, w, idx, 3) = (center_y + box_height) / img_height;
idx++;
}
// priors with different aspect ratios
for (size_t r = 0; r < aspect_ratios.size(); ++r) {
float ar = aspect_ratios[r];
if (fabs(ar - 1.) < 1e-6) {
continue;
}
box_width = min_size * sqrt(ar) / 2.;
box_height = min_size / sqrt(ar) / 2.;
e_boxes(h, w, idx, 0) = (center_x - box_width) / img_width;
e_boxes(h, w, idx, 1) = (center_y - box_height) / img_height;
e_boxes(h, w, idx, 2) = (center_x + box_width) / img_width;
e_boxes(h, w, idx, 3) = (center_y + box_height) / img_height;
idx++;
}
} else {
// priors with different aspect ratios
for (size_t r = 0; r < aspect_ratios.size(); ++r) {
float ar = aspect_ratios[r];
box_width = min_size * sqrt(ar) / 2.;
box_height = min_size / sqrt(ar) / 2.;
e_boxes(h, w, idx, 0) = (center_x - box_width) / img_width;
e_boxes(h, w, idx, 1) = (center_y - box_height) / img_height;
e_boxes(h, w, idx, 2) = (center_x + box_width) / img_width;
e_boxes(h, w, idx, 3) = (center_y + box_height) / img_height;
idx++;
}
if (max_sizes.size() > 0) {
auto max_size = max_sizes[s];
// square prior with size sqrt(minSize * maxSize)
box_width = box_height = sqrt(min_size * max_size) / 2.;
e_boxes(h, w, idx, 0) = (center_x - box_width) / img_width;
e_boxes(h, w, idx, 1) = (center_y - box_height) / img_height;
e_boxes(h, w, idx, 2) = (center_x + box_width) / img_width;
e_boxes(h, w, idx, 3) = (center_y + box_height) / img_height;
idx++;
}
} }
} }
} }
......
...@@ -789,7 +789,8 @@ def prior_box(input, ...@@ -789,7 +789,8 @@ def prior_box(input,
clip=False, clip=False,
steps=[0.0, 0.0], steps=[0.0, 0.0],
offset=0.5, offset=0.5,
name=None): name=None,
min_max_aspect_ratios_order=False):
""" """
**Prior Box Operator** **Prior Box Operator**
...@@ -818,6 +819,11 @@ def prior_box(input, ...@@ -818,6 +819,11 @@ def prior_box(input,
Default: [0., 0.] Default: [0., 0.]
offset(float): Prior boxes center offset. Default: 0.5 offset(float): Prior boxes center offset. Default: 0.5
name(str): Name of the prior box op. Default: None. name(str): Name of the prior box op. Default: None.
min_max_aspect_ratios_order(bool): If set True, the output prior box is
in order of [min, max, aspect_ratios], which is consistent with
Caffe. Please note, this order affects the weights order of
convolution layer followed by and does not affect the final
detection results. Default: False.
Returns: Returns:
tuple: A tuple with two Variable (boxes, variances) tuple: A tuple with two Variable (boxes, variances)
...@@ -871,7 +877,8 @@ def prior_box(input, ...@@ -871,7 +877,8 @@ def prior_box(input,
'clip': clip, 'clip': clip,
'step_w': steps[0], 'step_w': steps[0],
'step_h': steps[1], 'step_h': steps[1],
'offset': offset 'offset': offset,
'min_max_aspect_ratios_order': min_max_aspect_ratios_order
} }
if max_sizes is not None and len(max_sizes) > 0 and max_sizes[0] > 0: if max_sizes is not None and len(max_sizes) > 0 and max_sizes[0] > 0:
if not _is_list_or_tuple_(max_sizes): if not _is_list_or_tuple_(max_sizes):
...@@ -911,7 +918,8 @@ def multi_box_head(inputs, ...@@ -911,7 +918,8 @@ def multi_box_head(inputs,
kernel_size=1, kernel_size=1,
pad=0, pad=0,
stride=1, stride=1,
name=None): name=None,
min_max_aspect_ratios_order=False):
""" """
Generate prior boxes for SSD(Single Shot MultiBox Detector) Generate prior boxes for SSD(Single Shot MultiBox Detector)
algorithm. The details of this algorithm, please refer the algorithm. The details of this algorithm, please refer the
...@@ -954,6 +962,11 @@ def multi_box_head(inputs, ...@@ -954,6 +962,11 @@ def multi_box_head(inputs,
pad(int|list|tuple): The padding of conv2d. Default:0. pad(int|list|tuple): The padding of conv2d. Default:0.
stride(int|list|tuple): The stride of conv2d. Default:1, stride(int|list|tuple): The stride of conv2d. Default:1,
name(str): Name of the prior box layer. Default: None. name(str): Name of the prior box layer. Default: None.
min_max_aspect_ratios_order(bool): If set True, the output prior box is
in order of [min, max, aspect_ratios], which is consistent with
Caffe. Please note, this order affects the weights order of
convolution layer followed by and does not affect the fininal
detection results. Default: False.
Returns: Returns:
tuple: A tuple with four Variables. (mbox_loc, mbox_conf, boxes, variances) tuple: A tuple with four Variables. (mbox_loc, mbox_conf, boxes, variances)
...@@ -1068,7 +1081,8 @@ def multi_box_head(inputs, ...@@ -1068,7 +1081,8 @@ def multi_box_head(inputs,
step = [step_w[i] if step_w else 0.0, step_h[i] if step_w else 0.0] step = [step_w[i] if step_w else 0.0, step_h[i] if step_w else 0.0]
box, var = prior_box(input, image, min_size, max_size, aspect_ratio, box, var = prior_box(input, image, min_size, max_size, aspect_ratio,
variance, flip, clip, step, offset) variance, flip, clip, step, offset, None,
min_max_aspect_ratios_order)
box_results.append(box) box_results.append(box)
var_results.append(var) var_results.append(var)
......
...@@ -32,6 +32,7 @@ class TestPriorBoxOp(OpTest): ...@@ -32,6 +32,7 @@ class TestPriorBoxOp(OpTest):
'variances': self.variances, 'variances': self.variances,
'flip': self.flip, 'flip': self.flip,
'clip': self.clip, 'clip': self.clip,
'min_max_aspect_ratios_order': self.min_max_aspect_ratios_order,
'step_w': self.step_w, 'step_w': self.step_w,
'step_h': self.step_h, 'step_h': self.step_h,
'offset': self.offset 'offset': self.offset
...@@ -52,6 +53,9 @@ class TestPriorBoxOp(OpTest): ...@@ -52,6 +53,9 @@ class TestPriorBoxOp(OpTest):
max_sizes = [5, 10] max_sizes = [5, 10]
self.max_sizes = np.array(max_sizes).astype('float32').tolist() self.max_sizes = np.array(max_sizes).astype('float32').tolist()
def set_min_max_aspect_ratios_order(self):
self.min_max_aspect_ratios_order = False
def init_test_params(self): def init_test_params(self):
self.layer_w = 32 self.layer_w = 32
self.layer_h = 32 self.layer_h = 32
...@@ -71,6 +75,7 @@ class TestPriorBoxOp(OpTest): ...@@ -71,6 +75,7 @@ class TestPriorBoxOp(OpTest):
self.set_max_sizes() self.set_max_sizes()
self.aspect_ratios = [2.0, 3.0] self.aspect_ratios = [2.0, 3.0]
self.flip = True self.flip = True
self.set_min_max_aspect_ratios_order()
self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0] self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0]
self.aspect_ratios = np.array( self.aspect_ratios = np.array(
self.aspect_ratios, dtype=np.float).flatten() self.aspect_ratios, dtype=np.float).flatten()
...@@ -78,7 +83,6 @@ class TestPriorBoxOp(OpTest): ...@@ -78,7 +83,6 @@ class TestPriorBoxOp(OpTest):
self.variances = np.array(self.variances, dtype=np.float).flatten() self.variances = np.array(self.variances, dtype=np.float).flatten()
self.clip = True self.clip = True
self.num_priors = len(self.real_aspect_ratios) * len(self.min_sizes) self.num_priors = len(self.real_aspect_ratios) * len(self.min_sizes)
if len(self.max_sizes) > 0: if len(self.max_sizes) > 0:
self.num_priors += len(self.max_sizes) self.num_priors += len(self.max_sizes)
...@@ -106,26 +110,60 @@ class TestPriorBoxOp(OpTest): ...@@ -106,26 +110,60 @@ class TestPriorBoxOp(OpTest):
idx = 0 idx = 0
for s in range(len(self.min_sizes)): for s in range(len(self.min_sizes)):
min_size = self.min_sizes[s] min_size = self.min_sizes[s]
# rest of priors if not self.min_max_aspect_ratios_order:
for r in range(len(self.real_aspect_ratios)): # rest of priors
ar = self.real_aspect_ratios[r] for r in range(len(self.real_aspect_ratios)):
c_w = min_size * math.sqrt(ar) / 2 ar = self.real_aspect_ratios[r]
c_h = (min_size / math.sqrt(ar)) / 2 c_w = min_size * math.sqrt(ar) / 2
out_boxes[h, w, idx, :] = [(c_x - c_w) / self.image_w, c_h = (min_size / math.sqrt(ar)) / 2
(c_y - c_h) / self.image_h, out_boxes[h, w, idx, :] = [
(c_x + c_w) / self.image_w, (c_x - c_w) / self.image_w, (c_y - c_h) /
(c_y + c_h) / self.image_h] self.image_h, (c_x + c_w) / self.image_w,
idx += 1 (c_y + c_h) / self.image_h
]
if len(self.max_sizes) > 0: idx += 1
max_size = self.max_sizes[s]
# second prior: aspect_ratio = 1, if len(self.max_sizes) > 0:
c_w = c_h = math.sqrt(min_size * max_size) / 2 max_size = self.max_sizes[s]
# second prior: aspect_ratio = 1,
c_w = c_h = math.sqrt(min_size * max_size) / 2
out_boxes[h, w, idx, :] = [
(c_x - c_w) / self.image_w, (c_y - c_h) /
self.image_h, (c_x + c_w) / self.image_w,
(c_y + c_h) / self.image_h
]
idx += 1
else:
c_w = c_h = min_size / 2.
out_boxes[h, w, idx, :] = [(c_x - c_w) / self.image_w, out_boxes[h, w, idx, :] = [(c_x - c_w) / self.image_w,
(c_y - c_h) / self.image_h, (c_y - c_h) / self.image_h,
(c_x + c_w) / self.image_w, (c_x + c_w) / self.image_w,
(c_y + c_h) / self.image_h] (c_y + c_h) / self.image_h]
idx += 1 idx += 1
if len(self.max_sizes) > 0:
max_size = self.max_sizes[s]
# second prior: aspect_ratio = 1,
c_w = c_h = math.sqrt(min_size * max_size) / 2
out_boxes[h, w, idx, :] = [
(c_x - c_w) / self.image_w, (c_y - c_h) /
self.image_h, (c_x + c_w) / self.image_w,
(c_y + c_h) / self.image_h
]
idx += 1
# rest of priors
for r in range(len(self.real_aspect_ratios)):
ar = self.real_aspect_ratios[r]
if abs(ar - 1.) < 1e-6:
continue
c_w = min_size * math.sqrt(ar) / 2
c_h = (min_size / math.sqrt(ar)) / 2
out_boxes[h, w, idx, :] = [
(c_x - c_w) / self.image_w, (c_y - c_h) /
self.image_h, (c_x + c_w) / self.image_w,
(c_y + c_h) / self.image_h
]
idx += 1
# clip the prior's coordidate such that it is within[0, 1] # clip the prior's coordidate such that it is within[0, 1]
if self.clip: if self.clip:
...@@ -137,10 +175,15 @@ class TestPriorBoxOp(OpTest): ...@@ -137,10 +175,15 @@ class TestPriorBoxOp(OpTest):
self.out_var = out_var.astype('float32') self.out_var = out_var.astype('float32')
class TestPriorBoxOpWithMaxSize(TestPriorBoxOp): class TestPriorBoxOpWithoutMaxSize(TestPriorBoxOp):
def set_max_sizes(self): def set_max_sizes(self):
self.max_sizes = [] self.max_sizes = []
class TestPriorBoxOpWithSpecifiedOutOrder(TestPriorBoxOp):
def set_min_max_aspect_ratios_order(self):
self.min_max_aspect_ratios_order = True
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册