diff --git a/paddle/fluid/operators/compare_op.cc b/paddle/fluid/operators/compare_op.cc index f3414c33b5ab3cc8dffee640fd85b9625b3f237b..b1f09fb0029affe671d63874cf3d3db86476c367 100644 --- a/paddle/fluid/operators/compare_op.cc +++ b/paddle/fluid/operators/compare_op.cc @@ -102,3 +102,5 @@ REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y"); REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor); REGISTER_LOGICAL_OP(equal, "Out = X == Y"); REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor); +REGISTER_LOGICAL_OP(not_equal, "Out = X != Y"); +REGISTER_LOGICAL_KERNEL(not_equal, CPU, paddle::operators::NotEqualFunctor); diff --git a/paddle/fluid/operators/compare_op.cu b/paddle/fluid/operators/compare_op.cu index 3507af2ae3add8cf02f5b9f3b3d89b40d73bcb0d..00263a2ade4502e732d53b871665185f8d0fa9f1 100644 --- a/paddle/fluid/operators/compare_op.cu +++ b/paddle/fluid/operators/compare_op.cu @@ -17,3 +17,4 @@ limitations under the License. */ REGISTER_LOGICAL_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor); REGISTER_LOGICAL_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor); REGISTER_LOGICAL_KERNEL(equal, CUDA, paddle::operators::EqualFunctor); +REGISTER_LOGICAL_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor); diff --git a/paddle/fluid/operators/compare_op.h b/paddle/fluid/operators/compare_op.h index 4b2ee5a9d68f5f1fd3d2d374669763855659f1db..c651335268fee08c08bcac6247f5a2ff92784330 100644 --- a/paddle/fluid/operators/compare_op.h +++ b/paddle/fluid/operators/compare_op.h @@ -48,6 +48,14 @@ struct EqualFunctor { } }; +template +struct NotEqualFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { + return !EqualFunctor()(a, b); + } +}; + template class CompareOpKernel : public framework::OpKernel { diff --git a/paddle/fluid/operators/detection_map_op.cc b/paddle/fluid/operators/detection_map_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..48308a11b4b313ec19b578110b9e369f4bfc52bf --- /dev/null +++ b/paddle/fluid/operators/detection_map_op.cc @@ -0,0 +1,184 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/detection_map_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class DetectionMAPOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("DetectRes"), + "Input(DetectRes) of DetectionMAPOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), + "Input(Label) of DetectionMAPOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("AccumPosCount"), + "Output(AccumPosCount) of DetectionMAPOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("AccumTruePos"), + "Output(AccumTruePos) of DetectionMAPOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("AccumFalsePos"), + "Output(AccumFalsePos) of DetectionMAPOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("MAP"), + "Output(MAP) of DetectionMAPOp should not be null."); + + auto det_dims = ctx->GetInputDim("DetectRes"); + PADDLE_ENFORCE_EQ(det_dims.size(), 2UL, + "The rank of Input(DetectRes) must be 2, " + "the shape is [N, 6]."); + PADDLE_ENFORCE_EQ(det_dims[1], 6UL, + "The shape is of Input(DetectRes) [N, 6]."); + auto label_dims = ctx->GetInputDim("Label"); + PADDLE_ENFORCE_EQ(label_dims.size(), 2UL, + "The rank of Input(Label) must be 2, " + "the shape is [N, 6]."); + PADDLE_ENFORCE_EQ(label_dims[1], 6UL, + "The shape is of Input(Label) [N, 6]."); + + if (ctx->HasInput("PosCount")) { + PADDLE_ENFORCE(ctx->HasInput("TruePos"), + "Input(TruePos) of DetectionMAPOp should not be null when " + "Input(TruePos) is not null."); + PADDLE_ENFORCE( + ctx->HasInput("FalsePos"), + "Input(FalsePos) of DetectionMAPOp should not be null when " + "Input(FalsePos) is not null."); + } + + ctx->SetOutputDim("MAP", framework::make_ddim({1})); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType( + ctx.Input("DetectRes")->type()), + ctx.device_context()); + } +}; + +class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker { + public: + DetectionMAPOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("DetectRes", + "(LoDTensor) A 2-D LoDTensor with shape [M, 6] represents the " + "detections. Each row has 6 values: " + "[label, confidence, xmin, ymin, xmax, ymax], M is the total " + "number of detect results in this mini-batch. For each instance, " + "the offsets in first dimension are called LoD, the number of " + "offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is " + "no detected data."); + AddInput("Label", + "(LoDTensor) A 2-D LoDTensor with shape[N, 6] represents the" + "Labeled ground-truth data. Each row has 6 values: " + "[label, is_difficult, xmin, ymin, xmax, ymax], N is the total " + "number of ground-truth data in this mini-batch. For each " + "instance, the offsets in first dimension are called LoD, " + "the number of offset is N + 1, if LoD[i + 1] - LoD[i] == 0, " + "means there is no ground-truth data."); + AddInput("PosCount", + "(Tensor) A tensor with shape [Ncls, 1], store the " + "input positive example count of each class, Ncls is the count of " + "input classification. " + "This input is used to pass the AccumPosCount generated by the " + "previous mini-batch when the multi mini-batches cumulative " + "calculation carried out. " + "When the input(PosCount) is empty, the cumulative " + "calculation is not carried out, and only the results of the " + "current mini-batch are calculated.") + .AsDispensable(); + AddInput("TruePos", + "(LoDTensor) A 2-D LoDTensor with shape [Ntp, 2], store the " + "input true positive example of each class." + "This input is used to pass the AccumTruePos generated by the " + "previous mini-batch when the multi mini-batches cumulative " + "calculation carried out. ") + .AsDispensable(); + AddInput("FalsePos", + "(LoDTensor) A 2-D LoDTensor with shape [Nfp, 2], store the " + "input false positive example of each class." + "This input is used to pass the AccumFalsePos generated by the " + "previous mini-batch when the multi mini-batches cumulative " + "calculation carried out. ") + .AsDispensable(); + AddOutput("AccumPosCount", + "(Tensor) A tensor with shape [Ncls, 1], store the " + "positive example count of each class. It combines the input " + "input(PosCount) and the positive example count computed from " + "input(Detection) and input(Label)."); + AddOutput("AccumTruePos", + "(LoDTensor) A LoDTensor with shape [Ntp', 2], store the " + "true positive example of each class. It combines the " + "input(TruePos) and the true positive examples computed from " + "input(Detection) and input(Label)."); + AddOutput("AccumFalsePos", + "(LoDTensor) A LoDTensor with shape [Nfp', 2], store the " + "false positive example of each class. It combines the " + "input(FalsePos) and the false positive examples computed from " + "input(Detection) and input(Label)."); + AddOutput("MAP", + "(Tensor) A tensor with shape [1], store the mAP evaluate " + "result of the detection."); + + AddAttr( + "overlap_threshold", + "(float) " + "The lower bound jaccard overlap threshold of detection output and " + "ground-truth data.") + .SetDefault(.3f); + AddAttr("evaluate_difficult", + "(bool, default true) " + "Switch to control whether the difficult data is evaluated.") + .SetDefault(true); + AddAttr("ap_type", + "(string, default 'integral') " + "The AP algorithm type, 'integral' or '11point'.") + .SetDefault("integral") + .InEnum({"integral", "11point"}) + .AddCustomChecker([](const std::string& ap_type) { + PADDLE_ENFORCE_NE(GetAPType(ap_type), APType::kNone, + "The ap_type should be 'integral' or '11point."); + }); + AddComment(R"DOC( +Detection mAP evaluate operator. +The general steps are as follows. First, calculate the true positive and + false positive according to the input of detection and labels, then + calculate the mAP evaluate value. + Supporting '11 point' and 'integral' mAP algorithm. Please get more information + from the following articles: + https://sanchom.wordpress.com/tag/average-precision/ + https://arxiv.org/abs/1512.02325 + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(detection_map, ops::DetectionMAPOp, + ops::DetectionMAPOpMaker); +REGISTER_OP_CPU_KERNEL( + detection_map, ops::DetectionMAPOpKernel, + ops::DetectionMAPOpKernel); diff --git a/paddle/fluid/operators/detection_map_op.h b/paddle/fluid/operators/detection_map_op.h new file mode 100644 index 0000000000000000000000000000000000000000..0f5f588e9c448a6d84d388848aa5701f2b4882dd --- /dev/null +++ b/paddle/fluid/operators/detection_map_op.h @@ -0,0 +1,451 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +enum APType { kNone = 0, kIntegral, k11point }; + +APType GetAPType(std::string str) { + if (str == "integral") { + return APType::kIntegral; + } else if (str == "11point") { + return APType::k11point; + } else { + return APType::kNone; + } +} + +template +inline bool SortScorePairDescend(const std::pair& pair1, + const std::pair& pair2) { + return pair1.first > pair2.first; +} + +template +inline void GetAccumulation(std::vector> in_pairs, + std::vector* accu_vec) { + std::stable_sort(in_pairs.begin(), in_pairs.end(), SortScorePairDescend); + accu_vec->clear(); + size_t sum = 0; + for (size_t i = 0; i < in_pairs.size(); ++i) { + auto count = in_pairs[i].second; + sum += count; + accu_vec->push_back(sum); + } +} + +template +class DetectionMAPOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in_detect = ctx.Input("DetectRes"); + auto* in_label = ctx.Input("Label"); + auto* out_map = ctx.Output("MAP"); + + auto* in_pos_count = ctx.Input("PosCount"); + auto* in_true_pos = ctx.Input("TruePos"); + auto* in_false_pos = ctx.Input("FalsePos"); + + auto* out_pos_count = ctx.Output("AccumPosCount"); + auto* out_true_pos = ctx.Output("AccumTruePos"); + auto* out_false_pos = ctx.Output("AccumFalsePos"); + + float overlap_threshold = ctx.Attr("overlap_threshold"); + float evaluate_difficult = ctx.Attr("evaluate_difficult"); + auto ap_type = GetAPType(ctx.Attr("ap_type")); + + auto label_lod = in_label->lod(); + auto detect_lod = in_detect->lod(); + PADDLE_ENFORCE_EQ(label_lod.size(), 1UL, + "Only support one level sequence now."); + PADDLE_ENFORCE_EQ(label_lod[0].size(), detect_lod[0].size(), + "The batch_size of input(Label) and input(Detection) " + "must be the same."); + + std::vector>> gt_boxes; + std::vector>>> detect_boxes; + + GetBoxes(*in_label, *in_detect, gt_boxes, detect_boxes); + + std::map label_pos_count; + std::map>> true_pos; + std::map>> false_pos; + + if (in_pos_count != nullptr) { + GetInputPos(*in_pos_count, *in_true_pos, *in_false_pos, label_pos_count, + true_pos, false_pos); + } + + CalcTrueAndFalsePositive(gt_boxes, detect_boxes, evaluate_difficult, + overlap_threshold, label_pos_count, true_pos, + false_pos); + + T map = CalcMAP(ap_type, label_pos_count, true_pos, false_pos); + + GetOutputPos(ctx, label_pos_count, true_pos, false_pos, *out_pos_count, + *out_true_pos, *out_false_pos); + + T* map_data = out_map->mutable_data(ctx.GetPlace()); + map_data[0] = map; + } + + protected: + struct Box { + Box(T xmin, T ymin, T xmax, T ymax) + : xmin(xmin), ymin(ymin), xmax(xmax), ymax(ymax), is_difficult(false) {} + + T xmin, ymin, xmax, ymax; + bool is_difficult; + }; + + inline T JaccardOverlap(const Box& box1, const Box& box2) const { + if (box2.xmin > box1.xmax || box2.xmax < box1.xmin || + box2.ymin > box1.ymax || box2.ymax < box1.ymin) { + return 0.0; + } else { + T inter_xmin = std::max(box1.xmin, box2.xmin); + T inter_ymin = std::max(box1.ymin, box2.ymin); + T inter_xmax = std::min(box1.xmax, box2.xmax); + T inter_ymax = std::min(box1.ymax, box2.ymax); + + T inter_width = inter_xmax - inter_xmin; + T inter_height = inter_ymax - inter_ymin; + T inter_area = inter_width * inter_height; + + T bbox_area1 = (box1.xmax - box1.xmin) * (box1.ymax - box1.ymin); + T bbox_area2 = (box2.xmax - box2.xmin) * (box2.ymax - box2.ymin); + + return inter_area / (bbox_area1 + bbox_area2 - inter_area); + } + } + + void GetBoxes(const framework::LoDTensor& input_label, + const framework::LoDTensor& input_detect, + std::vector>>& gt_boxes, + std::vector>>>& + detect_boxes) const { + auto labels = framework::EigenTensor::From(input_label); + auto detect = framework::EigenTensor::From(input_detect); + + auto label_lod = input_label.lod(); + auto detect_lod = input_detect.lod(); + + int batch_size = label_lod[0].size() - 1; + auto label_index = label_lod[0]; + + for (int n = 0; n < batch_size; ++n) { + std::map> boxes; + for (int i = label_index[n]; i < label_index[n + 1]; ++i) { + Box box(labels(i, 2), labels(i, 3), labels(i, 4), labels(i, 5)); + int label = labels(i, 0); + auto is_difficult = labels(i, 1); + if (std::abs(is_difficult - 0.0) < 1e-6) + box.is_difficult = false; + else + box.is_difficult = true; + boxes[label].push_back(box); + } + gt_boxes.push_back(boxes); + } + + auto detect_index = detect_lod[0]; + for (int n = 0; n < batch_size; ++n) { + std::map>> boxes; + for (int i = detect_index[n]; i < detect_index[n + 1]; ++i) { + Box box(detect(i, 2), detect(i, 3), detect(i, 4), detect(i, 5)); + int label = detect(i, 0); + auto score = detect(i, 1); + boxes[label].push_back(std::make_pair(score, box)); + } + detect_boxes.push_back(boxes); + } + } + + void GetOutputPos( + const framework::ExecutionContext& ctx, + const std::map& label_pos_count, + const std::map>>& true_pos, + const std::map>>& false_pos, + framework::Tensor& output_pos_count, + framework::LoDTensor& output_true_pos, + framework::LoDTensor& output_false_pos) const { + int max_class_id = 0; + int true_pos_count = 0; + int false_pos_count = 0; + for (auto it = label_pos_count.begin(); it != label_pos_count.end(); ++it) { + int label = it->first; + if (label > max_class_id) max_class_id = label; + int label_num_pos = it->second; + if (label_num_pos == 0 || true_pos.find(label) == true_pos.end()) + continue; + auto label_true_pos = true_pos.find(label)->second; + auto label_false_pos = false_pos.find(label)->second; + true_pos_count += label_true_pos.size(); + false_pos_count += label_false_pos.size(); + } + + int* pos_count_data = output_pos_count.mutable_data( + framework::make_ddim({max_class_id + 1, 1}), ctx.GetPlace()); + T* true_pos_data = output_true_pos.mutable_data( + framework::make_ddim({true_pos_count, 2}), ctx.GetPlace()); + T* false_pos_data = output_false_pos.mutable_data( + framework::make_ddim({false_pos_count, 2}), ctx.GetPlace()); + true_pos_count = 0; + false_pos_count = 0; + std::vector true_pos_starts = {0}; + std::vector false_pos_starts = {0}; + for (int i = 0; i <= max_class_id; ++i) { + auto it_count = label_pos_count.find(i); + pos_count_data[i] = 0; + if (it_count != label_pos_count.end()) { + pos_count_data[i] = it_count->second; + } + auto it_true_pos = true_pos.find(i); + if (it_true_pos != true_pos.end()) { + const std::vector>& true_pos_vec = + it_true_pos->second; + for (const std::pair& tp : true_pos_vec) { + true_pos_data[true_pos_count * 2] = tp.first; + true_pos_data[true_pos_count * 2 + 1] = static_cast(tp.second); + true_pos_count++; + } + } + true_pos_starts.push_back(true_pos_count); + + auto it_false_pos = false_pos.find(i); + if (it_false_pos != false_pos.end()) { + const std::vector>& false_pos_vec = + it_false_pos->second; + for (const std::pair& fp : false_pos_vec) { + false_pos_data[false_pos_count * 2] = fp.first; + false_pos_data[false_pos_count * 2 + 1] = static_cast(fp.second); + false_pos_count++; + } + } + false_pos_starts.push_back(false_pos_count); + } + + framework::LoD true_pos_lod; + true_pos_lod.emplace_back(true_pos_starts); + framework::LoD false_pos_lod; + false_pos_lod.emplace_back(false_pos_starts); + + output_true_pos.set_lod(true_pos_lod); + output_false_pos.set_lod(false_pos_lod); + return; + } + + void GetInputPos( + const framework::Tensor& input_pos_count, + const framework::LoDTensor& input_true_pos, + const framework::LoDTensor& input_false_pos, + std::map& label_pos_count, + std::map>>& true_pos, + std::map>>& false_pos) const { + constexpr T kEPS = static_cast(1e-6); + int class_number = input_pos_count.dims()[0]; + const int* pos_count_data = input_pos_count.data(); + for (int i = 0; i < class_number; ++i) { + label_pos_count[i] = pos_count_data[i]; + } + + auto SetData = [](const framework::LoDTensor& pos_tensor, + std::map>>& pos) { + const T* pos_data = pos_tensor.data(); + auto pos_data_lod = pos_tensor.lod(); + for (int i = 0; i < pos_data_lod.size(); ++i) { + for (int j = pos_data_lod[0][i]; j < pos_data_lod[0][i + 1]; ++j) { + T score = pos_data[j * 2]; + int flag = 1; + if (pos_data[j * 2 + 1] < kEPS) flag = 0; + pos[i].push_back(std::make_pair(score, flag)); + } + } + }; + + SetData(input_true_pos, true_pos); + SetData(input_false_pos, false_pos); + return; + } + + void CalcTrueAndFalsePositive( + const std::vector>>& gt_boxes, + const std::vector>>>& + detect_boxes, + bool evaluate_difficult, float overlap_threshold, + std::map& label_pos_count, + std::map>>& true_pos, + std::map>>& false_pos) const { + int batch_size = gt_boxes.size(); + for (int n = 0; n < batch_size; ++n) { + auto image_gt_boxes = gt_boxes[n]; + for (auto it = image_gt_boxes.begin(); it != image_gt_boxes.end(); ++it) { + size_t count = 0; + auto labeled_bboxes = it->second; + if (evaluate_difficult) { + count = labeled_bboxes.size(); + } else { + for (size_t i = 0; i < labeled_bboxes.size(); ++i) + if (!(labeled_bboxes[i].is_difficult)) ++count; + } + if (count == 0) { + continue; + } + int label = it->first; + if (label_pos_count.find(label) == label_pos_count.end()) { + label_pos_count[label] = count; + } else { + label_pos_count[label] += count; + } + } + } + + for (size_t n = 0; n < detect_boxes.size(); ++n) { + auto image_gt_boxes = gt_boxes[n]; + auto detections = detect_boxes[n]; + + if (image_gt_boxes.size() == 0) { + for (auto it = detections.begin(); it != detections.end(); ++it) { + auto pred_boxes = it->second; + int label = it->first; + for (size_t i = 0; i < pred_boxes.size(); ++i) { + auto score = pred_boxes[i].first; + true_pos[label].push_back(std::make_pair(score, 0)); + false_pos[label].push_back(std::make_pair(score, 1)); + } + } + continue; + } + + for (auto it = detections.begin(); it != detections.end(); ++it) { + int label = it->first; + auto pred_boxes = it->second; + if (image_gt_boxes.find(label) == image_gt_boxes.end()) { + for (size_t i = 0; i < pred_boxes.size(); ++i) { + auto score = pred_boxes[i].first; + true_pos[label].push_back(std::make_pair(score, 0)); + false_pos[label].push_back(std::make_pair(score, 1)); + } + continue; + } + + auto matched_bboxes = image_gt_boxes.find(label)->second; + std::vector visited(matched_bboxes.size(), false); + // Sort detections in descend order based on scores + std::sort(pred_boxes.begin(), pred_boxes.end(), + SortScorePairDescend); + for (size_t i = 0; i < pred_boxes.size(); ++i) { + T max_overlap = -1.0; + size_t max_idx = 0; + auto score = pred_boxes[i].first; + for (size_t j = 0; j < matched_bboxes.size(); ++j) { + T overlap = JaccardOverlap(pred_boxes[i].second, matched_bboxes[j]); + if (overlap > max_overlap) { + max_overlap = overlap; + max_idx = j; + } + } + if (max_overlap > overlap_threshold) { + bool match_evaluate_difficult = + evaluate_difficult || + (!evaluate_difficult && !matched_bboxes[max_idx].is_difficult); + if (match_evaluate_difficult) { + if (!visited[max_idx]) { + true_pos[label].push_back(std::make_pair(score, 1)); + false_pos[label].push_back(std::make_pair(score, 0)); + visited[max_idx] = true; + } else { + true_pos[label].push_back(std::make_pair(score, 0)); + false_pos[label].push_back(std::make_pair(score, 1)); + } + } + } else { + true_pos[label].push_back(std::make_pair(score, 0)); + false_pos[label].push_back(std::make_pair(score, 1)); + } + } + } + } + } + + T CalcMAP( + APType ap_type, const std::map& label_pos_count, + const std::map>>& true_pos, + const std::map>>& false_pos) const { + T mAP = 0.0; + int count = 0; + for (auto it = label_pos_count.begin(); it != label_pos_count.end(); ++it) { + int label = it->first; + int label_num_pos = it->second; + if (label_num_pos == 0 || true_pos.find(label) == true_pos.end()) + continue; + auto label_true_pos = true_pos.find(label)->second; + auto label_false_pos = false_pos.find(label)->second; + // Compute average precision. + std::vector tp_sum; + GetAccumulation(label_true_pos, &tp_sum); + std::vector fp_sum; + GetAccumulation(label_false_pos, &fp_sum); + std::vector precision, recall; + size_t num = tp_sum.size(); + // Compute Precision. + for (size_t i = 0; i < num; ++i) { + precision.push_back(static_cast(tp_sum[i]) / + static_cast(tp_sum[i] + fp_sum[i])); + recall.push_back(static_cast(tp_sum[i]) / label_num_pos); + } + // VOC2007 style + if (ap_type == APType::k11point) { + std::vector max_precisions(11, 0.0); + int start_idx = num - 1; + for (int j = 10; j >= 0; --j) + for (int i = start_idx; i >= 0; --i) { + if (recall[i] < j / 10.) { + start_idx = i; + if (j > 0) max_precisions[j - 1] = max_precisions[j]; + break; + } else { + if (max_precisions[j] < precision[i]) + max_precisions[j] = precision[i]; + } + } + for (int j = 10; j >= 0; --j) mAP += max_precisions[j] / 11; + ++count; + } else if (ap_type == APType::kIntegral) { + // Nature integral + float average_precisions = 0.; + float prev_recall = 0.; + for (size_t i = 0; i < num; ++i) { + if (fabs(recall[i] - prev_recall) > 1e-6) + average_precisions += precision[i] * fabs(recall[i] - prev_recall); + prev_recall = recall[i]; + } + mAP += average_precisions; + ++count; + } else { + LOG(FATAL) << "Unkown ap version: " << ap_type; + } + } + if (count != 0) mAP /= count; + return mAP * 100; + } +}; // namespace operators + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/prior_box_op.cc b/paddle/fluid/operators/prior_box_op.cc index ed48603e17f38f89705186fb9fb992f69d26d2ff..1385a6cdce838b7f376cd784a8eaa63f591c7ef2 100644 --- a/paddle/fluid/operators/prior_box_op.cc +++ b/paddle/fluid/operators/prior_box_op.cc @@ -38,8 +38,8 @@ class PriorBoxOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_LT(input_dims[3], image_dims[3], "The width of input must smaller than image."); - auto min_sizes = ctx->Attrs().Get>("min_sizes"); - auto max_sizes = ctx->Attrs().Get>("max_sizes"); + auto min_sizes = ctx->Attrs().Get>("min_sizes"); + auto max_sizes = ctx->Attrs().Get>("max_sizes"); auto variances = ctx->Attrs().Get>("variances"); auto aspect_ratios = ctx->Attrs().Get>("aspect_ratios"); bool flip = ctx->Attrs().Get("flip"); @@ -47,15 +47,15 @@ class PriorBoxOp : public framework::OperatorWithKernel { std::vector aspect_ratios_vec; ExpandAspectRatios(aspect_ratios, flip, aspect_ratios_vec); - int num_priors = aspect_ratios_vec.size() * min_sizes.size(); + size_t num_priors = aspect_ratios_vec.size() * min_sizes.size(); if (max_sizes.size() > 0) { PADDLE_ENFORCE_EQ(max_sizes.size(), min_sizes.size(), "The number of min_size and max_size must be equal."); - for (size_t i = 0; i < min_sizes.size(); ++i) { + num_priors += max_sizes.size(); + for (size_t i = 0; i < max_sizes.size(); ++i) { PADDLE_ENFORCE_GT(max_sizes[i], min_sizes[i], "max_size[%d] must be greater than min_size[%d].", i, i); - num_priors += 1; } } @@ -90,20 +90,20 @@ class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker { "H is the height of input, W is the width of input, num_priors " "is the box count of each position."); - AddAttr>("min_sizes", - "(vector) List of min sizes " - "of generated prior boxes.") - .AddCustomChecker([](const std::vector& min_sizes) { + AddAttr>("min_sizes", + "(vector) List of min sizes " + "of generated prior boxes.") + .AddCustomChecker([](const std::vector& min_sizes) { PADDLE_ENFORCE_GT(min_sizes.size(), 0, "Size of min_sizes must be at least 1."); for (size_t i = 0; i < min_sizes.size(); ++i) { - PADDLE_ENFORCE_GT(min_sizes[i], 0, + PADDLE_ENFORCE_GT(min_sizes[i], 0.0, "min_sizes[%d] must be positive.", i); } }); - AddAttr>( + AddAttr>( "max_sizes", - "(vector) List of max sizes of generated prior boxes."); + "(vector) List of max sizes of generated prior boxes."); AddAttr>( "aspect_ratios", "(vector) List of aspect ratios of generated prior boxes."); @@ -125,16 +125,16 @@ class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault(true); AddAttr("step_w", - "Prior boxes step across width, 0 for auto calculation.") + "Prior boxes step across width, 0.0 for auto calculation.") .SetDefault(0.0) .AddCustomChecker([](const float& step_w) { - PADDLE_ENFORCE_GT(step_w, 0.0, "step_w should be larger than 0."); + PADDLE_ENFORCE_GE(step_w, 0.0, "step_w should be larger than 0."); }); AddAttr("step_h", - "Prior boxes step across height, 0 for auto calculation.") + "Prior boxes step across height, 0.0 for auto calculation.") .SetDefault(0.0) .AddCustomChecker([](const float& step_h) { - PADDLE_ENFORCE_GT(step_h, 0.0, "step_h should be larger than 0."); + PADDLE_ENFORCE_GE(step_h, 0.0, "step_h should be larger than 0."); }); AddAttr("offset", diff --git a/paddle/fluid/operators/prior_box_op.h b/paddle/fluid/operators/prior_box_op.h index fd07041233495660605e9cf9acb33d57eb57bc30..e2c9514ed0814f21fec6c4184b7e971c4528d489 100644 --- a/paddle/fluid/operators/prior_box_op.h +++ b/paddle/fluid/operators/prior_box_op.h @@ -60,8 +60,8 @@ class PriorBoxOpKernel : public framework::OpKernel { auto* boxes = ctx.Output("Boxes"); auto* vars = ctx.Output("Variances"); - auto min_sizes = ctx.Attr>("min_sizes"); - auto max_sizes = ctx.Attr>("max_sizes"); + auto min_sizes = ctx.Attr>("min_sizes"); + auto max_sizes = ctx.Attr>("max_sizes"); auto input_aspect_ratio = ctx.Attr>("aspect_ratios"); auto variances = ctx.Attr>("variances"); auto flip = ctx.Attr("flip"); @@ -108,7 +108,7 @@ class PriorBoxOpKernel : public framework::OpKernel { T box_width, box_height; int idx = 0; for (size_t s = 0; s < min_sizes.size(); ++s) { - int min_size = min_sizes[s]; + auto min_size = min_sizes[s]; // first prior: aspect_ratio = 1, size = min_size box_width = box_height = min_size; // xmin @@ -124,7 +124,7 @@ class PriorBoxOpKernel : public framework::OpKernel { idx++; if (max_sizes.size() > 0) { - int max_size = max_sizes[s]; + auto max_size = max_sizes[s]; // second prior: aspect_ratio = 1, // size = sqrt(min_size * max_size) box_width = box_height = sqrt(min_size * max_size); diff --git a/paddle/fluid/operators/smooth_l1_loss_op.cc b/paddle/fluid/operators/smooth_l1_loss_op.cc index be4c7a56a84e84c39a578b958fe7c9ad551f54f6..e6eede23ee367200f9a2b531d1cbd402ceea6b54 100644 --- a/paddle/fluid/operators/smooth_l1_loss_op.cc +++ b/paddle/fluid/operators/smooth_l1_loss_op.cc @@ -44,7 +44,6 @@ class SmoothL1LossOp : public framework::OperatorWithKernel { } }; -template class SmoothL1LossOpMaker : public framework::OpProtoAndCheckerMaker { public: SmoothL1LossOpMaker(OpProto* proto, OpAttrChecker* op_checker) @@ -73,10 +72,10 @@ class SmoothL1LossOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "(Tensor, default Tensor) A tensor with rank be 2. " "The output smooth l1 loss with shape [batch_size, 1]."); - AddAttr("sigma", - "Hyper parameter of smooth l1 loss op." - "A float scalar with default value 3.0.") - .SetDefault(3.0); + AddAttr("sigma", + "Hyper parameter of smooth l1 loss op." + "A float scalar with default value 3.0.") + .SetDefault(1.0); AddComment(R"DOC( Smooth L1 Loss Operator. @@ -133,9 +132,8 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(smooth_l1_loss, ops::SmoothL1LossOp, - ops::SmoothL1LossOpMaker, smooth_l1_loss_grad, - ops::SmoothL1LossGradOp); +REGISTER_OP(smooth_l1_loss, ops::SmoothL1LossOp, ops::SmoothL1LossOpMaker, + smooth_l1_loss_grad, ops::SmoothL1LossGradOp); REGISTER_OP_CPU_KERNEL( smooth_l1_loss, ops::SmoothL1LossKernel); diff --git a/python/paddle/v2/fluid/layers/__init__.py b/python/paddle/v2/fluid/layers/__init__.py index 89b9f30668ee3ed84a9b728932c4ba0227e454b3..cfbbf710b6ac63b9a0fe7d51b0d1940532e948fc 100644 --- a/python/paddle/v2/fluid/layers/__init__.py +++ b/python/paddle/v2/fluid/layers/__init__.py @@ -28,8 +28,11 @@ import device from device import * import math_op_patch from math_op_patch import * +import detection +from detection import * __all__ = [] +__all__ += math_op_patch.__all__ __all__ += detection.__all__ __all__ += nn.__all__ __all__ += io.__all__ @@ -37,4 +40,4 @@ __all__ += tensor.__all__ __all__ += control_flow.__all__ __all__ += ops.__all__ __all__ += device.__all__ -__all__ += math_op_patch.__all__ +__all__ += detection.__all__ diff --git a/python/paddle/v2/fluid/layers/detection.py b/python/paddle/v2/fluid/layers/detection.py index bbe2765e138e0e05b9fdb2967822cefd337bf100..b045e1c56c8e97161137ae52b96483807e028063 100644 --- a/python/paddle/v2/fluid/layers/detection.py +++ b/python/paddle/v2/fluid/layers/detection.py @@ -18,15 +18,15 @@ All layers just related to the detection neural network. from ..layer_helper import LayerHelper from ..param_attr import ParamAttr from ..framework import Variable -from layer_function_generator import autodoc +from ..nets import img_conv_with_bn from tensor import concat from ops import reshape -from ..nets import img_conv_with_bn from nn import transpose import math __all__ = [ 'detection_output', + 'prior_box', 'multi_box_head', ] @@ -44,7 +44,7 @@ def detection_output(scores, """ **Detection Output Layer** - This layer applies the NMS to the output of network and computes the + This layer applies the NMS to the output of network and computes the predict bounding box location. The output's shape of this layer could be zero if there is no valid bounding box. @@ -127,6 +127,211 @@ def detection_output(scores, return nmsed_outs +def prior_box(inputs, + image, + min_ratio, + max_ratio, + aspect_ratios, + base_size, + steps=None, + step_w=None, + step_h=None, + offset=0.5, + variance=[0.1, 0.1, 0.1, 0.1], + flip=False, + clip=False, + min_sizes=None, + max_sizes=None, + name=None): + """ + **Prior_boxes** + + Generate prior boxes for SSD(Single Shot MultiBox Detector) + algorithm. The details of this algorithm, please refer the + section 2.2 of SSD paper (SSD: Single Shot MultiBox Detector) + `_ . + + Args: + inputs(list): The list of input Variables, the format + of all Variables is NCHW. + image(Variable): The input image data of PriorBoxOp, + the layout is NCHW. + min_ratio(int): the min ratio of generated prior boxes. + max_ratio(int): the max ratio of generated prior boxes. + aspect_ratios(list): the aspect ratios of generated prior + boxes. The length of input and aspect_ratios must be equal. + base_size(int): the base_size is used to get min_size + and max_size according to min_ratio and max_ratio. + step_w(list, optional, default=None): Prior boxes step + across width. If step_w[i] == 0.0, the prior boxes step + across width of the inputs[i] will be automatically calculated. + step_h(list, optional, default=None): Prior boxes step + across height, If step_h[i] == 0.0, the prior boxes + step across height of the inputs[i] will be automatically calculated. + offset(float, optional, default=0.5): Prior boxes center offset. + variance(list, optional, default=[0.1, 0.1, 0.1, 0.1]): the variances + to be encoded in prior boxes. + flip(bool, optional, default=False): Whether to flip + aspect ratios. + clip(bool, optional, default=False): Whether to clip + out-of-boundary boxes. + min_sizes(list, optional, default=None): If `len(inputs) <=2`, + min_sizes must be set up, and the length of min_sizes + should equal to the length of inputs. + max_sizes(list, optional, default=None): If `len(inputs) <=2`, + max_sizes must be set up, and the length of min_sizes + should equal to the length of inputs. + name(str, optional, None): Name of the prior box layer. + + Returns: + boxes(Variable): the output prior boxes of PriorBoxOp. + The layout is [num_priors, 4]. num_priors is the total + box count of each position of inputs. + Variances(Variable): the expanded variances of PriorBoxOp. + The layout is [num_priors, 4]. num_priors is the total + box count of each position of inputs + + Examples: + .. code-block:: python + + prior_box( + inputs = [conv1, conv2, conv3, conv4, conv5, conv6], + image = data, + min_ratio = 20, # 0.20 + max_ratio = 90, # 0.90 + offset = 0.5, + base_size = 300, + variance = [0.1,0.1,0.1,0.1], + aspect_ratios = [[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]], + flip=True, + clip=True) + """ + + def _prior_box_(input, + image, + min_sizes, + max_sizes, + aspect_ratios, + variance, + flip=False, + clip=False, + step_w=0.0, + step_h=0.0, + offset=0.5, + name=None): + helper = LayerHelper("prior_box", **locals()) + dtype = helper.input_dtype() + + box = helper.create_tmp_variable(dtype) + var = helper.create_tmp_variable(dtype) + helper.append_op( + type="prior_box", + inputs={"Input": input, + "Image": image}, + outputs={"Boxes": box, + "Variances": var}, + attrs={ + 'min_sizes': min_sizes, + 'max_sizes': max_sizes, + 'aspect_ratios': aspect_ratios, + 'variances': variance, + 'flip': flip, + 'clip': clip, + 'step_w': step_w, + 'step_h': step_h, + 'offset': offset + }) + return box, var + + def _reshape_with_axis_(input, axis=1): + if not (axis > 0 and axis < len(input.shape)): + raise ValueError("The axis should be smaller than " + "the arity of input and bigger than 0.") + new_shape = [ + -1, reduce(lambda x, y: x * y, input.shape[axis:len(input.shape)]) + ] + out = reshape(x=input, shape=new_shape) + return out + + assert isinstance(inputs, list), 'inputs should be a list.' + num_layer = len(inputs) + + if num_layer <= 2: + assert min_sizes is not None and max_sizes is not None + assert len(min_sizes) == num_layer and len(max_sizes) == num_layer + else: + min_sizes = [] + max_sizes = [] + step = int(math.floor(((max_ratio - min_ratio)) / (num_layer - 2))) + for ratio in xrange(min_ratio, max_ratio + 1, step): + min_sizes.append(base_size * ratio / 100.) + max_sizes.append(base_size * (ratio + step) / 100.) + min_sizes = [base_size * .10] + min_sizes + max_sizes = [base_size * .20] + max_sizes + + if aspect_ratios: + if not (isinstance(aspect_ratios, list) and + len(aspect_ratios) == num_layer): + raise ValueError( + 'aspect_ratios should be list and the length of inputs ' + 'and aspect_ratios should be the same.') + if step_h: + if not (isinstance(step_h, list) and len(step_h) == num_layer): + raise ValueError( + 'step_h should be list and the length of inputs and ' + 'step_h should be the same.') + if step_w: + if not (isinstance(step_w, list) and len(step_w) == num_layer): + raise ValueError( + 'step_w should be list and the length of inputs and ' + 'step_w should be the same.') + if steps: + if not (isinstance(steps, list) and len(steps) == num_layer): + raise ValueError( + 'steps should be list and the length of inputs and ' + 'step_w should be the same.') + step_w = steps + step_h = steps + + box_results = [] + var_results = [] + for i, input in enumerate(inputs): + min_size = min_sizes[i] + max_size = max_sizes[i] + aspect_ratio = [] + if not isinstance(min_size, list): + min_size = [min_size] + if not isinstance(max_size, list): + max_size = [max_size] + if aspect_ratios: + aspect_ratio = aspect_ratios[i] + if not isinstance(aspect_ratio, list): + aspect_ratio = [aspect_ratio] + + box, var = _prior_box_(input, image, min_size, max_size, aspect_ratio, + variance, flip, clip, step_w[i] + if step_w else 0.0, step_h[i] + if step_w else 0.0, offset) + + box_results.append(box) + var_results.append(var) + + if len(box_results) == 1: + box = box_results[0] + var = var_results[0] + else: + reshaped_boxes = [] + reshaped_vars = [] + for i in range(len(box_results)): + reshaped_boxes.append(_reshape_with_axis_(box_results[i], axis=3)) + reshaped_vars.append(_reshape_with_axis_(var_results[i], axis=3)) + + box = concat(reshaped_boxes) + var = concat(reshaped_vars) + + return box, var + + def multi_box_head(inputs, num_classes, min_sizes=None, @@ -171,34 +376,53 @@ def multi_box_head(inputs, Returns: - mbox_loc(Variable): the output prior boxes of PriorBoxOp. The layout is + mbox_loc(list): the output prior boxes of PriorBoxOp. The layout is [num_priors, 4]. num_priors is the total box count of each position of inputs. - mbox_conf(Variable): the expanded variances of PriorBoxOp. The layout + mbox_conf(list): the expanded variances of PriorBoxOp. The layout is [num_priors, 4]. num_priors is the total box count of each position of inputs Examples: .. code-block:: python - + mbox_locs, mbox_confs = detection.multi_box_head( + inputs=[conv1, conv2, conv3, conv4, conv5, conv5], + num_classes=21, + min_ratio=20, + max_ratio=90, + aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]], + base_size=300, + flip=True) """ - assert isinstance(inputs, list), 'inputs should be a list.' + if not (isinstance(inputs, list)): + raise ValueError('inputs should be a list.') if min_sizes is not None: - assert len(inputs) == len(min_sizes) + if not (len(inputs) == len(min_sizes)): + raise ValueError('the length of min_sizes ' + 'and inputs should be the same.') if max_sizes is not None: - assert len(inputs) == len(max_sizes) + if not (len(inputs) == len(max_sizes)): + raise ValueError('the length of max_sizes ' + 'and inputs should be the same.') + + if aspect_ratios is not None: + if not (len(inputs) == len(aspect_ratios)): + raise ValueError('the length of aspect_ratios ' + 'and inputs should be the same.') if min_sizes is None: - # if min_sizes is None, min_sizes and max_sizes - # will be set according to max_ratio and min_ratio - assert max_ratio is not None and min_ratio is not None + # If min_sizes is None, min_sizes and max_sizes + # will be set according to max_ratio and min_ratio. + num_layer = len(inputs) + assert max_ratio is not None and min_ratio is not None,\ + 'max_ratio and min_ratio must be not None.' + assert num_layer >= 3, 'The length of the input data is at least three.' min_sizes = [] max_sizes = [] - num_layer = len(inputs) step = int(math.floor(((max_ratio - min_ratio)) / (num_layer - 2))) for ratio in xrange(min_ratio, max_ratio + 1, step): min_sizes.append(base_size * ratio / 100.) @@ -206,9 +430,6 @@ def multi_box_head(inputs, min_sizes = [base_size * .10] + min_sizes max_sizes = [base_size * .20] + max_sizes - if aspect_ratios is not None: - assert len(inputs) == len(aspect_ratios) - mbox_locs = [] mbox_confs = [] for i, input in enumerate(inputs): @@ -221,9 +442,9 @@ def multi_box_head(inputs, max_size = max_sizes[i] if type(max_size) is not list: max_size = [max_size] - if max_size: - assert len(max_size) == len( - min_size), "max_size and min_size should have same length." + if not (len(max_size) == len(min_size)): + raise ValueError( + 'max_size and min_size should have same length.') aspect_ratio = [] if aspect_ratios is not None: @@ -231,17 +452,19 @@ def multi_box_head(inputs, if type(aspect_ratio) is not list: aspect_ratio = [aspect_ratio] + # get the number of prior box on each location num_priors_per_location = 0 if max_sizes is not None: - num_priors_per_location = len(min_size) + len(aspect_ratio) * len( - min_size) + len(max_size) + num_priors_per_location = len(min_size) + \ + len(aspect_ratio) * len(min_size) +\ + len(max_size) else: - num_priors_per_location = len(min_size) + len(aspect_ratio) * len( - min_size) + num_priors_per_location = len(min_size) +\ + len(aspect_ratio) * len(min_size) if flip: num_priors_per_location += len(aspect_ratio) * len(min_size) - # mbox_loc + # get mbox_loc num_loc_output = num_priors_per_location * 4 if share_location: num_loc_output *= num_classes @@ -256,7 +479,7 @@ def multi_box_head(inputs, mbox_loc = transpose(mbox_loc, perm=[0, 2, 3, 1]) mbox_locs.append(mbox_loc) - # get the number of prior box + # get conf_loc num_conf_output = num_priors_per_location * num_classes conf_loc = img_conv_with_bn( input=input, diff --git a/python/paddle/v2/fluid/layers/math_op_patch.py b/python/paddle/v2/fluid/layers/math_op_patch.py index 00e4e6907804c7a460e60d960b4aa94ca23b4886..d829bba1b101cc802ea29f32e0b7ecdb1ac448f5 100644 --- a/python/paddle/v2/fluid/layers/math_op_patch.py +++ b/python/paddle/v2/fluid/layers/math_op_patch.py @@ -152,7 +152,12 @@ def monkey_patch_variable(): ("__div__", "elementwise_div", False), ("__rdiv__", "elementwise_div", True), ("__pow__", "elementwise_pow", False), - ("__rpow__", "elementwise_pow", True)): + ("__rpow__", "elementwise_pow", True), + # for logical compare + ("__eq__", "equal", False), + ("__ne__", "not_equal", False), + ("__lt__", "less_than", False), + ("__le__", "less_equal", False)): setattr(Variable, method_name, _elemwise_method_creator_(method_name, op_type, reverse)) diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index 5ebd329fc0285a39111a23b3c58c80944cfe23f6..051b5368180d3f7951b100c26fb7367372d9a343 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -66,6 +66,8 @@ __all__ = [ 'row_conv', 'multiplex', 'layer_norm', + 'softmax_with_cross_entropy', + 'smooth_l1', ] @@ -3091,3 +3093,122 @@ def multiplex(inputs, index): 'Ids': index}, outputs={'Out': [out]}) return out + + +def softmax_with_cross_entropy(logits, label, soft_label=False): + """ + **Softmax With Cross Entropy Operator.** + + Cross entropy loss with softmax is used as the output layer extensively. This + operator computes the softmax normalized values for each row of the input + tensor, after which cross-entropy loss is computed. This provides a more + numerically stable gradient. + + Because this operator performs a softmax on logits internally, it expects + unscaled logits. This operator should not be used with the output of + softmax operator since that would produce incorrect results. + + When the attribute soft_label is set false, this operators expects mutually + exclusive hard labels, each sample in a batch is in exactly one class with a + probability of 1.0. Each sample in the batch will have a single label. + + The equation is as follows: + + 1) Hard label (one-hot label, so every sample has exactly one class) + + .. math:: + + loss_j = -\\text{logit}_{label_j} + + \\log\\left(\\sum_{i=0}^{K}\\exp(\\text{logit}_i)\\right), j = 1,..., K + + 2) Soft label (each sample can have a distribution over all classes) + + .. math:: + + loss_j = -\\sum_{i=0}^{K}\\text{label}_i + \\left(\\text{logit}_i - \\log\\left(\\sum_{i=0}^{K} + \\exp(\\text{logit}_i)\\right)\\right), j = 1,...,K + + Args: + logits (Variable): The unscaled log probabilities, which is a 2-D tensor + with shape [N x K]. N is the batch_size, and K is the class number. + label (Variable): The ground truth which is a 2-D tensor. If soft_label + is set to false, Label is a Tensor with shape [N x 1]. If + soft_label is set to true, Label is a Tensor with + soft_label (bool): A flag to indicate whether to interpretate the given + labels as soft labels. By default, `soft_label` is set to False. + Returns: + Variable: The cross entropy loss is a 2-D tensor with shape [N x 1]. + + Examples: + .. code-block:: python + + data = fluid.layers.data(name='data', shape=[128], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + fc = fluid.layers.fc(input=data, size=100) + out = fluid.layers.softmax_with_cross_entropy(logits=fc, label=label) + """ + helper = LayerHelper('softmax_with_cross_entropy', **locals()) + softmax = helper.create_tmp_variable(dtype=logits.dtype) + loss = helper.create_tmp_variable(dtype=logits.dtype) + helper.append_op( + type='softmax_with_cross_entropy', + inputs={'Logits': logits, + 'Label': label}, + outputs={'Softmax': softmax, + 'Loss': loss}, + attrs={'soft_label': soft_label}) + return loss + + +def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None): + """ + **Smooth L1 Loss Operator. ** + + This operator computes the smooth l1 loss for X and Y. + The operator takes the first dimension of X and Y as batch size. + For each instance, it computes the smooth l1 loss element by element first + and then sums all the losses. So the shape of Out is [batch_size, 1]. + + Args: + x (Variable): A tensor with rank at least 2. The input value of smooth + l1 loss op with shape [batch_size, dim1, ..., dimN]. + y (Variable): A tensor with rank at least 2. The target value of smooth + l1 loss op with same shape as x. + inside_weight (Variable|None): A tensor with rank at least 2. This + input is optional and should have same shape with x. If provided, + the result of (x - y) will be multiplied by this tensor element by + element. + outside_weight (Variable|None): A tensor with rank at least 2. This + input is optional and should have same shape with x. If provided, + the out smooth l1 loss will be multiplied by this tensor element + by element. + sigma (float|None): Hyper parameter of smooth l1 loss op. A float scalar + with default value 1.0. + Returns: + Variable: A tensor with rank be 2. The output smooth l1 loss with + shape [batch_size, 1]. + + Examples: + .. code-block:: python + + data = fluid.layers.data(name='data', shape=[128], dtype='float32') + label = fluid.layers.data(name='label', shape=[100], dtype='int64') + fc = fluid.layers.fc(input=data, size=100) + out = fluid.layers.smooth_l1(logits=fc, label=label) + """ + helper = LayerHelper('smooth_l1_loss', **locals()) + diff = helper.create_tmp_variable(dtype=x.dtype) + loss = helper.create_tmp_variable(dtype=x.dtype) + helper.append_op( + type='smooth_l1_loss', + inputs={ + 'X': x, + 'Y': y, + 'InsideWeight': inside_weight, + 'OutsideWeight': outside_weight + }, + outputs={'Diff': diff, + 'Out': loss}, + attrs={'sigma': sigma}) + return loss diff --git a/python/paddle/v2/fluid/learning_rate_decay.py b/python/paddle/v2/fluid/learning_rate_decay.py index 2a2a29fd9cbedc138dc82ca75ccd78208fd33195..0826d3da79a96590f00159a2d2e6f069792909c4 100644 --- a/python/paddle/v2/fluid/learning_rate_decay.py +++ b/python/paddle/v2/fluid/learning_rate_decay.py @@ -179,7 +179,7 @@ def polynomial_decay(learning_rate, shape=[1], dtype='float32', value=1.0) with layers.Switch() as switch: - with switch.case(layers.equal(x=global_step, y=zero_var)): + with switch.case(global_step == zero_var): layers.assign(input=one_var, output=div_res) decay_steps = decay_steps * div_res else: @@ -229,7 +229,7 @@ def piecewise_decay(global_step, boundaries, values): shape=[1], dtype='float32', value=float(boundaries[i])) value_var = layers.fill_constant( shape=[1], dtype='float32', value=float(values[i])) - with switch.case(layers.less_than(global_step, boundary_val)): + with switch.case(global_step < boundary_val): layers.assign(value_var, lr) last_value_var = layers.fill_constant( shape=[1], diff --git a/python/paddle/v2/fluid/tests/test_detection.py b/python/paddle/v2/fluid/tests/test_detection.py index d2207f1bfa9c73d0ededd8f353066e50e59c0522..d50efb3f7466d0ccedf41d16d1faa3d27fefd28e 100644 --- a/python/paddle/v2/fluid/tests/test_detection.py +++ b/python/paddle/v2/fluid/tests/test_detection.py @@ -14,15 +14,10 @@ from __future__ import print_function import paddle.v2.fluid as fluid -import paddle.v2.fluid.core as core import paddle.v2.fluid.layers as layers import paddle.v2.fluid.layers.detection as detection from paddle.v2.fluid.framework import Program, program_guard import unittest -import numpy as np - -import paddle.v2.fluid.layers as layers -from paddle.v2.fluid.framework import Program, program_guard class TestBook(unittest.TestCase): @@ -55,15 +50,67 @@ class TestBook(unittest.TestCase): print(str(program)) +class TestPriorBox(unittest.TestCase): + def test_prior_box(self): + data_shape = [3, 224, 224] + box, var = self.prior_box_output(data_shape) + + assert len(box.shape) == 2 + assert box.shape == var.shape + assert box.shape[1] == 4 + + def prior_box_output(self, data_shape): + images = fluid.layers.data( + name='pixel', shape=data_shape, dtype='float32') + conv1 = fluid.layers.conv2d( + input=images, + num_filters=3, + filter_size=3, + stride=2, + use_cudnn=False) + conv2 = fluid.layers.conv2d( + input=conv1, + num_filters=3, + filter_size=3, + stride=2, + use_cudnn=False) + conv3 = fluid.layers.conv2d( + input=conv2, + num_filters=3, + filter_size=3, + stride=2, + use_cudnn=False) + conv4 = fluid.layers.conv2d( + input=conv3, + num_filters=3, + filter_size=3, + stride=2, + use_cudnn=False) + conv5 = fluid.layers.conv2d( + input=conv4, + num_filters=3, + filter_size=3, + stride=2, + use_cudnn=False) + + box, var = detection.prior_box( + inputs=[conv1, conv2, conv3, conv4, conv5, conv5], + image=images, + min_ratio=20, + max_ratio=90, + # steps=[8, 16, 32, 64, 100, 300], + aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]], + base_size=300, + offset=0.5, + flip=True, + clip=True) + return box, var + + class TestMultiBoxHead(unittest.TestCase): def test_prior_box(self): data_shape = [3, 224, 224] mbox_locs, mbox_confs = self.multi_box_output(data_shape) - # print mbox_locs.shape - # print mbox_confs.shape - # assert len(box.shape) == 2 - # assert box.shape == var.shape - # assert box.shape[1] == 4 def multi_box_output(self, data_shape): images = fluid.layers.data( diff --git a/python/paddle/v2/fluid/tests/test_detection_map_op.py b/python/paddle/v2/fluid/tests/test_detection_map_op.py new file mode 100644 index 0000000000000000000000000000000000000000..70ccd885d89f245df492bad0fbcecc093dc1928c --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_detection_map_op.py @@ -0,0 +1,265 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import sys +import collections +import math +from op_test import OpTest + + +class TestDetectionMAPOp(OpTest): + def set_data(self): + self.init_test_case() + + self.mAP = [self.calc_map(self.tf_pos, self.tf_pos_lod)] + self.label = np.array(self.label).astype('float32') + self.detect = np.array(self.detect).astype('float32') + self.mAP = np.array(self.mAP).astype('float32') + + if (len(self.class_pos_count) > 0): + self.class_pos_count = np.array(self.class_pos_count).astype( + 'int32') + self.true_pos = np.array(self.true_pos).astype('float32') + self.false_pos = np.array(self.false_pos).astype('float32') + + self.inputs = { + 'Label': (self.label, self.label_lod), + 'DetectRes': (self.detect, self.detect_lod), + 'PosCount': self.class_pos_count, + 'TruePos': (self.true_pos, self.true_pos_lod), + 'FalsePos': (self.false_pos, self.false_pos_lod) + } + else: + self.inputs = { + 'Label': (self.label, self.label_lod), + 'DetectRes': (self.detect, self.detect_lod), + } + + self.attrs = { + 'overlap_threshold': self.overlap_threshold, + 'evaluate_difficult': self.evaluate_difficult, + 'ap_type': self.ap_type + } + + self.out_class_pos_count = np.array(self.out_class_pos_count).astype( + 'int') + self.out_true_pos = np.array(self.out_true_pos).astype('float32') + self.out_false_pos = np.array(self.out_false_pos).astype('float32') + + self.outputs = { + 'MAP': self.mAP, + 'AccumPosCount': self.out_class_pos_count, + 'AccumTruePos': (self.out_true_pos, self.out_true_pos_lod), + 'AccumFalsePos': (self.out_false_pos, self.out_false_pos_lod) + } + + def init_test_case(self): + self.overlap_threshold = 0.3 + self.evaluate_difficult = True + self.ap_type = "integral" + + self.label_lod = [[0, 2, 4]] + # label difficult xmin ymin xmax ymax + self.label = [[1, 0, 0.1, 0.1, 0.3, 0.3], [1, 1, 0.6, 0.6, 0.8, 0.8], + [2, 0, 0.3, 0.3, 0.6, 0.5], [1, 0, 0.7, 0.1, 0.9, 0.3]] + + # label score xmin ymin xmax ymax difficult + self.detect_lod = [[0, 3, 7]] + self.detect = [ + [1, 0.3, 0.1, 0.0, 0.4, 0.3], [1, 0.7, 0.0, 0.1, 0.2, 0.3], + [1, 0.9, 0.7, 0.6, 0.8, 0.8], [2, 0.8, 0.2, 0.1, 0.4, 0.4], + [2, 0.1, 0.4, 0.3, 0.7, 0.5], [1, 0.2, 0.8, 0.1, 1.0, 0.3], + [3, 0.2, 0.8, 0.1, 1.0, 0.3] + ] + + # label score true_pos false_pos + self.tf_pos_lod = [[0, 3, 7]] + self.tf_pos = [[1, 0.9, 1, 0], [1, 0.7, 1, 0], [1, 0.3, 0, 1], + [1, 0.2, 1, 0], [2, 0.8, 0, 1], [2, 0.1, 1, 0], + [3, 0.2, 0, 1]] + + self.class_pos_count = [] + self.true_pos_lod = [[]] + self.true_pos = [[]] + self.false_pos_lod = [[]] + self.false_pos = [[]] + + def calc_map(self, tf_pos, tf_pos_lod): + mAP = 0.0 + count = 0 + + def get_input_pos(class_pos_count, true_pos, true_pos_lod, false_pos, + false_pos_lod): + class_pos_count_dict = collections.Counter() + true_pos_dict = collections.defaultdict(list) + false_pos_dict = collections.defaultdict(list) + for i, count in enumerate(class_pos_count): + class_pos_count_dict[i] = count + + for i in range(len(true_pos_lod[0]) - 1): + start = true_pos_lod[0][i] + end = true_pos_lod[0][i + 1] + for j in range(start, end): + true_pos_dict[i].append(true_pos[j]) + + for i in range(len(false_pos_lod[0]) - 1): + start = false_pos_lod[0][i] + end = false_pos_lod[0][i + 1] + for j in range(start, end): + false_pos_dict[i].append(false_pos[j]) + + return class_pos_count_dict, true_pos_dict, false_pos_dict + + def get_output_pos(label_count, true_pos, false_pos): + max_label = 0 + for (label, label_pos_num) in label_count.items(): + if max_label < label: + max_label = label + + label_number = max_label + 1 + + out_class_pos_count = [] + out_true_pos_lod = [0] + out_true_pos = [] + out_false_pos_lod = [0] + out_false_pos = [] + + for i in range(label_number): + out_class_pos_count.append([label_count[i]]) + true_pos_list = true_pos[i] + out_true_pos += true_pos_list + out_true_pos_lod.append(len(out_true_pos)) + false_pos_list = false_pos[i] + out_false_pos += false_pos_list + out_false_pos_lod.append(len(out_false_pos)) + + return out_class_pos_count, out_true_pos, [ + out_true_pos_lod + ], out_false_pos, [out_false_pos_lod] + + def get_accumulation(pos_list): + sorted_list = sorted(pos_list, key=lambda pos: pos[0], reverse=True) + sum = 0 + accu_list = [] + for (score, count) in sorted_list: + sum += count + accu_list.append(sum) + return accu_list + + label_count, true_pos, false_pos = get_input_pos( + self.class_pos_count, self.true_pos, self.true_pos_lod, + self.false_pos, self.false_pos_lod) + for (label, difficult, xmin, ymin, xmax, ymax) in self.label: + if self.evaluate_difficult: + label_count[label] += 1 + elif not difficult: + label_count[label] += 1 + + true_pos = collections.defaultdict(list) + false_pos = collections.defaultdict(list) + for (label, score, tp, fp) in tf_pos: + true_pos[label].append([score, tp]) + false_pos[label].append([score, fp]) + + for (label, label_pos_num) in label_count.items(): + if label_pos_num == 0 or label not in true_pos: continue + label_true_pos = true_pos[label] + label_false_pos = false_pos[label] + + accu_tp_sum = get_accumulation(label_true_pos) + accu_fp_sum = get_accumulation(label_false_pos) + + precision = [] + recall = [] + + for i in range(len(accu_tp_sum)): + precision.append( + float(accu_tp_sum[i]) / + float(accu_tp_sum[i] + accu_fp_sum[i])) + recall.append(float(accu_tp_sum[i]) / label_pos_num) + + if self.ap_type == "11point": + max_precisions = [0.0] * 11 + start_idx = len(accu_tp_sum) - 1 + for j in range(10, -1, -1): + for i in range(start_idx, -1, -1): + if recall[i] < float(j) / 10.0: + start_idx = i + if j > 0: + max_precisions[j - 1] = max_precisions[j] + break + else: + if max_precisions[j] < precision[i]: + max_precisions[j] = precision[i] + for j in range(10, -1, -1): + mAP += max_precisions[j] / 11 + count += 1 + elif self.ap_type == "integral": + average_precisions = 0.0 + prev_recall = 0.0 + for i in range(len(accu_tp_sum)): + if math.fabs(recall[i] - prev_recall) > 1e-6: + average_precisions += precision[i] * \ + math.fabs(recall[i] - prev_recall) + prev_recall = recall[i] + + mAP += average_precisions + count += 1 + self.out_class_pos_count, self.out_true_pos, self.out_true_pos_lod, self.out_false_pos, self.out_false_pos_lod = get_output_pos( + label_count, true_pos, false_pos) + if count != 0: + mAP /= count + return mAP * 100.0 + + def setUp(self): + self.op_type = "detection_map" + self.set_data() + + def test_check_output(self): + self.check_output() + + +class TestDetectionMAPOpSkipDiff(TestDetectionMAPOp): + def init_test_case(self): + super(TestDetectionMAPOpSkipDiff, self).init_test_case() + + self.evaluate_difficult = False + + self.tf_pos_lod = [[0, 2, 6]] + # label score true_pos false_pos + self.tf_pos = [[1, 0.7, 1, 0], [1, 0.3, 0, 1], [1, 0.2, 1, 0], + [2, 0.8, 0, 1], [2, 0.1, 1, 0], [3, 0.2, 0, 1]] + + +class TestDetectionMAPOp11Point(TestDetectionMAPOp): + def init_test_case(self): + super(TestDetectionMAPOp11Point, self).init_test_case() + + self.ap_type = "11point" + + +class TestDetectionMAPOpMultiBatch(TestDetectionMAPOp): + def init_test_case(self): + super(TestDetectionMAPOpMultiBatch, self).init_test_case() + self.class_pos_count = [0, 2, 1] + self.true_pos_lod = [[0, 0, 3, 5]] + self.true_pos = [[0.7, 1.], [0.3, 0.], [0.2, 1.], [0.8, 0.], [0.1, 1.]] + self.false_pos_lod = [[0, 0, 3, 5]] + self.false_pos = [[0.7, 0.], [0.3, 1.], [0.2, 0.], [0.8, 1.], [0.1, 0.]] + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/fluid/tests/test_layers.py b/python/paddle/v2/fluid/tests/test_layers.py index aea43c2517a02c72c1ee3307afdd3b21910f0064..50ef8204249250b5ca1555a5192bc3ed0ca108b9 100644 --- a/python/paddle/v2/fluid/tests/test_layers.py +++ b/python/paddle/v2/fluid/tests/test_layers.py @@ -161,8 +161,8 @@ class TestBook(unittest.TestCase): label=label, chunk_scheme="IOB", num_chunk_types=(label_dict_len - 1) / 2) - self.assertNotEqual(crf, None) - self.assertNotEqual(crf_decode, None) + self.assertFalse(crf is None) + self.assertFalse(crf_decode is None) print(str(program)) @@ -309,6 +309,24 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(out) print(str(program)) + def test_softmax_with_cross_entropy(self): + program = Program() + with program_guard(program): + x = layers.data(name='x', shape=[16], dtype='float32') + y = layers.data(name='label', shape=[1], dtype='int64') + loss = layers.softmax_with_cross_entropy(x, y) + self.assertIsNotNone(loss) + print(str(program)) + + def test_smooth_l1(self): + program = Program() + with program_guard(program): + x = layers.data(name='x', shape=[4], dtype='float32') + y = layers.data(name='label', shape=[4], dtype='float32') + loss = layers.smooth_l1(x, y) + self.assertIsNotNone(loss) + print(str(program)) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/fluid/tests/test_prior_box_op.py b/python/paddle/v2/fluid/tests/test_prior_box_op.py index ca8d2bca74ce2d4be8160c8851e393489691ae56..a6c21af49f63269720156ec833c94688d0e3230e 100644 --- a/python/paddle/v2/fluid/tests/test_prior_box_op.py +++ b/python/paddle/v2/fluid/tests/test_prior_box_op.py @@ -65,9 +65,9 @@ class TestPriorBoxOp(OpTest): self.batch_size = 10 self.min_sizes = [2, 4] - self.min_sizes = np.array(self.min_sizes).astype('int64') + self.min_sizes = np.array(self.min_sizes).astype('float32').tolist() self.max_sizes = [5, 10] - self.max_sizes = np.array(self.max_sizes).astype('int64') + self.max_sizes = np.array(self.max_sizes).astype('float32').tolist() self.aspect_ratios = [2.0, 3.0] self.flip = True self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0] diff --git a/python/paddle/v2/fluid/tests/test_python_operator_overriding.py b/python/paddle/v2/fluid/tests/test_python_operator_overriding.py new file mode 100644 index 0000000000000000000000000000000000000000..e5198ec17d027f007b4a831ef2e427481f8ff8c4 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_python_operator_overriding.py @@ -0,0 +1,76 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle.v2.fluid.layers as layers +import paddle.v2.fluid.framework as framework +import paddle.v2.fluid as fluid + + +class TestPythonOperatorOverride(unittest.TestCase): + def check_result(self, fn, place, dtype): + shape = [9, 10] + + x_data = np.random.random(size=shape).astype(dtype) + y_data = np.random.random(size=shape).astype(dtype) + python_out = fn(x_data, y_data) + + x_var = layers.create_global_var( + name='x', shape=shape, value=0.0, dtype=dtype, persistable=True) + y_var = layers.create_global_var( + name='y', shape=shape, value=0.0, dtype=dtype, persistable=True) + out = fn(x_var, y_var) + + exe = fluid.Executor(place) + + exe.run(fluid.default_startup_program()) + fluid_out = exe.run(fluid.default_main_program(), + feed={'x': x_data, + 'y': y_data}, + fetch_list=[out]) + + np.testing.assert_array_equal(python_out, fluid_out[0]) + + def test_override(self): + # compare func to check + compare_fns = [ + lambda _a, _b: _a == _b, + lambda _a, _b: _a != _b, + lambda _a, _b: _a < _b, + lambda _a, _b: _a <= _b, + lambda _a, _b: _a > _b, + lambda _a, _b: _a >= _b, + ] + + # places to check + places = [fluid.CPUPlace()] + if fluid.core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + + # dtypes to check + dtypes = ['int32', 'float32'] + + for place in places: + for dtype in dtypes: + for compare_fn in compare_fns: + with framework.program_guard(framework.Program(), + framework.Program()): + self.check_result(compare_fn, place, dtype) + + +if __name__ == '__main__': + unittest.main() diff --git a/tools/manylinux1/Dockerfile.x64 b/tools/manylinux1/Dockerfile.x64 index 0f1b8331309248aaaf0ed32cf14c583a4cdb7437..93cab692e363cde43bdd4dd9ad20f4a2c06be121 100644 --- a/tools/manylinux1/Dockerfile.x64 +++ b/tools/manylinux1/Dockerfile.x64 @@ -52,3 +52,5 @@ RUN wget -O /opt/swig-2.0.12.tar.gz https://sourceforge.net/projects/swig/files/ RUN mkdir -p /src && cd /src && git clone https://github.com/NVIDIA/nccl.git nccl && cd nccl &&\ make -j `nproc` install && cd .. && rm -rf nccl + +CMD ["bash", "/paddle/paddle/scripts/docker/build.sh"]