未验证 提交 a28ae86e 编写于 作者: W wangguanzhong 提交者: GitHub

Enhance ops to support LoD as input for dygraph detection models. (#25316)

* enhance collect_op for dygraph, test=develop

* enhance detection ops with lod, test=develop

* support the case where no bbox is left in generate_proposals, test=develop

* unify MultiLevelRoisNum, test=develop

* update core.ops, test=develop

* add op register for new input & output, test=develop
上级 0dab0fc2
......@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License.*/
#include "paddle/fluid/operators/detection/collect_fpn_proposals_op.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace operators {
......@@ -54,11 +55,14 @@ class CollectFpnProposalsOp : public framework::OperatorWithKernel {
score_dim[1]));
}
context->SetOutputDim("FpnRois", {post_nms_topN, 4});
if (context->HasOutput("RoisNum")) {
context->SetOutputDim("RoisNum", {-1});
}
if (!context->IsRuntime()) { // Runtime LoD infershape will be computed
// in Kernel.
context->ShareLoD("MultiLevelRois", "FpnRois");
}
if (context->IsRuntime()) {
if (context->IsRuntime() && !context->HasInputs("MultiLevelRoIsNum")) {
std::vector<framework::InferShapeVarPtr> roi_inputs =
context->GetInputVarPtrs("MultiLevelRois");
std::vector<framework::InferShapeVarPtr> score_inputs =
......@@ -99,7 +103,16 @@ class CollectFpnProposalsOpMaker : public framework::OpProtoAndCheckerMaker {
"(LoDTensor) Multiple score LoDTensors from each level in shape"
" (N, 1), N is the number of RoIs.")
.AsDuplicable();
AddInput(
"MultiLevelRoIsNum",
"(List of Tensor) The RoIs' number of each image on multiple levels."
"The number on each level has the shape of (N), N is the number of "
"images.")
.AsDuplicable()
.AsDispensable();
AddOutput("FpnRois", "(LoDTensor) All selected RoIs with highest scores");
AddOutput("RoisNum", "(Tensor), Number of RoIs in each images.")
.AsDispensable();
AddAttr<int>("post_nms_topN",
"Select post_nms_topN RoIs from"
" all images and all fpn layers");
......@@ -123,3 +136,14 @@ REGISTER_OPERATOR(
REGISTER_OP_CPU_KERNEL(collect_fpn_proposals,
ops::CollectFpnProposalsOpKernel<float>,
ops::CollectFpnProposalsOpKernel<double>);
// Op-version checkpoint for collect_fpn_proposals. It records that this
// revision of the operator adds the input "MultiLevelRoIsNum" and the output
// "RoisNum"; both names match the ones declared in the op maker.
REGISTER_OP_VERSION(collect_fpn_proposals)
    .AddCheckpoint(
        R"ROC(
Upgrade collect_fpn_proposals add a new input
[MultiLevelRoIsNum] and add a new output [RoisNum].)ROC",
        paddle::framework::compatible::OpVersionDesc()
            // Adjacent literals are concatenated by the compiler, so each
            // fragment must carry its own trailing space between sentences.
            .NewInput("MultiLevelRoIsNum",
                      "The RoIs' number of each image on multiple levels. "
                      "The number on each level has the shape of (N), "
                      "N is the number of images.")
            .NewOutput("RoisNum", "The number of RoIs in each image."));
......@@ -80,14 +80,27 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
int lod_size;
auto place = BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace());
auto multi_rois_num = ctx.MultiInput<Tensor>("MultiLevelRoIsNum");
for (size_t i = 0; i < roi_ins.size(); ++i) {
auto roi_in = roi_ins[i];
auto score_in = score_ins[i];
auto roi_lod = roi_in->lod().back();
lod_size = roi_lod.size() - 1;
for (size_t n = 0; n < lod_size; ++n) {
for (size_t j = roi_lod[n]; j < roi_lod[n + 1]; ++j) {
roi_batch_id_data[index++] = n;
if (multi_rois_num.size() > 0) {
framework::Tensor temp;
TensorCopySync(*multi_rois_num[i], platform::CPUPlace(), &temp);
const int* length_in = temp.data<int>();
lod_size = multi_rois_num[i]->numel();
for (size_t n = 0; n < lod_size; ++n) {
for (size_t j = 0; j < length_in[n]; ++j) {
roi_batch_id_data[index++] = n;
}
}
} else {
auto length_in = roi_in->lod().back();
lod_size = length_in.size() - 1;
for (size_t n = 0; n < lod_size; ++n) {
for (size_t j = length_in[n]; j < length_in[n + 1]; ++j) {
roi_batch_id_data[index++] = n;
}
}
}
......@@ -190,6 +203,13 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
offset.emplace_back(offset.back() + length_lod_cpu[i]);
}
if (ctx.HasOutput("RoisNum")) {
auto* rois_num = ctx.Output<Tensor>("RoisNum");
int* rois_num_data = rois_num->mutable_data<int>({lod_size}, place);
memory::Copy(place, rois_num_data, place, length_lod_data,
lod_size * sizeof(int), dev_ctx.stream());
}
framework::LoD lod;
lod.emplace_back(offset);
fpn_rois->set_lod(lod);
......
......@@ -17,6 +17,7 @@ limitations under the License.*/
#include <algorithm>
#include <cmath>
#include <cstring>
#include <numeric>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
......@@ -65,6 +66,8 @@ class CollectFpnProposalsOpKernel : public framework::OpKernel<T> {
auto multi_layer_scores =
context.MultiInput<paddle::framework::LoDTensor>("MultiLevelScores");
auto multi_rois_num = context.MultiInput<Tensor>("MultiLevelRoIsNum");
int num_size = multi_rois_num.size();
auto* fpn_rois = context.Output<paddle::framework::LoDTensor>("FpnRois");
......@@ -88,11 +91,21 @@ class CollectFpnProposalsOpKernel : public framework::OpKernel<T> {
const int num_fpn_level = multi_layer_rois.size();
std::vector<int> integral_of_all_rois(num_fpn_level + 1, 0);
for (int i = 0; i < num_fpn_level; ++i) {
auto cur_rois_lod = multi_layer_rois[i]->lod().back();
integral_of_all_rois[i + 1] =
integral_of_all_rois[i] + cur_rois_lod[cur_rois_lod.size() - 1];
int all_rois = 0;
if (num_size == 0) {
auto cur_rois_lod = multi_layer_rois[i]->lod().back();
all_rois = cur_rois_lod[cur_rois_lod.size() - 1];
} else {
const int* cur_rois_num = multi_rois_num[i]->data<int>();
all_rois = std::accumulate(
cur_rois_num, cur_rois_num + multi_rois_num[i]->numel(), 0);
}
integral_of_all_rois[i + 1] = integral_of_all_rois[i] + all_rois;
}
const int batch_size = (num_size == 0)
? multi_layer_rois[0]->lod().back().size() - 1
: multi_rois_num[0]->numel();
// concatenate all fpn rois scores into a list
// create a vector to store all scores
std::vector<ScoreWithID<T>> scores_of_all_rois(
......@@ -100,11 +113,20 @@ class CollectFpnProposalsOpKernel : public framework::OpKernel<T> {
for (int i = 0; i < num_fpn_level; ++i) {
const T* cur_level_scores = multi_layer_scores[i]->data<T>();
int cur_level_num = integral_of_all_rois[i + 1] - integral_of_all_rois[i];
auto cur_scores_lod = multi_layer_scores[i]->lod().back();
int cur_batch_id = 0;
int pre_num = 0;
for (int j = 0; j < cur_level_num; ++j) {
if (static_cast<size_t>(j) >= cur_scores_lod[cur_batch_id + 1]) {
cur_batch_id++;
if (num_size == 0) {
auto cur_scores_lod = multi_layer_scores[i]->lod().back();
if (static_cast<size_t>(j) >= cur_scores_lod[cur_batch_id + 1]) {
cur_batch_id++;
}
} else {
const int* rois_num_data = multi_rois_num[i]->data<int>();
if (j >= pre_num + rois_num_data[cur_batch_id]) {
pre_num += rois_num_data[cur_batch_id];
cur_batch_id++;
}
}
int cur_index = j + integral_of_all_rois[i];
scores_of_all_rois[cur_index].score = cur_level_scores[j];
......@@ -134,6 +156,9 @@ class CollectFpnProposalsOpKernel : public framework::OpKernel<T> {
T* fpn_rois_data = fpn_rois->data<T>();
std::vector<size_t> lod0(1, 0);
int cur_batch_id = 0;
std::vector<int64_t> num_per_batch;
int pre_idx = 0;
int cur_num = 0;
for (int i = 0; i < post_nms_topN; ++i) {
int cur_fpn_level = scores_of_all_rois[i].level;
int cur_level_index = scores_of_all_rois[i].index;
......@@ -144,6 +169,18 @@ class CollectFpnProposalsOpKernel : public framework::OpKernel<T> {
if (scores_of_all_rois[i].batch_id != cur_batch_id) {
cur_batch_id = scores_of_all_rois[i].batch_id;
lod0.emplace_back(i);
cur_num = i - pre_idx;
pre_idx = i;
num_per_batch.emplace_back(cur_num);
}
}
num_per_batch.emplace_back(post_nms_topN - pre_idx);
if (context.HasOutput("RoisNum")) {
auto* rois_num = context.Output<Tensor>("RoisNum");
int* rois_num_data =
rois_num->mutable_data<int>({batch_size}, context.GetPlace());
for (int i = 0; i < batch_size; i++) {
rois_num_data[i] = num_per_batch[i];
}
}
lod0.emplace_back(post_nms_topN);
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/distribute_fpn_proposals_op.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace operators {
......@@ -48,6 +49,14 @@ class DistributeFpnProposalsOp : public framework::OperatorWithKernel {
}
ctx->SetOutputsDim("MultiFpnRois", outs_dims);
ctx->SetOutputDim("RestoreIndex", {-1, 1});
if (ctx->HasOutputs("MultiLevelRoIsNum")) {
std::vector<framework::DDim> outs_num_dims;
for (size_t i = 0; i < num_out_rois; ++i) {
outs_num_dims.push_back({-1});
}
ctx->SetOutputsDim("MultiLevelRoIsNum", outs_num_dims);
}
if (!ctx->IsRuntime()) {
for (size_t i = 0; i < num_out_rois; ++i) {
ctx->SetLoDLevel("MultiFpnRois", ctx->GetLoDLevel("FpnRois"), i);
......@@ -66,12 +75,22 @@ class DistributeFpnProposalsOp : public framework::OperatorWithKernel {
class DistributeFpnProposalsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("FpnRois", "(LoDTensor) The rois at all levels in shape (-1, 4)");
AddInput("FpnRois", "(LoDTensor) The RoIs at all levels in shape (-1, 4)");
AddInput("RoisNum",
"(Tensor) The number of RoIs in shape (B),"
"B is the number of images")
.AsDispensable();
AddOutput("MultiFpnRois", "(LoDTensor) Output with distribute operator")
.AsDuplicable();
AddOutput("RestoreIndex",
"(Tensor) An array of positive number which is "
"used to restore the order of FpnRois");
AddOutput("MultiLevelRoIsNum",
"(List of Tensor) The RoIs' number of each image on multiple "
"levels. The number on each level has the shape of (B),"
"B is the number of images.")
.AsDuplicable()
.AsDispensable();
AddAttr<int>("min_level",
"The lowest level of FPN layer where the"
" proposals come from");
......@@ -105,3 +124,14 @@ REGISTER_OPERATOR(
REGISTER_OP_CPU_KERNEL(distribute_fpn_proposals,
ops::DistributeFpnProposalsOpKernel<float>,
ops::DistributeFpnProposalsOpKernel<double>);
// Op-version checkpoint for distribute_fpn_proposals. It records that this
// revision of the operator adds the input "RoisNum" and the output
// "MultiLevelRoIsNum". The names registered here must be spelled exactly as
// in DistributeFpnProposalsOpMaker (AddInput("RoisNum") and
// AddOutput("MultiLevelRoIsNum")); otherwise the checkpoint describes
// variables that the operator does not have.
REGISTER_OP_VERSION(distribute_fpn_proposals)
    .AddCheckpoint(
        R"ROC(
Upgrade distribute_fpn_proposals add a new input
[RoisNum] and add a new output [MultiLevelRoIsNum].)ROC",
        paddle::framework::compatible::OpVersionDesc()
            .NewInput("RoisNum", "The number of RoIs in each image.")
            // Adjacent literals are concatenated by the compiler, so each
            // fragment must carry its own separating space.
            .NewOutput("MultiLevelRoIsNum",
                       "The RoIs' number of each image on multiple "
                       "levels. The number on each level has the shape of "
                       "(B), B is the number of images."));
......@@ -76,12 +76,20 @@ class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
int num_level = max_level - min_level + 1;
// check that the fpn_rois is not empty
PADDLE_ENFORCE_EQ(
fpn_rois->lod().size(), 1UL,
platform::errors::InvalidArgument("DistributeFpnProposalsOp needs LoD"
"with one level"));
if (!ctx.HasInput("RoisNum")) {
PADDLE_ENFORCE_EQ(
fpn_rois->lod().size(), 1UL,
platform::errors::InvalidArgument("DistributeFpnProposalsOp needs LoD"
"with one level"));
}
auto fpn_rois_lod = fpn_rois->lod().back();
std::vector<size_t> fpn_rois_lod;
if (ctx.HasInput("RoisNum")) {
auto* rois_num = ctx.Input<Tensor>("RoisNum");
fpn_rois_lod = GetLodFromRoisNum(rois_num);
} else {
fpn_rois_lod = fpn_rois->lod().back();
}
int lod_size = fpn_rois_lod.size() - 1;
int roi_num = fpn_rois_lod[lod_size];
......@@ -154,6 +162,8 @@ class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
restore_idx_data, roi_num);
int start = 0;
auto multi_rois_num = ctx.MultiOutput<Tensor>("MultiLevelRoIsNum");
for (int i = 0; i < num_level; ++i) {
Tensor sub_lod = sub_lod_list.Slice(i, i + 1);
int* sub_lod_data = sub_lod.data<int>();
......@@ -180,6 +190,11 @@ class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
multi_fpn_rois[i]->mutable_data<T>({sub_rois_num, kBoxDim},
dev_ctx.GetPlace());
}
if (multi_rois_num.size() > 0) {
Tensor* rois_num_t = multi_rois_num[i];
TensorCopySync(sub_lod, dev_ctx.GetPlace(), rois_num_t);
rois_num_t->Resize({lod_size});
}
framework::LoD lod;
lod.emplace_back(offset);
multi_fpn_rois[i]->set_lod(lod);
......
......@@ -28,6 +28,21 @@ namespace operators {
const int kBoxDim = 4;
inline std::vector<size_t> GetLodFromRoisNum(const Tensor* rois_num) {
std::vector<size_t> rois_lod;
auto* rois_num_data = rois_num->data<int>();
Tensor cpu_tensor;
if (platform::is_gpu_place(rois_num->place())) {
TensorCopySync(*rois_num, platform::CPUPlace(), &cpu_tensor);
rois_num_data = cpu_tensor.data<int>();
}
rois_lod.push_back(static_cast<size_t>(0));
for (int i = 0; i < rois_num->numel(); ++i) {
rois_lod.push_back(rois_lod.back() + static_cast<size_t>(rois_num_data[i]));
}
return rois_lod;
}
template <typename T>
static inline T BBoxArea(const T* box, bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) {
......@@ -65,13 +80,22 @@ class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
const int num_level = max_level - min_level + 1;
// check that the fpn_rois is not empty
PADDLE_ENFORCE_EQ(
fpn_rois->lod().size(), 1UL,
platform::errors::InvalidArgument("DistributeFpnProposalsOp needs LoD "
"with one level."));
if (!context.HasInput("RoisNum")) {
PADDLE_ENFORCE_EQ(fpn_rois->lod().size(), 1UL,
platform::errors::InvalidArgument(
"DistributeFpnProposalsOp needs LoD "
"with one level."));
}
auto fpn_rois_lod = fpn_rois->lod().back();
int fpn_rois_num = fpn_rois_lod[fpn_rois_lod.size() - 1];
std::vector<size_t> fpn_rois_lod;
int fpn_rois_num;
if (context.HasInput("RoisNum")) {
auto* rois_num = context.Input<Tensor>("RoisNum");
fpn_rois_lod = GetLodFromRoisNum(rois_num);
} else {
fpn_rois_lod = fpn_rois->lod().back();
}
fpn_rois_num = fpn_rois_lod[fpn_rois_lod.size() - 1];
std::vector<int> target_level;
// std::vector<int> target_level(fpn_rois_num, -1);
// record the number of rois in each level
......@@ -136,6 +160,18 @@ class DistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
for (int i = 0; i < fpn_rois_num; ++i) {
restore_index_data[restore_index_inter[i]] = i;
}
auto multi_rois_num = context.MultiOutput<Tensor>("MultiLevelRoIsNum");
if (multi_rois_num.size() > 0) {
int batch_size = fpn_rois_lod.size() - 1;
for (int i = 0; i < num_level; ++i) {
int* rois_num_data = multi_rois_num[i]->mutable_data<int>(
{batch_size}, context.GetPlace());
for (int j = 0; j < batch_size; ++j) {
rois_num_data[j] = static_cast<int>(multi_fpn_rois_lod0[i][j + 1] -
multi_fpn_rois_lod0[i][j]);
}
}
}
// merge lod information into LoDTensor
for (int i = 0; i < num_level; ++i) {
framework::LoD lod;
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h"
......@@ -61,6 +62,10 @@ class GenerateProposalsOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("RpnRois", {-1, 4});
ctx->SetOutputDim("RpnRoiProbs", {-1, 1});
if (!ctx->IsRuntime()) {
ctx->SetLoDLevel("RpnRois", std::max(ctx->GetLoDLevel("Scores"), 1));
ctx->SetLoDLevel("RpnRoiProbs", std::max(ctx->GetLoDLevel("Scores"), 1));
}
}
protected:
......@@ -347,7 +352,7 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
lod0.push_back(0);
anchors.Resize({anchors.numel() / 4, 4});
variances.Resize({variances.numel() / 4, 4});
std::vector<int64_t> tmp_lod;
std::vector<int> tmp_num;
int64_t num_proposals = 0;
for (int64_t i = 0; i < num; ++i) {
......@@ -369,16 +374,16 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
AppendProposals(rpn_roi_probs, num_proposals, scores);
num_proposals += proposals.dims()[0];
lod0.push_back(num_proposals);
tmp_lod.push_back(num_proposals);
tmp_num.push_back(proposals.dims()[0]);
}
if (context.HasOutput("RpnRoisLod")) {
auto *rpn_rois_lod = context.Output<Tensor>("RpnRoisLod");
rpn_rois_lod->mutable_data<int64_t>({num}, context.GetPlace());
int64_t *lod_data = rpn_rois_lod->data<int64_t>();
if (context.HasOutput("RpnRoisNum")) {
auto *rpn_rois_num = context.Output<Tensor>("RpnRoisNum");
rpn_rois_num->mutable_data<int>({num}, context.GetPlace());
int *num_data = rpn_rois_num->data<int>();
for (int i = 0; i < num; i++) {
lod_data[i] = tmp_lod[i];
num_data[i] = tmp_num[i];
}
rpn_rois_lod->Resize({num});
rpn_rois_num->Resize({num});
}
rpn_rois->set_lod(lod);
rpn_roi_probs->set_lod(lod);
......@@ -433,6 +438,16 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
Tensor keep;
FilterBoxes<T>(ctx, &proposals, min_size, im_info_slice, &keep);
// Handle the case when there is no keep index left
if (keep.numel() == 0) {
math::SetConstant<platform::CPUDeviceContext, T> set_zero;
bbox_sel.mutable_data<T>({1, 4}, ctx.GetPlace());
set_zero(ctx, &bbox_sel, static_cast<T>(0));
Tensor scores_filter;
scores_filter.mutable_data<T>({1, 1}, ctx.GetPlace());
set_zero(ctx, &scores_filter, static_cast<T>(0));
return std::make_pair(bbox_sel, scores_filter);
}
Tensor scores_filter;
bbox_sel.mutable_data<T>({keep.numel(), 4}, ctx.GetPlace());
......@@ -481,7 +496,8 @@ class GenerateProposalsOpMaker : public framework::OpProtoAndCheckerMaker {
"(LoDTensor), Output proposals with shape (rois_num, 4).");
AddOutput("RpnRoiProbs",
"(LoDTensor) Scores of proposals with shape (rois_num, 1).");
AddOutput("RpnRoisLod", "(Tensor), rpn rois's lod info").AsDispensable();
AddOutput("RpnRoisNum", "(Tensor), The number of Rpn RoIs in each image")
.AsDispensable();
AddAttr<int>("pre_nms_topN",
"Number of top scoring RPN proposals to keep before "
"applying NMS.");
......@@ -515,3 +531,11 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(generate_proposals, ops::GenerateProposalsKernel<float>,
ops::GenerateProposalsKernel<double>);
// Op-version checkpoint for generate_proposals: records that this revision of
// the operator adds the new output "RpnRoisNum".
REGISTER_OP_VERSION(generate_proposals)
.AddCheckpoint(
R"ROC(
Upgrade generate_proposals add a new output [RpnRoisNum])ROC",
paddle::framework::compatible::OpVersionDesc().NewOutput(
"RpnRoisNum",
"The number of Rpn RoIs in each image. RpnRoisNum is "
"dispensable."));
......@@ -330,6 +330,15 @@ static std::pair<Tensor, Tensor> ProposalForOneImage(
keep_index.Resize({keep_num});
Tensor scores_filter, proposals_filter;
// Handle the case when there is no keep index left
if (keep_num == 0) {
math::SetConstant<platform::CUDADeviceContext, T> set_zero;
proposals_filter.mutable_data<T>({1, 4}, ctx.GetPlace());
scores_filter.mutable_data<T>({1, 1}, ctx.GetPlace());
set_zero(ctx, &proposals_filter, static_cast<T>(0));
set_zero(ctx, &scores_filter, static_cast<T>(0));
return std::make_pair(proposals_filter, scores_filter);
}
proposals_filter.mutable_data<T>({keep_num, 4}, ctx.GetPlace());
scores_filter.mutable_data<T>({keep_num, 1}, ctx.GetPlace());
GPUGather<T>(ctx, proposals, keep_index, &proposals_filter);
......@@ -421,7 +430,7 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> {
int64_t num_proposals = 0;
std::vector<size_t> offset(1, 0);
std::vector<int64_t> tmp_lod;
std::vector<int> tmp_num;
for (int64_t i = 0; i < num; ++i) {
Tensor im_info_slice = im_info->Slice(i, i + 1);
......@@ -448,15 +457,15 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> {
dev_ctx.Wait();
num_proposals += proposals.dims()[0];
offset.emplace_back(num_proposals);
tmp_lod.push_back(num_proposals);
tmp_num.push_back(proposals.dims()[0]);
}
if (context.HasOutput("RpnRoisLod")) {
auto *rpn_rois_lod = context.Output<Tensor>("RpnRoisLod");
rpn_rois_lod->mutable_data<int64_t>({num}, context.GetPlace());
int64_t *lod_data = rpn_rois_lod->data<int64_t>();
memory::Copy(place, lod_data, cpu_place, &tmp_lod[0],
sizeof(int64_t) * num, dev_ctx.stream());
rpn_rois_lod->Resize({num});
if (context.HasOutput("RpnRoisNum")) {
auto *rpn_rois_num = context.Output<Tensor>("RpnRoisNum");
rpn_rois_num->mutable_data<int>({num}, context.GetPlace());
int *num_data = rpn_rois_num->data<int>();
memory::Copy(place, num_data, cpu_place, &tmp_num[0], sizeof(int) * num,
dev_ctx.stream());
rpn_rois_num->Resize({num});
}
framework::LoD lod;
lod.emplace_back(offset);
......
......@@ -11,6 +11,7 @@ limitations under the License. */
#include "paddle/fluid/operators/roi_align_op.h"
#include <memory>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace operators {
......@@ -35,13 +36,13 @@ class ROIAlignOp : public framework::OperatorWithKernel {
auto input_dims = ctx->GetInputDim("X");
auto rois_dims = ctx->GetInputDim("ROIs");
if (ctx->HasInput("RoisLod")) {
auto rois_lod_dims = ctx->GetInputDim("RoisLod");
if (ctx->HasInput("RoisNum")) {
auto rois_num_dims = ctx->GetInputDim("RoisNum");
PADDLE_ENFORCE_EQ(
rois_lod_dims.size(), 1,
platform::errors::InvalidArgument("The RoisLod dimension should be 1"
", but got dimension = %d",
rois_lod_dims.size()));
rois_num_dims.size(), 1,
platform::errors::InvalidArgument("The size of RoisNum should be 1"
", but received size = %d",
rois_num_dims.size()));
}
PADDLE_ENFORCE_EQ(
input_dims.size(), 4,
......@@ -145,9 +146,9 @@ class ROIAlignOpMaker : public framework::OpProtoAndCheckerMaker {
"given as [[x1, y1, x2, y2], ...]. "
"(x1, y1) is the top left coordinates, and "
"(x2, y2) is the bottom right coordinates.");
AddInput("RoisLod",
AddInput("RoisNum",
"(Tensor), "
"The lod info of rois.")
"The number of RoIs in each image.")
.AsDispensable();
AddOutput("Out",
"(Tensor), "
......@@ -203,7 +204,7 @@ class ROIAlignGradMaker : public framework::SingleGradOpMaker<T> {
op->SetType("roi_align_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("ROIs", this->Input("ROIs"));
op->SetInput("RoisLod", this->Input("RoisLod"));
op->SetInput("RoisNum", this->Input("RoisNum"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
......@@ -231,3 +232,10 @@ REGISTER_OP_CPU_KERNEL(
ops::CPUROIAlignGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUROIAlignGradOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::CPUROIAlignGradOpKernel<paddle::platform::CPUDeviceContext, int>);
// Op-version checkpoint for roi_align: records that this revision of the
// operator adds the new input "RoisNum".
REGISTER_OP_VERSION(roi_align)
.AddCheckpoint(
R"ROC(
Upgrade roi_align add a new input [RoisNum])ROC",
paddle::framework::compatible::OpVersionDesc().NewInput(
"RoisNum",
"The number of RoIs in each image. RoisNum is dispensable."));
......@@ -257,24 +257,26 @@ class GPUROIAlignOpKernel : public framework::OpKernel<T> {
int* roi_batch_id_data = roi_batch_id_list.mutable_data<int>(cplace);
auto& dev_ctx = ctx.cuda_device_context();
auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
if (ctx.HasInput("RoisLod")) {
auto* rois_lod = ctx.Input<Tensor>("RoisLod");
int rois_batch_size = rois_lod->numel();
if (ctx.HasInput("RoisNum")) {
auto* rois_num_t = ctx.Input<Tensor>("RoisNum");
int rois_batch_size = rois_num_t->numel();
PADDLE_ENFORCE_EQ(
rois_batch_size - 1, batch_size,
rois_batch_size, batch_size,
platform::errors::InvalidArgument(
"The rois_batch_size and imgs "
"batch_size must be the same. But received rois_batch_size = %d, "
"batch_size = %d",
rois_batch_size, batch_size));
std::vector<int64_t> rois_lod_(rois_batch_size);
memory::Copy(cplace, rois_lod_.data(), gplace, rois_lod->data<int64_t>(),
sizeof(int64_t) * rois_batch_size, 0);
for (int n = 0; n < rois_batch_size - 1; ++n) {
for (size_t i = rois_lod_[n]; i < rois_lod_[n + 1]; ++i) {
std::vector<int> rois_num_list(rois_batch_size);
memory::Copy(cplace, rois_num_list.data(), gplace,
rois_num_t->data<int>(), sizeof(int) * rois_batch_size, 0);
int start = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_list[n]; ++i) {
roi_batch_id_data[i] = n;
}
start += rois_num_list[n];
}
} else {
auto lod = rois->lod();
......@@ -348,16 +350,18 @@ class GPUROIAlignGradOpKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.cuda_device_context();
auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
if (ctx.HasInput("RoisLod")) {
auto* rois_lod = ctx.Input<Tensor>("RoisLod");
int rois_batch_size = rois_lod->numel();
std::vector<int64_t> rois_lod_(rois_batch_size);
memory::Copy(cplace, rois_lod_.data(), gplace, rois_lod->data<int64_t>(),
sizeof(int64_t) * rois_batch_size, 0);
for (int n = 0; n < rois_batch_size - 1; ++n) {
for (size_t i = rois_lod_[n]; i < rois_lod_[n + 1]; ++i) {
if (ctx.HasInput("RoisNum")) {
auto* rois_num_t = ctx.Input<Tensor>("RoisNum");
int rois_batch_size = rois_num_t->numel();
std::vector<int> rois_num_list(rois_batch_size);
memory::Copy(cplace, rois_num_list.data(), gplace,
rois_num_t->data<int>(), sizeof(int) * rois_batch_size, 0);
int start = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (size_t i = start; i < start + rois_num_list[n]; ++i) {
roi_batch_id_data[i] = n;
}
start += rois_num_list[n];
}
} else {
auto rois_lod = rois->lod().back();
......
......@@ -165,21 +165,23 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> {
int* roi_batch_id_data =
roi_batch_id_list.mutable_data<int>(ctx.GetPlace());
int rois_batch_size;
if (ctx.HasInput("RoisLod")) {
auto* rois_lod_t = ctx.Input<framework::Tensor>("RoisLod");
rois_batch_size = rois_lod_t->numel();
if (ctx.HasInput("RoisNum")) {
auto* rois_num_t = ctx.Input<framework::Tensor>("RoisNum");
rois_batch_size = rois_num_t->numel();
PADDLE_ENFORCE_EQ(
rois_batch_size - 1, batch_size,
rois_batch_size, batch_size,
platform::errors::InvalidArgument(
"The batch size of rois and the batch size of images "
" must be the same. But received the batch size of rois is %d, "
"and the batch size of images is %d",
rois_batch_size, batch_size));
auto* rois_lod = rois_lod_t->data<int64_t>();
for (int n = 0; n < rois_batch_size - 1; ++n) {
for (int i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
auto* rois_num_data = rois_num_t->data<int>();
int start = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_data[n]; ++i) {
roi_batch_id_data[i] = n;
}
start += rois_num_data[n];
}
} else {
auto lod = rois->lod();
......@@ -303,14 +305,16 @@ class CPUROIAlignGradOpKernel : public framework::OpKernel<T> {
roi_batch_id_list.mutable_data<int>(ctx.GetPlace());
int rois_batch_size;
if (ctx.HasInput("RoisLod")) {
auto* rois_lod_t = ctx.Input<framework::Tensor>("RoisLod");
rois_batch_size = rois_lod_t->numel();
auto* rois_lod = rois_lod_t->data<int64_t>();
for (int n = 0; n < rois_batch_size - 1; ++n) {
for (int i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
if (ctx.HasInput("RoisNum")) {
auto* rois_num_t = ctx.Input<framework::Tensor>("RoisNum");
rois_batch_size = rois_num_t->numel();
auto* rois_num_data = rois_num_t->data<int>();
int start = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_data[n]; ++i) {
roi_batch_id_data[i] = n;
}
start += rois_num_data[n];
}
} else {
auto rois_lod = rois->lod().back();
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/roi_pool_op.h"
#include <memory>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace operators {
......@@ -34,12 +35,13 @@ class ROIPoolOp : public framework::OperatorWithKernel {
auto input_dims = ctx->GetInputDim("X");
auto rois_dims = ctx->GetInputDim("ROIs");
if (ctx->HasInput("RoisLod")) {
auto rois_lod_dims = ctx->GetInputDim("RoisLod");
PADDLE_ENFORCE_EQ(rois_lod_dims.size(), 1,
if (ctx->HasInput("RoisNum")) {
auto rois_num_dims = ctx->GetInputDim("RoisNum");
PADDLE_ENFORCE_EQ(rois_num_dims.size(), 1,
platform::errors::InvalidArgument(
"The lod information tensor of ROIs should "
"be one-dimensional"));
"The second dimension of RoisNum should "
"be 1, but received dimension is %d",
rois_num_dims.size()));
}
PADDLE_ENFORCE_EQ(input_dims.size(), 4,
platform::errors::InvalidArgument(
......@@ -140,7 +142,8 @@ class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
"Where batch_id is the id of the data, "
"(x1, y1) is the top left coordinates, and "
"(x2, y2) is the bottom right coordinates.");
AddInput("RoisLod", "(Tensor), The lod info of rois.").AsDispensable();
AddInput("RoisNum", "(Tensor), The number of RoIs in each image.")
.AsDispensable();
AddOutput("Out",
"(Tensor), "
"The output of ROIPoolOp is a 4-D tensor with shape "
......@@ -197,7 +200,7 @@ class ROIPoolGradMaker : public framework::SingleGradOpMaker<T> {
op->SetType("roi_pool_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("ROIs", this->Input("ROIs"));
op->SetInput("RoisLod", this->Input("RoisLod"));
op->SetInput("RoisNum", this->Input("RoisNum"));
op->SetInput("Argmax", this->Output("Argmax"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
......@@ -223,3 +226,10 @@ REGISTER_OP_CPU_KERNEL(
ops::CPUROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::CPUROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::CPUROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, int>);
// Op-version checkpoint for roi_pool: records that this revision of the
// operator adds the new input "RoisNum".
REGISTER_OP_VERSION(roi_pool)
.AddCheckpoint(
R"ROC(
Upgrade roi_pool add a new input [RoisNum])ROC",
paddle::framework::compatible::OpVersionDesc().NewInput(
"RoisNum",
"The number of RoIs in each image. RoisNum is dispensable."));
......@@ -157,19 +157,21 @@ class GPUROIPoolOpKernel : public framework::OpKernel<T> {
int* roi_batch_id_data = roi_batch_id_list.mutable_data<int>(cplace);
auto& dev_ctx = ctx.cuda_device_context();
auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
if (ctx.HasInput("RoisLod")) {
auto* rois_lod = ctx.Input<Tensor>("RoisLod");
int rois_batch_size = rois_lod->numel();
if (ctx.HasInput("RoisNum")) {
auto* rois_num_t = ctx.Input<Tensor>("RoisNum");
int rois_batch_size = rois_num_t->numel();
PADDLE_ENFORCE_EQ(
rois_batch_size - 1, batch_size,
rois_batch_size, batch_size,
"The rois_batch_size and imgs batch_size must be the same.");
std::vector<int64_t> rois_lod_(rois_batch_size);
memory::Copy(cplace, rois_lod_.data(), gplace, rois_lod->data<int64_t>(),
sizeof(int64_t) * rois_batch_size, 0);
for (int n = 0; n < rois_batch_size - 1; ++n) {
for (size_t i = rois_lod_[n]; i < rois_lod_[n + 1]; ++i) {
std::vector<int> rois_num_list(rois_batch_size);
memory::Copy(cplace, rois_num_list.data(), gplace,
rois_num_t->data<int>(), sizeof(int) * rois_batch_size, 0);
int start = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_list[n]; ++i) {
roi_batch_id_data[i] = n;
}
start += rois_num_list[n];
}
} else {
auto rois_lod = rois->lod().back();
......@@ -206,7 +208,7 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<Tensor>("X");
auto* rois = ctx.Input<LoDTensor>("ROIs");
auto* rois_lod = ctx.Input<Tensor>("RoisLod");
auto* rois_lod = ctx.Input<Tensor>("RoisNum");
auto* argmax = ctx.Input<Tensor>("Argmax");
auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
......@@ -229,17 +231,18 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.cuda_device_context();
auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
if (ctx.HasInput("RoisLod")) {
auto* rois_lod = ctx.Input<Tensor>("RoisLod");
int rois_batch_size = rois_lod->numel();
std::vector<int64_t> rois_lod_(rois_batch_size);
memory::Copy(cplace, rois_lod_.data(), gplace,
rois_lod->data<int64_t>(),
sizeof(int64_t) * rois_batch_size, 0);
for (int n = 0; n < rois_batch_size - 1; ++n) {
for (size_t i = rois_lod_[n]; i < rois_lod_[n + 1]; ++i) {
if (ctx.HasInput("RoisNum")) {
auto* rois_num_t = ctx.Input<Tensor>("RoisNum");
int rois_batch_size = rois_num_t->numel();
std::vector<int> rois_num_list(rois_batch_size);
memory::Copy(cplace, rois_num_list.data(), gplace,
rois_num_t->data<int>(), sizeof(int) * rois_batch_size, 0);
int start = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_list[n]; ++i) {
roi_batch_id_data[i] = n;
}
start += rois_num_list[n];
}
} else {
auto rois_lod = rois->lod().back();
......
......@@ -58,18 +58,20 @@ class CPUROIPoolOpKernel : public framework::OpKernel<T> {
roi_batch_id_list.mutable_data<int>(ctx.GetPlace());
int rois_batch_size;
if (ctx.HasInput("RoisLod")) {
auto* rois_lod_t = ctx.Input<framework::Tensor>("RoisLod");
rois_batch_size = rois_lod_t->numel();
if (ctx.HasInput("RoisNum")) {
auto* rois_num_t = ctx.Input<framework::Tensor>("RoisNum");
rois_batch_size = rois_num_t->numel();
PADDLE_ENFORCE_EQ(
rois_batch_size - 1, batch_size,
rois_batch_size, batch_size,
platform::errors::InvalidArgument("The rois_batch_size and imgs "
"batch_size must be the same."));
auto* rois_lod = rois_lod_t->data<int64_t>();
for (int n = 0; n < rois_batch_size - 1; ++n) {
for (int i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
auto* rois_num_data = rois_num_t->data<int>();
int start = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_data[n]; ++i) {
roi_batch_id_data[i] = n;
}
start += rois_num_data[n];
}
} else {
auto rois_lod = rois->lod().back();
......@@ -185,14 +187,16 @@ class CPUROIPoolGradOpKernel : public framework::OpKernel<T> {
roi_batch_id_list.mutable_data<int>(ctx.GetPlace());
int rois_batch_size;
if (ctx.HasInput("RoisLod")) {
auto* rois_lod_t = ctx.Input<framework::Tensor>("RoisLod");
rois_batch_size = rois_lod_t->numel();
auto* rois_lod = rois_lod_t->data<int64_t>();
for (int n = 0; n < rois_batch_size - 1; ++n) {
for (int i = rois_lod[n]; i < rois_lod[n + 1]; ++i) {
if (ctx.HasInput("RoisNum")) {
auto* rois_num_t = ctx.Input<framework::Tensor>("RoisNum");
rois_batch_size = rois_num_t->numel();
auto* rois_num_data = rois_num_t->data<int>();
int start = 0;
for (int n = 0; n < rois_batch_size; ++n) {
for (int i = start; i < start + rois_num_data[n]; ++i) {
roi_batch_id_data[i] = n;
}
start += rois_num_data[n];
}
} else {
auto rois_lod = rois->lod().back();
......
......@@ -43,6 +43,11 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"nll_loss", {"X", "Label", "Weight"}},
{"bilinear_tensor_product", {"X", "Y", "Weight", "Bias"}},
{"gather", {"X", "Index", "Axis"}},
{"roi_pool", {"X", "ROIs", "RoisNum"}},
{"roi_align", {"X", "ROIs", "RoisNum"}},
{"collect_fpn_proposals",
{"MultiLevelRois", "MultiLevelScores", "MultiLevelRoIsNum"}},
{"distribute_fpn_proposals", {"FpnRois", "RoisNum"}},
};
// NOTE(zhiqiu): Like op_ins_map.
......@@ -63,6 +68,10 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance",
"ReserveSpace"}},
{"unique", {"Out", "Index", "Indices", "Counts"}},
{"generate_proposals", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}},
{"collect_fpn_proposals", {"FpnRois", "RoisNum"}},
{"distribute_fpn_proposals",
{"MultiFpnRois", "RestoreIndex", "MultiLevelRoIsNum"}},
};
// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
......
......@@ -20,7 +20,8 @@ from __future__ import print_function
from .layer_function_generator import generate_layer_fn
from .layer_function_generator import autodoc, templatedoc
from ..layer_helper import LayerHelper
from ..framework import Variable
from ..framework import Variable, in_dygraph_mode
from .. import core
from .loss import softmax_with_cross_entropy
from . import tensor
from . import nn
......@@ -2893,8 +2894,8 @@ def generate_proposals(scores,
nms_thresh=0.5,
min_size=0.1,
eta=1.0,
name=None,
return_rois_num=False):
return_rois_num=False,
name=None):
"""
:alias_main: paddle.nn.functional.generate_proposals
:alias: paddle.nn.functional.generate_proposals,paddle.nn.functional.vision.generate_proposals
......@@ -2949,6 +2950,10 @@ def generate_proposals(scores,
num of each image in one batch. The N is the image's num. For example, the tensor has values [4,5] that represents
the first image has 4 Rois, the second image has 5 Rois. It is only used in rcnn model.
'False' by default.
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
Returns:
tuple:
A tuple with format ``(rpn_rois, rpn_roi_probs)``.
......@@ -2969,6 +2974,14 @@ def generate_proposals(scores,
im_info, anchors, variances)
"""
if in_dygraph_mode():
assert return_rois_num, "return_rois_num should be True in dygraph mode."
attrs = ('pre_nms_topN', pre_nms_top_n, 'post_nms_topN', post_nms_top_n,
'nms_thresh', nms_thresh, 'min_size', min_size, 'eta', eta)
rpn_rois, rpn_roi_probs, rpn_rois_num = core.ops.generate_proposals(
scores, bbox_deltas, im_info, anchors, variances, *attrs)
return rpn_rois, rpn_roi_probs, rpn_rois_num
helper = LayerHelper('generate_proposals', **locals())
check_variable_and_dtype(scores, 'scores', ['float32'],
......@@ -2986,7 +2999,14 @@ def generate_proposals(scores,
dtype=bbox_deltas.dtype)
rpn_roi_probs = helper.create_variable_for_type_inference(
dtype=scores.dtype)
rpn_rois_lod = helper.create_variable_for_type_inference(dtype='int32')
outputs = {
'RpnRois': rpn_rois,
'RpnRoiProbs': rpn_roi_probs,
}
if return_rois_num:
rpn_rois_num = helper.create_variable_for_type_inference(dtype='int32')
rpn_rois_num.stop_gradient = True
outputs['RpnRoisNum'] = rpn_rois_num
helper.append_op(
type="generate_proposals",
......@@ -3004,17 +3024,12 @@ def generate_proposals(scores,
'min_size': min_size,
'eta': eta
},
outputs={
'RpnRois': rpn_rois,
'RpnRoiProbs': rpn_roi_probs,
'RpnRoisLod': rpn_rois_lod
})
outputs=outputs)
rpn_rois.stop_gradient = True
rpn_roi_probs.stop_gradient = True
rpn_rois_lod.stop_gradient = True
if return_rois_num:
return rpn_rois, rpn_roi_probs, rpn_rois_lod
return rpn_rois, rpn_roi_probs, rpn_rois_num
else:
return rpn_rois, rpn_roi_probs
......@@ -3656,6 +3671,7 @@ def distribute_fpn_proposals(fpn_rois,
max_level,
refer_level,
refer_scale,
rois_num=None,
name=None):
"""
:alias_main: paddle.nn.functional.distribute_fpn_proposals
......@@ -3687,6 +3703,11 @@ def distribute_fpn_proposals(fpn_rois,
come from.
refer_level(int32): The referring level of FPN layer with specified scale.
refer_scale(int32): The referring scale of FPN layer with specified level.
rois_num(Tensor): 1-D Tensor contains the number of RoIs in each image.
The shape is [B] and data type is int32. B is the number of images.
If it is not None then return a list of 1-D Tensor. Each element
is the output RoIs' number of each image on the corresponding level
and the shape is [B]. None by default.
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
......@@ -3702,6 +3723,10 @@ def distribute_fpn_proposals(fpn_rois,
the number of total rois. The data type is int32. It is
used to restore the order of fpn_rois.
rois_num_per_level(List): A list of 1-D Tensors. Each Tensor holds
the number of RoIs in each image on the corresponding level. The shape
is [B] and the data type is int32. B is the number of images.
Examples:
.. code-block:: python
......@@ -3716,26 +3741,52 @@ def distribute_fpn_proposals(fpn_rois,
refer_level=4,
refer_scale=224)
"""
num_lvl = max_level - min_level + 1
if in_dygraph_mode():
assert rois_num is not None, "rois_num should not be None in dygraph mode."
attrs = ('min_level', min_level, 'max_level', max_level, 'refer_level',
refer_level, 'refer_scale', refer_scale)
multi_rois, restore_ind, rois_num_per_level = core.ops.distribute_fpn_proposals(
fpn_rois, rois_num, num_lvl, num_lvl, *attrs)
return multi_rois, restore_ind, rois_num_per_level
check_variable_and_dtype(fpn_rois, 'fpn_rois', ['float32', 'float64'],
'distribute_fpn_proposals')
helper = LayerHelper('distribute_fpn_proposals', **locals())
dtype = helper.input_dtype('fpn_rois')
num_lvl = max_level - min_level + 1
multi_rois = [
helper.create_variable_for_type_inference(dtype) for i in range(num_lvl)
]
restore_ind = helper.create_variable_for_type_inference(dtype='int32')
inputs = {'FpnRois': fpn_rois}
outputs = {
'MultiFpnRois': multi_rois,
'RestoreIndex': restore_ind,
}
if rois_num is not None:
inputs['RoisNum'] = rois_num
rois_num_per_level = [
helper.create_variable_for_type_inference(dtype='int32')
for i in range(num_lvl)
]
outputs['MultiLevelRoIsNum'] = rois_num_per_level
helper.append_op(
type='distribute_fpn_proposals',
inputs={'FpnRois': fpn_rois},
outputs={'MultiFpnRois': multi_rois,
'RestoreIndex': restore_ind},
inputs=inputs,
outputs=outputs,
attrs={
'min_level': min_level,
'max_level': max_level,
'refer_level': refer_level,
'refer_scale': refer_scale
})
if rois_num is not None:
return multi_rois, restore_ind, rois_num_per_level
return multi_rois, restore_ind
......@@ -3820,6 +3871,7 @@ def collect_fpn_proposals(multi_rois,
min_level,
max_level,
post_nms_top_n,
rois_num_per_level=None,
name=None):
"""
:alias_main: paddle.nn.functional.collect_fpn_proposals
......@@ -3846,6 +3898,12 @@ def collect_fpn_proposals(multi_rois,
min_level(int): The lowest level of FPN layer to collect
max_level(int): The highest level of FPN layer to collect
post_nms_top_n(int): The number of selected RoIs
rois_num_per_level(list, optional): The list of RoIs' numbers.
Each element is a 1-D Tensor which contains the number of RoIs of each
image on that level. The shape is [B] and the data type is
int32, where B is the number of images. If it is not None, a 1-D
Tensor containing the output RoIs' number of each image is also
returned, and its shape is [B]. Default: None
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
......@@ -3856,6 +3914,9 @@ def collect_fpn_proposals(multi_rois,
fpn_rois(Variable): 2-D LoDTensor with shape [N, 4] and data type is
float32 or float64. Selected RoIs.
rois_num(Tensor): 1-D Tensor that contains the number of RoIs of each
image. The shape is [B] and data type is int32. B is the number of
images.
Examples:
.. code-block:: python
......@@ -3879,21 +3940,38 @@ def collect_fpn_proposals(multi_rois,
"""
check_type(multi_rois, 'multi_rois', list, 'collect_fpn_proposals')
check_type(multi_scores, 'multi_scores', list, 'collect_fpn_proposals')
num_lvl = max_level - min_level + 1
input_rois = multi_rois[:num_lvl]
input_scores = multi_scores[:num_lvl]
if in_dygraph_mode():
assert rois_num_per_level is not None, "rois_num_per_level should not be None in dygraph mode."
attrs = ('post_nms_topN', post_nms_top_n)
output_rois, rois_num = core.ops.collect_fpn_proposals(
input_rois, input_scores, rois_num_per_level, *attrs)
helper = LayerHelper('collect_fpn_proposals', **locals())
dtype = helper.input_dtype('multi_rois')
check_dtype(dtype, 'multi_rois', ['float32', 'float64'],
'collect_fpn_proposals')
num_lvl = max_level - min_level + 1
input_rois = multi_rois[:num_lvl]
input_scores = multi_scores[:num_lvl]
output_rois = helper.create_variable_for_type_inference(dtype)
output_rois.stop_gradient = True
inputs = {
'MultiLevelRois': input_rois,
'MultiLevelScores': input_scores,
}
outputs = {'FpnRois': output_rois}
if rois_num_per_level is not None:
inputs['MultiLevelRoIsNum'] = rois_num_per_level
rois_num = helper.create_variable_for_type_inference(dtype='int32')
rois_num.stop_gradient = True
outputs['RoisNum'] = rois_num
helper.append_op(
type='collect_fpn_proposals',
inputs={
'MultiLevelRois': input_rois,
'MultiLevelScores': input_scores
},
outputs={'FpnRois': output_rois},
inputs=inputs,
outputs=outputs,
attrs={'post_nms_topN': post_nms_top_n})
if rois_num_per_level is not None:
return output_rois, rois_num
return output_rois
......@@ -6862,7 +6862,8 @@ def roi_pool(input,
pooled_height=1,
pooled_width=1,
spatial_scale=1.0,
rois_lod=None):
rois_num=None,
name=None):
"""
:alias_main: paddle.nn.functional.roi_pool
:alias: paddle.nn.functional.roi_pool,paddle.nn.functional.vision.roi_pool
......@@ -6882,10 +6883,14 @@ def roi_pool(input,
Args:
input (Variable): Input feature, 4D-Tensor with the shape of [N,C,H,W], where N is the batch size, C is the input channel, H is height, W is width. The data type is float32 or float64.
rois (Variable): ROIs (Regions of Interest) to pool over. 2D-LoDTensor with the shape of [num_rois,4], the lod level is 1. Given as [[x1, y1, x2, y2], ...], (x1, y1) is the top left coordinates, and (x2, y2) is the bottom right coordinates.
rois_lod (Variable): The lod info of rois. Default: None
pooled_height (int, optional): The pooled output height, data type is int32. Default: 1
pooled_width (int, optional): The pooled output width, data type is int32. Default: 1
spatial_scale (float, optional): Multiplicative spatial scale factor to translate ROI coords from their input scale to the scale used when pooling. Default: 1.0
rois_num (Tensor): The number of RoIs in each image. Default: None
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
Returns:
Variable: The pooled feature, 4D-Tensor with the shape of [num_rois, C, pooled_height, pooled_width].
......@@ -6905,11 +6910,11 @@ def roi_pool(input,
input_data = np.array([i for i in range(1,17)]).reshape(1,1,4,4).astype(DATATYPE)
roi_data =fluid.create_lod_tensor(np.array([[1., 1., 2., 2.], [1.5, 1.5, 3., 3.]]).astype(DATATYPE),[[2]], place)
rois_lod_data = np.array([0, 2])
rois_num_data = np.array([2]).astype('int32')
x = fluid.data(name='input', shape=[None,1,4,4], dtype=DATATYPE)
rois = fluid.data(name='roi', shape=[None,4], dtype=DATATYPE)
rois_lod = fluid.data(name='rois_lod', shape=[None], dtype='int64')
rois_num = fluid.data(name='rois_num', shape=[None], dtype='int32')
pool_out = fluid.layers.roi_pool(
input=x,
......@@ -6917,24 +6922,36 @@ def roi_pool(input,
pooled_height=1,
pooled_width=1,
spatial_scale=1.0,
rois_lod=rois_lod)
rois_num=rois_num)
exe = fluid.Executor(place)
out, = exe.run(feed={'input':input_data ,'roi':roi_data, 'rois_lod': rois_lod_data}, fetch_list=[pool_out.name])
out, = exe.run(feed={'input':input_data ,'roi':roi_data, 'rois_num': rois_num_data}, fetch_list=[pool_out.name])
print(out) #array([[[[11.]]], [[[16.]]]], dtype=float32)
print(np.array(out).shape) # (2, 1, 1, 1)
"""
if in_dygraph_mode():
assert rois_num is not None, "rois_num should not be None in dygraph mode."
pool_out, argmaxes = core.ops.roi_pool(
input, rois, rois_num, "pooled_height", pooled_height,
"pooled_width", pooled_width, "spatial_scale", spatial_scale)
return pool_out, argmaxes
check_variable_and_dtype(input, 'input', ['float32'], 'roi_pool')
check_variable_and_dtype(rois, 'rois', ['float32'], 'roi_pool')
helper = LayerHelper('roi_pool', **locals())
dtype = helper.input_dtype()
pool_out = helper.create_variable_for_type_inference(dtype)
argmaxes = helper.create_variable_for_type_inference(dtype='int32')
inputs = {
"X": input,
"ROIs": rois,
}
if rois_num is not None:
inputs['RoisNum'] = rois_num
helper.append_op(
type="roi_pool",
inputs={"X": input,
"ROIs": rois,
"RoisLod": rois_lod},
inputs=inputs,
outputs={"Out": pool_out,
"Argmax": argmaxes},
attrs={
......@@ -6952,8 +6969,8 @@ def roi_align(input,
pooled_width=1,
spatial_scale=1.0,
sampling_ratio=-1,
name=None,
rois_lod=None):
rois_num=None,
name=None):
"""
:alias_main: paddle.nn.functional.roi_align
:alias: paddle.nn.functional.roi_align,paddle.nn.functional.vision.roi_align
......@@ -6968,11 +6985,11 @@ def roi_align(input,
data type is float32 or float64. Given as [[x1, y1, x2, y2], ...],
(x1, y1) is the top left coordinates, and (x2, y2) is the bottom
right coordinates.
rois_lod (Variable): The lod info of rois. Default: None
pooled_height (int32, optional): ${pooled_height_comment} Default: 1
pooled_width (int32, optional): ${pooled_width_comment} Default: 1
spatial_scale (float32, optional): ${spatial_scale_comment} Default: 1.0
sampling_ratio(int32, optional): ${sampling_ratio_comment} Default: -1
rois_num (Tensor): The number of RoIs in each image. Default: None
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
......@@ -6991,26 +7008,38 @@ def roi_align(input,
name='data', shape=[None, 256, 32, 32], dtype='float32')
rois = fluid.data(
name='rois', shape=[None, 4], dtype='float32')
rois_lod = fluid.data(name='rois_lod', shape=[None], dtype='int64')
rois_num = fluid.data(name='rois_num', shape=[None], dtype='int32')
align_out = fluid.layers.roi_align(input=x,
rois=rois,
pooled_height=7,
pooled_width=7,
spatial_scale=0.5,
sampling_ratio=-1,
rois_lod=rois_lod)
rois_num=rois_num)
"""
if in_dygraph_mode():
assert rois_num is not None, "rois_num should not be None in dygraph mode."
align_out = core.ops.roi_align(
input, rois, rois_num, "pooled_height", pooled_height,
"pooled_width", pooled_width, "spatial_scale", spatial_scale,
"sampling_ratio", sampling_ratio)
return align_out
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'roi_align')
check_variable_and_dtype(rois, 'rois', ['float32', 'float64'], 'roi_align')
helper = LayerHelper('roi_align', **locals())
dtype = helper.input_dtype()
align_out = helper.create_variable_for_type_inference(dtype)
inputs = {
"X": input,
"ROIs": rois,
}
if rois_num is not None:
inputs['RoisNum'] = rois_num
helper.append_op(
type="roi_align",
inputs={"X": input,
"ROIs": rois,
"RoisLod": rois_lod},
inputs=inputs,
outputs={"Out": align_out},
attrs={
"pooled_height": pooled_height,
......
......@@ -19,6 +19,57 @@ import paddle.fluid.layers as layers
from paddle.fluid.layers import detection
from paddle.fluid.framework import Program, program_guard
import unittest
import contextlib
import numpy as np
from unittests.test_imperative_base import new_program_scope
from paddle.fluid.dygraph import base
from paddle.fluid import core
class LayerTest(unittest.TestCase):
    """Base test case offering helpers to run code in static-graph or dygraph mode.

    Both context managers seed the default startup and main programs with
    ``cls.seed`` before yielding to the test body.
    """

    @classmethod
    def setUpClass(cls):
        # Seed written to the default programs by static_graph()/dynamic_graph().
        cls.seed = 111

    @classmethod
    def tearDownClass(cls):
        pass

    def _get_place(self, force_to_use_cpu=False):
        """Pick the place to execute on.

        Returns a CPU place when ``force_to_use_cpu`` is set (an option for
        ops that only have a CPU kernel) or when the build has no CUDA
        support; otherwise returns CUDA device 0.
        """
        run_on_cpu = force_to_use_cpu or not core.is_compiled_with_cuda()
        return core.CPUPlace() if run_on_cpu else core.CUDAPlace(0)

    @contextlib.contextmanager
    def static_graph(self):
        """Enter a new program scope with seeded default programs."""
        with new_program_scope():
            startup_prog = fluid.default_startup_program()
            startup_prog.random_seed = self.seed
            main_prog = fluid.default_main_program()
            main_prog.random_seed = self.seed
            yield

    def get_static_graph_result(self,
                                feed,
                                fetch_list,
                                with_lod=False,
                                force_to_use_cpu=False):
        """Run the default startup program, then the default main program.

        Returns what ``Executor.run`` returns for the main program. Results
        are converted to numpy unless ``with_lod`` is True.
        """
        place = self._get_place(force_to_use_cpu)
        executor = fluid.Executor(place)
        executor.run(fluid.default_startup_program())
        want_numpy = not with_lod
        return executor.run(fluid.default_main_program(),
                            feed=feed,
                            fetch_list=fetch_list,
                            return_numpy=want_numpy)

    @contextlib.contextmanager
    def dynamic_graph(self, force_to_use_cpu=False):
        """Enter a dygraph guard on the selected place with seeded default programs."""
        place = self._get_place(force_to_use_cpu=force_to_use_cpu)
        with fluid.dygraph.guard(place):
            startup_prog = fluid.default_startup_program()
            startup_prog.random_seed = self.seed
            main_prog = fluid.default_main_program()
            main_prog.random_seed = self.seed
            yield
class TestDetection(unittest.TestCase):
......@@ -481,45 +532,67 @@ class TestRpnTargetAssign(unittest.TestCase):
print(str(program))
class TestGenerateProposals(unittest.TestCase):
class TestGenerateProposals(LayerTest):
def test_generate_proposals(self):
program = Program()
with program_guard(program):
data_shape = [20, 64, 64]
images = fluid.layers.data(
name='images', shape=data_shape, dtype='float32')
im_info = fluid.layers.data(
name='im_info', shape=[3], dtype='float32')
anchors, variances = fluid.layers.anchor_generator(
name='anchor_generator',
input=images,
anchor_sizes=[32, 64],
aspect_ratios=[1.0],
variance=[0.1, 0.1, 0.2, 0.2],
stride=[16.0, 16.0],
offset=0.5)
num_anchors = anchors.shape[2]
scores = fluid.layers.data(
name='scores', shape=[num_anchors, 8, 8], dtype='float32')
bbox_deltas = fluid.layers.data(
name='bbox_deltas',
shape=[num_anchors * 4, 8, 8],
dtype='float32')
rpn_rois, rpn_roi_probs = fluid.layers.generate_proposals(
name='generate_proposals',
scores=scores,
bbox_deltas=bbox_deltas,
im_info=im_info,
anchors=anchors,
variances=variances,
pre_nms_top_n=6000,
post_nms_top_n=1000,
nms_thresh=0.5,
min_size=0.1,
eta=1.0)
self.assertIsNotNone(rpn_rois)
self.assertIsNotNone(rpn_roi_probs)
print(rpn_rois.shape)
scores_np = np.random.rand(2, 3, 4, 4).astype('float32')
bbox_deltas_np = np.random.rand(2, 12, 4, 4).astype('float32')
im_info_np = np.array([[8, 8, 0.5], [6, 6, 0.5]]).astype('float32')
anchors_np = np.reshape(np.arange(4 * 4 * 3 * 4),
[4, 4, 3, 4]).astype('float32')
variances_np = np.ones((4, 4, 3, 4)).astype('float32')
with self.static_graph():
scores = fluid.data(
name='scores', shape=[2, 3, 4, 4], dtype='float32')
bbox_deltas = fluid.data(
name='bbox_deltas', shape=[2, 12, 4, 4], dtype='float32')
im_info = fluid.data(name='im_info', shape=[2, 3], dtype='float32')
anchors = fluid.data(
name='anchors', shape=[4, 4, 3, 4], dtype='float32')
variances = fluid.data(
name='var', shape=[4, 4, 3, 4], dtype='float32')
rois, roi_probs, rois_num = fluid.layers.generate_proposals(
scores,
bbox_deltas,
im_info,
anchors,
variances,
pre_nms_top_n=10,
post_nms_top_n=5,
return_rois_num=True)
rois_stat, roi_probs_stat, rois_num_stat = self.get_static_graph_result(
feed={
'scores': scores_np,
'bbox_deltas': bbox_deltas_np,
'im_info': im_info_np,
'anchors': anchors_np,
'var': variances_np
},
fetch_list=[rois, roi_probs, rois_num],
with_lod=True)
with self.dynamic_graph():
scores_dy = base.to_variable(scores_np)
bbox_deltas_dy = base.to_variable(bbox_deltas_np)
im_info_dy = base.to_variable(im_info_np)
anchors_dy = base.to_variable(anchors_np)
variances_dy = base.to_variable(variances_np)
rois, roi_probs, rois_num = fluid.layers.generate_proposals(
scores_dy,
bbox_deltas_dy,
im_info_dy,
anchors_dy,
variances_dy,
pre_nms_top_n=10,
post_nms_top_n=5,
return_rois_num=True)
rois_dy = rois.numpy()
roi_probs_dy = roi_probs.numpy()
rois_num_dy = rois_num.numpy()
self.assertTrue(np.array_equal(np.array(rois_stat), rois_dy))
self.assertTrue(np.array_equal(np.array(roi_probs_stat), roi_probs_dy))
self.assertTrue(np.array_equal(np.array(rois_num_stat), rois_num_dy))
class TestYoloDetection(unittest.TestCase):
......@@ -648,30 +721,81 @@ class TestMulticlassNMS2(unittest.TestCase):
self.assertIsNotNone(index)
class TestCollectFpnPropsals(unittest.TestCase):
class TestCollectFpnPropsals(LayerTest):
def test_collect_fpn_proposals(self):
program = Program()
with program_guard(program):
multi_bboxes_np = []
multi_scores_np = []
rois_num_per_level_np = []
for i in range(4):
bboxes_np = np.random.rand(5, 4).astype('float32')
scores_np = np.random.rand(5, 1).astype('float32')
rois_num = np.array([2, 3]).astype('int32')
multi_bboxes_np.append(bboxes_np)
multi_scores_np.append(scores_np)
rois_num_per_level_np.append(rois_num)
with self.static_graph():
multi_bboxes = []
multi_scores = []
rois_num_per_level = []
for i in range(4):
bboxes = layers.data(
bboxes = fluid.data(
name='rois' + str(i),
shape=[10, 4],
shape=[5, 4],
dtype='float32',
lod_level=1,
append_batch_size=False)
scores = layers.data(
lod_level=1)
scores = fluid.data(
name='scores' + str(i),
shape=[10, 1],
shape=[5, 1],
dtype='float32',
lod_level=1,
append_batch_size=False)
lod_level=1)
rois_num = fluid.data(
name='rois_num' + str(i), shape=[None], dtype='int32')
multi_bboxes.append(bboxes)
multi_scores.append(scores)
fpn_rois = layers.collect_fpn_proposals(multi_bboxes, multi_scores,
2, 5, 10)
self.assertIsNotNone(fpn_rois)
rois_num_per_level.append(rois_num)
fpn_rois, rois_num = layers.collect_fpn_proposals(
multi_bboxes,
multi_scores,
2,
5,
10,
rois_num_per_level=rois_num_per_level)
feed = {}
for i in range(4):
feed['rois' + str(i)] = multi_bboxes_np[i]
feed['scores' + str(i)] = multi_scores_np[i]
feed['rois_num' + str(i)] = rois_num_per_level_np[i]
fpn_rois_stat, rois_num_stat = self.get_static_graph_result(
feed=feed, fetch_list=[fpn_rois, rois_num], with_lod=True)
fpn_rois_stat = np.array(fpn_rois_stat)
rois_num_stat = np.array(rois_num_stat)
with self.dynamic_graph():
multi_bboxes_dy = []
multi_scores_dy = []
rois_num_per_level_dy = []
for i in range(4):
bboxes_dy = base.to_variable(multi_bboxes_np[i])
scores_dy = base.to_variable(multi_scores_np[i])
rois_num_dy = base.to_variable(rois_num_per_level_np[i])
multi_bboxes_dy.append(bboxes_dy)
multi_scores_dy.append(scores_dy)
rois_num_per_level_dy.append(rois_num_dy)
fpn_rois_dy, rois_num_dy = fluid.layers.collect_fpn_proposals(
multi_bboxes_dy,
multi_scores_dy,
2,
5,
10,
rois_num_per_level=rois_num_per_level_dy)
fpn_rois_dy = fpn_rois_dy.numpy()
rois_num_dy = rois_num_dy.numpy()
self.assertTrue(np.array_equal(fpn_rois_stat, fpn_rois_dy))
self.assertTrue(np.array_equal(rois_num_stat, rois_num_dy))
def test_collect_fpn_proposals_error(self):
def generate_input(bbox_type, score_type, name):
......@@ -717,20 +841,51 @@ class TestCollectFpnPropsals(unittest.TestCase):
post_nms_top_n=2000)
class TestDistributeFpnProposals(unittest.TestCase):
class TestDistributeFpnProposals(LayerTest):
def test_distribute_fpn_proposals(self):
program = Program()
with program_guard(program):
fpn_rois = fluid.layers.data(
name='data', shape=[4], dtype='float32', lod_level=1)
multi_rois, restore_ind = layers.distribute_fpn_proposals(
fpn_rois=fpn_rois,
rois_np = np.random.rand(10, 4).astype('float32')
rois_num_np = np.array([4, 6]).astype('int32')
with self.static_graph():
rois = fluid.data(name='rois', shape=[10, 4], dtype='float32')
rois_num = fluid.data(name='rois_num', shape=[None], dtype='int32')
multi_rois, restore_ind, rois_num_per_level = layers.distribute_fpn_proposals(
fpn_rois=rois,
min_level=2,
max_level=5,
refer_level=4,
refer_scale=224)
self.assertIsNotNone(multi_rois)
self.assertIsNotNone(restore_ind)
refer_scale=224,
rois_num=rois_num)
fetch_list = multi_rois + [restore_ind] + rois_num_per_level
output_stat = self.get_static_graph_result(
feed={'rois': rois_np,
'rois_num': rois_num_np},
fetch_list=fetch_list,
with_lod=True)
output_stat_np = []
for output in output_stat:
output_np = np.array(output)
if len(output_np) > 0:
output_stat_np.append(output_np)
with self.dynamic_graph():
rois_dy = base.to_variable(rois_np)
rois_num_dy = base.to_variable(rois_num_np)
multi_rois_dy, restore_ind_dy, rois_num_per_level_dy = layers.distribute_fpn_proposals(
fpn_rois=rois_dy,
min_level=2,
max_level=5,
refer_level=4,
refer_scale=224,
rois_num=rois_num_dy)
output_dy = multi_rois_dy + [restore_ind_dy] + rois_num_per_level_dy
output_dy_np = []
for output in output_dy:
output_np = output.numpy()
if len(output_np) > 0:
output_dy_np.append(output_np)
for res_stat, res_dy in zip(output_stat_np, output_dy_np):
self.assertTrue(np.array_equal(res_stat, res_dy))
def test_distribute_fpn_proposals_error(self):
program = Program()
......
......@@ -33,10 +33,14 @@ class TestCollectFPNProposalstOp(OpTest):
for i in range(self.num_level)]
self.inputs = {
'MultiLevelRois': inputs_x,
"MultiLevelScores": self.scores_input
"MultiLevelScores": self.scores_input,
'MultiLevelRoIsNum': []
}
self.attrs = {'post_nms_topN': self.post_nms_top_n, }
self.outputs = {'FpnRois': (self.rois, [self.lod])}
self.outputs = {
'FpnRois': (self.rois, [self.lod]),
'RoisNum': np.array(self.lod).astype('int32')
}
def init_test_case(self):
self.post_nms_top_n = 20
......@@ -96,5 +100,32 @@ class TestCollectFPNProposalstOp(OpTest):
self.check_output(check_dygraph=False)
class TestCollectFPNProposalstOpWithRoisNum(TestCollectFPNProposalstOp):
    """collect_fpn_proposals op test that also feeds ``MultiLevelRoIsNum``."""

    def set_data(self):
        """Populate ``self.inputs``, ``self.attrs`` and ``self.outputs``.

        Besides the per-level RoIs and scores, each level contributes an
        int32 array built from the first entry of that level's LoD.
        """
        self.init_test_case()
        self.make_rois()

        self.scores_input = [
            ('y%d' % lvl,
             (self.scores[lvl].reshape(-1, 1), self.rois_lod[lvl]))
            for lvl in range(self.num_level)
        ]
        self.rois, self.lod = self.calc_rois_collect()

        rois_per_level = []
        for lvl in range(self.num_level):
            # Column 0 of roi_inputs is dropped; only columns 1.. are fed.
            level_rois = self.roi_inputs[lvl][:, 1:]
            rois_per_level.append(('x%d' % lvl,
                                   (level_rois, self.rois_lod[lvl])))

        rois_num_per_level = []
        for lvl in range(self.num_level):
            level_counts = np.array(self.rois_lod[lvl][0]).astype('int32')
            rois_num_per_level.append(('rois%d' % lvl, level_counts))

        self.inputs = {
            'MultiLevelRois': rois_per_level,
            "MultiLevelScores": self.scores_input,
            'MultiLevelRoIsNum': rois_num_per_level
        }
        self.attrs = {'post_nms_topN': self.post_nms_top_n, }

        expected_rois_num = np.array(self.lod).astype('int32')
        self.outputs = {
            'FpnRois': (self.rois, [self.lod]),
            'RoisNum': expected_rois_num
        }
if __name__ == '__main__':
unittest.main()
......@@ -35,9 +35,10 @@ class TestDistributeFPNProposalsOp(OpTest):
}
output = [('out%d' % i, self.rois_fpn[i])
for i in range(len(self.rois_fpn))]
self.outputs = {
'MultiFpnRois': output,
'RestoreIndex': self.rois_idx_restore.reshape(-1, 1)
'RestoreIndex': self.rois_idx_restore.reshape(-1, 1),
}
def init_test_case(self):
......@@ -117,5 +118,34 @@ class TestDistributeFPNProposalsOp(OpTest):
self.check_output()
class TestDistributeFPNProposalsOpWithRoisNum(TestDistributeFPNProposalsOp):
    """distribute_fpn_proposals op test that feeds ``RoisNum`` and checks ``MultiLevelRoIsNum``."""

    def set_data(self):
        """Populate ``self.inputs``, ``self.attrs`` and ``self.outputs``."""
        self.init_test_case()
        self.make_rois()
        self.rois_fpn, self.rois_idx_restore = self.calc_rois_distribute()

        # RoisNum is built from the first entry of the RoIs' LoD.
        rois_num = np.array(self.rois_lod[0]).astype('int32')
        self.inputs = {
            'FpnRois': (self.rois[:, 1:5], self.rois_lod),
            'RoisNum': rois_num
        }
        self.attrs = {
            'max_level': self.roi_max_level,
            'min_level': self.roi_min_level,
            'refer_scale': self.canonical_scale,
            'refer_level': self.canonical_level
        }

        num_outputs = len(self.rois_fpn)
        multi_fpn_rois = []
        rois_num_per_level = []
        for idx in range(num_outputs):
            level_entry = self.rois_fpn[idx]
            multi_fpn_rois.append(('out%d' % idx, level_entry))
            # NOTE(review): level_entry[1][0] is presumably the per-image RoI
            # count of this level (first LoD level) - confirm against
            # calc_rois_distribute.
            level_counts = np.array(level_entry[1][0]).astype('int32')
            rois_num_per_level.append(('rois_num%d' % idx, level_counts))

        self.outputs = {
            'MultiFpnRois': multi_fpn_rois,
            'RestoreIndex': self.rois_idx_restore.reshape(-1, 1),
            'MultiLevelRoIsNum': rois_num_per_level
        }
if __name__ == '__main__':
unittest.main()
......@@ -34,18 +34,18 @@ def generate_proposals_in_python(scores, bbox_deltas, im_info, anchors,
rpn_rois = []
rpn_roi_probs = []
lod = []
rois_num = []
num_images = scores.shape[0]
for img_idx in range(num_images):
img_i_boxes, img_i_probs = proposal_for_one_image(
im_info[img_idx, :], all_anchors, variances,
bbox_deltas[img_idx, :, :, :], scores[img_idx, :, :, :],
pre_nms_topN, post_nms_topN, nms_thresh, min_size, eta)
lod.append(img_i_probs.shape[0])
rois_num.append(img_i_probs.shape[0])
rpn_rois.append(img_i_boxes)
rpn_roi_probs.append(img_i_probs)
return rpn_rois, rpn_roi_probs, lod
return rpn_rois, rpn_roi_probs, rois_num
def proposal_for_one_image(im_info, all_anchors, variances, bbox_deltas, scores,
......@@ -87,6 +87,10 @@ def proposal_for_one_image(im_info, all_anchors, variances, bbox_deltas, scores,
proposals = clip_tiled_boxes(proposals, im_info[:2])
# remove predicted boxes with height or width < min_size
keep = filter_boxes(proposals, min_size, im_info)
if len(keep) == 0:
proposals = np.zeros((1, 4)).astype('float32')
scores = np.zeros((1, 1)).astype('float32')
return proposals, scores
proposals = proposals[keep, :]
scores = scores[keep, :]
......@@ -280,8 +284,8 @@ class TestGenerateProposalsOp(OpTest):
}
self.outputs = {
'RpnRois': (self.rpn_rois[0], [self.lod]),
'RpnRoiProbs': (self.rpn_roi_probs[0], [self.lod]),
'RpnRois': (self.rpn_rois[0], [self.rois_num]),
'RpnRoiProbs': (self.rpn_roi_probs[0], [self.rois_num]),
}
def test_check_output(self):
......@@ -320,7 +324,7 @@ class TestGenerateProposalsOp(OpTest):
(batch_size, num_anchors * 4, layer_h, layer_w)).astype('float32')
def init_test_output(self):
self.rpn_rois, self.rpn_roi_probs, self.lod = generate_proposals_in_python(
self.rpn_rois, self.rpn_roi_probs, self.rois_num = generate_proposals_in_python(
self.scores, self.bbox_deltas, self.im_info, self.anchors,
self.variances, self.pre_nms_topN, self.post_nms_topN,
self.nms_thresh, self.min_size, self.eta)
......@@ -349,12 +353,21 @@ class TestGenerateProposalsOutLodOp(TestGenerateProposalsOp):
}
self.outputs = {
'RpnRois': (self.rpn_rois[0], [self.lod]),
'RpnRoiProbs': (self.rpn_roi_probs[0], [self.lod]),
'RpnRoisLod': (np.asarray(
self.lod, dtype=np.int32))
'RpnRois': (self.rpn_rois[0], [self.rois_num]),
'RpnRoiProbs': (self.rpn_roi_probs[0], [self.rois_num]),
'RpnRoisNum': (np.asarray(
self.rois_num, dtype=np.int32))
}
class TestGenerateProposalsOpNoBoxLeft(TestGenerateProposalsOp):
    """generate_proposals op test configured with a very large ``min_size``.

    NOTE(review): judging by the class name, min_size=1000.0 is presumably
    chosen so that every box is filtered out - confirm against the input
    sizes set up by the parent test.
    """

    def init_test_params(self):
        """Set the op attributes used by this test case."""
        self.pre_nms_topN, self.post_nms_topN = 12000, 5000
        self.nms_thresh = 0.7
        self.min_size = 1000.0
        self.eta = 1.
if __name__ == '__main__':
unittest.main()
......@@ -3318,15 +3318,29 @@ class TestBook(LayerTest):
return (out)
def test_roi_pool(self):
    """Check that ``layers.roi_pool`` agrees between static graph and dygraph.

    The same random feature map, RoIs and ``rois_num`` tensor are fed to
    both execution modes and the two results must be exactly equal.
    """
    # Batch of 2 feature maps (3 channels, 8x8) and 3 RoIs with 4 values each.
    x_np = np.random.rand(2, 3, 8, 8).astype('float32')
    rois_np = np.random.rand(3, 4).astype('float32')
    # Sums to the 3 rows of rois_np; presumably the number of RoIs per image
    # (1 for the first image, 2 for the second) -- confirm with the op docs.
    rois_num_np = np.array([1, 2]).astype('int32')

    with self.static_graph():
        x = layers.data(name="x", shape=[3, 8, 8], dtype="float32")
        rois = layers.data(name="rois", shape=[4], dtype="float32")
        rois_num = fluid.data(name="rois_num", shape=[None], dtype="int32")
        output = layers.roi_pool(x, rois, 4, 4, 0.5, rois_num=rois_num)
        static_res = self.get_static_graph_result(
            feed={'x': x_np,
                  'rois': rois_np,
                  'rois_num': rois_num_np},
            fetch_list=[output])[0]

    with self.dynamic_graph():
        x_dy = base.to_variable(x_np)
        rois_dy = base.to_variable(rois_np)
        rois_num_dy = base.to_variable(rois_num_np)
        dy_res = layers.roi_pool(
            x_dy, rois_dy, 4, 4, 0.5, rois_num=rois_num_dy)
        # The dygraph result is indexed with [0] before conversion;
        # presumably roi_pool returns several outputs in this mode and the
        # first one is the pooled feature map -- confirm.
        dy_res_value = dy_res[0].numpy()

    self.assertTrue(np.array_equal(static_res, dy_res_value))
def test_sequence_enumerate(self):
# TODO(minqiyang): dygraph do not support lod now
......@@ -3335,16 +3349,29 @@ class TestBook(LayerTest):
out = layers.sequence_enumerate(input=x, win_size=2, pad_value=0)
def test_roi_align(self):
    """Check that ``layers.roi_align`` agrees between static graph and dygraph.

    The same random feature map, RoIs and ``rois_num`` tensor are fed to
    both execution modes and the two results must be exactly equal.
    """
    # Batch of 2 feature maps (3 channels, 8x8) and 3 RoIs with 4 values each.
    x_np = np.random.rand(2, 3, 8, 8).astype('float32')
    rois_np = np.random.rand(3, 4).astype('float32')
    # Sums to the 3 rows of rois_np; presumably the number of RoIs per image
    # (1 for the first image, 2 for the second) -- confirm with the op docs.
    rois_num_np = np.array([1, 2]).astype('int32')

    with self.static_graph():
        x = layers.data(name="x", shape=[3, 8, 8], dtype="float32")
        rois = layers.data(name="rois", shape=[4], dtype="float32")
        rois_num = fluid.data(name="rois_num", shape=[None], dtype="int32")
        output = layers.roi_align(x, rois, 4, 4, 0.5, 2, rois_num=rois_num)
        static_res = self.get_static_graph_result(
            feed={'x': x_np,
                  'rois': rois_np,
                  'rois_num': rois_num_np},
            fetch_list=[output])[0]

    with self.dynamic_graph():
        x_dy = base.to_variable(x_np)
        rois_dy = base.to_variable(rois_np)
        rois_num_dy = base.to_variable(rois_num_np)
        dy_res = layers.roi_align(
            x_dy, rois_dy, 4, 4, 0.5, 2, rois_num=rois_num_dy)
        dy_res_value = dy_res.numpy()

    self.assertTrue(np.array_equal(static_res, dy_res_value))
def test_roi_perspective_transform(self):
# TODO(minqiyang): dygraph do not support lod now
......
......@@ -181,16 +181,11 @@ class TestROIAlignInLodOp(TestROIAlignOp):
self.calc_roi_align()
seq_len = self.rois_lod[0]
cur_len = 0
lod = [cur_len]
for l in seq_len:
cur_len += l
lod.append(cur_len)
self.inputs = {
'X': self.x,
'ROIs': (self.rois[:, 1:5], self.rois_lod),
'RoisLod': np.asarray(lod).astype('int64')
'RoisNum': np.asarray(seq_len).astype('int32')
}
self.attrs = {
......
......@@ -174,16 +174,11 @@ class TestROIPoolInLodOp(TestROIPoolOp):
self.calc_roi_pool()
seq_len = self.rois_lod[0]
cur_len = 0
lod = [cur_len]
for l in seq_len:
cur_len += l
lod.append(cur_len)
self.inputs = {
'X': self.x,
'ROIs': (self.rois[:, 1:5], self.rois_lod),
'RoisLod': np.asarray(lod).astype('int64')
'RoisNum': np.asarray(seq_len).astype('int32')
}
self.attrs = {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册