提交 452373de 编写于 作者: D dengkaipeng

Rescale output boxes to the input image scale. test=develop

上级 3896d955
......@@ -23,12 +23,15 @@ class YoloBoxOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of YoloBoxOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("ImgSize"),
"Input(ImgSize) of YoloBoxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Boxes"),
"Output(Boxes) of YoloBoxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Scores"),
"Output(Scores) of YoloBoxOp should not be null.");
auto dim_x = ctx->GetInputDim("X");
auto dim_imgsize = ctx->GetInputDim("ImgSize");
auto anchors = ctx->Attrs().Get<std::vector<int>>("anchors");
int anchor_num = anchors.size() / 2;
auto class_num = ctx->Attrs().Get<int>("class_num");
......@@ -39,6 +42,12 @@ class YoloBoxOp : public framework::OperatorWithKernel {
dim_x[1], anchor_num * (5 + class_num),
"Input(X) dim[1] should be equal to (anchor_mask_number * (5 "
"+ class_num)).");
PADDLE_ENFORCE_EQ(dim_imgsize.size(), 2,
"Input(ImgSize) should be a 2-D tensor.");
PADDLE_ENFORCE_EQ(
dim_imgsize[0], dim_x[0],
"Input(ImgSize) dim[0] and Input(X) dim[0] should be same.");
PADDLE_ENFORCE_EQ(dim_imgsize[1], 2, "Input(ImgSize) dim[1] should be 2.");
PADDLE_ENFORCE_GT(anchors.size(), 0,
"Attr(anchors) length should be greater then 0.");
PADDLE_ENFORCE_EQ(anchors.size() % 2, 0,
......@@ -72,6 +81,11 @@ class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker {
"box locations, confidence score and classification one-hot"
"keys of each anchor box. Generally, X should be the output"
"of YOLOv3 network.");
AddInput("ImgSize",
"The image size tensor of YoloBox operator, "
"This is a 2-D tensor with shape of [N, 2]. This tensor holds"
"height and width of each input image using for resize output"
"box in input image scale.");
AddOutput("Boxes",
"The output tensor of detection boxes of YoloBox operator, "
"This is a 3-D tensor with shape of [N, M, 4], N is the"
......
......@@ -32,12 +32,15 @@ static inline T sigmoid(T x) {
// Builds a Box<T> (fields x, y, w, h) for anchor `an_idx` at grid cell (i, j)
// from four raw values of `x`: x[index], x[index + stride],
// x[index + 2 * stride] and x[index + 3 * stride].
//   - x / y: the cell index plus sigmoid(raw value), then scaled.
//   - w / h: exp(raw value) times the anchor entry anchors[2 * an_idx] /
//            anchors[2 * an_idx + 1], then (in variant B) scaled.
// Presumably b.x / b.y are the box centre - TODO confirm against the code
// that consumes the returned Box.
//
// NOTE(review): this span is rendered diff text without +/- markers, so two
// versions of the signature tail and of the field assignments appear back to
// back. Variant A scales by input_size only; variant B takes img_height /
// img_width and scales into image coordinates. Presumably A is the removed
// version and B the added one - confirm against the repository.
template <typename T>
static inline Box<T> GetYoloBox(const T* x, std::vector<int> anchors, int i,
int j, int an_idx, int grid_size,
// Parameter-list tail, variant A: no image size.
int input_size, int index, int stride) {
// Parameter-list tail, variant B: adds the per-image height and width.
int input_size, int index, int stride,
int img_height, int img_width) {
Box<T> b;
// Variant A: x / y are scaled by input_size / grid_size; w / h are
// exp(raw) * anchor with no further scaling.
b.x = (i + sigmoid<T>(x[index])) * input_size / grid_size;
b.y = (j + sigmoid<T>(x[index + stride])) * input_size / grid_size;
b.w = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx];
b.h = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1];
// Variant B: x / y are scaled by img_width (resp. img_height) / grid_size;
// w / h are additionally multiplied by img_width (resp. img_height) /
// input_size. Each expression is evaluated left to right, so the multiply
// by the T-typed term happens before the division by the int operand.
b.x = (i + sigmoid<T>(x[index])) * img_width / grid_size;
b.y = (j + sigmoid<T>(x[index + stride])) * img_height / grid_size;
b.w = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx] * img_width /
input_size;
b.h = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1] * img_height /
input_size;
return b;
}
......@@ -69,6 +72,7 @@ class YoloBoxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* imgsize = ctx.Input<Tensor>("ImgSize");
auto* boxes = ctx.Output<Tensor>("Boxes");
auto* scores = ctx.Output<Tensor>("Scores");
auto anchors = ctx.Attr<std::vector<int>>("anchors");
......@@ -87,6 +91,7 @@ class YoloBoxKernel : public framework::OpKernel<T> {
const int an_stride = (class_num + 5) * stride;
const T* input_data = input->data<T>();
const int* imgsize_data = imgsize->data<int>();
T* boxes_data = boxes->mutable_data<T>({n, box_num, 4}, ctx.GetPlace());
memset(boxes_data, 0, boxes->numel() * sizeof(T));
T* scores_data =
......@@ -94,6 +99,9 @@ class YoloBoxKernel : public framework::OpKernel<T> {
memset(scores_data, 0, scores->numel() * sizeof(T));
for (int i = 0; i < n; i++) {
int img_height = imgsize_data[2 * i];
int img_width = imgsize_data[2 * i + 1];
for (int j = 0; j < an_num; j++) {
for (int k = 0; k < h; k++) {
for (int l = 0; l < w; l++) {
......@@ -106,8 +114,9 @@ class YoloBoxKernel : public framework::OpKernel<T> {
int box_idx =
GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0);
Box<T> pred = GetYoloBox(input_data, anchors, l, k, j, h,
input_size, box_idx, stride);
Box<T> pred =
GetYoloBox(input_data, anchors, l, k, j, h, input_size, box_idx,
stride, img_height, img_width);
box_idx = (i * box_num + j * stride + k * w + l) * 4;
CalcDetectionBox<T>(boxes_data, pred, box_idx);
......
......@@ -611,12 +611,19 @@ def yolov3_loss(x,
@templatedoc(op_type="yolo_box")
def yolo_box(x, anchors, class_num, conf_thresh, downsample_ratio, name=None):
def yolo_box(x,
img_size,
anchors,
class_num,
conf_thresh,
downsample_ratio,
name=None):
"""
${comment}
Args:
x (Variable): ${x_comment}
img_size (Variable): ${img_size_comment}
anchors (list|tuple): ${anchors_comment}
class_num (int): ${class_num_comment}
conf_thresh (float): ${conf_thresh_comment}
......@@ -643,16 +650,17 @@ def yolo_box(x, anchors, class_num, conf_thresh, downsample_ratio, name=None):
helper = LayerHelper('yolo_box', **locals())
if not isinstance(x, Variable):
raise TypeError("Input x of yolov3_loss must be Variable")
raise TypeError("Input x of yolo_box must be Variable")
if not isinstance(img_size, Variable):
raise TypeError("Input img_size of yolo_box must be Variable")
if not isinstance(anchors, list) and not isinstance(anchors, tuple):
raise TypeError("Attr anchors of yolov3_loss must be list or tuple")
raise TypeError("Attr anchors of yolo_box must be list or tuple")
if not isinstance(anchor_mask, list) and not isinstance(anchor_mask, tuple):
raise TypeError("Attr anchor_mask of yolov3_loss must be list or tuple")
raise TypeError("Attr anchor_mask of yolo_box must be list or tuple")
if not isinstance(class_num, int):
raise TypeError("Attr class_num of yolov3_loss must be an integer")
raise TypeError("Attr class_num of yolo_box must be an integer")
if not isinstance(conf_thresh, float):
raise TypeError(
"Attr ignore_thresh of yolov3_loss must be a float number")
raise TypeError("Attr ignore_thresh of yolo_box must be a float number")
boxes = helper.create_variable_for_type_inference(dtype=x.dtype)
scores = helper.create_variable_for_type_inference(dtype=x.dtype)
......@@ -666,7 +674,10 @@ def yolo_box(x, anchors, class_num, conf_thresh, downsample_ratio, name=None):
helper.append_op(
type='yolo_box',
inputs={"X": x, },
inputs={
"X": x,
"ImgSize": img_size,
},
outputs={
'Boxes': boxes,
'Scores': scores,
......
......@@ -484,7 +484,9 @@ class TestYoloDetection(unittest.TestCase):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[30, 7, 7], dtype='float32')
boxes, scores = layers.yolo_box(x, [10, 13, 30, 13], 10, 0.01, 32)
img_size = layers.data(name='x', shape=[2], dtype='int32')
boxes, scores = layers.yolo_box(x, img_size, [10, 13, 30, 13], 10,
0.01, 32)
self.assertIsNotNone(boxes)
self.assertIsNotNone(scores)
......
......@@ -25,7 +25,7 @@ def sigmoid(x):
return 1.0 / (1.0 + np.exp(-1.0 * x))
def YoloBox(x, attrs):
def YoloBox(x, img_size, attrs):
n, c, h, w = x.shape
anchors = attrs['anchors']
an_num = int(len(anchors) // 2)
......@@ -56,15 +56,14 @@ def YoloBox(x, attrs):
pred_box = pred_box * (pred_conf > 0.).astype('float32')
pred_box = pred_box.reshape((n, -1, 4))
pred_box[:, :, :
2], pred_box[:, :, 2:
4] = pred_box[:, :, :
2] - pred_box[:, :, 2:
4] / 2., pred_box[:, :, :
2] + pred_box[:, :,
2:
4] / 2.0
pred_box = pred_box * input_size
pred_box[:, :, :2], pred_box[:, :, 2:4] = \
pred_box[:, :, :2] - pred_box[:, :, 2:4] / 2., \
pred_box[:, :, :2] + pred_box[:, :, 2:4] / 2.0
# pred_box = pred_box * input_size
pred_box[:, :, 0] = pred_box[:, :, 0] * img_size[:, 1][:, np.newaxis]
pred_box[:, :, 1] = pred_box[:, :, 1] * img_size[:, 0][:, np.newaxis]
pred_box[:, :, 2] = pred_box[:, :, 2] * img_size[:, 1][:, np.newaxis]
pred_box[:, :, 3] = pred_box[:, :, 3] * img_size[:, 0][:, np.newaxis]
return pred_box, pred_score.reshape((n, -1, class_num))
......@@ -74,6 +73,7 @@ class TestYoloBoxOp(OpTest):
self.initTestCase()
self.op_type = 'yolo_box'
x = np.random.random(self.x_shape).astype('float32')
img_size = np.random.randint(10, 20, self.imgsize_shape).astype('int32')
self.attrs = {
"anchors": self.anchors,
......@@ -82,8 +82,11 @@ class TestYoloBoxOp(OpTest):
"downsample": self.downsample,
}
self.inputs = {'X': x, }
boxes, scores = YoloBox(x, self.attrs)
self.inputs = {
'X': x,
'ImgSize': img_size,
}
boxes, scores = YoloBox(x, img_size, self.attrs)
self.outputs = {
"Boxes": boxes,
"Scores": scores,
......@@ -95,10 +98,12 @@ class TestYoloBoxOp(OpTest):
# Sets the fixture parameters (anchors, class count, thresholds, tensor shapes)
# stored on the test instance.
# NOTE(review): this span is rendered diff text with indentation and +/-
# markers stripped. x_shape is assigned twice below (first with a literal 3,
# then with self.batch_size) - presumably the removed and added lines of one
# change. Both produce the same tuple because batch_size is 3.
def initTestCase(self):
# Flat list of anchor values; an_num counts them in pairs.
self.anchors = [10, 13, 16, 30, 33, 23]
an_num = int(len(self.anchors) // 2)
self.batch_size = 3
self.class_num = 2
self.conf_thresh = 0.5
self.downsample = 32
# Shape of X: (batch, an_num * (5 + class_num), 5, 5).
self.x_shape = (3, an_num * (5 + self.class_num), 5, 5)
self.x_shape = (self.batch_size, an_num * (5 + self.class_num), 5, 5)
# Shape of ImgSize: one 2-element row per batch item.
self.imgsize_shape = (self.batch_size, 2)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册