diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index f6fbe97565c43c306ea885c765c0a665492fa317..933a28f3f909ad48dc5a62b9b04beccbe60e7d6b 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -33,6 +33,7 @@ detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc) detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc) detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu) detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc) +detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc box_decoder_and_assign_op.cu) if(WITH_GPU) detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub) diff --git a/paddle/fluid/operators/detection/box_decoder_and_assign_op.cc b/paddle/fluid/operators/detection/box_decoder_and_assign_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..4fb4a4c669ee388364a89878471854cd47d8c637 --- /dev/null +++ b/paddle/fluid/operators/detection/box_decoder_and_assign_op.cc @@ -0,0 +1,164 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/detection/box_decoder_and_assign_op.h" + +namespace paddle { +namespace operators { + +using LoDTensor = framework::LoDTensor; + +class BoxDecoderAndAssignOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE( + ctx->HasInput("PriorBox"), + "Input(PriorBox) of BoxDecoderAndAssignOp should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("PriorBoxVar"), + "Input(PriorBoxVar) of BoxDecoderAndAssignOp should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("TargetBox"), + "Input(TargetBox) of BoxDecoderAndAssignOp should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("BoxScore"), + "Input(BoxScore) of BoxDecoderAndAssignOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("OutputBox"), + "Output(OutputBox) of BoxDecoderAndAssignOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("OutputAssignBox"), + "Output(OutputAssignBox) of BoxDecoderAndAssignOp should not be null."); + + auto prior_box_dims = ctx->GetInputDim("PriorBox"); + auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar"); + auto target_box_dims = ctx->GetInputDim("TargetBox"); + auto box_score_dims = ctx->GetInputDim("BoxScore"); + + PADDLE_ENFORCE_EQ(prior_box_dims.size(), 2, + "The rank of Input of PriorBox must be 2"); + PADDLE_ENFORCE_EQ(prior_box_dims[1], 4, "The shape of PriorBox is [N, 4]"); + PADDLE_ENFORCE_EQ(prior_box_var_dims.size(), 1, + "The rank of Input of PriorBoxVar must be 1"); + PADDLE_ENFORCE_EQ(prior_box_var_dims[0], 4, + "The shape of PriorBoxVar is [4]"); + PADDLE_ENFORCE_EQ(target_box_dims.size(), 2, + "The rank of Input of TargetBox must be 2"); + PADDLE_ENFORCE_EQ(box_score_dims.size(), 2, + "The rank of Input of BoxScore must be 2"); + PADDLE_ENFORCE_EQ(prior_box_dims[0], target_box_dims[0], + "The first dim of prior_box and target_box is roi nums " + "and should be same!"); + PADDLE_ENFORCE_EQ(prior_box_dims[0], box_score_dims[0], + "The first dim of prior_box and box_score is roi nums " + "and should be same!"); + PADDLE_ENFORCE_EQ(target_box_dims[1], box_score_dims[1] * prior_box_dims[1], + "The shape of target_box is [N, classnum * 4], The shape " + "of box_score is [N, classnum], The shape of prior_box " + "is [N, 4]"); + + ctx->SetOutputDim("OutputBox", framework::make_ddim({target_box_dims[0], + target_box_dims[1]})); + ctx->ShareLoD("TargetBox", /*->*/ "OutputBox"); + ctx->SetOutputDim( + "OutputAssignBox", + framework::make_ddim({prior_box_dims[0], prior_box_dims[1]})); + ctx->ShareLoD("PriorBox", /*->*/ "OutputAssignBox"); + } +}; + +class BoxDecoderAndAssignOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput( + "PriorBox", + "(Tensor, default Tensor) " + "Box list PriorBox is a 2-D Tensor with shape [M, 4] holds M boxes, " + "each box is represented as [xmin, ymin, xmax, ymax], " + "[xmin, ymin] is the left top coordinate of the anchor box, " + "if the input is image feature map, they are close to the origin " + "of the coordinate system. [xmax, ymax] is the right bottom " + "coordinate of the anchor box."); + AddInput("PriorBoxVar", + "(Tensor, default Tensor, optional) " + "PriorBoxVar is a 2-D Tensor with shape [M, 4] holds M group " + "of variance. PriorBoxVar will set all elements to 1 by " + "default.") + .AsDispensable(); + AddInput( + "TargetBox", + "(LoDTensor or Tensor) This input can be a 2-D LoDTensor with shape " + "[N, classnum*4]. [N, classnum*4], each box is represented as " + "[xmin, ymin, xmax, ymax], [xmin, ymin] is the left top coordinate " + "of the box if the input is image feature map, they are close to " + "the origin of the coordinate system. [xmax, ymax] is the right " + "bottom coordinate of the box. This tensor can contain LoD " + "information to represent a batch of inputs. One instance of this " + "batch can contain different numbers of entities."); + AddInput( + "BoxScore", + "(LoDTensor or Tensor) This input can be a 2-D LoDTensor with shape " + "[N, classnum], each box is represented as [classnum] which is " + "the classification probabilities."); + AddAttr("box_clip", + "(float, default 4.135, np.log(1000. / 16.)) " + "clip box to prevent overflowing") + .SetDefault(4.135f); + AddOutput("OutputBox", + "(LoDTensor or Tensor) " + "the output tensor of op with shape [N, classnum * 4] " + "representing the result of N target boxes decoded with " + "M Prior boxes and variances for each class."); + AddOutput("OutputAssignBox", + "(LoDTensor or Tensor) " + "the output tensor of op with shape [N, 4] " + "representing the result of N target boxes decoded with " + "M Prior boxes and variances with the best non-background class " + "by BoxScore."); + AddComment(R"DOC( + +Bounding Box Coder. + +Decode the target bounding box with the priorbox information. + +The Decoding schema described below: + + ox = (pw * pxv * tx * + px) - tw / 2 + + oy = (ph * pyv * ty * + py) - th / 2 + + ow = exp(pwv * tw) * pw + tw / 2 + + oh = exp(phv * th) * ph + th / 2 + +where `tx`, `ty`, `tw`, `th` denote the target box's center coordinates, width +and height respectively. Similarly, `px`, `py`, `pw`, `ph` denote the +priorbox's (anchor) center coordinates, width and height. `pxv`, `pyv`, `pwv`, +`phv` denote the variance of the priorbox and `ox`, `oy`, `ow`, `oh` denote the +encoded/decoded coordinates, width and height. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(box_decoder_and_assign, ops::BoxDecoderAndAssignOp, + ops::BoxDecoderAndAssignOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL( + box_decoder_and_assign, + ops::BoxDecoderAndAssignKernel, + ops::BoxDecoderAndAssignKernel); diff --git a/paddle/fluid/operators/detection/box_decoder_and_assign_op.cu b/paddle/fluid/operators/detection/box_decoder_and_assign_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..ef17c4c0006d7ff33a3fa3b0479dda4fe1d720de --- /dev/null +++ b/paddle/fluid/operators/detection/box_decoder_and_assign_op.cu @@ -0,0 +1,147 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/operators/detection/box_decoder_and_assign_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" + +namespace paddle { +namespace operators { + +template +__global__ void DecodeBoxKernel(const T* prior_box_data, + const T* prior_box_var_data, + const T* target_box_data, const int roi_num, + const int class_num, const T box_clip, + T* output_box_data) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < roi_num * class_num) { + int i = idx / class_num; + int j = idx % class_num; + T prior_box_width = prior_box_data[i * 4 + 2] - prior_box_data[i * 4] + 1; + T prior_box_height = + prior_box_data[i * 4 + 3] - prior_box_data[i * 4 + 1] + 1; + T prior_box_center_x = prior_box_data[i * 4] + prior_box_width / 2; + T prior_box_center_y = prior_box_data[i * 4 + 1] + prior_box_height / 2; + + int offset = i * class_num * 4 + j * 4; + T dw = prior_box_var_data[2] * target_box_data[offset + 2]; + T dh = prior_box_var_data[3] * target_box_data[offset + 3]; + if (dw > box_clip) { + dw = box_clip; + } + if (dh > box_clip) { + dh = box_clip; + } + T target_box_center_x = 0, target_box_center_y = 0; + T target_box_width = 0, target_box_height = 0; + target_box_center_x = + prior_box_var_data[0] * target_box_data[offset] * prior_box_width + + prior_box_center_x; + target_box_center_y = + prior_box_var_data[1] * target_box_data[offset + 1] * prior_box_height + + prior_box_center_y; + target_box_width = expf(dw) * prior_box_width; + target_box_height = expf(dh) * prior_box_height; + + output_box_data[offset] = target_box_center_x - target_box_width / 2; + output_box_data[offset + 1] = target_box_center_y - target_box_height / 2; + output_box_data[offset + 2] = + target_box_center_x + target_box_width / 2 - 1; + output_box_data[offset + 3] = + target_box_center_y + target_box_height / 2 - 1; + } +} + +template +__global__ void AssignBoxKernel(const T* prior_box_data, + const T* box_score_data, T* output_box_data, + const int roi_num, const int class_num, + T* output_assign_box_data) { + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx < roi_num) { + int i = idx; + T max_score = -1; + int max_j = -1; + for (int j = 0; j < class_num; ++j) { + T score = box_score_data[i * class_num + j]; + if (score > max_score && j > 0) { + max_score = score; + max_j = j; + } + } + if (max_j > 0) { + for (int pno = 0; pno < 4; pno++) { + output_assign_box_data[i * 4 + pno] = + output_box_data[i * class_num * 4 + max_j * 4 + pno]; + } + } else { + for (int pno = 0; pno < 4; pno++) { + output_assign_box_data[i * 4 + pno] = prior_box_data[i * 4 + pno]; + } + } + } +} + +template +class BoxDecoderAndAssignCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()), + "This kernel only runs on GPU device."); + auto* prior_box = context.Input("PriorBox"); + auto* prior_box_var = context.Input("PriorBoxVar"); + auto* target_box = context.Input("TargetBox"); + auto* box_score = context.Input("BoxScore"); + auto* output_box = context.Output("OutputBox"); + auto* output_assign_box = + context.Output("OutputAssignBox"); + + auto roi_num = target_box->dims()[0]; + auto class_num = box_score->dims()[1]; + auto* target_box_data = target_box->data(); + auto* prior_box_data = prior_box->data(); + auto* prior_box_var_data = prior_box_var->data(); + auto* box_score_data = box_score->data(); + output_box->mutable_data({roi_num, class_num * 4}, context.GetPlace()); + output_assign_box->mutable_data({roi_num, 4}, context.GetPlace()); + T* output_box_data = output_box->data(); + T* output_assign_box_data = output_assign_box->data(); + + int block = 512; + int grid = (roi_num * class_num + block - 1) / block; + auto& device_ctx = context.cuda_device_context(); + + const T box_clip = context.Attr("box_clip"); + + DecodeBoxKernel<<>>( + prior_box_data, prior_box_var_data, target_box_data, roi_num, class_num, + box_clip, output_box_data); + + context.device_context().Wait(); + int assign_grid = (roi_num + block - 1) / block; + AssignBoxKernel<<>>( + prior_box_data, box_score_data, output_box_data, roi_num, class_num, + output_assign_box_data); + context.device_context().Wait(); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + box_decoder_and_assign, + ops::BoxDecoderAndAssignCUDAKernel, + ops::BoxDecoderAndAssignCUDAKernel); diff --git a/paddle/fluid/operators/detection/box_decoder_and_assign_op.h b/paddle/fluid/operators/detection/box_decoder_and_assign_op.h new file mode 100644 index 0000000000000000000000000000000000000000..ff343e5d44b05b6639b5a22d23530ca818f7a7a9 --- /dev/null +++ b/paddle/fluid/operators/detection/box_decoder_and_assign_op.h @@ -0,0 +1,103 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +template +class BoxDecoderAndAssignKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* prior_box = context.Input("PriorBox"); + auto* prior_box_var = context.Input("PriorBoxVar"); + auto* target_box = context.Input("TargetBox"); + auto* box_score = context.Input("BoxScore"); + auto* output_box = context.Output("OutputBox"); + auto* output_assign_box = + context.Output("OutputAssignBox"); + int roi_num = target_box->dims()[0]; + int class_num = box_score->dims()[1]; + auto* target_box_data = target_box->data(); + auto* prior_box_data = prior_box->data(); + auto* prior_box_var_data = prior_box_var->data(); + auto* box_score_data = box_score->data(); + output_box->mutable_data({roi_num, class_num * 4}, context.GetPlace()); + output_assign_box->mutable_data({roi_num, 4}, context.GetPlace()); + T* output_box_data = output_box->data(); + T* output_assign_box_data = output_assign_box->data(); + const T bbox_clip = context.Attr("box_clip"); + + for (int i = 0; i < roi_num; ++i) { + T prior_box_width = prior_box_data[i * 4 + 2] - prior_box_data[i * 4] + 1; + T prior_box_height = + prior_box_data[i * 4 + 3] - prior_box_data[i * 4 + 1] + 1; + T prior_box_center_x = prior_box_data[i * 4] + prior_box_width / 2; + T prior_box_center_y = prior_box_data[i * 4 + 1] + prior_box_height / 2; + for (int j = 0; j < class_num; ++j) { + int64_t offset = i * class_num * 4 + j * 4; + T dw = std::min(prior_box_var_data[2] * target_box_data[offset + 2], + bbox_clip); + T dh = std::min(prior_box_var_data[3] * target_box_data[offset + 3], + bbox_clip); + T target_box_center_x = 0, target_box_center_y = 0; + T target_box_width = 0, target_box_height = 0; + target_box_center_x = + prior_box_var_data[0] * target_box_data[offset] * prior_box_width + + prior_box_center_x; + target_box_center_y = prior_box_var_data[1] * + target_box_data[offset + 1] * + prior_box_height + + prior_box_center_y; + target_box_width = std::exp(dw) * prior_box_width; + target_box_height = std::exp(dh) * prior_box_height; + + output_box_data[offset] = target_box_center_x - target_box_width / 2; + output_box_data[offset + 1] = + target_box_center_y - target_box_height / 2; + output_box_data[offset + 2] = + target_box_center_x + target_box_width / 2 - 1; + output_box_data[offset + 3] = + target_box_center_y + target_box_height / 2 - 1; + } + + T max_score = -1; + int max_j = -1; + for (int j = 0; j < class_num; ++j) { + T score = box_score_data[i * class_num + j]; + if (score > max_score && j > 0) { + max_score = score; + max_j = j; + } + } + + if (max_j > 0) { + for (int pno = 0; pno < 4; pno++) { + output_assign_box_data[i * 4 + pno] = + output_box_data[i * class_num * 4 + max_j * 4 + pno]; + } + } else { + for (int pno = 0; pno < 4; pno++) { + output_assign_box_data[i * 4 + pno] = prior_box_data[i * 4 + pno]; + } + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 61a7d4f31d5245e635e2e1fe33e418ce20e94180..4ee92cd5c69077f766c803df8bb4f2f216c6825d 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -51,6 +51,7 @@ __all__ = [ 'yolov3_loss', 'box_clip', 'multiclass_nms', + 'box_decoder_and_assign', ] @@ -2221,3 +2222,53 @@ def multiclass_nms(bboxes, output.stop_gradient = True return output + + +@templatedoc() +def box_decoder_and_assign(prior_box, prior_box_var, target_box, box_score, + box_clip): + """ + ${comment} + Args: + prior_box(${prior_box_type}): ${prior_box_comment} + prior_box_var(${prior_box_var_type}): ${prior_box_var_comment} + target_box(${target_box_type}): ${target_box_comment} + box_score(${box_score_type}): ${box_score_comment} + Returns: + output_box(${output_box_type}): ${output_box_comment} + output_assign_box(${output_assign_box_type}): ${output_assign_box_comment} + Examples: + .. code-block:: python + + pb = fluid.layers.data(name='prior_box', shape=[20, 4], + dtype='float32') + pbv = fluid.layers.data(name='prior_box_var', shape=[1, 4], + dtype='float32') + loc = fluid.layers.data(name='target_box', shape=[20, 4*81], + dtype='float32') + scores = fluid.layers.data(name='scores', shape=[20, 81], + dtype='float32') + output_box, output_assign_box = fluid.layers.box_decoder_and_assign(pb, pbv, loc, scores, 4.135) + + """ + helper = LayerHelper("box_decoder_and_assign", **locals()) + + output_box = helper.create_variable_for_type_inference( + dtype=prior_box.dtype) + output_assign_box = helper.create_variable_for_type_inference( + dtype=prior_box.dtype) + + helper.append_op( + type="box_decoder_and_assign", + inputs={ + "PriorBox": prior_box, + "PriorBoxVar": prior_box_var, + "TargetBox": target_box, + "BoxScore": box_score + }, + attrs={"box_clip": box_clip}, + outputs={ + "OutputBox": output_box, + "OutputAssignBox": output_assign_box + }) + return output_box, output_assign_box diff --git a/python/paddle/fluid/tests/unittests/test_box_decoder_and_assign_op.py b/python/paddle/fluid/tests/unittests/test_box_decoder_and_assign_op.py new file mode 100644 index 0000000000000000000000000000000000000000..b136c90f2d6c049ab1d000ac8d077769c181a7f4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_box_decoder_and_assign_op.py @@ -0,0 +1,96 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import sys +import math +from op_test import OpTest + + +def box_decoder_and_assign(deltas, weights, boxes, box_score, box_clip): + boxes = boxes.astype(deltas.dtype, copy=False) + widths = boxes[:, 2] - boxes[:, 0] + 1.0 + heights = boxes[:, 3] - boxes[:, 1] + 1.0 + ctr_x = boxes[:, 0] + 0.5 * widths + ctr_y = boxes[:, 1] + 0.5 * heights + wx, wy, ww, wh = weights + dx = deltas[:, 0::4] * wx + dy = deltas[:, 1::4] * wy + dw = deltas[:, 2::4] * ww + dh = deltas[:, 3::4] * wh + # Prevent sending too large values into np.exp() + dw = np.minimum(dw, box_clip) + dh = np.minimum(dh, box_clip) + pred_ctr_x = dx * widths[:, np.newaxis] + ctr_x[:, np.newaxis] + pred_ctr_y = dy * heights[:, np.newaxis] + ctr_y[:, np.newaxis] + pred_w = np.exp(dw) * widths[:, np.newaxis] + pred_h = np.exp(dh) * heights[:, np.newaxis] + pred_boxes = np.zeros(deltas.shape, dtype=deltas.dtype) + # x1 + pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w + # y1 + pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h + # x2 (note: "- 1" is correct; don't be fooled by the asymmetry) + pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w - 1 + # y2 (note: "- 1" is correct; don't be fooled by the asymmetry) + pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h - 1 + + output_assign_box = [] + for ino in range(len(pred_boxes)): + rank = np.argsort(-box_score[ino]) + maxidx = rank[0] + if maxidx == 0: + maxidx = rank[1] + beg_pos = maxidx * 4 + end_pos = maxidx * 4 + 4 + output_assign_box.append(pred_boxes[ino, beg_pos:end_pos]) + output_assign_box = np.array(output_assign_box) + + return pred_boxes, output_assign_box + + +class TestBoxDecoderAndAssignOpWithLoD(OpTest): + def test_check_output(self): + self.check_output() + + def setUp(self): + self.op_type = "box_decoder_and_assign" + lod = [[4, 8, 8]] + num_classes = 10 + prior_box = np.random.random((20, 4)).astype('float32') + prior_box_var = np.array([0.1, 0.1, 0.2, 0.2], dtype=np.float32) + target_box = np.random.random((20, 4 * num_classes)).astype('float32') + box_score = np.random.random((20, num_classes)).astype('float32') + box_clip = 4.135 + output_box, output_assign_box = box_decoder_and_assign( + target_box, prior_box_var, prior_box, box_score, box_clip) + + self.inputs = { + 'PriorBox': (prior_box, lod), + 'PriorBoxVar': prior_box_var, + 'TargetBox': (target_box, lod), + 'BoxScore': (box_score, lod), + } + self.attrs = {'box_clip': box_clip} + self.outputs = { + 'OutputBox': output_box, + 'OutputAssignBox': output_assign_box + } + + +if __name__ == '__main__': + unittest.main()