提交 a0284f6f 编写于 作者: D dengkaipeng

Add backward CPU kernel. test=develop

上级 36c46152
......@@ -183,6 +183,7 @@ paddle.fluid.layers.similarity_focus ArgSpec(args=['input', 'axis', 'indexes', '
paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None))
paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None))
paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'anchors', 'class_num', 'ignore_thresh', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act', 'name', 'param_attr', 'bias_attr'], varargs=None, keywords=None, defaults=(None, None, None, None))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
......
......@@ -20,8 +20,6 @@ using framework::Tensor;
class Yolov3LossOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of Yolov3LossOp should not be null.");
......@@ -32,7 +30,6 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
auto dim_x = ctx->GetInputDim("X");
auto dim_gt = ctx->GetInputDim("GTBox");
auto img_height = ctx->Attrs().Get<int>("img_height");
auto anchors = ctx->Attrs().Get<std::vector<int>>("anchors");
auto class_num = ctx->Attrs().Get<int>("class_num");
PADDLE_ENFORCE_EQ(dim_x.size(), 4, "Input(X) should be a 4-D tensor.");
......@@ -43,8 +40,6 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
"+ class_num)).");
PADDLE_ENFORCE_EQ(dim_gt.size(), 3, "Input(GTBox) should be a 3-D tensor");
PADDLE_ENFORCE_EQ(dim_gt[2], 5, "Input(GTBox) dim[2] should be 5");
PADDLE_ENFORCE_GT(img_height, 0,
"Attr(img_height) value should be greater then 0");
PADDLE_ENFORCE_GT(anchors.size(), 0,
"Attr(anchors) length should be greater then 0.");
PADDLE_ENFORCE_EQ(anchors.size() % 2, 0,
......@@ -87,13 +82,43 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::vector<int>>("anchors",
"The anchor width and height, "
"it will be parsed pair by pair.");
AddAttr<int>("img_height",
"The input image height after crop of yolov3 network.");
AddAttr<float>("ignore_thresh",
"The ignore threshold to ignore confidence loss.");
AddComment(R"DOC(
This operator generate yolov3 loss by given predict result and ground
truth boxes.
The output of previous network is in shape [N, C, H, W], while H and W
should be the same, specify the grid size, each grid point predict given
number boxes, this given number is specified by anchors, it should be
half anchors length, which following will be represented as S. In the
second dimention(the channel dimention), C should be S * (class_num + 5),
class_num is the box categoriy number of source dataset(such as coco),
so in the second dimention, stores 4 box location coordinates x, y, w, h
and confidence score of the box and class one-hot key of each anchor box.
While the 4 location coordinates if $$tx, ty, tw, th$$, the box predictions
correspnd to:
$$
b_x = \sigma(t_x) + c_x
b_y = \sigma(t_y) + c_y
b_w = p_w e^{t_w}
b_h = p_h e^{t_h}
$$
While $$c_x, c_y$$ is the left top corner of current grid and $$p_w, p_h$$
is specified by anchors.
As for confidence score, it is the logistic regression value of IoU between
anchor boxes and ground truth boxes, the score of the anchor box which has
the max IoU should be 1, and if the anchor box has IoU bigger then ignore
thresh, the confidence score loss of this anchor box will be ignored.
Therefore, the yolov3 loss consist of three major parts, box location loss,
confidence score loss, and classification loss. The MSE loss is used for
box location, and binary cross entropy loss is used for confidence score
loss and classification loss.
)DOC");
}
};
......@@ -101,8 +126,6 @@ class Yolov3LossOpMaker : public framework::OpProtoAndCheckerMaker {
class Yolov3LossOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
......@@ -113,6 +136,7 @@ class Yolov3LossOpGrad : public framework::OperatorWithKernel {
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
......@@ -120,12 +144,32 @@ class Yolov3LossOpGrad : public framework::OperatorWithKernel {
}
};
class Yolov3LossGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* op = new framework::OpDesc();
op->SetType("yolov3_loss_grad");
op->SetInput("X", Input("X"));
op->SetInput("GTBox", Input("GTBox"));
op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
op->SetAttrMap(Attrs());
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("GTBox"), {});
return std::unique_ptr<framework::OpDesc>(op);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(yolov3_loss, ops::Yolov3LossOp, ops::Yolov3LossOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::Yolov3LossGradMaker);
REGISTER_OPERATOR(yolov3_loss_grad, ops::Yolov3LossOpGrad);
REGISTER_OP_CPU_KERNEL(
yolov3_loss,
......
......@@ -17,7 +17,7 @@
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
yolov3_loss,
ops::Yolov3LossOpKernel<paddle::platform::CUDADeviceContext, float>);
ops::Yolov3LossKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(
yolov3_loss_grad,
ops::Yolov3LossGradOpKernel<paddle::platform::CUDADeviceContext, float>);
ops::Yolov3LossGradKernel<paddle::platform::CUDADeviceContext, float>);
......@@ -33,10 +33,22 @@ static inline bool isZero(T x) {
}
template <typename T>
static inline T sigmod(T x) {
static inline T sigmoid(T x) {
return 1.0 / (exp(-1.0 * x) + 1.0);
}
template <typename T>
static inline T CalcMaskPointNum(const Tensor& mask) {
auto mask_t = EigenVector<int>::Flatten(mask);
T count = 0.0;
for (int i = 0; i < mask_t.dimensions()[0]; i++) {
if (mask_t(i)) {
count += 1.0;
}
}
return count;
}
template <typename T>
static inline T CalcMSEWithMask(const Tensor& x, const Tensor& y,
const Tensor& mask) {
......@@ -55,6 +67,21 @@ static inline T CalcMSEWithMask(const Tensor& x, const Tensor& y,
return (error_sum / points);
}
template <typename T>
static void CalcMSEGradWithMask(Tensor* grad, const Tensor& x, const Tensor& y,
const Tensor& mask, T mf) {
auto grad_t = EigenVector<T>::Flatten(*grad).setConstant(0.0);
auto x_t = EigenVector<T>::Flatten(x);
auto y_t = EigenVector<T>::Flatten(y);
auto mask_t = EigenVector<int>::Flatten(mask);
for (int i = 0; i < x_t.dimensions()[0]; i++) {
if (mask_t(i)) {
grad_t(i) = 2.0 * (x_t(i) - y_t(i)) / mf;
}
}
}
template <typename T>
static inline T CalcBCEWithMask(const Tensor& x, const Tensor& y,
const Tensor& mask) {
......@@ -75,21 +102,34 @@ static inline T CalcBCEWithMask(const Tensor& x, const Tensor& y,
}
template <typename T>
static void CalcPredResult(const Tensor& input, Tensor* pred_confs,
Tensor* pred_classes, Tensor* pred_x, Tensor* pred_y,
Tensor* pred_w, Tensor* pred_h,
std::vector<int> anchors, const int class_num,
const int stride) {
static inline void CalcBCEGradWithMask(Tensor* grad, const Tensor& x,
const Tensor& y, const Tensor& mask,
T mf) {
auto grad_t = EigenVector<T>::Flatten(*grad).setConstant(0.0);
auto x_t = EigenVector<T>::Flatten(x);
auto y_t = EigenVector<T>::Flatten(y);
auto mask_t = EigenVector<int>::Flatten(mask);
for (int i = 0; i < x_t.dimensions()[0]; i++) {
if (mask_t(i)) {
grad_t(i) = ((1.0 - y_t(i)) / (1.0 - x_t(i)) - y_t(i) / x_t(i)) / mf;
}
}
}
template <typename T>
static void CalcPredResult(const Tensor& input, Tensor* pred_conf,
Tensor* pred_class, Tensor* pred_x, Tensor* pred_y,
Tensor* pred_w, Tensor* pred_h, const int anchor_num,
const int class_num) {
const int n = input.dims()[0];
const int c = input.dims()[1];
const int h = input.dims()[2];
const int w = input.dims()[3];
const int anchor_num = anchors.size() / 2;
const int box_attr_num = 5 + class_num;
auto input_t = EigenTensor<T, 4>::From(input);
auto pred_confs_t = EigenTensor<T, 4>::From(*pred_confs);
auto pred_classes_t = EigenTensor<T, 5>::From(*pred_classes);
auto pred_conf_t = EigenTensor<T, 4>::From(*pred_conf);
auto pred_class_t = EigenTensor<T, 5>::From(*pred_class);
auto pred_x_t = EigenTensor<T, 4>::From(*pred_x);
auto pred_y_t = EigenTensor<T, 4>::From(*pred_y);
auto pred_w_t = EigenTensor<T, 4>::From(*pred_w);
......@@ -97,26 +137,23 @@ static void CalcPredResult(const Tensor& input, Tensor* pred_confs,
for (int i = 0; i < n; i++) {
for (int an_idx = 0; an_idx < anchor_num; an_idx++) {
float an_w = anchors[an_idx * 2] / stride;
float an_h = anchors[an_idx * 2 + 1] / stride;
for (int j = 0; j < h; j++) {
for (int k = 0; k < w; k++) {
pred_x_t(i, an_idx, j, k) =
sigmod(input_t(i, box_attr_num * an_idx, j, k));
sigmoid(input_t(i, box_attr_num * an_idx, j, k));
pred_y_t(i, an_idx, j, k) =
sigmod(input_t(i, box_attr_num * an_idx + 1, j, k));
sigmoid(input_t(i, box_attr_num * an_idx + 1, j, k));
pred_w_t(i, an_idx, j, k) =
input_t(i, box_attr_num * an_idx + 2, j, k);
pred_h_t(i, an_idx, j, k) =
input_t(i, box_attr_num * an_idx + 3, j, k);
pred_confs_t(i, an_idx, j, k) =
sigmod(input_t(i, box_attr_num * an_idx + 4, j, k));
pred_conf_t(i, an_idx, j, k) =
sigmoid(input_t(i, box_attr_num * an_idx + 4, j, k));
for (int c = 0; c < class_num; c++) {
pred_classes_t(i, an_idx, j, k, c) =
sigmod(input_t(i, box_attr_num * an_idx + 5 + c, j, k));
pred_class_t(i, an_idx, j, k, c) =
sigmoid(input_t(i, box_attr_num * an_idx + 5 + c, j, k));
}
}
}
......@@ -148,27 +185,11 @@ static T CalcBoxIoU(std::vector<T> box1, std::vector<T> box2) {
return inter_area / (b1_area + b2_area - inter_area);
}
template <typename T>
static inline int GetPredLabel(const Tensor& pred_classes, int n,
int best_an_index, int gj, int gi) {
auto pred_classes_t = EigenTensor<T, 5>::From(pred_classes);
T score = 0.0;
int label = -1;
for (int i = 0; i < pred_classes.dims()[4]; i++) {
if (pred_classes_t(n, best_an_index, gj, gi, i) > score) {
score = pred_classes_t(n, best_an_index, gj, gi, i);
label = i;
}
}
return label;
}
template <typename T>
static void PrePorcessGTBox(const Tensor& gt_boxes, const float ignore_thresh,
std::vector<int> anchors, const int img_height,
const int grid_size, Tensor* obj_mask,
Tensor* noobj_mask, Tensor* tx, Tensor* ty,
Tensor* tw, Tensor* th, Tensor* tconf,
std::vector<int> anchors, const int grid_size,
Tensor* obj_mask, Tensor* noobj_mask, Tensor* tx,
Tensor* ty, Tensor* tw, Tensor* th, Tensor* tconf,
Tensor* tclass) {
const int n = gt_boxes.dims()[0];
const int b = gt_boxes.dims()[1];
......@@ -240,6 +261,61 @@ static void ExpandObjMaskByClassNum(Tensor* obj_mask_expand,
.broadcast(Array5(1, 1, 1, 1, class_num));
}
template <typename T>
static void AddAllGradToInputGrad(
Tensor* grad, T loss, const Tensor& pred_x, const Tensor& pred_y,
const Tensor& pred_conf, const Tensor& pred_class, const Tensor& grad_x,
const Tensor& grad_y, const Tensor& grad_w, const Tensor& grad_h,
const Tensor& grad_conf_obj, const Tensor& grad_conf_noobj,
const Tensor& grad_class, const int class_num) {
const int n = pred_x.dims()[0];
const int an_num = pred_x.dims()[1];
const int h = pred_x.dims()[2];
const int w = pred_x.dims()[3];
const int attr_num = class_num + 5;
auto grad_t = EigenTensor<T, 4>::From(*grad).setConstant(0.0);
auto pred_x_t = EigenTensor<T, 4>::From(pred_x);
auto pred_y_t = EigenTensor<T, 4>::From(pred_y);
auto pred_conf_t = EigenTensor<T, 4>::From(pred_conf);
auto pred_class_t = EigenTensor<T, 5>::From(pred_class);
auto grad_x_t = EigenTensor<T, 4>::From(grad_x);
auto grad_y_t = EigenTensor<T, 4>::From(grad_y);
auto grad_w_t = EigenTensor<T, 4>::From(grad_w);
auto grad_h_t = EigenTensor<T, 4>::From(grad_h);
auto grad_conf_obj_t = EigenTensor<T, 4>::From(grad_conf_obj);
auto grad_conf_noobj_t = EigenTensor<T, 4>::From(grad_conf_noobj);
auto grad_class_t = EigenTensor<T, 5>::From(grad_class);
for (int i = 0; i < n; i++) {
for (int j = 0; j < an_num; j++) {
for (int k = 0; k < h; k++) {
for (int l = 0; l < w; l++) {
grad_t(i, j * attr_num, k, l) = grad_x_t(i, j, k, l) *
pred_x_t(i, j, k, l) *
(1.0 - pred_x_t(i, j, k, l)) * loss;
grad_t(i, j * attr_num + 1, k, l) =
grad_y_t(i, j, k, l) * pred_y_t(i, j, k, l) *
(1.0 - pred_y_t(i, j, k, l)) * loss;
grad_t(i, j * attr_num + 2, k, l) = grad_w_t(i, j, k, l) * loss;
grad_t(i, j * attr_num + 3, k, l) = grad_h_t(i, j, k, l) * loss;
grad_t(i, j * attr_num + 4, k, l) =
grad_conf_obj_t(i, j, k, l) * pred_conf_t(i, j, k, l) *
(1.0 - pred_conf_t(i, j, k, l)) * loss;
grad_t(i, j * attr_num + 4, k, l) +=
grad_conf_noobj_t(i, j, k, l) * pred_conf_t(i, j, k, l) *
(1.0 - pred_conf_t(i, j, k, l)) * loss;
for (int c = 0; c < class_num; c++) {
grad_t(i, j * attr_num + 5 + c, k, l) =
grad_class_t(i, j, k, l, c) * pred_class_t(i, j, k, l, c) *
(1.0 - pred_class_t(i, j, k, l, c)) * loss;
}
}
}
}
}
}
template <typename DeviceContext, typename T>
class Yolov3LossKernel : public framework::OpKernel<T> {
public:
......@@ -247,28 +323,25 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
auto* input = ctx.Input<Tensor>("X");
auto* gt_boxes = ctx.Input<Tensor>("GTBox");
auto* loss = ctx.Output<Tensor>("Loss");
int img_height = ctx.Attr<int>("img_height");
auto anchors = ctx.Attr<std::vector<int>>("anchors");
int class_num = ctx.Attr<int>("class_num");
float ignore_thresh = ctx.Attr<float>("ignore_thresh");
const int n = input->dims()[0];
const int c = input->dims()[1];
const int h = input->dims()[2];
const int w = input->dims()[3];
const int an_num = anchors.size() / 2;
const T stride = static_cast<T>(img_height) / h;
Tensor pred_x, pred_y, pred_w, pred_h;
Tensor pred_confs, pred_classes;
Tensor pred_conf, pred_class;
pred_x.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
pred_y.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
pred_w.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
pred_h.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
pred_confs.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
pred_classes.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
CalcPredResult<T>(*input, &pred_confs, &pred_classes, &pred_x, &pred_y,
&pred_w, &pred_h, anchors, class_num, stride);
pred_conf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
pred_class.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
CalcPredResult<T>(*input, &pred_conf, &pred_class, &pred_x, &pred_y,
&pred_w, &pred_h, an_num, class_num);
Tensor obj_mask, noobj_mask;
Tensor tx, ty, tw, th, tconf, tclass;
......@@ -280,9 +353,8 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
th.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
tconf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
tclass.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
PrePorcessGTBox<T>(*gt_boxes, ignore_thresh, anchors, img_height, h,
&obj_mask, &noobj_mask, &tx, &ty, &tw, &th, &tconf,
&tclass);
PrePorcessGTBox<T>(*gt_boxes, ignore_thresh, anchors, h, &obj_mask,
&noobj_mask, &tx, &ty, &tw, &th, &tconf, &tclass);
Tensor obj_mask_expand;
obj_mask_expand.mutable_data<int>({n, an_num, h, w, class_num},
......@@ -293,17 +365,9 @@ class Yolov3LossKernel : public framework::OpKernel<T> {
T loss_y = CalcMSEWithMask<T>(pred_y, ty, obj_mask);
T loss_w = CalcMSEWithMask<T>(pred_w, tw, obj_mask);
T loss_h = CalcMSEWithMask<T>(pred_h, th, obj_mask);
T loss_conf_obj = CalcBCEWithMask<T>(pred_confs, tconf, obj_mask);
T loss_conf_noobj = CalcBCEWithMask<T>(pred_confs, tconf, noobj_mask);
T loss_class = CalcBCEWithMask<T>(pred_classes, tclass, obj_mask_expand);
// LOG(ERROR) << "loss_x: " << loss_x;
// LOG(ERROR) << "loss_y: " << loss_y;
// LOG(ERROR) << "loss_w: " << loss_w;
// LOG(ERROR) << "loss_h: " << loss_h;
// LOG(ERROR) << "loss_conf_obj: " << loss_conf_obj;
// LOG(ERROR) << "loss_conf_noobj: " << loss_conf_noobj;
// LOG(ERROR) << "loss_class: " << loss_class;
T loss_conf_obj = CalcBCEWithMask<T>(pred_conf, tconf, obj_mask);
T loss_conf_noobj = CalcBCEWithMask<T>(pred_conf, tconf, noobj_mask);
T loss_class = CalcBCEWithMask<T>(pred_class, tclass, obj_mask_expand);
auto* loss_data = loss->mutable_data<T>({1}, ctx.GetPlace());
loss_data[0] = loss_x + loss_y + loss_w + loss_h + loss_conf_obj +
......@@ -315,8 +379,76 @@ template <typename DeviceContext, typename T>
class Yolov3LossGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* d_input_t = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* d_output_t = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* input = ctx.Input<Tensor>("X");
auto* gt_boxes = ctx.Input<Tensor>("GTBox");
auto anchors = ctx.Attr<std::vector<int>>("anchors");
int class_num = ctx.Attr<int>("class_num");
float ignore_thresh = ctx.Attr<float>("ignore_thresh");
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
const T loss = output_grad->data<T>()[0];
const int n = input->dims()[0];
const int c = input->dims()[1];
const int h = input->dims()[2];
const int w = input->dims()[3];
const int an_num = anchors.size() / 2;
Tensor pred_x, pred_y, pred_w, pred_h;
Tensor pred_conf, pred_class;
pred_x.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
pred_y.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
pred_w.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
pred_h.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
pred_conf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
pred_class.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
CalcPredResult<T>(*input, &pred_conf, &pred_class, &pred_x, &pred_y,
&pred_w, &pred_h, an_num, class_num);
Tensor obj_mask, noobj_mask;
Tensor tx, ty, tw, th, tconf, tclass;
obj_mask.mutable_data<int>({n, an_num, h, w}, ctx.GetPlace());
noobj_mask.mutable_data<int>({n, an_num, h, w}, ctx.GetPlace());
tx.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
ty.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
tw.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
th.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
tconf.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
tclass.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
PrePorcessGTBox<T>(*gt_boxes, ignore_thresh, anchors, h, &obj_mask,
&noobj_mask, &tx, &ty, &tw, &th, &tconf, &tclass);
Tensor obj_mask_expand;
obj_mask_expand.mutable_data<int>({n, an_num, h, w, class_num},
ctx.GetPlace());
ExpandObjMaskByClassNum(&obj_mask_expand, obj_mask);
Tensor grad_x, grad_y, grad_w, grad_h;
Tensor grad_conf_obj, grad_conf_noobj, grad_class;
grad_x.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
grad_y.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
grad_w.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
grad_h.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
grad_conf_obj.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
grad_conf_noobj.mutable_data<T>({n, an_num, h, w}, ctx.GetPlace());
grad_class.mutable_data<T>({n, an_num, h, w, class_num}, ctx.GetPlace());
T obj_mf = CalcMaskPointNum<int>(obj_mask);
T noobj_mf = CalcMaskPointNum<int>(noobj_mask);
T obj_expand_mf = CalcMaskPointNum<int>(obj_mask_expand);
CalcMSEGradWithMask<T>(&grad_x, pred_x, tx, obj_mask, obj_mf);
CalcMSEGradWithMask<T>(&grad_y, pred_y, ty, obj_mask, obj_mf);
CalcMSEGradWithMask<T>(&grad_w, pred_w, tw, obj_mask, obj_mf);
CalcMSEGradWithMask<T>(&grad_h, pred_h, th, obj_mask, obj_mf);
CalcBCEGradWithMask<T>(&grad_conf_obj, pred_conf, tconf, obj_mask, obj_mf);
CalcBCEGradWithMask<T>(&grad_conf_noobj, pred_conf, tconf, noobj_mask,
noobj_mf);
CalcBCEGradWithMask<T>(&grad_class, pred_class, tclass, obj_mask_expand,
obj_expand_mf);
input_grad->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
AddAllGradToInputGrad<T>(
input_grad, loss, pred_x, pred_y, pred_conf, pred_class, grad_x, grad_y,
grad_w, grad_h, grad_conf_obj, grad_conf_noobj, grad_class, class_num);
}
};
......
......@@ -8244,14 +8244,55 @@ def log_loss(input, label, epsilon=1e-4, name=None):
return loss
def yolov3_loss(x, gtbox, img_height, anchors, ignore_thresh, name=None):
@templatedoc(op_type="yolov3_loss")
def yolov3_loss(x, gtbox, anchors, class_num, ignore_thresh, name=None):
"""
**YOLOv3 Loss Layer**
${comment}
Args:
x (Variable): ${x_comment}
gtbox (Variable): groud truth boxes, shoulb be in shape of [N, B, 5],
in the third dimenstion, class_id, x, y, w, h should
be stored and x, y, w, h should be relative valud of
input image.
anchors (list|tuple): ${anchors_comment}
class_num (int): ${class_num_comment}
ignore_thresh (float): ${ignore_thresh_comment}
name (string): the name of yolov3 loss
This layer
Returns:
Variable: A 1-D tensor with shape [1], the value of yolov3 loss
Raises:
TypeError: Input x of yolov3_loss must be Variable
TypeError: Input gtbox of yolov3_loss must be Variable"
TypeError: Attr anchors of yolov3_loss must be list or tuple
TypeError: Attr class_num of yolov3_loss must be an integer
TypeError: Attr ignore_thresh of yolov3_loss must be a float number
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[10, 255, 13, 13], dtype='float32')
gtbox = fluid.layers.data(name='gtbox', shape=[10, 6, 5], dtype='float32')
anchors = [10, 13, 16, 30, 33, 23]
loss = fluid.layers.yolov3_loss(x=x, gtbox=gtbox, class_num=80
anchors=anchors, ignore_thresh=0.5)
"""
helper = LayerHelper('yolov3_loss', **locals())
if not isinstance(x, Variable):
raise TypeError("Input x of yolov3_loss must be Variable")
if not isinstance(gtbox, Variable):
raise TypeError("Input gtbox of yolov3_loss must be Variable")
if not isinstance(anchors, list) and not isinstance(anchors, tuple):
raise TypeError("Attr anchors of yolov3_loss must be list or tuple")
if not isinstance(class_num, int):
raise TypeError("Attr class_num of yolov3_loss must be an integer")
if not isinstance(ignore_thresh, float):
raise TypeError(
"Attr ignore_thresh of yolov3_loss must be a float number")
if name is None:
loss = helper.create_variable_for_type_inference(dtype=x.dtype)
else:
......@@ -8264,8 +8305,8 @@ def yolov3_loss(x, gtbox, img_height, anchors, ignore_thresh, name=None):
"GTBox": gtbox},
outputs={'Loss': loss},
attrs={
"img_height": img_height,
"anchors": anchors,
"class_num": class_num,
"ignore_thresh": ignore_thresh,
})
return loss
......
......@@ -911,6 +911,15 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(data_1)
print(str(program))
def test_yolov3_loss(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[30, 7, 7], dtype='float32')
gtbox = layers.data(name='gtbox', shape=[10, 5], dtype='float32')
loss = layers.yolov3_loss(x, gtbox, [10, 13, 30, 13], 10, 0.5)
self.assertIsNotNone(loss)
def test_bilinear_tensor_product_layer(self):
program = Program()
with program_guard(program):
......
......@@ -12,10 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import unittest
import numpy as np
from op_test import OpTest
from paddle.fluid import core
def sigmoid(x):
return 1.0 / (1.0 + np.exp(-1.0 * x))
......@@ -65,10 +69,9 @@ def box_iou(box1, box2):
def build_target(gtboxs, attrs, grid_size):
n, b, _ = gtboxs.shape
ignore_thresh = attrs["ignore_thresh"]
img_height = attrs["img_height"]
anchors = attrs["anchors"]
class_num = attrs["class_num"]
an_num = len(anchors) / 2
an_num = len(anchors) // 2
obj_mask = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
noobj_mask = np.ones((n, an_num, grid_size, grid_size)).astype('float32')
tx = np.zeros((n, an_num, grid_size, grid_size)).astype('float32')
......@@ -120,7 +123,7 @@ def build_target(gtboxs, attrs, grid_size):
def YoloV3Loss(x, gtbox, attrs):
n, c, h, w = x.shape
an_num = len(attrs['anchors']) / 2
an_num = len(attrs['anchors']) // 2
class_num = attrs["class_num"]
x = x.reshape((n, an_num, 5 + class_num, h, w)).transpose((0, 1, 3, 4, 2))
pred_x = sigmoid(x[:, :, :, :, 0])
......@@ -144,13 +147,6 @@ def YoloV3Loss(x, gtbox, attrs):
noobj_mask)
loss_class = bce(pred_cls * obj_mask_expand, tcls * obj_mask_expand,
obj_mask_expand)
# print "loss_x: ", loss_x
# print "loss_y: ", loss_y
# print "loss_w: ", loss_w
# print "loss_h: ", loss_h
# print "loss_conf_obj: ", loss_conf_obj
# print "loss_conf_noobj: ", loss_conf_noobj
# print "loss_class: ", loss_class
return loss_x + loss_y + loss_w + loss_h + loss_conf_obj + loss_conf_noobj + loss_class
......@@ -165,29 +161,35 @@ class TestYolov3LossOp(OpTest):
self.gtbox_shape[:2])
self.attrs = {
"img_height": self.img_height,
"anchors": self.anchors,
"class_num": self.class_num,
"ignore_thresh": self.ignore_thresh,
}
self.inputs = {'X': x, 'GTBox': gtbox}
self.outputs = {'Loss': np.array([YoloV3Loss(x, gtbox, self.attrs)])}
print self.outputs
self.outputs = {
'Loss':
np.array([YoloV3Loss(x, gtbox, self.attrs)]).astype('float32')
}
def test_check_output(self):
self.check_output(atol=1e-3)
place = core.CPUPlace()
self.check_output_with_place(place, atol=1e-3)
# def test_check_grad_normal(self):
# self.check_grad(['X', 'Grid'], 'Output', max_relative_error=0.61)
def test_check_grad_ignore_gtbox(self):
place = core.CPUPlace()
self.check_grad_with_place(
place, ['X'],
'Loss',
no_grad_set=set("GTBox"),
max_relative_error=0.1)
def initTestCase(self):
self.img_height = 608
self.anchors = [10, 13, 16, 30, 33, 23]
self.anchors = [10, 13, 12, 12]
self.class_num = 10
self.ignore_thresh = 0.5
self.x_shape = (5, len(self.anchors) / 2 * (5 + self.class_num), 7, 7)
self.gtbox_shape = (5, 10, 5)
self.x_shape = (5, len(self.anchors) // 2 * (5 + self.class_num), 7, 7)
self.gtbox_shape = (5, 5, 5)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册