From 3896d955c75aa537a30e1a8dde9ad02f6540bb73 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Mon, 18 Feb 2019 21:14:40 +0800 Subject: [PATCH] add yolo_box_op CPU kernel --- .../fluid/operators/detection/CMakeLists.txt | 1 + .../fluid/operators/detection/yolo_box_op.cc | 144 ++++++++++++++++++ .../fluid/operators/detection/yolo_box_op.cu | 71 +++++++++ .../fluid/operators/detection/yolo_box_op.h | 127 +++++++++++++++ python/paddle/fluid/layers/detection.py | 66 ++++++++ python/paddle/fluid/tests/test_detection.py | 9 +- .../fluid/tests/unittests/test_yolo_box_op.py | 105 +++++++++++++ .../tests/unittests/test_yolov3_loss_op.py | 12 +- 8 files changed, 526 insertions(+), 9 deletions(-) create mode 100644 paddle/fluid/operators/detection/yolo_box_op.cc create mode 100644 paddle/fluid/operators/detection/yolo_box_op.cu create mode 100644 paddle/fluid/operators/detection/yolo_box_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_yolo_box_op.py diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index c87837e69..94a2016aa 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -33,6 +33,7 @@ detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc) detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc) detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu) detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc) +detection_library(yolo_box_op SRCS yolo_box_op.cc yolo_box_op.cu) detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc box_decoder_and_assign_op.cu) if(WITH_GPU) diff --git a/paddle/fluid/operators/detection/yolo_box_op.cc b/paddle/fluid/operators/detection/yolo_box_op.cc new file mode 100644 index 000000000..4c2c5d1e6 --- /dev/null +++ b/paddle/fluid/operators/detection/yolo_box_op.cc @@ -0,0 +1,144 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/fluid/operators/detection/yolo_box_op.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class YoloBoxOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of YoloBoxOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Boxes"), + "Output(Boxes) of YoloBoxOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Scores"), + "Output(Scores) of YoloBoxOp should not be null."); + + auto dim_x = ctx->GetInputDim("X"); + auto anchors = ctx->Attrs().Get>("anchors"); + int anchor_num = anchors.size() / 2; + auto class_num = ctx->Attrs().Get("class_num"); + auto conf_thresh = ctx->Attrs().Get("conf_thresh"); + + PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor."); + PADDLE_ENFORCE_EQ( + dim_x[1], anchor_num * (5 + class_num), + "Input(X) dim[1] should be equal to (anchor_mask_number * (5 " + "+ class_num))."); + PADDLE_ENFORCE_GT(anchors.size(), 0, + "Attr(anchors) length should be greater then 0."); + PADDLE_ENFORCE_EQ(anchors.size() % 2, 0, + "Attr(anchors) length should be even integer."); + PADDLE_ENFORCE_GT(class_num, 0, + "Attr(class_num) should be an integer greater then 0."); + + int box_num = dim_x[2] * dim_x[3] * anchor_num; + std::vector dim_boxes({dim_x[0], box_num, 4}); + ctx->SetOutputDim("Boxes", framework::make_ddim(dim_boxes)); + + std::vector dim_scores({dim_x[0], box_num, class_num}); + ctx->SetOutputDim("Scores", framework::make_ddim(dim_scores)); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace()); + } +}; + +class YoloBoxOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "The input tensor of YoloBox operator, " + "This is a 4-D tensor with shape of [N, C, H, W]." + "H and W should be same, and the second dimention(C) stores" + "box locations, confidence score and classification one-hot" + "keys of each anchor box. Generally, X should be the output" + "of YOLOv3 network."); + AddOutput("Boxes", + "The output tensor of detection boxes of YoloBox operator, " + "This is a 3-D tensor with shape of [N, M, 4], N is the" + "batch num, M is output box number, and the 3rd dimention" + "stores [xmin, ymin, xmax, ymax] coordinates of boxes."); + AddOutput("Scores", + "The output tensor ofdetection boxes scores of YoloBox" + "operator, This is a 3-D tensor with shape of [N, M, C]," + "N is the batch num, M is output box number, C is the" + "class number."); + + AddAttr("class_num", "The number of classes to predict."); + AddAttr>("anchors", + "The anchor width and height, " + "it will be parsed pair by pair.") + .SetDefault(std::vector{}); + AddAttr("downsample_ratio", + "The downsample ratio from network input to YoloBox operator " + "input, so 32, 16, 8 should be set for the first, second, " + "and thrid YoloBox operators.") + .SetDefault(32); + AddAttr("conf_thresh", + "The confidence scores threshold of detection boxes." + "boxes with confidence scores under threshold should" + "be ignored.") + .SetDefault(0.01); + AddComment(R"DOC( + This operator generate YOLO detection boxes fron output of YOLOv3 network. + + The output of previous network is in shape [N, C, H, W], while H and W + should be the same, specify the grid size, each grid point predict given + number boxes, this given number is specified by anchors, it should be + half anchors length, which following will be represented as S. In the + second dimention(the channel dimention), C should be S * (class_num + 5), + class_num is the box categoriy number of source dataset(such as coco), + so in the second dimention, stores 4 box location coordinates x, y, w, h + and confidence score of the box and class one-hot key of each anchor box. + + While the 4 location coordinates if $$tx, ty, tw, th$$, the box predictions + correspnd to: + + $$ + b_x = \sigma(t_x) + c_x + b_y = \sigma(t_y) + c_y + b_w = p_w e^{t_w} + b_h = p_h e^{t_h} + $$ + + While $$c_x, c_y$$ is the left top corner of current grid and $$p_w, p_h$$ + is specified by anchors. + + The logistic scores of the 5rd channel of each anchor prediction boxes + represent the confidence score of each prediction scores, and the logistic + scores of the last class_num channels of each anchor prediction boxes + represent the classifcation scores. Boxes with confidence scores less then + conf_thresh should be ignored, and boxes final scores if the products result + of confidence scores and classification scores. + + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(yolo_box, ops::YoloBoxOp, ops::YoloBoxOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL(yolo_box, ops::YoloBoxKernel, + ops::YoloBoxKernel); diff --git a/paddle/fluid/operators/detection/yolo_box_op.cu b/paddle/fluid/operators/detection/yolo_box_op.cu new file mode 100644 index 000000000..9cc94794f --- /dev/null +++ b/paddle/fluid/operators/detection/yolo_box_op.cu @@ -0,0 +1,71 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/detection/yolo_box_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +static __global__ void GenDensityPriorBox( + const int height, const int width, const int im_height, const int im_width, + const T offset, const T step_width, const T step_height, + const int num_priors, const T* ratios_shift, bool is_clip, const T var_xmin, + const T var_ymin, const T var_xmax, const T var_ymax, T* out, T* var) { + int gidx = blockIdx.x * blockDim.x + threadIdx.x; + int gidy = blockIdx.y * blockDim.y + threadIdx.y; + int step_x = blockDim.x * gridDim.x; + int step_y = blockDim.y * gridDim.y; +} + +template +class YoloBoxOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* boxes = ctx.Output("Boxes"); + auto* scores = ctx.Output("Scores"); + + auto anchors = ctx.Attr>("anchors"); + int class_num = ctx.Attr("class_num"); + float conf_thresh = ctx.Attr("conf_thresh"); + int downsample_ratio = ctx.Attr("downsample_ratio"); + + const int n = input->dims()[0]; + const int h = input->dims()[2]; + const int w = input->dims()[3]; + const int box_num = boxes->dims()[1]; + const int an_num = anchors.size() / 2; + int input_size = downsample_ratio * h; + + const int stride = h * w; + const int an_stride = (class_num + 5) * stride; + + const T* input_data = input->data(); + T* boxes_data = boxes->mutable_data({n}, ctx.GetPlace()); + memset(loss_data, 0, boxes->numel() * sizeof(T)); + T* scores_data = scores->mutable_data({n}, ctx.GetPlace()); + memset(scores_data, 0, scores->numel() * sizeof(T)); + } +}; // namespace operators + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(density_prior_box, + ops::DensityPriorBoxOpCUDAKernel, + ops::DensityPriorBoxOpCUDAKernel); diff --git a/paddle/fluid/operators/detection/yolo_box_op.h b/paddle/fluid/operators/detection/yolo_box_op.h new file mode 100644 index 000000000..7a9ebf46d --- /dev/null +++ b/paddle/fluid/operators/detection/yolo_box_op.h @@ -0,0 +1,127 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +struct Box { + T x, y, w, h; +}; + +template +static inline T sigmoid(T x) { + return 1.0 / (1.0 + std::exp(-x)); +} + +template +static inline Box GetYoloBox(const T* x, std::vector anchors, int i, + int j, int an_idx, int grid_size, + int input_size, int index, int stride) { + Box b; + b.x = (i + sigmoid(x[index])) * input_size / grid_size; + b.y = (j + sigmoid(x[index + stride])) * input_size / grid_size; + b.w = std::exp(x[index + 2 * stride]) * anchors[2 * an_idx]; + b.h = std::exp(x[index + 3 * stride]) * anchors[2 * an_idx + 1]; + return b; +} + +static inline int GetEntryIndex(int batch, int an_idx, int hw_idx, int an_num, + int an_stride, int stride, int entry) { + return (batch * an_num + an_idx) * an_stride + entry * stride + hw_idx; +} + +template +static inline void CalcDetectionBox(T* boxes, Box pred, const int box_idx) { + boxes[box_idx] = pred.x - pred.w / 2; + boxes[box_idx + 1] = pred.y - pred.h / 2; + boxes[box_idx + 2] = pred.x + pred.w / 2; + boxes[box_idx + 3] = pred.y + pred.h / 2; +} + +template +static inline void CalcLabelScore(T* scores, const T* input, + const int label_idx, const int score_idx, + const int class_num, const T conf, + const int stride) { + for (int i = 0; i < class_num; i++) { + scores[score_idx + i] = conf * sigmoid(input[label_idx + i * stride]); + } +} + +template +class YoloBoxKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* boxes = ctx.Output("Boxes"); + auto* scores = ctx.Output("Scores"); + auto anchors = ctx.Attr>("anchors"); + int class_num = ctx.Attr("class_num"); + float conf_thresh = ctx.Attr("conf_thresh"); + int downsample_ratio = ctx.Attr("downsample_ratio"); + + const int n = input->dims()[0]; + const int h = input->dims()[2]; + const int w = input->dims()[3]; + const int box_num = boxes->dims()[1]; + const int an_num = anchors.size() / 2; + int input_size = downsample_ratio * h; + + const int stride = h * w; + const int an_stride = (class_num + 5) * stride; + + const T* input_data = input->data(); + T* boxes_data = boxes->mutable_data({n, box_num, 4}, ctx.GetPlace()); + memset(boxes_data, 0, boxes->numel() * sizeof(T)); + T* scores_data = + scores->mutable_data({n, box_num, class_num}, ctx.GetPlace()); + memset(scores_data, 0, scores->numel() * sizeof(T)); + + for (int i = 0; i < n; i++) { + for (int j = 0; j < an_num; j++) { + for (int k = 0; k < h; k++) { + for (int l = 0; l < w; l++) { + int obj_idx = + GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 4); + T conf = sigmoid(input_data[obj_idx]); + if (conf < conf_thresh) { + continue; + } + + int box_idx = + GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 0); + Box pred = GetYoloBox(input_data, anchors, l, k, j, h, + input_size, box_idx, stride); + box_idx = (i * box_num + j * stride + k * w + l) * 4; + CalcDetectionBox(boxes_data, pred, box_idx); + + int label_idx = + GetEntryIndex(i, j, k * w + l, an_num, an_stride, stride, 5); + int score_idx = (i * box_num + j * stride + k * w + l) * class_num; + CalcLabelScore(scores_data, input_data, label_idx, score_idx, + class_num, conf, stride); + } + } + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index cbedd70f8..29020f824 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -49,6 +49,7 @@ __all__ = [ 'box_coder', 'polygon_box_transform', 'yolov3_loss', + 'yolo_box', 'box_clip', 'multiclass_nms', 'distribute_fpn_proposals', @@ -609,6 +610,71 @@ def yolov3_loss(x, return loss +@templatedoc(op_type="yolo_box") +def yolo_box(x, anchors, class_num, conf_thresh, downsample_ratio, name=None): + """ + ${comment} + + Args: + x (Variable): ${x_comment} + anchors (list|tuple): ${anchors_comment} + class_num (int): ${class_num_comment} + conf_thresh (float): ${conf_thresh_comment} + downsample_ratio (int): ${downsample_ratio_comment} + name (string): the name of yolov3 loss + + Returns: + Variable: A 1-D tensor with shape [1], the value of yolov3 loss + + Raises: + TypeError: Input x of yolov_box must be Variable + TypeError: Attr anchors of yolo box must be list or tuple + TypeError: Attr class_num of yolo box must be an integer + TypeError: Attr conf_thresh of yolo box must be a float number + + Examples: + .. code-block:: python + + x = fluid.layers.data(name='x', shape=[255, 13, 13], dtype='float32') + anchors = [10, 13, 16, 30, 33, 23] + loss = fluid.layers.yolov3_loss(x=x, class_num=80, anchors=anchors, + conf_thresh=0.01, downsample_ratio=32) + """ + helper = LayerHelper('yolo_box', **locals()) + + if not isinstance(x, Variable): + raise TypeError("Input x of yolov3_loss must be Variable") + if not isinstance(anchors, list) and not isinstance(anchors, tuple): + raise TypeError("Attr anchors of yolov3_loss must be list or tuple") + if not isinstance(anchor_mask, list) and not isinstance(anchor_mask, tuple): + raise TypeError("Attr anchor_mask of yolov3_loss must be list or tuple") + if not isinstance(class_num, int): + raise TypeError("Attr class_num of yolov3_loss must be an integer") + if not isinstance(conf_thresh, float): + raise TypeError( + "Attr ignore_thresh of yolov3_loss must be a float number") + + boxes = helper.create_variable_for_type_inference(dtype=x.dtype) + scores = helper.create_variable_for_type_inference(dtype=x.dtype) + + attrs = { + "anchors": anchors, + "class_num": class_num, + "conf_thresh": ignore_thresh, + "downsample_ratio": downsample_ratio, + } + + helper.append_op( + type='yolo_box', + inputs={"X": x, }, + outputs={ + 'Boxes': boxes, + 'Scores': scores, + }, + attrs=attrs) + return boxes, scores + + @templatedoc() def detection_map(detect_res, label, diff --git a/python/paddle/fluid/tests/test_detection.py b/python/paddle/fluid/tests/test_detection.py index 6218db734..9592bbe2e 100644 --- a/python/paddle/fluid/tests/test_detection.py +++ b/python/paddle/fluid/tests/test_detection.py @@ -478,9 +478,16 @@ class TestYoloDetection(unittest.TestCase): gtlabel = layers.data(name='gtlabel', shape=[10], dtype='int32') loss = layers.yolov3_loss(x, gtbox, gtlabel, [10, 13, 30, 13], [0, 1], 10, 0.7, 32) - self.assertIsNotNone(loss) + def test_yolo_box(self): + program = Program() + with program_guard(program): + x = layers.data(name='x', shape=[30, 7, 7], dtype='float32') + boxes, scores = layers.yolo_box(x, [10, 13, 30, 13], 10, 0.01, 32) + self.assertIsNotNone(boxes) + self.assertIsNotNone(scores) + class TestBoxClip(unittest.TestCase): def test_box_clip(self): diff --git a/python/paddle/fluid/tests/unittests/test_yolo_box_op.py b/python/paddle/fluid/tests/unittests/test_yolo_box_op.py new file mode 100644 index 000000000..bed0be9a5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_yolo_box_op.py @@ -0,0 +1,105 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division + +import unittest +import numpy as np +from op_test import OpTest + +from paddle.fluid import core + + +def sigmoid(x): + return 1.0 / (1.0 + np.exp(-1.0 * x)) + + +def YoloBox(x, attrs): + n, c, h, w = x.shape + anchors = attrs['anchors'] + an_num = int(len(anchors) // 2) + class_num = attrs['class_num'] + conf_thresh = attrs['conf_thresh'] + downsample = attrs['downsample'] + input_size = downsample * h + + x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) + + pred_box = x[:, :, :, :, :4].copy() + grid_x = np.tile(np.arange(w).reshape((1, w)), (h, 1)) + grid_y = np.tile(np.arange(h).reshape((h, 1)), (1, w)) + pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0])) / w + pred_box[:, :, :, :, 1] = (grid_y + sigmoid(pred_box[:, :, :, :, 1])) / h + + anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)] + anchors_s = np.array( + [(an_w / input_size, an_h / input_size) for an_w, an_h in anchors]) + anchor_w = anchors_s[:, 0:1].reshape((1, an_num, 1, 1)) + anchor_h = anchors_s[:, 1:2].reshape((1, an_num, 1, 1)) + pred_box[:, :, :, :, 2] = np.exp(pred_box[:, :, :, :, 2]) * anchor_w + pred_box[:, :, :, :, 3] = np.exp(pred_box[:, :, :, :, 3]) * anchor_h + + pred_conf = sigmoid(x[:, :, :, :, 4:5]) + pred_conf[pred_conf < conf_thresh] = 0. + pred_score = sigmoid(x[:, :, :, :, 5:]) * pred_conf + pred_box = pred_box * (pred_conf > 0.).astype('float32') + + pred_box = pred_box.reshape((n, -1, 4)) + pred_box[:, :, : + 2], pred_box[:, :, 2: + 4] = pred_box[:, :, : + 2] - pred_box[:, :, 2: + 4] / 2., pred_box[:, :, : + 2] + pred_box[:, :, + 2: + 4] / 2.0 + pred_box = pred_box * input_size + + return pred_box, pred_score.reshape((n, -1, class_num)) + + +class TestYoloBoxOp(OpTest): + def setUp(self): + self.initTestCase() + self.op_type = 'yolo_box' + x = np.random.random(self.x_shape).astype('float32') + + self.attrs = { + "anchors": self.anchors, + "class_num": self.class_num, + "conf_thresh": self.conf_thresh, + "downsample": self.downsample, + } + + self.inputs = {'X': x, } + boxes, scores = YoloBox(x, self.attrs) + self.outputs = { + "Boxes": boxes, + "Scores": scores, + } + + def test_check_output(self): + self.check_output() + + def initTestCase(self): + self.anchors = [10, 13, 16, 30, 33, 23] + an_num = int(len(self.anchors) // 2) + self.class_num = 2 + self.conf_thresh = 0.5 + self.downsample = 32 + self.x_shape = (3, an_num * (5 + self.class_num), 5, 5) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py index 020c11392..569fe63d0 100644 --- a/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_yolov3_loss_op.py @@ -75,8 +75,8 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs): mask_num = len(anchor_mask) class_num = attrs["class_num"] ignore_thresh = attrs['ignore_thresh'] - downsample = attrs['downsample'] - input_size = downsample * h + downsample_ratio = attrs['downsample_ratio'] + input_size = downsample_ratio * h x = x.reshape((n, mask_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2)) loss = np.zeros((n)).astype('float32') @@ -86,10 +86,6 @@ def YOLOv3Loss(x, gtbox, gtlabel, attrs): pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0])) / w pred_box[:, :, :, :, 1] = (grid_y + sigmoid(pred_box[:, :, :, :, 1])) / h - x[:, :, :, :, 5:] = np.where(x[:, :, :, :, 5:] < -0.5, x[:, :, :, :, 5:], - np.ones_like(x[:, :, :, :, 5:]) * 1.0 / - class_num) - mask_anchors = [] for m in anchor_mask: mask_anchors.append((anchors[2 * m], anchors[2 * m + 1])) @@ -176,7 +172,7 @@ class TestYolov3LossOp(OpTest): "anchor_mask": self.anchor_mask, "class_num": self.class_num, "ignore_thresh": self.ignore_thresh, - "downsample": self.downsample, + "downsample_ratio": self.downsample_ratio, } self.inputs = { @@ -208,7 +204,7 @@ class TestYolov3LossOp(OpTest): self.anchor_mask = [1, 2] self.class_num = 5 self.ignore_thresh = 0.5 - self.downsample = 32 + self.downsample_ratio = 32 self.x_shape = (3, len(self.anchor_mask) * (5 + self.class_num), 5, 5) self.gtbox_shape = (3, 5, 4) -- GitLab