未验证 提交 e254e7c6 编写于 作者: T TTerror 提交者: GitHub

optimize prior_box for kunlun, *test=kunlun (#39477)

上级 f138371c
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/detection/prior_box_op.h" #include "paddle/fluid/operators/detection/prior_box_op.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -81,21 +82,17 @@ class PriorBoxOpXPUKernel : public framework::OpKernel<T> { ...@@ -81,21 +82,17 @@ class PriorBoxOpXPUKernel : public framework::OpKernel<T> {
dev_ctx.x_context(), boxes_data, aspect_ratios_param, min_sizes_param, dev_ctx.x_context(), boxes_data, aspect_ratios_param, min_sizes_param,
max_sizes_param, feature_height, feature_width, img_height, img_width, max_sizes_param, feature_height, feature_width, img_height, img_width,
offset, step_height, step_width, clip, min_max_aspect_ratios_order); offset, step_height, step_width, clip, min_max_aspect_ratios_order);
PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS, PADDLE_ENFORCE_XDNN_SUCCESS(ret, "gen_prior_box");
platform::errors::External(
"XPU gen_prior_box kernel return wrong value[%d %s]",
ret, XPUAPIErrorMsg[ret]));
int box_num = feature_height * feature_width * num_priors; int box_num = feature_height * feature_width * num_priors;
int vlen = variances.size(); int vlen = variances.size();
std::vector<K> var_cpu(vlen * box_num);
for (int i = 0; i < box_num; ++i) { for (int i = 0; i < box_num; ++i) {
ret = xpu_memcpy(vars_data + i * vlen, variances.data(), vlen * sizeof(K), std::copy(variances.begin(), variances.end(), var_cpu.begin() + i * vlen);
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS, platform::errors::External(
"XPU xpu_memcpy return wrong "
"value[%d %s] in prior_box.",
ret, XPUAPIErrorMsg[ret]));
} }
ret = xpu_memcpy(vars_data, var_cpu.data(), var_cpu.size() * sizeof(K),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
PADDLE_ENFORCE_XPU_SUCCESS(ret);
} }
}; };
......
...@@ -14,188 +14,196 @@ ...@@ -14,188 +14,196 @@
from __future__ import print_function from __future__ import print_function
import unittest import math
import numpy as np import numpy as np
import sys import sys
import unittest
sys.path.append("..") sys.path.append("..")
import math
import paddle import paddle
from op_test import OpTest
from op_test_xpu import XPUOpTest from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
paddle.enable_static() paddle.enable_static()
class TestPriorBoxOp(XPUOpTest): class XPUTestPriorBoxOp(XPUOpTestWrapper):
def set_data(self): def __init__(self):
self.init_test_params() self.op_name = 'prior_box'
self.init_test_input() self.use_dynamic_create_class = False
self.init_test_output()
self.inputs = {'Input': self.input, 'Image': self.image} class TestPriorBoxOp(XPUOpTest):
def setUp(self):
self.attrs = { self.op_type = "prior_box"
'min_sizes': self.min_sizes, self.use_xpu = True
'aspect_ratios': self.aspect_ratios, self.dtype = self.in_type
'variances': self.variances, self.set_data()
'flip': self.flip,
'clip': self.clip, def set_data(self):
'min_max_aspect_ratios_order': self.min_max_aspect_ratios_order, self.init_test_params()
'step_w': self.step_w, self.init_test_input()
'step_h': self.step_h, self.init_test_output()
'offset': self.offset self.inputs = {'Input': self.input, 'Image': self.image}
}
if len(self.max_sizes) > 0: self.attrs = {
self.attrs['max_sizes'] = self.max_sizes 'min_sizes': self.min_sizes,
'aspect_ratios': self.aspect_ratios,
self.outputs = {'Boxes': self.out_boxes, 'Variances': self.out_var} 'variances': self.variances,
'flip': self.flip,
def test_check_output(self): 'clip': self.clip,
place = paddle.XPUPlace(0) 'min_max_aspect_ratios_order': self.min_max_aspect_ratios_order,
self.check_output_with_place(place) 'step_w': self.step_w,
'step_h': self.step_h,
def test_check_grad(self): 'offset': self.offset
pass }
if len(self.max_sizes) > 0:
def setUp(self): self.attrs['max_sizes'] = self.max_sizes
self.op_type = "prior_box"
self.use_xpu = True self.outputs = {'Boxes': self.out_boxes, 'Variances': self.out_var}
self.set_data()
def test_check_output(self):
def set_max_sizes(self): place = paddle.XPUPlace(0)
max_sizes = [5, 10] self.check_output_with_place(place)
self.max_sizes = np.array(max_sizes).astype('float32').tolist()
def set_max_sizes(self):
def set_min_max_aspect_ratios_order(self): max_sizes = [5, 10]
self.min_max_aspect_ratios_order = False self.max_sizes = np.array(max_sizes).astype('float32').tolist()
def init_test_params(self): def set_min_max_aspect_ratios_order(self):
self.layer_w = 32 self.min_max_aspect_ratios_order = False
self.layer_h = 32
def init_test_params(self):
self.image_w = 40 self.layer_w = 32
self.image_h = 40 self.layer_h = 32
self.step_w = float(self.image_w) / float(self.layer_w) self.image_w = 40
self.step_h = float(self.image_h) / float(self.layer_h) self.image_h = 40
self.input_channels = 2 self.step_w = float(self.image_w) / float(self.layer_w)
self.image_channels = 3 self.step_h = float(self.image_h) / float(self.layer_h)
self.batch_size = 10
self.input_channels = 2
self.min_sizes = [2, 4] self.image_channels = 3
self.min_sizes = np.array(self.min_sizes).astype('float32').tolist() self.batch_size = 10
self.set_max_sizes()
self.aspect_ratios = [2.0, 3.0] self.min_sizes = [2, 4]
self.flip = True self.min_sizes = np.array(self.min_sizes).astype('float32').tolist()
self.set_min_max_aspect_ratios_order() self.set_max_sizes()
self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0] self.aspect_ratios = [2.0, 3.0]
self.aspect_ratios = np.array( self.flip = True
self.aspect_ratios, dtype=np.float).flatten() self.set_min_max_aspect_ratios_order()
self.variances = [0.1, 0.1, 0.2, 0.2] self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0]
self.variances = np.array(self.variances, dtype=np.float).flatten() self.aspect_ratios = np.array(
self.aspect_ratios, dtype=np.float).flatten()
self.clip = True self.variances = [0.1, 0.1, 0.2, 0.2]
self.num_priors = len(self.real_aspect_ratios) * len(self.min_sizes) self.variances = np.array(self.variances, dtype=np.float).flatten()
if len(self.max_sizes) > 0:
self.num_priors += len(self.max_sizes) self.clip = True
self.offset = 0.5 self.num_priors = len(self.real_aspect_ratios) * len(self.min_sizes)
if len(self.max_sizes) > 0:
def init_test_input(self): self.num_priors += len(self.max_sizes)
self.image = np.random.random( self.offset = 0.5
(self.batch_size, self.image_channels, self.image_w,
self.image_h)).astype('float32') def init_test_input(self):
self.image = np.random.random(
self.input = np.random.random( (self.batch_size, self.image_channels, self.image_w,
(self.batch_size, self.input_channels, self.layer_w, self.image_h)).astype(self.dtype)
self.layer_h)).astype('float32')
self.input = np.random.random(
def init_test_output(self): (self.batch_size, self.input_channels, self.layer_w,
out_dim = (self.layer_h, self.layer_w, self.num_priors, 4) self.layer_h)).astype(self.dtype)
out_boxes = np.zeros(out_dim).astype('float32')
out_var = np.zeros(out_dim).astype('float32') def init_test_output(self):
out_dim = (self.layer_h, self.layer_w, self.num_priors, 4)
idx = 0 out_boxes = np.zeros(out_dim).astype(self.dtype)
for h in range(self.layer_h): out_var = np.zeros(out_dim).astype(self.dtype)
for w in range(self.layer_w):
c_x = (w + self.offset) * self.step_w idx = 0
c_y = (h + self.offset) * self.step_h for h in range(self.layer_h):
idx = 0 for w in range(self.layer_w):
for s in range(len(self.min_sizes)): c_x = (w + self.offset) * self.step_w
min_size = self.min_sizes[s] c_y = (h + self.offset) * self.step_h
if not self.min_max_aspect_ratios_order: idx = 0
# rest of priors for s in range(len(self.min_sizes)):
for r in range(len(self.real_aspect_ratios)): min_size = self.min_sizes[s]
ar = self.real_aspect_ratios[r] if not self.min_max_aspect_ratios_order:
c_w = min_size * math.sqrt(ar) / 2 # rest of priors
c_h = (min_size / math.sqrt(ar)) / 2 for r in range(len(self.real_aspect_ratios)):
out_boxes[h, w, idx, :] = [ ar = self.real_aspect_ratios[r]
(c_x - c_w) / self.image_w, (c_y - c_h) / c_w = min_size * math.sqrt(ar) / 2
self.image_h, (c_x + c_w) / self.image_w, c_h = (min_size / math.sqrt(ar)) / 2
(c_y + c_h) / self.image_h out_boxes[h, w, idx, :] = [
] (c_x - c_w) / self.image_w, (c_y - c_h) /
idx += 1 self.image_h, (c_x + c_w) / self.image_w,
(c_y + c_h) / self.image_h
if len(self.max_sizes) > 0: ]
max_size = self.max_sizes[s] idx += 1
# second prior: aspect_ratio = 1,
c_w = c_h = math.sqrt(min_size * max_size) / 2 if len(self.max_sizes) > 0:
out_boxes[h, w, idx, :] = [ max_size = self.max_sizes[s]
(c_x - c_w) / self.image_w, (c_y - c_h) / # second prior: aspect_ratio = 1,
self.image_h, (c_x + c_w) / self.image_w, c_w = c_h = math.sqrt(min_size * max_size) / 2
(c_y + c_h) / self.image_h out_boxes[h, w, idx, :] = [
] (c_x - c_w) / self.image_w, (c_y - c_h) /
idx += 1 self.image_h, (c_x + c_w) / self.image_w,
else: (c_y + c_h) / self.image_h
c_w = c_h = min_size / 2. ]
out_boxes[h, w, idx, :] = [(c_x - c_w) / self.image_w, idx += 1
(c_y - c_h) / self.image_h, else:
(c_x + c_w) / self.image_w, c_w = c_h = min_size / 2.
(c_y + c_h) / self.image_h]
idx += 1
if len(self.max_sizes) > 0:
max_size = self.max_sizes[s]
# second prior: aspect_ratio = 1,
c_w = c_h = math.sqrt(min_size * max_size) / 2
out_boxes[h, w, idx, :] = [
(c_x - c_w) / self.image_w, (c_y - c_h) /
self.image_h, (c_x + c_w) / self.image_w,
(c_y + c_h) / self.image_h
]
idx += 1
# rest of priors
for r in range(len(self.real_aspect_ratios)):
ar = self.real_aspect_ratios[r]
if abs(ar - 1.) < 1e-6:
continue
c_w = min_size * math.sqrt(ar) / 2
c_h = (min_size / math.sqrt(ar)) / 2
out_boxes[h, w, idx, :] = [ out_boxes[h, w, idx, :] = [
(c_x - c_w) / self.image_w, (c_y - c_h) / (c_x - c_w) / self.image_w, (c_y - c_h) /
self.image_h, (c_x + c_w) / self.image_w, self.image_h, (c_x + c_w) / self.image_w,
(c_y + c_h) / self.image_h (c_y + c_h) / self.image_h
] ]
idx += 1 idx += 1
if len(self.max_sizes) > 0:
# clip the prior's coordidate such that it is within[0, 1] max_size = self.max_sizes[s]
if self.clip: # second prior: aspect_ratio = 1,
out_boxes = np.clip(out_boxes, 0.0, 1.0) c_w = c_h = math.sqrt(min_size * max_size) / 2
# set the variance. out_boxes[h, w, idx, :] = [
out_var = np.tile(self.variances, (self.layer_h, self.layer_w, (c_x - c_w) / self.image_w, (c_y - c_h) /
self.num_priors, 1)) self.image_h, (c_x + c_w) / self.image_w,
self.out_boxes = out_boxes.astype('float32') (c_y + c_h) / self.image_h
self.out_var = out_var.astype('float32') ]
idx += 1
class TestPriorBoxOpWithoutMaxSize(TestPriorBoxOp): # rest of priors
def set_max_sizes(self): for r in range(len(self.real_aspect_ratios)):
self.max_sizes = [] ar = self.real_aspect_ratios[r]
if abs(ar - 1.) < 1e-6:
continue
class TestPriorBoxOpWithSpecifiedOutOrder(TestPriorBoxOp): c_w = min_size * math.sqrt(ar) / 2
def set_min_max_aspect_ratios_order(self): c_h = (min_size / math.sqrt(ar)) / 2
self.min_max_aspect_ratios_order = True out_boxes[h, w, idx, :] = [
(c_x - c_w) / self.image_w, (c_y - c_h) /
self.image_h, (c_x + c_w) / self.image_w,
(c_y + c_h) / self.image_h
]
idx += 1
# clip the prior's coordidate such that it is within[0, 1]
if self.clip:
out_boxes = np.clip(out_boxes, 0.0, 1.0)
# set the variance.
out_var = np.tile(self.variances, (self.layer_h, self.layer_w,
self.num_priors, 1))
self.out_boxes = out_boxes.astype(self.dtype)
self.out_var = out_var.astype(self.dtype)
class TestPriorBoxOpWithoutMaxSize(TestPriorBoxOp):
def set_max_sizes(self):
self.max_sizes = []
class TestPriorBoxOpWithSpecifiedOutOrder(TestPriorBoxOp):
def set_min_max_aspect_ratios_order(self):
self.min_max_aspect_ratios_order = True
support_types = get_xpu_op_support_types('prior_box')
for stype in support_types:
create_test_class(globals(), XPUTestPriorBoxOp, stype)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册