提交 ee7d8421 编写于 作者: D dangqingqing

Update doc and follow comments.

上级 09b78c72
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -61,10 +61,12 @@ class TargetAssignOp : public framework::OperatorWithKernel { ...@@ -61,10 +61,12 @@ class TargetAssignOp : public framework::OperatorWithKernel {
"The rank of Input(NegIndices) must be 2."); "The rank of Input(NegIndices) must be 2.");
PADDLE_ENFORCE_EQ(blabel_dims[0], slabel_dims[0], PADDLE_ENFORCE_EQ(blabel_dims[0], slabel_dims[0],
"The 1st dimension of Input(EncodedGTBBox) and " "The 1st dimension (means the total number of "
"Input(GTScoreLabel) must be the same."); "ground-truth bounding boxes) of Input(EncodedGTBBox) "
"and Input(GTScoreLabel) must be the same.");
PADDLE_ENFORCE_EQ(blabel_dims[1], mi_dims[1], PADDLE_ENFORCE_EQ(blabel_dims[1], mi_dims[1],
"The 2nd dimension of Input(EncodedGTBBox) and " "The 2nd dimension (means the number of priod boxes) "
"of Input(EncodedGTBBox) and "
"Input(MatchIndices) must be the same."); "Input(MatchIndices) must be the same.");
PADDLE_ENFORCE_EQ(blabel_dims[2], 4, PADDLE_ENFORCE_EQ(blabel_dims[2], 4,
"The 3rd dimension of Input(EncodedGTBBox) must be 4."); "The 3rd dimension of Input(EncodedGTBBox) must be 4.");
...@@ -101,31 +103,31 @@ class TargetAssignOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -101,31 +103,31 @@ class TargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
"labels with shape [Ng, 1], where the Ng is the same as it in " "labels with shape [Ng, 1], where the Ng is the same as it in "
"the input of EncodedGTBBox."); "the input of EncodedGTBBox.");
AddInput("MatchIndices", AddInput("MatchIndices",
"(Tensor, default LoDTensor<int>), The input matched indices " "(Tensor, default Tensor<int>), The input matched indices "
"with shape [N, Np], where N is the batch size, Np is the same " "with shape [N, Np], where N is the batch size, Np is the same "
"as it in the input of EncodedGTBBox. If MatchIndices[i][j] " "as it in the input of EncodedGTBBox. If MatchIndices[i][j] "
"is -1, the j-th prior box is not matched to any ground-truh " "is -1, the j-th prior box is not matched to any ground-truh "
"box in i-th instance."); "box in i-th instance.");
AddInput("NegIndices", AddInput("NegIndices",
"(LoDTensor, default LoDTensor<int>), The input negative example " "(LoDTensor, default LoDTensor<int>), The input negative example "
"indics with shape [Neg, 1], where is the total number of " "indices with shape [Neg, 1], where is the total number of "
"negative example indices."); "negative example indices.");
AddAttr<int>("background_label", AddAttr<int>("background_label",
"(int, default 0), Label id for background class.") "(int, default 0), Label index of background class.")
.SetDefault(0); .SetDefault(0);
AddOutput("PredBBoxLabel", AddOutput("PredBBoxLabel",
"(Tensor), The output encoded ground-truth labels " "(Tensor), The output encoded ground-truth labels "
"with shape [N, Np, 4], N is the batch size and Np, 4 is the " "with shape [N, Np, 4], N is the batch size and Np, 4 is the "
"same as they in input of EncodedGTBBox. If MatchIndices[i][j] " "same as they in input of EncodedGTBBox. If MatchIndices[i][j] "
"is -1, the PredBBoxLabel[i][j][:] is the encoded ground-truth " "is -1, the PredBBoxLabel[i][j][:] is the encoded ground-truth "
"box for background_label_id in i-th instance."); "box for background_label in i-th instance.");
AddOutput("PredBBoxWeight", AddOutput("PredBBoxWeight",
"(Tensor), The weight for PredBBoxLabel with the shape " "(Tensor), The weight for PredBBoxLabel with the shape "
"of [N, Np, 1]"); "of [N, Np, 1]");
AddOutput("PredScoreLabel", AddOutput("PredScoreLabel",
"(Tensor, default Tensor<int>), The output score labels for " "(Tensor, default Tensor<int>), The output score labels for "
"each predictions with shape [N, Np, 1]. If MatchIndices[i][j] " "each predictions with shape [N, Np, 1]. If MatchIndices[i][j] "
"is -1, PredScoreLabel[i][j] = background_label_id."); "is -1, PredScoreLabel[i][j] = background_label.");
AddOutput("PredScoreWeight", AddOutput("PredScoreWeight",
"(Tensor), The weight for PredScoreLabel with the shape " "(Tensor), The weight for PredScoreLabel with the shape "
"of [N, Np, 1]"); "of [N, Np, 1]");
...@@ -136,19 +138,47 @@ and regression targets to each prior box as well as weights to each ...@@ -136,19 +138,47 @@ and regression targets to each prior box as well as weights to each
prior box. The weights is used to specify which prior box would not contribute prior box. The weights is used to specify which prior box would not contribute
to training loss. to training loss.
TODO(dang qingqing) add an example. For each instance, the output `PredBBoxLabel`, `PredBBoxWeight`,
`PredScoreLabel` and `PredScoreWeight` are assigned based on `MatchIndices`.
Assumed that the row offset for each instance in `EncodedGTBBox` is called lod,
this operato assigns classification/regression targets by performing the
following steps:
1. Assigning all outpts based on `MatchIndices`:
If id = MatchIndices[i][j] > 0,
PredBBoxLabel[i][j] = EncodedGTBBox[lod[i] + id][j]
PredBBoxWeight[i][j] = 1.
PredScoreLabel[i][j] = GTScoreLabel[lod[i] + id]
PredScoreWeight[i][j] = 1.
Otherwise,
PredBBoxLabel[j][j] = [0., 0., 0., 0.]
PredBBoxWeight[i][j] = 0.
PredScoreLabel[i][j] = background_label
PredScoreWeight[i][j] = 0.
2. Assigning PredScoreWeight based on `NegIndices`:
Assumed that the row offset for each instance in `NegIndices` is caleed neg_lod,
for i-th instance and all ids of NegIndices in this instance:
PredScoreLabel[i][id] = background_label
PredScoreWeight[i][id] = 1.0
)DOC"); )DOC");
} }
}; };
template <typename T> template <typename T>
struct UpdateTargetLabelFunctor<platform::CPUDeviceContext, T> { struct NegTargetAssignFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx, const int* neg_indices, void operator()(const platform::CPUDeviceContext& ctx, const int* neg_indices,
const size_t* lod, const int num, const int num_prior_box, const size_t* lod, const int num, const int num_prior_box,
const int background_label, int* out_label, T* out_label_wt) { const int background_label, int* out_label, T* out_label_wt) {
for (int i = 0; i < num; ++i) { for (int i = 0; i < num; ++i) {
for (int j = lod[i]; j < lod[i + 1]; ++j) { for (size_t j = lod[i]; j < lod[i + 1]; ++j) {
int id = neg_indices[j]; int id = neg_indices[j];
out_label[i * num_prior_box + id] = background_label; out_label[i * num_prior_box + id] = background_label;
out_label_wt[i * num_prior_box + id] = static_cast<T>(1.0); out_label_wt[i * num_prior_box + id] = static_cast<T>(1.0);
...@@ -157,8 +187,8 @@ struct UpdateTargetLabelFunctor<platform::CPUDeviceContext, T> { ...@@ -157,8 +187,8 @@ struct UpdateTargetLabelFunctor<platform::CPUDeviceContext, T> {
} }
}; };
template struct UpdateTargetLabelFunctor<platform::CPUDeviceContext, float>; template struct NegTargetAssignFunctor<platform::CPUDeviceContext, float>;
template struct UpdateTargetLabelFunctor<platform::CPUDeviceContext, double>; template struct NegTargetAssignFunctor<platform::CPUDeviceContext, double>;
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -18,38 +18,38 @@ namespace paddle { ...@@ -18,38 +18,38 @@ namespace paddle {
namespace operators { namespace operators {
template <typename T> template <typename T>
__global__ void UpdateTargetLabelKernel(const int* neg_indices, __global__ void NegTargetAssignKernel(const int* neg_indices, const size_t* lod,
const size_t* lod, const int num, const int num, const int num_prior_box,
const int num_prior_box, const int background_label,
const int background_label, int* out_label, T* out_label_wt) {
int* out_label, T* out_label_wt) {
int bidx = blockIdx.x; int bidx = blockIdx.x;
int st = lod[bidx]; int st = lod[bidx];
int ed = lod[bidx + 1]; int ed = lod[bidx + 1];
int row_start = bidx * num_prior_box;
for (int i = st + threadIdx.x; i < ed; i += blockDim.x) { for (int i = st + threadIdx.x; i < ed; i += blockDim.x) {
int id = neg_indices[i]; int id = row_start + neg_indices[i];
out_label[bidx * num_prior_box + id] = background_label; out_label[id] = background_label;
out_label_wt[bidx * num_prior_box + id] = 1.; out_label_wt[id] = 1.;
} }
} }
template <typename T> template <typename T>
struct UpdateTargetLabelFunctor<platform::CUDADeviceContext, T> { struct NegTargetAssignFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx, void operator()(const platform::CUDADeviceContext& ctx,
const int* neg_indices, const size_t* lod, const int num, const int* neg_indices, const size_t* lod, const int num,
const int num_prior_box, const int background_label, const int num_prior_box, const int background_label,
int* out_label, T* out_label_wt) { int* out_label, T* out_label_wt) {
const int block_size = 256; const int block_size = 256;
const int grid_size = num; const int grid_size = num;
UpdateTargetLabelKernel<T><<<grid_size, block_size, 0, ctx.stream()>>>( NegTargetAssignKernel<T><<<grid_size, block_size, 0, ctx.stream()>>>(
neg_indices, lod, num, num_prior_box, background_label, out_label, neg_indices, lod, num, num_prior_box, background_label, out_label,
out_label_wt); out_label_wt);
} }
}; };
template struct UpdateTargetLabelFunctor<platform::CUDADeviceContext, float>; template struct NegTargetAssignFunctor<platform::CUDADeviceContext, float>;
template struct UpdateTargetLabelFunctor<platform::CUDADeviceContext, double>; template struct NegTargetAssignFunctor<platform::CUDADeviceContext, double>;
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -56,40 +56,41 @@ struct TargetAssignFunctor { ...@@ -56,40 +56,41 @@ struct TargetAssignFunctor {
int row = i / num_prior_box_; int row = i / num_prior_box_;
int col = i - row * num_prior_box_; int col = i - row * num_prior_box_;
size_t off = lod_[row]; size_t row_off = lod_[row];
int offset = row * num_prior_box_ + col;
int id = match_indices_[row * num_prior_box_ + col]; int id = match_indices_[offset];
T* obox = out_box_ + (row * num_prior_box_ + col) * 4; T* obox = out_box_ + offset * 4;
int* olabel = out_label_ + row * num_prior_box_ + col; int* olabel = out_label_ + offset;
T* obox_wt = out_box_wt_ + row * num_prior_box_ + col; T* obox_wt = out_box_wt_ + offset;
T* olabel_wt = out_label_wt_ + row * num_prior_box_ + col; T* olabel_wt = out_label_wt_ + offset;
if (id > -1) { if (id > -1) {
const T* gtbox = gt_box_ + ((off + id) * num_prior_box_ + col) * 4; const T* gtbox = gt_box_ + ((row_off + id) * num_prior_box_ + col) * 4;
obox[0] = gtbox[0]; obox[0] = gtbox[0];
obox[1] = gtbox[1]; obox[1] = gtbox[1];
obox[2] = gtbox[2]; obox[2] = gtbox[2];
obox[3] = gtbox[3]; obox[3] = gtbox[3];
olabel[0] = gt_label_[off + id]; olabel[0] = gt_label_[row_off + id];
obox_wt[0] = 1.; obox_wt[0] = static_cast<T>(1.);
olabel_wt[0] = 1.; olabel_wt[0] = static_cast<T>(1.);
} else { } else {
obox[0] = 0.; obox[0] = static_cast<T>(0.);
obox[1] = 0.; obox[1] = static_cast<T>(0.);
obox[2] = 0.; obox[2] = static_cast<T>(0.);
obox[3] = 0.; obox[3] = static_cast<T>(0.);
olabel[0] = background_label_; olabel[0] = background_label_;
obox_wt[0] = 0.; obox_wt[0] = static_cast<T>(0.);
olabel_wt[0] = 0.; olabel_wt[0] = static_cast<T>(0.);
} }
} }
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct UpdateTargetLabelFunctor { struct NegTargetAssignFunctor {
void operator()(const platform::DeviceContext& ctx, const int* neg_indices, void operator()(const platform::DeviceContext& ctx, const int* neg_indices,
const size_t* lod, const int num, const int num_prior_box, const size_t* lod, const int num, const int num_prior_box,
const int background_label, int* out_label, const int background_label, int* out_label,
...@@ -130,7 +131,11 @@ class TargetAssignKernel : public framework::OpKernel<T> { ...@@ -130,7 +131,11 @@ class TargetAssignKernel : public framework::OpKernel<T> {
int64_t num_prior_box = match_indices->dims()[1]; int64_t num_prior_box = match_indices->dims()[1];
auto gt_lod = enc_gt_box->lod().back(); auto gt_lod = enc_gt_box->lod().back();
auto gt_label_lod = gt_label->lod().back();
auto neg_lod = neg_indices->lod().back(); auto neg_lod = neg_indices->lod().back();
for (size_t i = 0; i < gt_lod.size(); ++i) {
PADDLE_ENFORCE_EQ(gt_lod.data()[i], gt_label_lod.data()[i]);
}
size_t* gt_lod_data = gt_lod.data(ctx.GetPlace()); size_t* gt_lod_data = gt_lod.data(ctx.GetPlace());
size_t* neg_lod_data = neg_lod.data(ctx.GetPlace()); size_t* neg_lod_data = neg_lod.data(ctx.GetPlace());
...@@ -145,9 +150,9 @@ class TargetAssignKernel : public framework::OpKernel<T> { ...@@ -145,9 +150,9 @@ class TargetAssignKernel : public framework::OpKernel<T> {
num * num_prior_box); num * num_prior_box);
for_range(functor); for_range(functor);
UpdateTargetLabelFunctor<DeviceContext, T> update_functor; NegTargetAssignFunctor<DeviceContext, T> neg_trg_functor;
update_functor(device_ctx, neg_idx_data, neg_lod_data, num, num_prior_box, neg_trg_functor(device_ctx, neg_idx_data, neg_lod_data, num, num_prior_box,
background_label, olabel_data, olabel_wt_data); background_label, olabel_data, olabel_wt_data);
} }
}; };
......
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
import unittest import unittest
import numpy as np import numpy as np
import math
import sys
import random import random
from op_test import OpTest from op_test import OpTest
...@@ -89,8 +87,6 @@ class TestTargetAssginOp(OpTest): ...@@ -89,8 +87,6 @@ class TestTargetAssginOp(OpTest):
num_class = 21 num_class = 21
gt_lod = [0, 5, 11, 23] gt_lod = [0, 5, 11, 23]
neg_lod = [0, 4, 7, 13] neg_lod = [0, 4, 7, 13]
#gt_lod = [0, 2, 5]
#neg_lod = [0, 2, 4]
batch_size = len(gt_lod) - 1 batch_size = len(gt_lod) - 1
num_gt = gt_lod[-1] num_gt = gt_lod[-1]
background_label = 0 background_label = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册