diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index 64f9e03d7e061e0ed0ebef024896b338241e2062..3e07b1f155452d920c9c0acbc355935a12d1350f 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -54,7 +54,6 @@ detection_library(generate_proposal_labels_op SRCS detection_library(multiclass_nms_op SRCS multiclass_nms_op.cc DEPS gpc) detection_library(locality_aware_nms_op SRCS locality_aware_nms_op.cc DEPS gpc) detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu) -detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc) detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc box_decoder_and_assign_op.cu) detection_library(sigmoid_focal_loss_op SRCS sigmoid_focal_loss_op.cc diff --git a/paddle/fluid/operators/detection/yolov3_loss_op.cc b/paddle/fluid/operators/detection/yolov3_loss_op.cc deleted file mode 100644 index 21aca33f65a1aaf45522887430b837db51b20d88..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/detection/yolov3_loss_op.cc +++ /dev/null @@ -1,231 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#include - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/imperative/type_defs.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/backward.h" -#include "paddle/phi/infermeta/multiary.h" - -namespace paddle { -namespace operators { - -class Yolov3LossOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), - platform::CPUPlace()); - } -}; - -class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", - "The input tensor of YOLOv3 loss operator, " - "This is a 4-D tensor with shape of [N, C, H, W]." - "H and W should be same, and the second dimension(C) stores" - "box locations, confidence score and classification one-hot" - "keys of each anchor box"); - AddInput("GTBox", - "The input tensor of ground truth boxes, " - "This is a 3-D tensor with shape of [N, max_box_num, 5], " - "max_box_num is the max number of boxes in each image, " - "In the third dimension, stores x, y, w, h coordinates, " - "x, y is the center coordinate of boxes and w, h is the " - "width and height and x, y, w, h should be divided by " - "input image height to scale to [0, 1]."); - AddInput("GTLabel", - "The input tensor of ground truth label, " - "This is a 2-D tensor with shape of [N, max_box_num], " - "and each element should be an integer to indicate the " - "box class id."); - AddInput("GTScore", - "The score of GTLabel, This is a 2-D tensor in same shape " - "GTLabel, and score values should in range (0, 1). This " - "input is for GTLabel score can be not 1.0 in image mixup " - "augmentation.") - .AsDispensable(); - AddOutput("Loss", - "The output yolov3 loss tensor, " - "This is a 1-D tensor with shape of [N]"); - AddOutput("ObjectnessMask", - "This is an intermediate tensor with shape of [N, M, H, W], " - "M is the number of anchor masks. This parameter caches the " - "mask for calculate objectness loss in gradient kernel.") - .AsIntermediate(); - AddOutput("GTMatchMask", - "This is an intermediate tensor with shape of [N, B], " - "B is the max box number of GT boxes. This parameter caches " - "matched mask index of each GT boxes for gradient calculate.") - .AsIntermediate(); - - AddAttr("class_num", "The number of classes to predict."); - AddAttr>("anchors", - "The anchor width and height, " - "it will be parsed pair by pair.") - .SetDefault(std::vector{}); - AddAttr>("anchor_mask", - "The mask index of anchors used in " - "current YOLOv3 loss calculation.") - .SetDefault(std::vector{}); - AddAttr("downsample_ratio", - "The downsample ratio from network input to YOLOv3 loss " - "input, so 32, 16, 8 should be set for the first, second, " - "and thrid YOLOv3 loss operators.") - .SetDefault(32); - AddAttr("ignore_thresh", - "The ignore threshold to ignore confidence loss.") - .SetDefault(0.7); - AddAttr("use_label_smooth", - "Whether to use label smooth. Default True.") - .SetDefault(true); - AddAttr("scale_x_y", - "Scale the center point of decoded bounding " - "box. Default 1.0") - .SetDefault(1.); - AddComment(R"DOC( - This operator generates yolov3 loss based on given predict result and ground - truth boxes. - - The output of previous network is in shape [N, C, H, W], while H and W - should be the same, H and W specify the grid size, each grid point predict - given number bounding boxes, this given number, which following will be represented as S, - is specified by the number of anchor clusters in each scale. In the second dimension(the channel - dimension), C should be equal to S * (class_num + 5), class_num is the object - category number of source dataset(such as 80 in coco dataset), so in the - second(channel) dimension, apart from 4 box location coordinates x, y, w, h, - also includes confidence score of the box and class one-hot key of each anchor box. - - Assume the 4 location coordinates are :math:`t_x, t_y, t_w, t_h`, the box predictions - should be as follows: - - $$ - b_x = \\sigma(t_x) + c_x - $$ - $$ - b_y = \\sigma(t_y) + c_y - $$ - $$ - b_w = p_w e^{t_w} - $$ - $$ - b_h = p_h e^{t_h} - $$ - - In the equation above, :math:`c_x, c_y` is the left top corner of current grid - and :math:`p_w, p_h` is specified by anchors. - - As for confidence score, it is the logistic regression value of IoU between - anchor boxes and ground truth boxes, the score of the anchor box which has - the max IoU should be 1, and if the anchor box has IoU bigger than ignore - thresh, the confidence score loss of this anchor box will be ignored. - - Therefore, the yolov3 loss consists of three major parts: box location loss, - objectness loss and classification loss. The L1 loss is used for - box coordinates (w, h), sigmoid cross entropy loss is used for box - coordinates (x, y), objectness loss and classification loss. - - Each groud truth box finds a best matching anchor box in all anchors. - Prediction of this anchor box will incur all three parts of losses, and - prediction of anchor boxes with no GT box matched will only incur objectness - loss. - - In order to trade off box coordinate losses between big boxes and small - boxes, box coordinate losses will be mutiplied by scale weight, which is - calculated as follows. - - $$ - weight_{box} = 2.0 - t_w * t_h - $$ - - Final loss will be represented as follows. - - $$ - loss = (loss_{xy} + loss_{wh}) * weight_{box} - + loss_{conf} + loss_{class} - $$ - - While :attr:`use_label_smooth` is set to be :attr:`True`, the classification - target will be smoothed when calculating classification loss, target of - positive samples will be smoothed to :math:`1.0 - 1.0 / class\_num` and target of - negetive samples will be smoothed to :math:`1.0 / class\_num`. - - While :attr:`GTScore` is given, which means the mixup score of ground truth - boxes, all losses incured by a ground truth box will be multiplied by its - mixup score. - )DOC"); - } -}; - -class Yolov3LossOpGrad : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), - platform::CPUPlace()); - } -}; - -template -class Yolov3LossGradMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr op) const override { - op->SetType("yolov3_loss_grad"); - op->SetInput("X", this->Input("X")); - op->SetInput("GTBox", this->Input("GTBox")); - op->SetInput("GTLabel", this->Input("GTLabel")); - op->SetInput("GTScore", this->Input("GTScore")); - op->SetInput(framework::GradVarName("Loss"), this->OutputGrad("Loss")); - op->SetInput("ObjectnessMask", this->Output("ObjectnessMask")); - op->SetInput("GTMatchMask", this->Output("GTMatchMask")); - - op->SetAttrMap(this->Attrs()); - - op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - op->SetOutput(framework::GradVarName("GTBox"), this->EmptyInputGrad()); - op->SetOutput(framework::GradVarName("GTLabel"), this->EmptyInputGrad()); - op->SetOutput(framework::GradVarName("GTScore"), this->EmptyInputGrad()); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -DECLARE_INFER_SHAPE_FUNCTOR(yolov3_loss, - Yolov3LossInferShapeFunctor, - PD_INFER_META(phi::YoloLossInferMeta)); -DECLARE_INFER_SHAPE_FUNCTOR(yolov3_loss_grad, - Yolov3LossGradInferShapeFunctor, - PD_INFER_META(phi::YoloLossGradInferMeta)); -REGISTER_OPERATOR(yolov3_loss, - ops::Yolov3LossOp, - ops::Yolov3LossOpMaker, - ops::Yolov3LossGradMaker, - ops::Yolov3LossGradMaker, - Yolov3LossInferShapeFunctor); -REGISTER_OPERATOR(yolov3_loss_grad, - ops::Yolov3LossOpGrad, - Yolov3LossGradInferShapeFunctor); diff --git a/paddle/fluid/operators/generator/get_expected_kernel_func.cc b/paddle/fluid/operators/generator/get_expected_kernel_func.cc index 6085dabaed6d880e6060d2c4235df2c8f5512d68..5c654d942c2456be2a8738e7c41f2cc23a739e3c 100644 --- a/paddle/fluid/operators/generator/get_expected_kernel_func.cc +++ b/paddle/fluid/operators/generator/get_expected_kernel_func.cc @@ -156,6 +156,13 @@ phi::KernelKey GetMatrixNmsExpectedKernelType( platform::CPUPlace()); } +phi::KernelKey GetYoloLossExpectedKernelType( + const framework::ExecutionContext& ctx, + const framework::OperatorWithKernel* op_ptr) { + return phi::KernelKey(op_ptr->IndicateVarDataType(ctx, "X"), + platform::CPUPlace()); +} + phi::KernelKey GetUniqueExpectedKernelType( const framework::ExecutionContext& ctx, const framework::OperatorWithKernel* op_ptr) { diff --git a/paddle/fluid/operators/generator/get_expected_kernel_func.h b/paddle/fluid/operators/generator/get_expected_kernel_func.h index cbbb74e2312ed3d916bf276d62bc31e635743694..a6fe0fc4b225a3121d0c10f93bec50da8f71f185 100644 --- a/paddle/fluid/operators/generator/get_expected_kernel_func.h +++ b/paddle/fluid/operators/generator/get_expected_kernel_func.h @@ -48,5 +48,9 @@ phi::KernelKey GetUniqueExpectedKernelType( const framework::ExecutionContext& ctx, const framework::OperatorWithKernel* op_ptr); +phi::KernelKey GetYoloLossExpectedKernelType( + const framework::ExecutionContext& ctx, + const framework::OperatorWithKernel* op_ptr); + } // namespace operators } // namespace paddle diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 7b3b0d225af109ac5232d744bbc822c3577c467b..bde673e60b60786b9c694158ad0556e743157322 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -2065,6 +2065,16 @@ func : where_grad no_need_buffer : x, y +- backward_op : yolo_loss_grad + forward : yolo_loss (Tensor x, Tensor gt_box, Tensor gt_label, Tensor gt_score, int[] anchors={}, int[] anchor_mask={}, int class_num =1 , float ignore_thresh=0.7, int downsample_ratio=32, bool use_label_smooth=true, float scale_x_y=1.0) -> Tensor(loss), Tensor(objectness_mask), Tensor(gt_match_mask) + args : (Tensor x, Tensor gt_box, Tensor gt_label, Tensor gt_score, Tensor objectness_mask, Tensor gt_match_mask, Tensor loss_grad, int[] anchors, int[] anchor_mask, int class_num, float ignore_thresh, int downsample_ratio, bool use_label_smooth, float scale_x_y) + output : Tensor(x_grad), Tensor(gt_box_grad), Tensor(gt_label_grad), Tensor(gt_score_grad) + infer_meta : + func : YoloLossGradInferMeta + kernel : + func : yolo_loss_grad + optional : gt_score + - backward_op: unpool3d_grad forward: unpool3d (Tensor x, Tensor indices, int[] ksize, int[] strides={1,1,1}, int[] paddings={0,0,0}, int[] output_size={0,0,0}, str data_format="NCDHW") -> Tensor(out) args: (Tensor x, Tensor indices, Tensor out, Tensor out_grad, int[] ksize, int[] strides, int[] paddings, int[] output_size, str data_format) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index b844f76290c0eb2d20fc370ff5087a4b3c3f35d9..9d1b2ce5b49337a9d889f8f0ec7b86f759bd168e 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -1032,13 +1032,3 @@ param : [out_grad] kernel : func : triu_grad - -- backward_op : yolo_loss_grad - forward : yolo_loss(Tensor x, Tensor gt_box, Tensor gt_label, Tensor gt_score, int[] anchors, int[] anchor_mask, int class_num, float ignore_thresh, int downsample_ratio, bool use_label_smooth=true, float scale_x_y=1.0) -> Tensor(loss), Tensor(objectness_mask), Tensor(gt_match_mask) - args : (Tensor x, Tensor gt_box, Tensor gt_label, Tensor gt_score, Tensor objectness_mask, Tensor gt_match_mask, Tensor loss_grad, int[] anchors, int[] anchor_mask, int class_num, float ignore_thresh, int downsample_ratio, bool use_label_smooth=true, float scale_x_y=1.0) - output : Tensor(x_grad), Tensor(gt_box_grad), Tensor(gt_label_grad), Tensor(gt_score_grad) - infer_meta : - func : YoloLossGradInferMeta - kernel : - func : yolo_loss_grad - optional : gt_score diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index e7b0a5ca4f330c4df70cfaafcf9d311f8829be86..6eff624761c89e9d30b0a144a85df08b9a6bc80f 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -1198,17 +1198,6 @@ func : unique data_type : x -- op : yolo_loss - args : (Tensor x, Tensor gt_box, Tensor gt_label, Tensor gt_score, int[] anchors, int[] anchor_mask, int class_num, float ignore_thresh, int downsample_ratio, bool use_label_smooth=true, float scale_x_y=1.0) - output : Tensor(loss), Tensor(objectness_mask), Tensor(gt_match_mask) - infer_meta : - func : YoloLossInferMeta - kernel : - func : yolo_loss - data_type : x - optional : gt_score - backward : yolo_loss_grad - - op : zeros args : (IntArray shape, DataType dtype=DataType::FLOAT32, Place place=CPUPlace()) output : Tensor(out) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 928797a3c394173c7a6154b4eb0793cb99457d69..1e81fcac8eb5e5c0a4518d1d5583567cf825836b 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -2511,6 +2511,16 @@ outputs : {boxes : Boxes, scores : Scores} +- op : yolo_loss (yolov3_loss) + backward: yolo_loss_grad (yolov3_loss_grad) + inputs : + {x : X, gt_box : GTBox, gt_label : GTLabel ,gt_score : GTScore} + outputs : + {loss : Loss , objectness_mask : ObjectnessMask, gt_match_mask : GTMatchMask} + get_expected_kernel_type : + yolo_loss : GetYoloLossExpectedKernelType + yolo_loss_grad : GetYoloLossExpectedKernelType + - op: lu backward: lu_grad inputs: diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index eaaf9f61fa7acdb508f8851f464811e2a2b3cc91..e9624cbca06e6638d3ed5e827f5eba72992d32d8 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -2265,3 +2265,15 @@ kernel : func : yolo_box data_type : x + +- op : yolo_loss + args : (Tensor x, Tensor gt_box, Tensor gt_label, Tensor gt_score, int[] anchors={}, int[] anchor_mask={}, int class_num =1 , float ignore_thresh=0.7, int downsample_ratio=32, bool use_label_smooth=true, float scale_x_y=1.0) + output : Tensor(loss), Tensor(objectness_mask), Tensor(gt_match_mask) + infer_meta : + func : YoloLossInferMeta + kernel : + func : yolo_loss + data_type : x + optional : gt_score + intermediate : objectness_mask, gt_match_mask + backward : yolo_loss_grad diff --git a/paddle/phi/ops/compat/yolov3_loss_sig.cc b/paddle/phi/ops/compat/yolov3_loss_sig.cc deleted file mode 100644 index f98709a9fdf330f009441a578346453699730226..0000000000000000000000000000000000000000 --- a/paddle/phi/ops/compat/yolov3_loss_sig.cc +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature Yolov3LossOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("yolo_loss", - {"X", "GTBox", "GTLabel", "GTScore"}, - {"anchors", - "anchor_mask", - "class_num", - "ignore_thresh", - "downsample_ratio", - "use_label_smooth", - "scale_x_y"}, - {"Loss", "ObjectnessMask", "GTMatchMask"}); -} - -KernelSignature Yolov3LossGradOpArgumentMapping( - const ArgumentMappingContext& ctx) { - return KernelSignature( - "yolo_loss_grad", - {"X", - "GTBox", - "GTLabel", - "GTScore", - "ObjectnessMask", - "GTMatchMask", - "Loss@GRAD"}, - {"anchors", - "anchor_mask", - "class_num", - "ignore_thresh", - "downsample_ratio", - "use_label_smooth", - "scale_x_y"}, - {"X@GRAD", "GTBox@GRAD", "GTLabel@GRAD", "GTScore@GRAD"}); -} -} // namespace phi - -PD_REGISTER_BASE_KERNEL_NAME(yolov3_loss, yolo_loss); -PD_REGISTER_BASE_KERNEL_NAME(yolov3_loss_grad, yolo_loss_grad); - -PD_REGISTER_ARG_MAPPING_FN(yolov3_loss, phi::Yolov3LossOpArgumentMapping); -PD_REGISTER_ARG_MAPPING_FN(yolov3_loss_grad, - phi::Yolov3LossGradOpArgumentMapping); diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py index e6ff2c02dcd96622ec0effb7071f763f4a073347..b02451f53b3323b39b8641548c55eb5b6e14eb7d 100755 --- a/python/paddle/vision/ops.py +++ b/python/paddle/vision/ops.py @@ -192,7 +192,7 @@ def yolo_loss( """ if in_dygraph_mode(): - loss, _, _ = _C_ops.yolo_loss( + loss = _C_ops.yolo_loss( x, gt_box, gt_label,