From e254e7c656f9fe3fc5136e51e1972e1753b7a1e2 Mon Sep 17 00:00:00 2001 From: TTerror Date: Wed, 16 Feb 2022 17:21:34 +0800 Subject: [PATCH] optimize prior_box for kunlun, *test=kunlun (#39477) --- .../operators/detection/prior_box_op_xpu.cc | 17 +- .../unittests/xpu/test_prior_box_op_xpu.py | 342 +++++++++--------- 2 files changed, 182 insertions(+), 177 deletions(-) diff --git a/paddle/fluid/operators/detection/prior_box_op_xpu.cc b/paddle/fluid/operators/detection/prior_box_op_xpu.cc index bab39468954..c39f702a486 100644 --- a/paddle/fluid/operators/detection/prior_box_op_xpu.cc +++ b/paddle/fluid/operators/detection/prior_box_op_xpu.cc @@ -14,6 +14,7 @@ limitations under the License. */ #ifdef PADDLE_WITH_XPU #include "paddle/fluid/operators/detection/prior_box_op.h" +#include "paddle/fluid/platform/device/device_wrapper.h" namespace paddle { namespace operators { @@ -81,21 +82,17 @@ class PriorBoxOpXPUKernel : public framework::OpKernel { dev_ctx.x_context(), boxes_data, aspect_ratios_param, min_sizes_param, max_sizes_param, feature_height, feature_width, img_height, img_width, offset, step_height, step_width, clip, min_max_aspect_ratios_order); - PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS, - platform::errors::External( - "XPU gen_prior_box kernel return wrong value[%d %s]", - ret, XPUAPIErrorMsg[ret])); + PADDLE_ENFORCE_XDNN_SUCCESS(ret, "gen_prior_box"); int box_num = feature_height * feature_width * num_priors; int vlen = variances.size(); + std::vector var_cpu(vlen * box_num); for (int i = 0; i < box_num; ++i) { - ret = xpu_memcpy(vars_data + i * vlen, variances.data(), vlen * sizeof(K), - XPUMemcpyKind::XPU_HOST_TO_DEVICE); - PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS, platform::errors::External( - "XPU xpu_memcpy return wrong " - "value[%d %s] in prior_box.", - ret, XPUAPIErrorMsg[ret])); + std::copy(variances.begin(), variances.end(), var_cpu.begin() + i * vlen); } + ret = xpu_memcpy(vars_data, var_cpu.data(), var_cpu.size() * sizeof(K), + XPUMemcpyKind::XPU_HOST_TO_DEVICE); + PADDLE_ENFORCE_XPU_SUCCESS(ret); } }; diff --git a/python/paddle/fluid/tests/unittests/xpu/test_prior_box_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_prior_box_op_xpu.py index 44137f47187..0830237d5a8 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_prior_box_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_prior_box_op_xpu.py @@ -14,188 +14,196 @@ from __future__ import print_function -import unittest +import math import numpy as np import sys +import unittest sys.path.append("..") -import math + import paddle -from op_test import OpTest + from op_test_xpu import XPUOpTest +from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper paddle.enable_static() -class TestPriorBoxOp(XPUOpTest): - def set_data(self): - self.init_test_params() - self.init_test_input() - self.init_test_output() - self.inputs = {'Input': self.input, 'Image': self.image} - - self.attrs = { - 'min_sizes': self.min_sizes, - 'aspect_ratios': self.aspect_ratios, - 'variances': self.variances, - 'flip': self.flip, - 'clip': self.clip, - 'min_max_aspect_ratios_order': self.min_max_aspect_ratios_order, - 'step_w': self.step_w, - 'step_h': self.step_h, - 'offset': self.offset - } - if len(self.max_sizes) > 0: - self.attrs['max_sizes'] = self.max_sizes - - self.outputs = {'Boxes': self.out_boxes, 'Variances': self.out_var} - - def test_check_output(self): - place = paddle.XPUPlace(0) - self.check_output_with_place(place) - - def test_check_grad(self): - pass - - def setUp(self): - self.op_type = "prior_box" - self.use_xpu = True - self.set_data() - - def set_max_sizes(self): - max_sizes = [5, 10] - self.max_sizes = np.array(max_sizes).astype('float32').tolist() - - def set_min_max_aspect_ratios_order(self): - self.min_max_aspect_ratios_order = False - - def init_test_params(self): - self.layer_w = 32 - self.layer_h = 32 - - self.image_w = 40 - self.image_h = 40 - - self.step_w = float(self.image_w) / float(self.layer_w) - self.step_h = float(self.image_h) / float(self.layer_h) - - self.input_channels = 2 - self.image_channels = 3 - self.batch_size = 10 - - self.min_sizes = [2, 4] - self.min_sizes = np.array(self.min_sizes).astype('float32').tolist() - self.set_max_sizes() - self.aspect_ratios = [2.0, 3.0] - self.flip = True - self.set_min_max_aspect_ratios_order() - self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0] - self.aspect_ratios = np.array( - self.aspect_ratios, dtype=np.float).flatten() - self.variances = [0.1, 0.1, 0.2, 0.2] - self.variances = np.array(self.variances, dtype=np.float).flatten() - - self.clip = True - self.num_priors = len(self.real_aspect_ratios) * len(self.min_sizes) - if len(self.max_sizes) > 0: - self.num_priors += len(self.max_sizes) - self.offset = 0.5 - - def init_test_input(self): - self.image = np.random.random( - (self.batch_size, self.image_channels, self.image_w, - self.image_h)).astype('float32') - - self.input = np.random.random( - (self.batch_size, self.input_channels, self.layer_w, - self.layer_h)).astype('float32') - - def init_test_output(self): - out_dim = (self.layer_h, self.layer_w, self.num_priors, 4) - out_boxes = np.zeros(out_dim).astype('float32') - out_var = np.zeros(out_dim).astype('float32') - - idx = 0 - for h in range(self.layer_h): - for w in range(self.layer_w): - c_x = (w + self.offset) * self.step_w - c_y = (h + self.offset) * self.step_h - idx = 0 - for s in range(len(self.min_sizes)): - min_size = self.min_sizes[s] - if not self.min_max_aspect_ratios_order: - # rest of priors - for r in range(len(self.real_aspect_ratios)): - ar = self.real_aspect_ratios[r] - c_w = min_size * math.sqrt(ar) / 2 - c_h = (min_size / math.sqrt(ar)) / 2 - out_boxes[h, w, idx, :] = [ - (c_x - c_w) / self.image_w, (c_y - c_h) / - self.image_h, (c_x + c_w) / self.image_w, - (c_y + c_h) / self.image_h - ] - idx += 1 - - if len(self.max_sizes) > 0: - max_size = self.max_sizes[s] - # second prior: aspect_ratio = 1, - c_w = c_h = math.sqrt(min_size * max_size) / 2 - out_boxes[h, w, idx, :] = [ - (c_x - c_w) / self.image_w, (c_y - c_h) / - self.image_h, (c_x + c_w) / self.image_w, - (c_y + c_h) / self.image_h - ] - idx += 1 - else: - c_w = c_h = min_size / 2. - out_boxes[h, w, idx, :] = [(c_x - c_w) / self.image_w, - (c_y - c_h) / self.image_h, - (c_x + c_w) / self.image_w, - (c_y + c_h) / self.image_h] - idx += 1 - if len(self.max_sizes) > 0: - max_size = self.max_sizes[s] - # second prior: aspect_ratio = 1, - c_w = c_h = math.sqrt(min_size * max_size) / 2 - out_boxes[h, w, idx, :] = [ - (c_x - c_w) / self.image_w, (c_y - c_h) / - self.image_h, (c_x + c_w) / self.image_w, - (c_y + c_h) / self.image_h - ] - idx += 1 - - # rest of priors - for r in range(len(self.real_aspect_ratios)): - ar = self.real_aspect_ratios[r] - if abs(ar - 1.) < 1e-6: - continue - c_w = min_size * math.sqrt(ar) / 2 - c_h = (min_size / math.sqrt(ar)) / 2 +class XPUTestPriorBoxOp(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'prior_box' + self.use_dynamic_create_class = False + + class TestPriorBoxOp(XPUOpTest): + def setUp(self): + self.op_type = "prior_box" + self.use_xpu = True + self.dtype = self.in_type + self.set_data() + + def set_data(self): + self.init_test_params() + self.init_test_input() + self.init_test_output() + self.inputs = {'Input': self.input, 'Image': self.image} + + self.attrs = { + 'min_sizes': self.min_sizes, + 'aspect_ratios': self.aspect_ratios, + 'variances': self.variances, + 'flip': self.flip, + 'clip': self.clip, + 'min_max_aspect_ratios_order': self.min_max_aspect_ratios_order, + 'step_w': self.step_w, + 'step_h': self.step_h, + 'offset': self.offset + } + if len(self.max_sizes) > 0: + self.attrs['max_sizes'] = self.max_sizes + + self.outputs = {'Boxes': self.out_boxes, 'Variances': self.out_var} + + def test_check_output(self): + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + def set_max_sizes(self): + max_sizes = [5, 10] + self.max_sizes = np.array(max_sizes).astype('float32').tolist() + + def set_min_max_aspect_ratios_order(self): + self.min_max_aspect_ratios_order = False + + def init_test_params(self): + self.layer_w = 32 + self.layer_h = 32 + + self.image_w = 40 + self.image_h = 40 + + self.step_w = float(self.image_w) / float(self.layer_w) + self.step_h = float(self.image_h) / float(self.layer_h) + + self.input_channels = 2 + self.image_channels = 3 + self.batch_size = 10 + + self.min_sizes = [2, 4] + self.min_sizes = np.array(self.min_sizes).astype('float32').tolist() + self.set_max_sizes() + self.aspect_ratios = [2.0, 3.0] + self.flip = True + self.set_min_max_aspect_ratios_order() + self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0] + self.aspect_ratios = np.array( + self.aspect_ratios, dtype=np.float).flatten() + self.variances = [0.1, 0.1, 0.2, 0.2] + self.variances = np.array(self.variances, dtype=np.float).flatten() + + self.clip = True + self.num_priors = len(self.real_aspect_ratios) * len(self.min_sizes) + if len(self.max_sizes) > 0: + self.num_priors += len(self.max_sizes) + self.offset = 0.5 + + def init_test_input(self): + self.image = np.random.random( + (self.batch_size, self.image_channels, self.image_w, + self.image_h)).astype(self.dtype) + + self.input = np.random.random( + (self.batch_size, self.input_channels, self.layer_w, + self.layer_h)).astype(self.dtype) + + def init_test_output(self): + out_dim = (self.layer_h, self.layer_w, self.num_priors, 4) + out_boxes = np.zeros(out_dim).astype(self.dtype) + out_var = np.zeros(out_dim).astype(self.dtype) + + idx = 0 + for h in range(self.layer_h): + for w in range(self.layer_w): + c_x = (w + self.offset) * self.step_w + c_y = (h + self.offset) * self.step_h + idx = 0 + for s in range(len(self.min_sizes)): + min_size = self.min_sizes[s] + if not self.min_max_aspect_ratios_order: + # rest of priors + for r in range(len(self.real_aspect_ratios)): + ar = self.real_aspect_ratios[r] + c_w = min_size * math.sqrt(ar) / 2 + c_h = (min_size / math.sqrt(ar)) / 2 + out_boxes[h, w, idx, :] = [ + (c_x - c_w) / self.image_w, (c_y - c_h) / + self.image_h, (c_x + c_w) / self.image_w, + (c_y + c_h) / self.image_h + ] + idx += 1 + + if len(self.max_sizes) > 0: + max_size = self.max_sizes[s] + # second prior: aspect_ratio = 1, + c_w = c_h = math.sqrt(min_size * max_size) / 2 + out_boxes[h, w, idx, :] = [ + (c_x - c_w) / self.image_w, (c_y - c_h) / + self.image_h, (c_x + c_w) / self.image_w, + (c_y + c_h) / self.image_h + ] + idx += 1 + else: + c_w = c_h = min_size / 2. out_boxes[h, w, idx, :] = [ (c_x - c_w) / self.image_w, (c_y - c_h) / self.image_h, (c_x + c_w) / self.image_w, (c_y + c_h) / self.image_h ] idx += 1 - - # clip the prior's coordidate such that it is within[0, 1] - if self.clip: - out_boxes = np.clip(out_boxes, 0.0, 1.0) - # set the variance. - out_var = np.tile(self.variances, (self.layer_h, self.layer_w, - self.num_priors, 1)) - self.out_boxes = out_boxes.astype('float32') - self.out_var = out_var.astype('float32') - - -class TestPriorBoxOpWithoutMaxSize(TestPriorBoxOp): - def set_max_sizes(self): - self.max_sizes = [] - - -class TestPriorBoxOpWithSpecifiedOutOrder(TestPriorBoxOp): - def set_min_max_aspect_ratios_order(self): - self.min_max_aspect_ratios_order = True - + if len(self.max_sizes) > 0: + max_size = self.max_sizes[s] + # second prior: aspect_ratio = 1, + c_w = c_h = math.sqrt(min_size * max_size) / 2 + out_boxes[h, w, idx, :] = [ + (c_x - c_w) / self.image_w, (c_y - c_h) / + self.image_h, (c_x + c_w) / self.image_w, + (c_y + c_h) / self.image_h + ] + idx += 1 + + # rest of priors + for r in range(len(self.real_aspect_ratios)): + ar = self.real_aspect_ratios[r] + if abs(ar - 1.) < 1e-6: + continue + c_w = min_size * math.sqrt(ar) / 2 + c_h = (min_size / math.sqrt(ar)) / 2 + out_boxes[h, w, idx, :] = [ + (c_x - c_w) / self.image_w, (c_y - c_h) / + self.image_h, (c_x + c_w) / self.image_w, + (c_y + c_h) / self.image_h + ] + idx += 1 + + # clip the prior's coordidate such that it is within[0, 1] + if self.clip: + out_boxes = np.clip(out_boxes, 0.0, 1.0) + # set the variance. + out_var = np.tile(self.variances, (self.layer_h, self.layer_w, + self.num_priors, 1)) + self.out_boxes = out_boxes.astype(self.dtype) + self.out_var = out_var.astype(self.dtype) + + class TestPriorBoxOpWithoutMaxSize(TestPriorBoxOp): + def set_max_sizes(self): + self.max_sizes = [] + + class TestPriorBoxOpWithSpecifiedOutOrder(TestPriorBoxOp): + def set_min_max_aspect_ratios_order(self): + self.min_max_aspect_ratios_order = True + + +support_types = get_xpu_op_support_types('prior_box') +for stype in support_types: + create_test_class(globals(), XPUTestPriorBoxOp, stype) if __name__ == '__main__': unittest.main() -- GitLab