提交 99c9dbf5 编写于 作者: C chengduoZH

remove conflict

......@@ -102,3 +102,5 @@ REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y");
REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor);
REGISTER_LOGICAL_OP(equal, "Out = X == Y");
REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor);
REGISTER_LOGICAL_OP(not_equal, "Out = X != Y");
REGISTER_LOGICAL_KERNEL(not_equal, CPU, paddle::operators::NotEqualFunctor);
......@@ -17,3 +17,4 @@ limitations under the License. */
REGISTER_LOGICAL_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor);
REGISTER_LOGICAL_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor);
REGISTER_LOGICAL_KERNEL(equal, CUDA, paddle::operators::EqualFunctor);
REGISTER_LOGICAL_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor);
......@@ -48,6 +48,14 @@ struct EqualFunctor {
}
};
template <typename T>
struct NotEqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const {
return !EqualFunctor<T>()(a, b);
}
};
template <typename DeviceContext, typename Functor>
class CompareOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection_map_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class DetectionMAPOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("DetectRes"),
"Input(DetectRes) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"),
"Input(Label) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("AccumPosCount"),
"Output(AccumPosCount) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("AccumTruePos"),
"Output(AccumTruePos) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("AccumFalsePos"),
"Output(AccumFalsePos) of DetectionMAPOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("MAP"),
"Output(MAP) of DetectionMAPOp should not be null.");
auto det_dims = ctx->GetInputDim("DetectRes");
PADDLE_ENFORCE_EQ(det_dims.size(), 2UL,
"The rank of Input(DetectRes) must be 2, "
"the shape is [N, 6].");
PADDLE_ENFORCE_EQ(det_dims[1], 6UL,
"The shape is of Input(DetectRes) [N, 6].");
auto label_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ(label_dims.size(), 2UL,
"The rank of Input(Label) must be 2, "
"the shape is [N, 6].");
PADDLE_ENFORCE_EQ(label_dims[1], 6UL,
"The shape is of Input(Label) [N, 6].");
if (ctx->HasInput("PosCount")) {
PADDLE_ENFORCE(ctx->HasInput("TruePos"),
"Input(TruePos) of DetectionMAPOp should not be null when "
"Input(TruePos) is not null.");
PADDLE_ENFORCE(
ctx->HasInput("FalsePos"),
"Input(FalsePos) of DetectionMAPOp should not be null when "
"Input(FalsePos) is not null.");
}
ctx->SetOutputDim("MAP", framework::make_ddim({1}));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.Input<framework::Tensor>("DetectRes")->type()),
ctx.device_context());
}
};
class DetectionMAPOpMaker : public framework::OpProtoAndCheckerMaker {
public:
DetectionMAPOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("DetectRes",
"(LoDTensor) A 2-D LoDTensor with shape [M, 6] represents the "
"detections. Each row has 6 values: "
"[label, confidence, xmin, ymin, xmax, ymax], M is the total "
"number of detect results in this mini-batch. For each instance, "
"the offsets in first dimension are called LoD, the number of "
"offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is "
"no detected data.");
AddInput("Label",
"(LoDTensor) A 2-D LoDTensor with shape[N, 6] represents the"
"Labeled ground-truth data. Each row has 6 values: "
"[label, is_difficult, xmin, ymin, xmax, ymax], N is the total "
"number of ground-truth data in this mini-batch. For each "
"instance, the offsets in first dimension are called LoD, "
"the number of offset is N + 1, if LoD[i + 1] - LoD[i] == 0, "
"means there is no ground-truth data.");
AddInput("PosCount",
"(Tensor) A tensor with shape [Ncls, 1], store the "
"input positive example count of each class, Ncls is the count of "
"input classification. "
"This input is used to pass the AccumPosCount generated by the "
"previous mini-batch when the multi mini-batches cumulative "
"calculation carried out. "
"When the input(PosCount) is empty, the cumulative "
"calculation is not carried out, and only the results of the "
"current mini-batch are calculated.")
.AsDispensable();
AddInput("TruePos",
"(LoDTensor) A 2-D LoDTensor with shape [Ntp, 2], store the "
"input true positive example of each class."
"This input is used to pass the AccumTruePos generated by the "
"previous mini-batch when the multi mini-batches cumulative "
"calculation carried out. ")
.AsDispensable();
AddInput("FalsePos",
"(LoDTensor) A 2-D LoDTensor with shape [Nfp, 2], store the "
"input false positive example of each class."
"This input is used to pass the AccumFalsePos generated by the "
"previous mini-batch when the multi mini-batches cumulative "
"calculation carried out. ")
.AsDispensable();
AddOutput("AccumPosCount",
"(Tensor) A tensor with shape [Ncls, 1], store the "
"positive example count of each class. It combines the input "
"input(PosCount) and the positive example count computed from "
"input(Detection) and input(Label).");
AddOutput("AccumTruePos",
"(LoDTensor) A LoDTensor with shape [Ntp', 2], store the "
"true positive example of each class. It combines the "
"input(TruePos) and the true positive examples computed from "
"input(Detection) and input(Label).");
AddOutput("AccumFalsePos",
"(LoDTensor) A LoDTensor with shape [Nfp', 2], store the "
"false positive example of each class. It combines the "
"input(FalsePos) and the false positive examples computed from "
"input(Detection) and input(Label).");
AddOutput("MAP",
"(Tensor) A tensor with shape [1], store the mAP evaluate "
"result of the detection.");
AddAttr<float>(
"overlap_threshold",
"(float) "
"The lower bound jaccard overlap threshold of detection output and "
"ground-truth data.")
.SetDefault(.3f);
AddAttr<bool>("evaluate_difficult",
"(bool, default true) "
"Switch to control whether the difficult data is evaluated.")
.SetDefault(true);
AddAttr<std::string>("ap_type",
"(string, default 'integral') "
"The AP algorithm type, 'integral' or '11point'.")
.SetDefault("integral")
.InEnum({"integral", "11point"})
.AddCustomChecker([](const std::string& ap_type) {
PADDLE_ENFORCE_NE(GetAPType(ap_type), APType::kNone,
"The ap_type should be 'integral' or '11point.");
});
AddComment(R"DOC(
Detection mAP evaluate operator.
The general steps are as follows. First, calculate the true positive and
false positive according to the input of detection and labels, then
calculate the mAP evaluate value.
Supporting '11 point' and 'integral' mAP algorithm. Please get more information
from the following articles:
https://sanchom.wordpress.com/tag/average-precision/
https://arxiv.org/abs/1512.02325
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(detection_map, ops::DetectionMAPOp,
ops::DetectionMAPOpMaker);
REGISTER_OP_CPU_KERNEL(
detection_map, ops::DetectionMAPOpKernel<paddle::platform::CPUPlace, float>,
ops::DetectionMAPOpKernel<paddle::platform::CPUPlace, double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
enum APType { kNone = 0, kIntegral, k11point };
APType GetAPType(std::string str) {
if (str == "integral") {
return APType::kIntegral;
} else if (str == "11point") {
return APType::k11point;
} else {
return APType::kNone;
}
}
template <typename T>
inline bool SortScorePairDescend(const std::pair<float, T>& pair1,
const std::pair<float, T>& pair2) {
return pair1.first > pair2.first;
}
template <typename T>
inline void GetAccumulation(std::vector<std::pair<T, int>> in_pairs,
std::vector<int>* accu_vec) {
std::stable_sort(in_pairs.begin(), in_pairs.end(), SortScorePairDescend<int>);
accu_vec->clear();
size_t sum = 0;
for (size_t i = 0; i < in_pairs.size(); ++i) {
auto count = in_pairs[i].second;
sum += count;
accu_vec->push_back(sum);
}
}
template <typename Place, typename T>
class DetectionMAPOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_detect = ctx.Input<framework::LoDTensor>("DetectRes");
auto* in_label = ctx.Input<framework::LoDTensor>("Label");
auto* out_map = ctx.Output<framework::Tensor>("MAP");
auto* in_pos_count = ctx.Input<framework::Tensor>("PosCount");
auto* in_true_pos = ctx.Input<framework::LoDTensor>("TruePos");
auto* in_false_pos = ctx.Input<framework::LoDTensor>("FalsePos");
auto* out_pos_count = ctx.Output<framework::Tensor>("AccumPosCount");
auto* out_true_pos = ctx.Output<framework::LoDTensor>("AccumTruePos");
auto* out_false_pos = ctx.Output<framework::LoDTensor>("AccumFalsePos");
float overlap_threshold = ctx.Attr<float>("overlap_threshold");
float evaluate_difficult = ctx.Attr<bool>("evaluate_difficult");
auto ap_type = GetAPType(ctx.Attr<std::string>("ap_type"));
auto label_lod = in_label->lod();
auto detect_lod = in_detect->lod();
PADDLE_ENFORCE_EQ(label_lod.size(), 1UL,
"Only support one level sequence now.");
PADDLE_ENFORCE_EQ(label_lod[0].size(), detect_lod[0].size(),
"The batch_size of input(Label) and input(Detection) "
"must be the same.");
std::vector<std::map<int, std::vector<Box>>> gt_boxes;
std::vector<std::map<int, std::vector<std::pair<T, Box>>>> detect_boxes;
GetBoxes(*in_label, *in_detect, gt_boxes, detect_boxes);
std::map<int, int> label_pos_count;
std::map<int, std::vector<std::pair<T, int>>> true_pos;
std::map<int, std::vector<std::pair<T, int>>> false_pos;
if (in_pos_count != nullptr) {
GetInputPos(*in_pos_count, *in_true_pos, *in_false_pos, label_pos_count,
true_pos, false_pos);
}
CalcTrueAndFalsePositive(gt_boxes, detect_boxes, evaluate_difficult,
overlap_threshold, label_pos_count, true_pos,
false_pos);
T map = CalcMAP(ap_type, label_pos_count, true_pos, false_pos);
GetOutputPos(ctx, label_pos_count, true_pos, false_pos, *out_pos_count,
*out_true_pos, *out_false_pos);
T* map_data = out_map->mutable_data<T>(ctx.GetPlace());
map_data[0] = map;
}
protected:
struct Box {
Box(T xmin, T ymin, T xmax, T ymax)
: xmin(xmin), ymin(ymin), xmax(xmax), ymax(ymax), is_difficult(false) {}
T xmin, ymin, xmax, ymax;
bool is_difficult;
};
inline T JaccardOverlap(const Box& box1, const Box& box2) const {
if (box2.xmin > box1.xmax || box2.xmax < box1.xmin ||
box2.ymin > box1.ymax || box2.ymax < box1.ymin) {
return 0.0;
} else {
T inter_xmin = std::max(box1.xmin, box2.xmin);
T inter_ymin = std::max(box1.ymin, box2.ymin);
T inter_xmax = std::min(box1.xmax, box2.xmax);
T inter_ymax = std::min(box1.ymax, box2.ymax);
T inter_width = inter_xmax - inter_xmin;
T inter_height = inter_ymax - inter_ymin;
T inter_area = inter_width * inter_height;
T bbox_area1 = (box1.xmax - box1.xmin) * (box1.ymax - box1.ymin);
T bbox_area2 = (box2.xmax - box2.xmin) * (box2.ymax - box2.ymin);
return inter_area / (bbox_area1 + bbox_area2 - inter_area);
}
}
void GetBoxes(const framework::LoDTensor& input_label,
const framework::LoDTensor& input_detect,
std::vector<std::map<int, std::vector<Box>>>& gt_boxes,
std::vector<std::map<int, std::vector<std::pair<T, Box>>>>&
detect_boxes) const {
auto labels = framework::EigenTensor<T, 2>::From(input_label);
auto detect = framework::EigenTensor<T, 2>::From(input_detect);
auto label_lod = input_label.lod();
auto detect_lod = input_detect.lod();
int batch_size = label_lod[0].size() - 1;
auto label_index = label_lod[0];
for (int n = 0; n < batch_size; ++n) {
std::map<int, std::vector<Box>> boxes;
for (int i = label_index[n]; i < label_index[n + 1]; ++i) {
Box box(labels(i, 2), labels(i, 3), labels(i, 4), labels(i, 5));
int label = labels(i, 0);
auto is_difficult = labels(i, 1);
if (std::abs(is_difficult - 0.0) < 1e-6)
box.is_difficult = false;
else
box.is_difficult = true;
boxes[label].push_back(box);
}
gt_boxes.push_back(boxes);
}
auto detect_index = detect_lod[0];
for (int n = 0; n < batch_size; ++n) {
std::map<int, std::vector<std::pair<T, Box>>> boxes;
for (int i = detect_index[n]; i < detect_index[n + 1]; ++i) {
Box box(detect(i, 2), detect(i, 3), detect(i, 4), detect(i, 5));
int label = detect(i, 0);
auto score = detect(i, 1);
boxes[label].push_back(std::make_pair(score, box));
}
detect_boxes.push_back(boxes);
}
}
void GetOutputPos(
const framework::ExecutionContext& ctx,
const std::map<int, int>& label_pos_count,
const std::map<int, std::vector<std::pair<T, int>>>& true_pos,
const std::map<int, std::vector<std::pair<T, int>>>& false_pos,
framework::Tensor& output_pos_count,
framework::LoDTensor& output_true_pos,
framework::LoDTensor& output_false_pos) const {
int max_class_id = 0;
int true_pos_count = 0;
int false_pos_count = 0;
for (auto it = label_pos_count.begin(); it != label_pos_count.end(); ++it) {
int label = it->first;
if (label > max_class_id) max_class_id = label;
int label_num_pos = it->second;
if (label_num_pos == 0 || true_pos.find(label) == true_pos.end())
continue;
auto label_true_pos = true_pos.find(label)->second;
auto label_false_pos = false_pos.find(label)->second;
true_pos_count += label_true_pos.size();
false_pos_count += label_false_pos.size();
}
int* pos_count_data = output_pos_count.mutable_data<int>(
framework::make_ddim({max_class_id + 1, 1}), ctx.GetPlace());
T* true_pos_data = output_true_pos.mutable_data<T>(
framework::make_ddim({true_pos_count, 2}), ctx.GetPlace());
T* false_pos_data = output_false_pos.mutable_data<T>(
framework::make_ddim({false_pos_count, 2}), ctx.GetPlace());
true_pos_count = 0;
false_pos_count = 0;
std::vector<size_t> true_pos_starts = {0};
std::vector<size_t> false_pos_starts = {0};
for (int i = 0; i <= max_class_id; ++i) {
auto it_count = label_pos_count.find(i);
pos_count_data[i] = 0;
if (it_count != label_pos_count.end()) {
pos_count_data[i] = it_count->second;
}
auto it_true_pos = true_pos.find(i);
if (it_true_pos != true_pos.end()) {
const std::vector<std::pair<T, int>>& true_pos_vec =
it_true_pos->second;
for (const std::pair<T, int>& tp : true_pos_vec) {
true_pos_data[true_pos_count * 2] = tp.first;
true_pos_data[true_pos_count * 2 + 1] = static_cast<T>(tp.second);
true_pos_count++;
}
}
true_pos_starts.push_back(true_pos_count);
auto it_false_pos = false_pos.find(i);
if (it_false_pos != false_pos.end()) {
const std::vector<std::pair<T, int>>& false_pos_vec =
it_false_pos->second;
for (const std::pair<T, int>& fp : false_pos_vec) {
false_pos_data[false_pos_count * 2] = fp.first;
false_pos_data[false_pos_count * 2 + 1] = static_cast<T>(fp.second);
false_pos_count++;
}
}
false_pos_starts.push_back(false_pos_count);
}
framework::LoD true_pos_lod;
true_pos_lod.emplace_back(true_pos_starts);
framework::LoD false_pos_lod;
false_pos_lod.emplace_back(false_pos_starts);
output_true_pos.set_lod(true_pos_lod);
output_false_pos.set_lod(false_pos_lod);
return;
}
void GetInputPos(
const framework::Tensor& input_pos_count,
const framework::LoDTensor& input_true_pos,
const framework::LoDTensor& input_false_pos,
std::map<int, int>& label_pos_count,
std::map<int, std::vector<std::pair<T, int>>>& true_pos,
std::map<int, std::vector<std::pair<T, int>>>& false_pos) const {
constexpr T kEPS = static_cast<T>(1e-6);
int class_number = input_pos_count.dims()[0];
const int* pos_count_data = input_pos_count.data<int>();
for (int i = 0; i < class_number; ++i) {
label_pos_count[i] = pos_count_data[i];
}
auto SetData = [](const framework::LoDTensor& pos_tensor,
std::map<int, std::vector<std::pair<T, int>>>& pos) {
const T* pos_data = pos_tensor.data<T>();
auto pos_data_lod = pos_tensor.lod();
for (int i = 0; i < pos_data_lod.size(); ++i) {
for (int j = pos_data_lod[0][i]; j < pos_data_lod[0][i + 1]; ++j) {
T score = pos_data[j * 2];
int flag = 1;
if (pos_data[j * 2 + 1] < kEPS) flag = 0;
pos[i].push_back(std::make_pair(score, flag));
}
}
};
SetData(input_true_pos, true_pos);
SetData(input_false_pos, false_pos);
return;
}
void CalcTrueAndFalsePositive(
const std::vector<std::map<int, std::vector<Box>>>& gt_boxes,
const std::vector<std::map<int, std::vector<std::pair<T, Box>>>>&
detect_boxes,
bool evaluate_difficult, float overlap_threshold,
std::map<int, int>& label_pos_count,
std::map<int, std::vector<std::pair<T, int>>>& true_pos,
std::map<int, std::vector<std::pair<T, int>>>& false_pos) const {
int batch_size = gt_boxes.size();
for (int n = 0; n < batch_size; ++n) {
auto image_gt_boxes = gt_boxes[n];
for (auto it = image_gt_boxes.begin(); it != image_gt_boxes.end(); ++it) {
size_t count = 0;
auto labeled_bboxes = it->second;
if (evaluate_difficult) {
count = labeled_bboxes.size();
} else {
for (size_t i = 0; i < labeled_bboxes.size(); ++i)
if (!(labeled_bboxes[i].is_difficult)) ++count;
}
if (count == 0) {
continue;
}
int label = it->first;
if (label_pos_count.find(label) == label_pos_count.end()) {
label_pos_count[label] = count;
} else {
label_pos_count[label] += count;
}
}
}
for (size_t n = 0; n < detect_boxes.size(); ++n) {
auto image_gt_boxes = gt_boxes[n];
auto detections = detect_boxes[n];
if (image_gt_boxes.size() == 0) {
for (auto it = detections.begin(); it != detections.end(); ++it) {
auto pred_boxes = it->second;
int label = it->first;
for (size_t i = 0; i < pred_boxes.size(); ++i) {
auto score = pred_boxes[i].first;
true_pos[label].push_back(std::make_pair(score, 0));
false_pos[label].push_back(std::make_pair(score, 1));
}
}
continue;
}
for (auto it = detections.begin(); it != detections.end(); ++it) {
int label = it->first;
auto pred_boxes = it->second;
if (image_gt_boxes.find(label) == image_gt_boxes.end()) {
for (size_t i = 0; i < pred_boxes.size(); ++i) {
auto score = pred_boxes[i].first;
true_pos[label].push_back(std::make_pair(score, 0));
false_pos[label].push_back(std::make_pair(score, 1));
}
continue;
}
auto matched_bboxes = image_gt_boxes.find(label)->second;
std::vector<bool> visited(matched_bboxes.size(), false);
// Sort detections in descend order based on scores
std::sort(pred_boxes.begin(), pred_boxes.end(),
SortScorePairDescend<Box>);
for (size_t i = 0; i < pred_boxes.size(); ++i) {
T max_overlap = -1.0;
size_t max_idx = 0;
auto score = pred_boxes[i].first;
for (size_t j = 0; j < matched_bboxes.size(); ++j) {
T overlap = JaccardOverlap(pred_boxes[i].second, matched_bboxes[j]);
if (overlap > max_overlap) {
max_overlap = overlap;
max_idx = j;
}
}
if (max_overlap > overlap_threshold) {
bool match_evaluate_difficult =
evaluate_difficult ||
(!evaluate_difficult && !matched_bboxes[max_idx].is_difficult);
if (match_evaluate_difficult) {
if (!visited[max_idx]) {
true_pos[label].push_back(std::make_pair(score, 1));
false_pos[label].push_back(std::make_pair(score, 0));
visited[max_idx] = true;
} else {
true_pos[label].push_back(std::make_pair(score, 0));
false_pos[label].push_back(std::make_pair(score, 1));
}
}
} else {
true_pos[label].push_back(std::make_pair(score, 0));
false_pos[label].push_back(std::make_pair(score, 1));
}
}
}
}
}
T CalcMAP(
APType ap_type, const std::map<int, int>& label_pos_count,
const std::map<int, std::vector<std::pair<T, int>>>& true_pos,
const std::map<int, std::vector<std::pair<T, int>>>& false_pos) const {
T mAP = 0.0;
int count = 0;
for (auto it = label_pos_count.begin(); it != label_pos_count.end(); ++it) {
int label = it->first;
int label_num_pos = it->second;
if (label_num_pos == 0 || true_pos.find(label) == true_pos.end())
continue;
auto label_true_pos = true_pos.find(label)->second;
auto label_false_pos = false_pos.find(label)->second;
// Compute average precision.
std::vector<int> tp_sum;
GetAccumulation<T>(label_true_pos, &tp_sum);
std::vector<int> fp_sum;
GetAccumulation<T>(label_false_pos, &fp_sum);
std::vector<T> precision, recall;
size_t num = tp_sum.size();
// Compute Precision.
for (size_t i = 0; i < num; ++i) {
precision.push_back(static_cast<T>(tp_sum[i]) /
static_cast<T>(tp_sum[i] + fp_sum[i]));
recall.push_back(static_cast<T>(tp_sum[i]) / label_num_pos);
}
// VOC2007 style
if (ap_type == APType::k11point) {
std::vector<T> max_precisions(11, 0.0);
int start_idx = num - 1;
for (int j = 10; j >= 0; --j)
for (int i = start_idx; i >= 0; --i) {
if (recall[i] < j / 10.) {
start_idx = i;
if (j > 0) max_precisions[j - 1] = max_precisions[j];
break;
} else {
if (max_precisions[j] < precision[i])
max_precisions[j] = precision[i];
}
}
for (int j = 10; j >= 0; --j) mAP += max_precisions[j] / 11;
++count;
} else if (ap_type == APType::kIntegral) {
// Nature integral
float average_precisions = 0.;
float prev_recall = 0.;
for (size_t i = 0; i < num; ++i) {
if (fabs(recall[i] - prev_recall) > 1e-6)
average_precisions += precision[i] * fabs(recall[i] - prev_recall);
prev_recall = recall[i];
}
mAP += average_precisions;
++count;
} else {
LOG(FATAL) << "Unkown ap version: " << ap_type;
}
}
if (count != 0) mAP /= count;
return mAP * 100;
}
}; // namespace operators
} // namespace operators
} // namespace paddle
......@@ -38,8 +38,8 @@ class PriorBoxOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_LT(input_dims[3], image_dims[3],
"The width of input must smaller than image.");
auto min_sizes = ctx->Attrs().Get<std::vector<int>>("min_sizes");
auto max_sizes = ctx->Attrs().Get<std::vector<int>>("max_sizes");
auto min_sizes = ctx->Attrs().Get<std::vector<float>>("min_sizes");
auto max_sizes = ctx->Attrs().Get<std::vector<float>>("max_sizes");
auto variances = ctx->Attrs().Get<std::vector<float>>("variances");
auto aspect_ratios = ctx->Attrs().Get<std::vector<float>>("aspect_ratios");
bool flip = ctx->Attrs().Get<bool>("flip");
......@@ -47,15 +47,15 @@ class PriorBoxOp : public framework::OperatorWithKernel {
std::vector<float> aspect_ratios_vec;
ExpandAspectRatios(aspect_ratios, flip, aspect_ratios_vec);
int num_priors = aspect_ratios_vec.size() * min_sizes.size();
size_t num_priors = aspect_ratios_vec.size() * min_sizes.size();
if (max_sizes.size() > 0) {
PADDLE_ENFORCE_EQ(max_sizes.size(), min_sizes.size(),
"The number of min_size and max_size must be equal.");
for (size_t i = 0; i < min_sizes.size(); ++i) {
num_priors += max_sizes.size();
for (size_t i = 0; i < max_sizes.size(); ++i) {
PADDLE_ENFORCE_GT(max_sizes[i], min_sizes[i],
"max_size[%d] must be greater than min_size[%d].", i,
i);
num_priors += 1;
}
}
......@@ -90,20 +90,20 @@ class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {
"H is the height of input, W is the width of input, num_priors "
"is the box count of each position.");
AddAttr<std::vector<int>>("min_sizes",
"(vector<int>) List of min sizes "
"of generated prior boxes.")
.AddCustomChecker([](const std::vector<int>& min_sizes) {
AddAttr<std::vector<float>>("min_sizes",
"(vector<float>) List of min sizes "
"of generated prior boxes.")
.AddCustomChecker([](const std::vector<float>& min_sizes) {
PADDLE_ENFORCE_GT(min_sizes.size(), 0,
"Size of min_sizes must be at least 1.");
for (size_t i = 0; i < min_sizes.size(); ++i) {
PADDLE_ENFORCE_GT(min_sizes[i], 0,
PADDLE_ENFORCE_GT(min_sizes[i], 0.0,
"min_sizes[%d] must be positive.", i);
}
});
AddAttr<std::vector<int>>(
AddAttr<std::vector<float>>(
"max_sizes",
"(vector<int>) List of max sizes of generated prior boxes.");
"(vector<float>) List of max sizes of generated prior boxes.");
AddAttr<std::vector<float>>(
"aspect_ratios",
"(vector<float>) List of aspect ratios of generated prior boxes.");
......@@ -125,16 +125,16 @@ class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(true);
AddAttr<float>("step_w",
"Prior boxes step across width, 0 for auto calculation.")
"Prior boxes step across width, 0.0 for auto calculation.")
.SetDefault(0.0)
.AddCustomChecker([](const float& step_w) {
PADDLE_ENFORCE_GT(step_w, 0.0, "step_w should be larger than 0.");
PADDLE_ENFORCE_GE(step_w, 0.0, "step_w should be larger than 0.");
});
AddAttr<float>("step_h",
"Prior boxes step across height, 0 for auto calculation.")
"Prior boxes step across height, 0.0 for auto calculation.")
.SetDefault(0.0)
.AddCustomChecker([](const float& step_h) {
PADDLE_ENFORCE_GT(step_h, 0.0, "step_h should be larger than 0.");
PADDLE_ENFORCE_GE(step_h, 0.0, "step_h should be larger than 0.");
});
AddAttr<float>("offset",
......
......@@ -60,8 +60,8 @@ class PriorBoxOpKernel : public framework::OpKernel<T> {
auto* boxes = ctx.Output<paddle::framework::Tensor>("Boxes");
auto* vars = ctx.Output<paddle::framework::Tensor>("Variances");
auto min_sizes = ctx.Attr<std::vector<int>>("min_sizes");
auto max_sizes = ctx.Attr<std::vector<int>>("max_sizes");
auto min_sizes = ctx.Attr<std::vector<float>>("min_sizes");
auto max_sizes = ctx.Attr<std::vector<float>>("max_sizes");
auto input_aspect_ratio = ctx.Attr<std::vector<float>>("aspect_ratios");
auto variances = ctx.Attr<std::vector<float>>("variances");
auto flip = ctx.Attr<bool>("flip");
......@@ -108,7 +108,7 @@ class PriorBoxOpKernel : public framework::OpKernel<T> {
T box_width, box_height;
int idx = 0;
for (size_t s = 0; s < min_sizes.size(); ++s) {
int min_size = min_sizes[s];
auto min_size = min_sizes[s];
// first prior: aspect_ratio = 1, size = min_size
box_width = box_height = min_size;
// xmin
......@@ -124,7 +124,7 @@ class PriorBoxOpKernel : public framework::OpKernel<T> {
idx++;
if (max_sizes.size() > 0) {
int max_size = max_sizes[s];
auto max_size = max_sizes[s];
// second prior: aspect_ratio = 1,
// size = sqrt(min_size * max_size)
box_width = box_height = sqrt(min_size * max_size);
......
......@@ -44,7 +44,6 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
}
};
template <typename AttrType>
class SmoothL1LossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SmoothL1LossOpMaker(OpProto* proto, OpAttrChecker* op_checker)
......@@ -73,10 +72,10 @@ class SmoothL1LossOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out",
"(Tensor, default Tensor<float>) A tensor with rank be 2. "
"The output smooth l1 loss with shape [batch_size, 1].");
AddAttr<AttrType>("sigma",
"Hyper parameter of smooth l1 loss op."
"A float scalar with default value 3.0.")
.SetDefault(3.0);
AddAttr<float>("sigma",
"Hyper parameter of smooth l1 loss op."
"A float scalar with default value 3.0.")
.SetDefault(1.0);
AddComment(R"DOC(
Smooth L1 Loss Operator.
......@@ -133,9 +132,8 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(smooth_l1_loss, ops::SmoothL1LossOp,
ops::SmoothL1LossOpMaker<float>, smooth_l1_loss_grad,
ops::SmoothL1LossGradOp);
REGISTER_OP(smooth_l1_loss, ops::SmoothL1LossOp, ops::SmoothL1LossOpMaker,
smooth_l1_loss_grad, ops::SmoothL1LossGradOp);
REGISTER_OP_CPU_KERNEL(
smooth_l1_loss,
ops::SmoothL1LossKernel<paddle::platform::CPUDeviceContext, float>);
......
......@@ -28,8 +28,11 @@ import device
from device import *
import math_op_patch
from math_op_patch import *
import detection
from detection import *
__all__ = []
__all__ += math_op_patch.__all__
__all__ += detection.__all__
__all__ += nn.__all__
__all__ += io.__all__
......@@ -37,4 +40,4 @@ __all__ += tensor.__all__
__all__ += control_flow.__all__
__all__ += ops.__all__
__all__ += device.__all__
__all__ += math_op_patch.__all__
__all__ += detection.__all__
......@@ -18,15 +18,15 @@ All layers just related to the detection neural network.
from ..layer_helper import LayerHelper
from ..param_attr import ParamAttr
from ..framework import Variable
from layer_function_generator import autodoc
from ..nets import img_conv_with_bn
from tensor import concat
from ops import reshape
from ..nets import img_conv_with_bn
from nn import transpose
import math
__all__ = [
'detection_output',
'prior_box',
'multi_box_head',
]
......@@ -44,7 +44,7 @@ def detection_output(scores,
"""
**Detection Output Layer**
This layer applies the NMS to the output of network and computes the
This layer applies the NMS to the output of network and computes the
predict bounding box location. The output's shape of this layer could
be zero if there is no valid bounding box.
......@@ -127,6 +127,211 @@ def detection_output(scores,
return nmsed_outs
def prior_box(inputs,
image,
min_ratio,
max_ratio,
aspect_ratios,
base_size,
steps=None,
step_w=None,
step_h=None,
offset=0.5,
variance=[0.1, 0.1, 0.1, 0.1],
flip=False,
clip=False,
min_sizes=None,
max_sizes=None,
name=None):
"""
**Prior_boxes**
Generate prior boxes for SSD(Single Shot MultiBox Detector)
algorithm. The details of this algorithm, please refer the
section 2.2 of SSD paper (SSD: Single Shot MultiBox Detector)
<https://arxiv.org/abs/1512.02325>`_ .
Args:
inputs(list): The list of input Variables, the format
of all Variables is NCHW.
image(Variable): The input image data of PriorBoxOp,
the layout is NCHW.
min_ratio(int): the min ratio of generated prior boxes.
max_ratio(int): the max ratio of generated prior boxes.
aspect_ratios(list): the aspect ratios of generated prior
boxes. The length of input and aspect_ratios must be equal.
base_size(int): the base_size is used to get min_size
and max_size according to min_ratio and max_ratio.
step_w(list, optional, default=None): Prior boxes step
across width. If step_w[i] == 0.0, the prior boxes step
across width of the inputs[i] will be automatically calculated.
step_h(list, optional, default=None): Prior boxes step
across height, If step_h[i] == 0.0, the prior boxes
step across height of the inputs[i] will be automatically calculated.
offset(float, optional, default=0.5): Prior boxes center offset.
variance(list, optional, default=[0.1, 0.1, 0.1, 0.1]): the variances
to be encoded in prior boxes.
flip(bool, optional, default=False): Whether to flip
aspect ratios.
clip(bool, optional, default=False): Whether to clip
out-of-boundary boxes.
min_sizes(list, optional, default=None): If `len(inputs) <=2`,
min_sizes must be set up, and the length of min_sizes
should equal to the length of inputs.
max_sizes(list, optional, default=None): If `len(inputs) <=2`,
max_sizes must be set up, and the length of min_sizes
should equal to the length of inputs.
name(str, optional, None): Name of the prior box layer.
Returns:
boxes(Variable): the output prior boxes of PriorBoxOp.
The layout is [num_priors, 4]. num_priors is the total
box count of each position of inputs.
Variances(Variable): the expanded variances of PriorBoxOp.
The layout is [num_priors, 4]. num_priors is the total
box count of each position of inputs
Examples:
.. code-block:: python
prior_box(
inputs = [conv1, conv2, conv3, conv4, conv5, conv6],
image = data,
min_ratio = 20, # 0.20
max_ratio = 90, # 0.90
offset = 0.5,
base_size = 300,
variance = [0.1,0.1,0.1,0.1],
aspect_ratios = [[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]],
flip=True,
clip=True)
"""
def _prior_box_(input,
image,
min_sizes,
max_sizes,
aspect_ratios,
variance,
flip=False,
clip=False,
step_w=0.0,
step_h=0.0,
offset=0.5,
name=None):
helper = LayerHelper("prior_box", **locals())
dtype = helper.input_dtype()
box = helper.create_tmp_variable(dtype)
var = helper.create_tmp_variable(dtype)
helper.append_op(
type="prior_box",
inputs={"Input": input,
"Image": image},
outputs={"Boxes": box,
"Variances": var},
attrs={
'min_sizes': min_sizes,
'max_sizes': max_sizes,
'aspect_ratios': aspect_ratios,
'variances': variance,
'flip': flip,
'clip': clip,
'step_w': step_w,
'step_h': step_h,
'offset': offset
})
return box, var
def _reshape_with_axis_(input, axis=1):
if not (axis > 0 and axis < len(input.shape)):
raise ValueError("The axis should be smaller than "
"the arity of input and bigger than 0.")
new_shape = [
-1, reduce(lambda x, y: x * y, input.shape[axis:len(input.shape)])
]
out = reshape(x=input, shape=new_shape)
return out
assert isinstance(inputs, list), 'inputs should be a list.'
num_layer = len(inputs)
if num_layer <= 2:
assert min_sizes is not None and max_sizes is not None
assert len(min_sizes) == num_layer and len(max_sizes) == num_layer
else:
min_sizes = []
max_sizes = []
step = int(math.floor(((max_ratio - min_ratio)) / (num_layer - 2)))
for ratio in xrange(min_ratio, max_ratio + 1, step):
min_sizes.append(base_size * ratio / 100.)
max_sizes.append(base_size * (ratio + step) / 100.)
min_sizes = [base_size * .10] + min_sizes
max_sizes = [base_size * .20] + max_sizes
if aspect_ratios:
if not (isinstance(aspect_ratios, list) and
len(aspect_ratios) == num_layer):
raise ValueError(
'aspect_ratios should be list and the length of inputs '
'and aspect_ratios should be the same.')
if step_h:
if not (isinstance(step_h, list) and len(step_h) == num_layer):
raise ValueError(
'step_h should be list and the length of inputs and '
'step_h should be the same.')
if step_w:
if not (isinstance(step_w, list) and len(step_w) == num_layer):
raise ValueError(
'step_w should be list and the length of inputs and '
'step_w should be the same.')
if steps:
if not (isinstance(steps, list) and len(steps) == num_layer):
raise ValueError(
'steps should be list and the length of inputs and '
'step_w should be the same.')
step_w = steps
step_h = steps
box_results = []
var_results = []
for i, input in enumerate(inputs):
min_size = min_sizes[i]
max_size = max_sizes[i]
aspect_ratio = []
if not isinstance(min_size, list):
min_size = [min_size]
if not isinstance(max_size, list):
max_size = [max_size]
if aspect_ratios:
aspect_ratio = aspect_ratios[i]
if not isinstance(aspect_ratio, list):
aspect_ratio = [aspect_ratio]
box, var = _prior_box_(input, image, min_size, max_size, aspect_ratio,
variance, flip, clip, step_w[i]
if step_w else 0.0, step_h[i]
if step_w else 0.0, offset)
box_results.append(box)
var_results.append(var)
if len(box_results) == 1:
box = box_results[0]
var = var_results[0]
else:
reshaped_boxes = []
reshaped_vars = []
for i in range(len(box_results)):
reshaped_boxes.append(_reshape_with_axis_(box_results[i], axis=3))
reshaped_vars.append(_reshape_with_axis_(var_results[i], axis=3))
box = concat(reshaped_boxes)
var = concat(reshaped_vars)
return box, var
def multi_box_head(inputs,
num_classes,
min_sizes=None,
......@@ -171,34 +376,53 @@ def multi_box_head(inputs,
Returns:
mbox_loc(Variable): the output prior boxes of PriorBoxOp. The layout is
mbox_loc(list): the output prior boxes of PriorBoxOp. The layout is
[num_priors, 4]. num_priors is the total box count of each
position of inputs.
mbox_conf(Variable): the expanded variances of PriorBoxOp. The layout
mbox_conf(list): the expanded variances of PriorBoxOp. The layout
is [num_priors, 4]. num_priors is the total box count of each
position of inputs
Examples:
.. code-block:: python
mbox_locs, mbox_confs = detection.multi_box_head(
inputs=[conv1, conv2, conv3, conv4, conv5, conv5],
num_classes=21,
min_ratio=20,
max_ratio=90,
aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]],
base_size=300,
flip=True)
"""
assert isinstance(inputs, list), 'inputs should be a list.'
if not (isinstance(inputs, list)):
raise ValueError('inputs should be a list.')
if min_sizes is not None:
assert len(inputs) == len(min_sizes)
if not (len(inputs) == len(min_sizes)):
raise ValueError('the length of min_sizes '
'and inputs should be the same.')
if max_sizes is not None:
assert len(inputs) == len(max_sizes)
if not (len(inputs) == len(max_sizes)):
raise ValueError('the length of max_sizes '
'and inputs should be the same.')
if aspect_ratios is not None:
if not (len(inputs) == len(aspect_ratios)):
raise ValueError('the length of aspect_ratios '
'and inputs should be the same.')
if min_sizes is None:
# if min_sizes is None, min_sizes and max_sizes
# will be set according to max_ratio and min_ratio
assert max_ratio is not None and min_ratio is not None
# If min_sizes is None, min_sizes and max_sizes
# will be set according to max_ratio and min_ratio.
num_layer = len(inputs)
assert max_ratio is not None and min_ratio is not None,\
'max_ratio and min_ratio must be not None.'
assert num_layer >= 3, 'The length of the input data is at least three.'
min_sizes = []
max_sizes = []
num_layer = len(inputs)
step = int(math.floor(((max_ratio - min_ratio)) / (num_layer - 2)))
for ratio in xrange(min_ratio, max_ratio + 1, step):
min_sizes.append(base_size * ratio / 100.)
......@@ -206,9 +430,6 @@ def multi_box_head(inputs,
min_sizes = [base_size * .10] + min_sizes
max_sizes = [base_size * .20] + max_sizes
if aspect_ratios is not None:
assert len(inputs) == len(aspect_ratios)
mbox_locs = []
mbox_confs = []
for i, input in enumerate(inputs):
......@@ -221,9 +442,9 @@ def multi_box_head(inputs,
max_size = max_sizes[i]
if type(max_size) is not list:
max_size = [max_size]
if max_size:
assert len(max_size) == len(
min_size), "max_size and min_size should have same length."
if not (len(max_size) == len(min_size)):
raise ValueError(
'max_size and min_size should have same length.')
aspect_ratio = []
if aspect_ratios is not None:
......@@ -231,17 +452,19 @@ def multi_box_head(inputs,
if type(aspect_ratio) is not list:
aspect_ratio = [aspect_ratio]
# get the number of prior box on each location
num_priors_per_location = 0
if max_sizes is not None:
num_priors_per_location = len(min_size) + len(aspect_ratio) * len(
min_size) + len(max_size)
num_priors_per_location = len(min_size) + \
len(aspect_ratio) * len(min_size) +\
len(max_size)
else:
num_priors_per_location = len(min_size) + len(aspect_ratio) * len(
min_size)
num_priors_per_location = len(min_size) +\
len(aspect_ratio) * len(min_size)
if flip:
num_priors_per_location += len(aspect_ratio) * len(min_size)
# mbox_loc
# get mbox_loc
num_loc_output = num_priors_per_location * 4
if share_location:
num_loc_output *= num_classes
......@@ -256,7 +479,7 @@ def multi_box_head(inputs,
mbox_loc = transpose(mbox_loc, perm=[0, 2, 3, 1])
mbox_locs.append(mbox_loc)
# get the number of prior box
# get conf_loc
num_conf_output = num_priors_per_location * num_classes
conf_loc = img_conv_with_bn(
input=input,
......
......@@ -152,7 +152,12 @@ def monkey_patch_variable():
("__div__", "elementwise_div", False),
("__rdiv__", "elementwise_div", True),
("__pow__", "elementwise_pow", False),
("__rpow__", "elementwise_pow", True)):
("__rpow__", "elementwise_pow", True),
# for logical compare
("__eq__", "equal", False),
("__ne__", "not_equal", False),
("__lt__", "less_than", False),
("__le__", "less_equal", False)):
setattr(Variable, method_name,
_elemwise_method_creator_(method_name, op_type, reverse))
......
......@@ -66,6 +66,8 @@ __all__ = [
'row_conv',
'multiplex',
'layer_norm',
'softmax_with_cross_entropy',
'smooth_l1',
]
......@@ -3091,3 +3093,122 @@ def multiplex(inputs, index):
'Ids': index},
outputs={'Out': [out]})
return out
def softmax_with_cross_entropy(logits, label, soft_label=False):
"""
**Softmax With Cross Entropy Operator.**
Cross entropy loss with softmax is used as the output layer extensively. This
operator computes the softmax normalized values for each row of the input
tensor, after which cross-entropy loss is computed. This provides a more
numerically stable gradient.
Because this operator performs a softmax on logits internally, it expects
unscaled logits. This operator should not be used with the output of
softmax operator since that would produce incorrect results.
When the attribute soft_label is set false, this operators expects mutually
exclusive hard labels, each sample in a batch is in exactly one class with a
probability of 1.0. Each sample in the batch will have a single label.
The equation is as follows:
1) Hard label (one-hot label, so every sample has exactly one class)
.. math::
loss_j = -\\text{logit}_{label_j} +
\\log\\left(\\sum_{i=0}^{K}\\exp(\\text{logit}_i)\\right), j = 1,..., K
2) Soft label (each sample can have a distribution over all classes)
.. math::
loss_j = -\\sum_{i=0}^{K}\\text{label}_i
\\left(\\text{logit}_i - \\log\\left(\\sum_{i=0}^{K}
\\exp(\\text{logit}_i)\\right)\\right), j = 1,...,K
Args:
logits (Variable): The unscaled log probabilities, which is a 2-D tensor
with shape [N x K]. N is the batch_size, and K is the class number.
label (Variable): The ground truth which is a 2-D tensor. If soft_label
is set to false, Label is a Tensor<int64> with shape [N x 1]. If
soft_label is set to true, Label is a Tensor<float/double> with
soft_label (bool): A flag to indicate whether to interpretate the given
labels as soft labels. By default, `soft_label` is set to False.
Returns:
Variable: The cross entropy loss is a 2-D tensor with shape [N x 1].
Examples:
.. code-block:: python
data = fluid.layers.data(name='data', shape=[128], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
fc = fluid.layers.fc(input=data, size=100)
out = fluid.layers.softmax_with_cross_entropy(logits=fc, label=label)
"""
helper = LayerHelper('softmax_with_cross_entropy', **locals())
softmax = helper.create_tmp_variable(dtype=logits.dtype)
loss = helper.create_tmp_variable(dtype=logits.dtype)
helper.append_op(
type='softmax_with_cross_entropy',
inputs={'Logits': logits,
'Label': label},
outputs={'Softmax': softmax,
'Loss': loss},
attrs={'soft_label': soft_label})
return loss
def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None):
"""
**Smooth L1 Loss Operator. **
This operator computes the smooth l1 loss for X and Y.
The operator takes the first dimension of X and Y as batch size.
For each instance, it computes the smooth l1 loss element by element first
and then sums all the losses. So the shape of Out is [batch_size, 1].
Args:
x (Variable): A tensor with rank at least 2. The input value of smooth
l1 loss op with shape [batch_size, dim1, ..., dimN].
y (Variable): A tensor with rank at least 2. The target value of smooth
l1 loss op with same shape as x.
inside_weight (Variable|None): A tensor with rank at least 2. This
input is optional and should have same shape with x. If provided,
the result of (x - y) will be multiplied by this tensor element by
element.
outside_weight (Variable|None): A tensor with rank at least 2. This
input is optional and should have same shape with x. If provided,
the out smooth l1 loss will be multiplied by this tensor element
by element.
sigma (float|None): Hyper parameter of smooth l1 loss op. A float scalar
with default value 1.0.
Returns:
Variable: A tensor with rank be 2. The output smooth l1 loss with
shape [batch_size, 1].
Examples:
.. code-block:: python
data = fluid.layers.data(name='data', shape=[128], dtype='float32')
label = fluid.layers.data(name='label', shape=[100], dtype='int64')
fc = fluid.layers.fc(input=data, size=100)
out = fluid.layers.smooth_l1(logits=fc, label=label)
"""
helper = LayerHelper('smooth_l1_loss', **locals())
diff = helper.create_tmp_variable(dtype=x.dtype)
loss = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(
type='smooth_l1_loss',
inputs={
'X': x,
'Y': y,
'InsideWeight': inside_weight,
'OutsideWeight': outside_weight
},
outputs={'Diff': diff,
'Out': loss},
attrs={'sigma': sigma})
return loss
......@@ -179,7 +179,7 @@ def polynomial_decay(learning_rate,
shape=[1], dtype='float32', value=1.0)
with layers.Switch() as switch:
with switch.case(layers.equal(x=global_step, y=zero_var)):
with switch.case(global_step == zero_var):
layers.assign(input=one_var, output=div_res)
decay_steps = decay_steps * div_res
else:
......@@ -229,7 +229,7 @@ def piecewise_decay(global_step, boundaries, values):
shape=[1], dtype='float32', value=float(boundaries[i]))
value_var = layers.fill_constant(
shape=[1], dtype='float32', value=float(values[i]))
with switch.case(layers.less_than(global_step, boundary_val)):
with switch.case(global_step < boundary_val):
layers.assign(value_var, lr)
last_value_var = layers.fill_constant(
shape=[1],
......
......@@ -14,15 +14,10 @@
from __future__ import print_function
import paddle.v2.fluid as fluid
import paddle.v2.fluid.core as core
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.layers.detection as detection
from paddle.v2.fluid.framework import Program, program_guard
import unittest
import numpy as np
import paddle.v2.fluid.layers as layers
from paddle.v2.fluid.framework import Program, program_guard
class TestBook(unittest.TestCase):
......@@ -55,15 +50,67 @@ class TestBook(unittest.TestCase):
print(str(program))
class TestPriorBox(unittest.TestCase):
def test_prior_box(self):
data_shape = [3, 224, 224]
box, var = self.prior_box_output(data_shape)
assert len(box.shape) == 2
assert box.shape == var.shape
assert box.shape[1] == 4
def prior_box_output(self, data_shape):
images = fluid.layers.data(
name='pixel', shape=data_shape, dtype='float32')
conv1 = fluid.layers.conv2d(
input=images,
num_filters=3,
filter_size=3,
stride=2,
use_cudnn=False)
conv2 = fluid.layers.conv2d(
input=conv1,
num_filters=3,
filter_size=3,
stride=2,
use_cudnn=False)
conv3 = fluid.layers.conv2d(
input=conv2,
num_filters=3,
filter_size=3,
stride=2,
use_cudnn=False)
conv4 = fluid.layers.conv2d(
input=conv3,
num_filters=3,
filter_size=3,
stride=2,
use_cudnn=False)
conv5 = fluid.layers.conv2d(
input=conv4,
num_filters=3,
filter_size=3,
stride=2,
use_cudnn=False)
box, var = detection.prior_box(
inputs=[conv1, conv2, conv3, conv4, conv5, conv5],
image=images,
min_ratio=20,
max_ratio=90,
# steps=[8, 16, 32, 64, 100, 300],
aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]],
base_size=300,
offset=0.5,
flip=True,
clip=True)
return box, var
class TestMultiBoxHead(unittest.TestCase):
def test_prior_box(self):
data_shape = [3, 224, 224]
mbox_locs, mbox_confs = self.multi_box_output(data_shape)
# print mbox_locs.shape
# print mbox_confs.shape
# assert len(box.shape) == 2
# assert box.shape == var.shape
# assert box.shape[1] == 4
def multi_box_output(self, data_shape):
images = fluid.layers.data(
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import sys
import collections
import math
from op_test import OpTest
class TestDetectionMAPOp(OpTest):
def set_data(self):
self.init_test_case()
self.mAP = [self.calc_map(self.tf_pos, self.tf_pos_lod)]
self.label = np.array(self.label).astype('float32')
self.detect = np.array(self.detect).astype('float32')
self.mAP = np.array(self.mAP).astype('float32')
if (len(self.class_pos_count) > 0):
self.class_pos_count = np.array(self.class_pos_count).astype(
'int32')
self.true_pos = np.array(self.true_pos).astype('float32')
self.false_pos = np.array(self.false_pos).astype('float32')
self.inputs = {
'Label': (self.label, self.label_lod),
'DetectRes': (self.detect, self.detect_lod),
'PosCount': self.class_pos_count,
'TruePos': (self.true_pos, self.true_pos_lod),
'FalsePos': (self.false_pos, self.false_pos_lod)
}
else:
self.inputs = {
'Label': (self.label, self.label_lod),
'DetectRes': (self.detect, self.detect_lod),
}
self.attrs = {
'overlap_threshold': self.overlap_threshold,
'evaluate_difficult': self.evaluate_difficult,
'ap_type': self.ap_type
}
self.out_class_pos_count = np.array(self.out_class_pos_count).astype(
'int')
self.out_true_pos = np.array(self.out_true_pos).astype('float32')
self.out_false_pos = np.array(self.out_false_pos).astype('float32')
self.outputs = {
'MAP': self.mAP,
'AccumPosCount': self.out_class_pos_count,
'AccumTruePos': (self.out_true_pos, self.out_true_pos_lod),
'AccumFalsePos': (self.out_false_pos, self.out_false_pos_lod)
}
def init_test_case(self):
self.overlap_threshold = 0.3
self.evaluate_difficult = True
self.ap_type = "integral"
self.label_lod = [[0, 2, 4]]
# label difficult xmin ymin xmax ymax
self.label = [[1, 0, 0.1, 0.1, 0.3, 0.3], [1, 1, 0.6, 0.6, 0.8, 0.8],
[2, 0, 0.3, 0.3, 0.6, 0.5], [1, 0, 0.7, 0.1, 0.9, 0.3]]
# label score xmin ymin xmax ymax difficult
self.detect_lod = [[0, 3, 7]]
self.detect = [
[1, 0.3, 0.1, 0.0, 0.4, 0.3], [1, 0.7, 0.0, 0.1, 0.2, 0.3],
[1, 0.9, 0.7, 0.6, 0.8, 0.8], [2, 0.8, 0.2, 0.1, 0.4, 0.4],
[2, 0.1, 0.4, 0.3, 0.7, 0.5], [1, 0.2, 0.8, 0.1, 1.0, 0.3],
[3, 0.2, 0.8, 0.1, 1.0, 0.3]
]
# label score true_pos false_pos
self.tf_pos_lod = [[0, 3, 7]]
self.tf_pos = [[1, 0.9, 1, 0], [1, 0.7, 1, 0], [1, 0.3, 0, 1],
[1, 0.2, 1, 0], [2, 0.8, 0, 1], [2, 0.1, 1, 0],
[3, 0.2, 0, 1]]
self.class_pos_count = []
self.true_pos_lod = [[]]
self.true_pos = [[]]
self.false_pos_lod = [[]]
self.false_pos = [[]]
def calc_map(self, tf_pos, tf_pos_lod):
mAP = 0.0
count = 0
def get_input_pos(class_pos_count, true_pos, true_pos_lod, false_pos,
false_pos_lod):
class_pos_count_dict = collections.Counter()
true_pos_dict = collections.defaultdict(list)
false_pos_dict = collections.defaultdict(list)
for i, count in enumerate(class_pos_count):
class_pos_count_dict[i] = count
for i in range(len(true_pos_lod[0]) - 1):
start = true_pos_lod[0][i]
end = true_pos_lod[0][i + 1]
for j in range(start, end):
true_pos_dict[i].append(true_pos[j])
for i in range(len(false_pos_lod[0]) - 1):
start = false_pos_lod[0][i]
end = false_pos_lod[0][i + 1]
for j in range(start, end):
false_pos_dict[i].append(false_pos[j])
return class_pos_count_dict, true_pos_dict, false_pos_dict
def get_output_pos(label_count, true_pos, false_pos):
max_label = 0
for (label, label_pos_num) in label_count.items():
if max_label < label:
max_label = label
label_number = max_label + 1
out_class_pos_count = []
out_true_pos_lod = [0]
out_true_pos = []
out_false_pos_lod = [0]
out_false_pos = []
for i in range(label_number):
out_class_pos_count.append([label_count[i]])
true_pos_list = true_pos[i]
out_true_pos += true_pos_list
out_true_pos_lod.append(len(out_true_pos))
false_pos_list = false_pos[i]
out_false_pos += false_pos_list
out_false_pos_lod.append(len(out_false_pos))
return out_class_pos_count, out_true_pos, [
out_true_pos_lod
], out_false_pos, [out_false_pos_lod]
def get_accumulation(pos_list):
sorted_list = sorted(pos_list, key=lambda pos: pos[0], reverse=True)
sum = 0
accu_list = []
for (score, count) in sorted_list:
sum += count
accu_list.append(sum)
return accu_list
label_count, true_pos, false_pos = get_input_pos(
self.class_pos_count, self.true_pos, self.true_pos_lod,
self.false_pos, self.false_pos_lod)
for (label, difficult, xmin, ymin, xmax, ymax) in self.label:
if self.evaluate_difficult:
label_count[label] += 1
elif not difficult:
label_count[label] += 1
true_pos = collections.defaultdict(list)
false_pos = collections.defaultdict(list)
for (label, score, tp, fp) in tf_pos:
true_pos[label].append([score, tp])
false_pos[label].append([score, fp])
for (label, label_pos_num) in label_count.items():
if label_pos_num == 0 or label not in true_pos: continue
label_true_pos = true_pos[label]
label_false_pos = false_pos[label]
accu_tp_sum = get_accumulation(label_true_pos)
accu_fp_sum = get_accumulation(label_false_pos)
precision = []
recall = []
for i in range(len(accu_tp_sum)):
precision.append(
float(accu_tp_sum[i]) /
float(accu_tp_sum[i] + accu_fp_sum[i]))
recall.append(float(accu_tp_sum[i]) / label_pos_num)
if self.ap_type == "11point":
max_precisions = [0.0] * 11
start_idx = len(accu_tp_sum) - 1
for j in range(10, -1, -1):
for i in range(start_idx, -1, -1):
if recall[i] < float(j) / 10.0:
start_idx = i
if j > 0:
max_precisions[j - 1] = max_precisions[j]
break
else:
if max_precisions[j] < precision[i]:
max_precisions[j] = precision[i]
for j in range(10, -1, -1):
mAP += max_precisions[j] / 11
count += 1
elif self.ap_type == "integral":
average_precisions = 0.0
prev_recall = 0.0
for i in range(len(accu_tp_sum)):
if math.fabs(recall[i] - prev_recall) > 1e-6:
average_precisions += precision[i] * \
math.fabs(recall[i] - prev_recall)
prev_recall = recall[i]
mAP += average_precisions
count += 1
self.out_class_pos_count, self.out_true_pos, self.out_true_pos_lod, self.out_false_pos, self.out_false_pos_lod = get_output_pos(
label_count, true_pos, false_pos)
if count != 0:
mAP /= count
return mAP * 100.0
def setUp(self):
self.op_type = "detection_map"
self.set_data()
def test_check_output(self):
self.check_output()
class TestDetectionMAPOpSkipDiff(TestDetectionMAPOp):
def init_test_case(self):
super(TestDetectionMAPOpSkipDiff, self).init_test_case()
self.evaluate_difficult = False
self.tf_pos_lod = [[0, 2, 6]]
# label score true_pos false_pos
self.tf_pos = [[1, 0.7, 1, 0], [1, 0.3, 0, 1], [1, 0.2, 1, 0],
[2, 0.8, 0, 1], [2, 0.1, 1, 0], [3, 0.2, 0, 1]]
class TestDetectionMAPOp11Point(TestDetectionMAPOp):
def init_test_case(self):
super(TestDetectionMAPOp11Point, self).init_test_case()
self.ap_type = "11point"
class TestDetectionMAPOpMultiBatch(TestDetectionMAPOp):
def init_test_case(self):
super(TestDetectionMAPOpMultiBatch, self).init_test_case()
self.class_pos_count = [0, 2, 1]
self.true_pos_lod = [[0, 0, 3, 5]]
self.true_pos = [[0.7, 1.], [0.3, 0.], [0.2, 1.], [0.8, 0.], [0.1, 1.]]
self.false_pos_lod = [[0, 0, 3, 5]]
self.false_pos = [[0.7, 0.], [0.3, 1.], [0.2, 0.], [0.8, 1.], [0.1, 0.]]
if __name__ == '__main__':
unittest.main()
......@@ -161,8 +161,8 @@ class TestBook(unittest.TestCase):
label=label,
chunk_scheme="IOB",
num_chunk_types=(label_dict_len - 1) / 2)
self.assertNotEqual(crf, None)
self.assertNotEqual(crf_decode, None)
self.assertFalse(crf is None)
self.assertFalse(crf_decode is None)
print(str(program))
......@@ -309,6 +309,24 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(out)
print(str(program))
def test_softmax_with_cross_entropy(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[16], dtype='float32')
y = layers.data(name='label', shape=[1], dtype='int64')
loss = layers.softmax_with_cross_entropy(x, y)
self.assertIsNotNone(loss)
print(str(program))
def test_smooth_l1(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[4], dtype='float32')
y = layers.data(name='label', shape=[4], dtype='float32')
loss = layers.smooth_l1(x, y)
self.assertIsNotNone(loss)
print(str(program))
if __name__ == '__main__':
unittest.main()
......@@ -65,9 +65,9 @@ class TestPriorBoxOp(OpTest):
self.batch_size = 10
self.min_sizes = [2, 4]
self.min_sizes = np.array(self.min_sizes).astype('int64')
self.min_sizes = np.array(self.min_sizes).astype('float32').tolist()
self.max_sizes = [5, 10]
self.max_sizes = np.array(self.max_sizes).astype('int64')
self.max_sizes = np.array(self.max_sizes).astype('float32').tolist()
self.aspect_ratios = [2.0, 3.0]
self.flip = True
self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0]
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.framework as framework
import paddle.v2.fluid as fluid
class TestPythonOperatorOverride(unittest.TestCase):
def check_result(self, fn, place, dtype):
shape = [9, 10]
x_data = np.random.random(size=shape).astype(dtype)
y_data = np.random.random(size=shape).astype(dtype)
python_out = fn(x_data, y_data)
x_var = layers.create_global_var(
name='x', shape=shape, value=0.0, dtype=dtype, persistable=True)
y_var = layers.create_global_var(
name='y', shape=shape, value=0.0, dtype=dtype, persistable=True)
out = fn(x_var, y_var)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
fluid_out = exe.run(fluid.default_main_program(),
feed={'x': x_data,
'y': y_data},
fetch_list=[out])
np.testing.assert_array_equal(python_out, fluid_out[0])
def test_override(self):
# compare func to check
compare_fns = [
lambda _a, _b: _a == _b,
lambda _a, _b: _a != _b,
lambda _a, _b: _a < _b,
lambda _a, _b: _a <= _b,
lambda _a, _b: _a > _b,
lambda _a, _b: _a >= _b,
]
# places to check
places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
# dtypes to check
dtypes = ['int32', 'float32']
for place in places:
for dtype in dtypes:
for compare_fn in compare_fns:
with framework.program_guard(framework.Program(),
framework.Program()):
self.check_result(compare_fn, place, dtype)
if __name__ == '__main__':
unittest.main()
......@@ -52,3 +52,5 @@ RUN wget -O /opt/swig-2.0.12.tar.gz https://sourceforge.net/projects/swig/files/
RUN mkdir -p /src && cd /src && git clone https://github.com/NVIDIA/nccl.git nccl && cd nccl &&\
make -j `nproc` install <NCCL_MAKE_OPTS> && cd .. && rm -rf nccl
CMD ["bash", "/paddle/paddle/scripts/docker/build.sh"]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册