未验证 提交 07dc5a15 编写于 作者: Q qingqing01 提交者: GitHub

Add generate_mask_labels_op to support Mask-RCNN and refine some code. (#15371)

* Add generate_mask_labels_op to support Mask-RCNN.
* Refine sigmoid_cross_entropy to support nomalize mode.
* Fix generator_proposals_label.
* Use DeviceTemporaryAllocator in roi_pool and roi_algin.
* Remove shape check in data_feeder.
上级 9f5108a6
......@@ -197,7 +197,7 @@ paddle.fluid.layers.clip ArgSpec(args=['x', 'min', 'max', 'name'], varargs=None,
paddle.fluid.layers.clip_by_norm ArgSpec(args=['x', 'max_norm', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None))
paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'ignore_index', 'name'], varargs=None, keywords=None, defaults=(-100, None))
paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'ignore_index', 'name', 'normalize'], varargs=None, keywords=None, defaults=(-100, None, False))
paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.space_to_depth ArgSpec(args=['x', 'blocksize', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,))
......@@ -318,6 +318,7 @@ paddle.fluid.layers.anchor_generator ArgSpec(args=['input', 'anchor_sizes', 'asp
paddle.fluid.layers.roi_perspective_transform ArgSpec(args=['input', 'rois', 'transformed_height', 'transformed_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1.0,))
paddle.fluid.layers.generate_proposal_labels ArgSpec(args=['rpn_rois', 'gt_classes', 'is_crowd', 'gt_boxes', 'im_info', 'batch_size_per_im', 'fg_fraction', 'fg_thresh', 'bg_thresh_hi', 'bg_thresh_lo', 'bbox_reg_weights', 'class_nums', 'use_random'], varargs=None, keywords=None, defaults=(256, 0.25, 0.25, 0.5, 0.0, [0.1, 0.1, 0.2, 0.2], None, True))
paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'im_info', 'anchors', 'variances', 'pre_nms_top_n', 'post_nms_top_n', 'nms_thresh', 'min_size', 'eta', 'name'], varargs=None, keywords=None, defaults=(6000, 1000, 0.5, 0.1, 1.0, None))
paddle.fluid.layers.generate_mask_labels ArgSpec(args=['im_info', 'gt_classes', 'is_crowd', 'gt_segms', 'rois', 'labels_int32', 'num_classes', 'resolution'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None))
paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,))
......
......@@ -83,7 +83,7 @@ __global__ void AffineChannelScaleBiasGradientCUDAKernel(
T* dbias) {
const int outer_size = C;
const int inner_size = N * HxW;
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
typedef cub::BlockReduce<double, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage ds_storage;
__shared__ typename BlockReduce::TempStorage db_storage;
......@@ -97,13 +97,16 @@ __global__ void AffineChannelScaleBiasGradientCUDAKernel(
ds_sum += dy[index] * x[index];
db_sum += dy[index];
}
ds_sum = BlockReduce(ds_storage).Reduce(ds_sum, cub::Sum());
db_sum = BlockReduce(db_storage).Reduce(db_sum, cub::Sum());
__syncthreads();
auto ds_out =
BlockReduce(ds_storage).Reduce(static_cast<double>(ds_sum), cub::Sum());
auto db_out =
BlockReduce(db_storage).Reduce(static_cast<double>(db_sum), cub::Sum());
__syncthreads();
if (threadIdx.x == 0) {
dscale[i] = ds_sum;
dbias[i] = db_sum;
dscale[i] = ds_out;
dbias[i] = db_out;
}
__syncthreads();
}
}
......
......@@ -45,3 +45,7 @@ detection_library(roi_perspective_transform_op SRCS roi_perspective_transform_op
foreach(src ${LOCAL_DETECTION_LIBS})
set(OP_LIBRARY ${src} ${OP_LIBRARY} CACHE INTERNAL "op libs")
endforeach()
cc_library(mask_util SRCS mask_util.cc DEPS memory)
cc_test(mask_util_test SRCS mask_util_test.cc DEPS memory mask_util)
detection_library(generate_mask_labels_op SRCS generate_mask_labels_op.cc DEPS mask_util)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/fluid/framework/eigen.h"
......@@ -88,7 +92,9 @@ void BboxOverlaps(const framework::Tensor& r_boxes,
inter_w = std::max(x_max - x_min + 1, zero);
inter_h = std::max(y_max - y_min + 1, zero);
inter_area = inter_w * inter_h;
overlaps_et(i, j) = inter_area / (r_box_area + c_box_area - inter_area);
overlaps_et(i, j) =
(inter_area == 0.) ? 0 : inter_area /
(r_box_area + c_box_area - inter_area);
}
}
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <math.h>
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/detection/mask_util.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
const int kBoxDim = 4;
template <typename T>
void AppendMask(LoDTensor* out, int64_t offset, Tensor* to_add) {
auto* out_data = out->data<T>();
auto* to_add_data = to_add->data<T>();
memcpy(out_data + offset, to_add_data, to_add->numel() * sizeof(T));
}
class GenerateMaskLabelsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("ImInfo"), "Input(ImInfo) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("GtClasses"),
"Input(GtClasses) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("IsCrowd"),
"Input(IsCrowd) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("GtSegms"),
"Input(GtSegms) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("Rois"), "Input(Rois) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("LabelsInt32"),
"Input(LabelsInt32) shouldn't be null.");
PADDLE_ENFORCE(
ctx->HasOutput("MaskRois"),
"Output(MaskRois) of GenerateMaskLabelsOp should not be null");
PADDLE_ENFORCE(
ctx->HasOutput("RoiHasMaskInt32"),
"Output(RoiHasMaskInt32) of GenerateMaskLabelsOp should not be null");
PADDLE_ENFORCE(
ctx->HasOutput("MaskInt32"),
"Output(MaskInt32) of GenerateMaskLabelsOp should not be null");
auto im_info_dims = ctx->GetInputDim("ImInfo");
auto gt_segms_dims = ctx->GetInputDim("GtSegms");
PADDLE_ENFORCE_EQ(im_info_dims.size(), 2,
"The rank of Input(ImInfo) must be 2.");
PADDLE_ENFORCE_EQ(gt_segms_dims.size(), 2,
"The rank of Input(GtSegms) must be 2.");
PADDLE_ENFORCE_EQ(gt_segms_dims[1], 2,
"The second dim of Input(GtSegms) must be 2.");
int num_classes = ctx->Attrs().Get<int>("num_classes");
int resolution = ctx->Attrs().Get<int>("resolution");
ctx->SetOutputDim("MaskRois", {-1, 4});
ctx->SetOutputDim("RoiHasMaskInt32", {-1, 1});
ctx->SetOutputDim("MaskInt32", {-1, num_classes * resolution * resolution});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Rois"));
return framework::OpKernelType(data_type, platform::CPUPlace());
}
};
/*
* Expand masks from shape (#masks, M ** 2) to (#masks, #classes * M ** 2)
* to encode class specific mask targets.
*/
template <typename T>
static inline void ExpandMaskTarget(const platform::CPUDeviceContext& ctx,
const Tensor& masks,
const Tensor& mask_class_labels,
const int resolution, const int num_classes,
Tensor* mask_targets) {
const uint8_t* masks_data = masks.data<uint8_t>();
int64_t num_mask = masks.dims()[0];
const int* mask_class_labels_data = mask_class_labels.data<int>();
const int M = resolution * resolution;
const int mask_dim = M * num_classes;
int* mask_targets_data =
mask_targets->mutable_data<int>({num_mask, mask_dim}, ctx.GetPlace());
math::set_constant(ctx, mask_targets, -1);
for (int64_t mask_id = 0; mask_id < num_mask; ++mask_id) {
int cls = mask_class_labels_data[mask_id];
int start = M * cls;
if (cls > 0) {
for (int i = 0; i < M; ++i) {
mask_targets_data[mask_id * mask_dim + start + i] =
static_cast<int>(masks_data[mask_id * M + i]);
}
}
}
}
template <typename T>
std::vector<Tensor> SampleMaskForOneImage(
const platform::CPUDeviceContext& ctx, const Tensor& im_info,
const Tensor& gt_classes, const Tensor& is_crowd, const Tensor& gt_segms,
const Tensor& rois, const Tensor& label_int32, const int num_classes,
const int resolution, const framework::LoD& segm_length) {
// Prepare the mask targets by associating one gt mask to each training roi
// that has a fg (non-bg) class label.
const int64_t gt_size = static_cast<int64_t>(gt_classes.dims()[0]);
const int64_t roi_size = static_cast<int64_t>(rois.dims()[0]);
const int* gt_classes_data = gt_classes.data<int>();
const int* is_crowd_data = is_crowd.data<int>();
const int* label_int32_data = label_int32.data<int>();
PADDLE_ENFORCE_EQ(roi_size, label_int32.dims()[0]);
std::vector<int> mask_gt_inds, fg_inds;
std::vector<std::vector<std::vector<T>>> gt_polys;
auto polys_num = segm_length[1];
auto segm_lod_offset = framework::ConvertToOffsetBasedLoD(segm_length);
auto lod1 = segm_lod_offset[1];
auto lod2 = segm_lod_offset[2];
const T* polys_data = gt_segms.data<T>();
for (int64_t i = 0; i < gt_size; ++i) {
if ((gt_classes_data[i] > 0) && (is_crowd_data[i] == 0)) {
mask_gt_inds.emplace_back(i);
// slice fg segmentation polys
int poly_num = polys_num[i];
std::vector<std::vector<T>> polys;
int s_idx = lod1[i];
for (int j = 0; j < poly_num; ++j) {
int s = lod2[s_idx + j];
int e = lod2[s_idx + j + 1];
PADDLE_ENFORCE_NE(s, e);
std::vector<T> plts(polys_data + s * 2, polys_data + e * 2);
polys.push_back(plts);
}
gt_polys.push_back(polys);
}
}
for (int64_t i = 0; i < roi_size; ++i) {
if (label_int32_data[i] > 0) {
fg_inds.emplace_back(i);
}
}
int gt_num = mask_gt_inds.size();
int fg_num = fg_inds.size();
Tensor boxes_from_polys;
boxes_from_polys.mutable_data<T>({gt_num, 4}, platform::CPUPlace());
Poly2Boxes(gt_polys, boxes_from_polys.data<T>());
std::vector<int> roi_has_mask =
std::vector<int>(fg_inds.begin(), fg_inds.end());
Tensor mask_class_labels;
Tensor masks;
Tensor rois_fg;
auto im_scale = im_info.data<T>()[2];
if (fg_num > 0) {
// Class labels for the foreground rois
mask_class_labels.mutable_data<int>({fg_num, 1}, ctx.GetPlace());
Gather<int>(label_int32_data, 1, fg_inds.data(), fg_inds.size(),
mask_class_labels.data<int>());
uint8_t* masks_data = masks.mutable_data<uint8_t>(
{fg_num, resolution * resolution}, ctx.GetPlace());
// Find overlap between all foreground rois and the bounding boxes
// enclosing each segmentation
T* rois_fg_data = rois_fg.mutable_data<T>({fg_num, 4}, ctx.GetPlace());
Gather<T>(rois.data<T>(), 4, fg_inds.data(), fg_inds.size(),
rois_fg.data<T>());
for (int k = 0; k < rois_fg.numel(); ++k) {
rois_fg_data[k] = rois_fg_data[k] / im_scale;
}
Tensor overlaps_bbfg_bbpolys;
overlaps_bbfg_bbpolys.mutable_data<T>({fg_num, gt_num}, ctx.GetPlace());
BboxOverlaps<T>(rois_fg, boxes_from_polys, &overlaps_bbfg_bbpolys);
// Map from each fg rois to the index of the mask with highest overlap
// (measured by bbox overlap)
T* overlaps_bbfg_bbpolys_data = overlaps_bbfg_bbpolys.data<T>();
std::vector<int> fg_masks_inds;
for (int64_t i = 0; i < fg_num; ++i) {
const T* v = overlaps_bbfg_bbpolys_data + i * gt_num;
T max_overlap = std::numeric_limits<T>::min();
int id = 0;
for (int64_t j = 0; j < gt_num; ++j) {
if (v[j] > max_overlap) {
max_overlap = v[j];
id = j;
}
}
fg_masks_inds.push_back(id);
}
// add fg targets
for (int64_t i = 0; i < fg_num; ++i) {
int fg_polys_ind = fg_masks_inds[i];
T* roi_fg = rois_fg_data + i * 4;
uint8_t* mask = masks_data + i * resolution * resolution;
Polys2MaskWrtBox(gt_polys[fg_polys_ind], roi_fg, resolution, mask);
}
} else {
// The network cannot handle empty blobs, so we must provide a mask
// We simply take the first bg roi, given it an all -1's mask (ignore
// label), and label it with class zero (bg).
int bg_num = 1;
T* rois_fg_data = rois_fg.mutable_data<T>({bg_num, 4}, ctx.GetPlace());
const T* rois_data = rois.data<T>();
std::vector<int> bg_inds;
for (int64_t i = 0; i < roi_size; ++i) {
if (label_int32_data[i] == 0) {
bg_inds.emplace_back(i);
rois_fg_data[0] = rois_data[0] / im_scale;
rois_fg_data[1] = rois_data[1] / im_scale;
rois_fg_data[2] = rois_data[2] / im_scale;
rois_fg_data[3] = rois_data[3] / im_scale;
break;
}
}
masks.mutable_data<uint8_t>({bg_num, resolution * resolution},
ctx.GetPlace());
math::set_constant(ctx, &masks, -1);
int* mask_class_labels_data =
mask_class_labels.mutable_data<int>({bg_num, 1}, ctx.GetPlace());
mask_class_labels_data[0] = 0;
roi_has_mask = std::vector<int>(bg_inds.begin(), bg_inds.end());
}
Tensor masks_expand;
ExpandMaskTarget<T>(ctx, masks, mask_class_labels, resolution, num_classes,
&masks_expand);
T* rois_fg_data = rois_fg.data<T>();
for (int k = 0; k < rois_fg.numel(); ++k) {
rois_fg_data[k] = rois_fg_data[k] * im_scale;
}
Tensor roi_has_mask_t;
int roi_has_mask_size = roi_has_mask.size();
int* roi_has_mask_data =
roi_has_mask_t.mutable_data<int>({roi_has_mask_size, 1}, ctx.GetPlace());
std::copy(roi_has_mask.begin(), roi_has_mask.end(), roi_has_mask_data);
std::vector<Tensor> res;
res.emplace_back(rois_fg);
res.emplace_back(roi_has_mask_t);
res.emplace_back(masks_expand);
return res;
}
template <typename T>
class GenerateMaskLabelsKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* im_info = ctx.Input<LoDTensor>("ImInfo");
auto* gt_classes = ctx.Input<LoDTensor>("GtClasses");
auto* is_crowd = ctx.Input<LoDTensor>("IsCrowd");
auto* gt_segms = ctx.Input<LoDTensor>("GtSegms");
auto* rois = ctx.Input<LoDTensor>("Rois");
auto* label_int32 = ctx.Input<LoDTensor>("LabelsInt32");
auto* mask_rois = ctx.Output<LoDTensor>("MaskRois");
auto* roi_has_mask_int32 = ctx.Output<LoDTensor>("RoiHasMaskInt32");
auto* mask_int32 = ctx.Output<LoDTensor>("MaskInt32");
int num_classes = ctx.Attr<int>("num_classes");
int resolution = ctx.Attr<int>("resolution");
PADDLE_ENFORCE_EQ(gt_classes->lod().size(), 1UL,
"GenerateMaskLabelsOp gt_classes needs 1 level of LoD");
PADDLE_ENFORCE_EQ(is_crowd->lod().size(), 1UL,
"GenerateMaskLabelsOp is_crowd needs 1 level of LoD");
PADDLE_ENFORCE_EQ(rois->lod().size(), 1UL,
"GenerateMaskLabelsOp rois needs 1 level of LoD");
PADDLE_ENFORCE_EQ(label_int32->lod().size(), 1UL,
"GenerateMaskLabelsOp label_int32 needs 1 level of LoD");
PADDLE_ENFORCE_EQ(gt_segms->lod().size(), 3UL);
int64_t n = static_cast<int64_t>(gt_classes->lod().back().size() - 1);
PADDLE_ENFORCE_EQ(gt_segms->lod()[0].size() - 1, n);
int mask_dim = num_classes * resolution * resolution;
mask_rois->mutable_data<T>({rois->numel(), kBoxDim}, ctx.GetPlace());
roi_has_mask_int32->mutable_data<int>({rois->numel(), 1}, ctx.GetPlace());
mask_int32->mutable_data<int>({rois->numel(), mask_dim}, ctx.GetPlace());
framework::LoD lod;
std::vector<size_t> lod0(1, 0);
int64_t num_mask = 0;
auto& dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
auto gt_classes_lod = gt_classes->lod().back();
auto is_crowd_lod = is_crowd->lod().back();
auto rois_lod = rois->lod().back();
auto label_int32_lod = label_int32->lod().back();
auto gt_segms_lod = gt_segms->lod();
for (int i = 0; i < n; ++i) {
Tensor im_info_slice = im_info->Slice(i, i + 1);
Tensor gt_classes_slice =
gt_classes->Slice(gt_classes_lod[i], gt_classes_lod[i + 1]);
Tensor is_crowd_slice =
is_crowd->Slice(is_crowd_lod[i], is_crowd_lod[i + 1]);
Tensor label_int32_slice =
label_int32->Slice(label_int32_lod[i], label_int32_lod[i + 1]);
Tensor rois_slice = rois->Slice(rois_lod[i], rois_lod[i + 1]);
auto sub_lod_and_offset =
framework::GetSubLoDAndAbsoluteOffset(gt_segms_lod, i, i + 1, 0);
auto lod_length = sub_lod_and_offset.first;
size_t s = sub_lod_and_offset.second.first;
size_t e = sub_lod_and_offset.second.second;
Tensor gt_segms_slice = gt_segms->Slice(s, e);
std::vector<Tensor> tensor_output = SampleMaskForOneImage<T>(
dev_ctx, im_info_slice, gt_classes_slice, is_crowd_slice,
gt_segms_slice, rois_slice, label_int32_slice, num_classes,
resolution, lod_length);
Tensor sampled_mask_rois = tensor_output[0];
Tensor sampled_roi_has_mask_int32 = tensor_output[1];
Tensor sampled_mask_int32 = tensor_output[2];
AppendMask<T>(mask_rois, kBoxDim * num_mask, &sampled_mask_rois);
AppendMask<int>(roi_has_mask_int32, num_mask,
&sampled_roi_has_mask_int32);
AppendMask<int>(mask_int32, mask_dim * num_mask, &sampled_mask_int32);
num_mask += sampled_mask_rois.dims()[0];
lod0.emplace_back(num_mask);
}
lod.emplace_back(lod0);
mask_rois->set_lod(lod);
roi_has_mask_int32->set_lod(lod);
mask_int32->set_lod(lod);
mask_rois->Resize({num_mask, kBoxDim});
roi_has_mask_int32->Resize({num_mask, 1});
mask_int32->Resize({num_mask, mask_dim});
}
};
class GenerateMaskLabelsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("ImInfo",
"(Tensor), This input is a 2D Tensor with shape [B, 3]. "
"B is the number of input images, "
"each element consists of im_height, im_width, im_scale.");
AddInput("GtClasses",
"(LoDTensor), This input is a 2D LoDTensor with shape [M, 1]. "
"M is the number of groundtruth, "
"each element is a class label of groundtruth.");
AddInput(
"IsCrowd",
"(LoDTensor), This input is a 2D LoDTensor with shape [M, 1]. "
"M is the number of groundtruth, "
"each element is a flag indicates whether a groundtruth is crowd.");
AddInput(
"GtSegms",
"(LoDTensor), This input is a 2D LoDTensor with shape [S, 2], it's LoD "
"level is 3. The LoD[0] represents the gt objects number of each "
"instance. LoD[1] represents the segmentation counts of each objects. "
"LoD[2] represents the polygons number of each segmentation. S the "
"total number of polygons coordinate points. Each element is (x, y) "
"coordinate points.");
AddInput(
"Rois",
"(LoDTensor), This input is a 2D LoDTensor with shape [R, 4]. "
"R is the number of rois which is the output of "
"generate_proposal_labels, "
"each element is a bounding box with (xmin, ymin, xmax, ymax) format.");
AddInput("LabelsInt32",
"(LoDTensor), This intput is a 2D LoDTensor with shape [R, 1], "
"each element repersents a class label of a roi");
AddOutput(
"MaskRois",
"(LoDTensor), This output is a 2D LoDTensor with shape [P, 4]. "
"P is the number of mask, "
"each element is a bounding box with [xmin, ymin, xmax, ymax] format.");
AddOutput("RoiHasMaskInt32",
"(LoDTensor), This output is a 2D LoDTensor with shape [P, 1], "
"each element repersents the output mask rois index with regard "
"to input rois");
AddOutput("MaskInt32",
"(LoDTensor), This output is a 4D LoDTensor with shape [P, Q], "
"Q equal to num_classes * resolution * resolution");
AddAttr<int>("num_classes", "Class number.");
AddAttr<int>("resolution", "Resolution of mask.");
AddComment(R"DOC(
This operator can be, for given the RoIs and corresponding labels,
to sample foreground RoIs. This mask branch also has
a :math: `K \\times M^{2}` dimensional output targets for each foreground
RoI, which encodes K binary masks of resolution M x M, one for each of the
K classes. This mask targets are used to compute loss of mask branch.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(generate_mask_labels, ops::GenerateMaskLabelsOp,
ops::GenerateMaskLabelsOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(generate_mask_labels,
ops::GenerateMaskLabelsKernel<float>);
......@@ -48,20 +48,21 @@ class GenerateProposalLabelsOp : public framework::OperatorWithKernel {
"Input(GtBoxes) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("ImInfo"), "Input(ImInfo) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasOutput("Rois"),
"Output(Rois) of RpnTargetAssignOp should not be null");
PADDLE_ENFORCE(
ctx->HasOutput("Rois"),
"Output(Rois) of GenerateProposalLabelsOp should not be null");
PADDLE_ENFORCE(
ctx->HasOutput("LabelsInt32"),
"Output(LabelsInt32) of RpnTargetAssignOp should not be null");
"Output(LabelsInt32) of GenerateProposalLabelsOp should not be null");
PADDLE_ENFORCE(
ctx->HasOutput("BboxTargets"),
"Output(BboxTargets) of RpnTargetAssignOp should not be null");
PADDLE_ENFORCE(
ctx->HasOutput("BboxInsideWeights"),
"Output(BboxInsideWeights) of RpnTargetAssignOp should not be null");
PADDLE_ENFORCE(
ctx->HasOutput("BboxOutsideWeights"),
"Output(BboxOutsideWeights) of RpnTargetAssignOp should not be null");
"Output(BboxTargets) of GenerateProposalLabelsOp should not be null");
PADDLE_ENFORCE(ctx->HasOutput("BboxInsideWeights"),
"Output(BboxInsideWeights) of GenerateProposalLabelsOp "
"should not be null");
PADDLE_ENFORCE(ctx->HasOutput("BboxOutsideWeights"),
"Output(BboxOutsideWeights) of GenerateProposalLabelsOp "
"should not be null");
auto rpn_rois_dims = ctx->GetInputDim("RpnRois");
auto gt_boxes_dims = ctx->GetInputDim("GtBoxes");
......@@ -225,30 +226,36 @@ void GatherBoxesLabels(const platform::CPUDeviceContext& context,
template <typename T>
std::vector<Tensor> SampleRoisForOneImage(
const platform::CPUDeviceContext& context, Tensor* rpn_rois,
Tensor* gt_classes, Tensor* is_crowd, Tensor* gt_boxes, Tensor* im_info,
const int batch_size_per_im, const float fg_fraction, const float fg_thresh,
const float bg_thresh_hi, const float bg_thresh_lo,
const platform::CPUDeviceContext& context, const Tensor& rpn_rois_in,
const Tensor& gt_classes, const Tensor& is_crowd, const Tensor& gt_boxes,
const Tensor& im_info, const int batch_size_per_im, const float fg_fraction,
const float fg_thresh, const float bg_thresh_hi, const float bg_thresh_lo,
const std::vector<float>& bbox_reg_weights, const int class_nums,
std::minstd_rand engine, bool use_random) {
auto rpn_rois_et = framework::EigenTensor<T, 2>::From(*rpn_rois);
auto im_scale = im_info->data<T>()[2];
rpn_rois_et = rpn_rois_et / im_scale;
auto im_scale = im_info.data<T>()[2];
Tensor rpn_rois;
rpn_rois.mutable_data<T>(rpn_rois_in.dims(), context.GetPlace());
T* rpn_rois_dt = rpn_rois.data<T>();
const T* rpn_rois_in_dt = rpn_rois_in.data<T>();
for (int i = 0; i < rpn_rois.numel(); ++i) {
rpn_rois_dt[i] = rpn_rois_in_dt[i] / im_scale;
}
Tensor boxes;
int proposals_num = gt_boxes->dims()[0] + rpn_rois->dims()[0];
int proposals_num = gt_boxes.dims()[0] + rpn_rois.dims()[0];
boxes.mutable_data<T>({proposals_num, kBoxDim}, context.GetPlace());
Concat<T>(context, *gt_boxes, *rpn_rois, &boxes);
Concat<T>(context, gt_boxes, rpn_rois, &boxes);
// Overlaps
Tensor proposal_to_gt_overlaps;
proposal_to_gt_overlaps.mutable_data<T>({proposals_num, gt_boxes->dims()[0]},
proposal_to_gt_overlaps.mutable_data<T>({proposals_num, gt_boxes.dims()[0]},
context.GetPlace());
BboxOverlaps<T>(boxes, *gt_boxes, &proposal_to_gt_overlaps);
BboxOverlaps<T>(boxes, gt_boxes, &proposal_to_gt_overlaps);
// Generate proposal index
std::vector<std::vector<int>> fg_bg_gt = SampleFgBgGt<T>(
context, &proposal_to_gt_overlaps, *is_crowd, batch_size_per_im,
context, &proposal_to_gt_overlaps, is_crowd, batch_size_per_im,
fg_fraction, fg_thresh, bg_thresh_hi, bg_thresh_lo, engine, use_random);
std::vector<int> fg_inds = fg_bg_gt[0];
std::vector<int> bg_inds = fg_bg_gt[1];
......@@ -263,7 +270,7 @@ std::vector<Tensor> SampleRoisForOneImage(
sampled_boxes.mutable_data<T>(bbox_dim, context.GetPlace());
sampled_labels.mutable_data<int>({boxes_num}, context.GetPlace());
sampled_gts.mutable_data<T>({fg_num, kBoxDim}, context.GetPlace());
GatherBoxesLabels<T>(context, boxes, *gt_boxes, *gt_classes, fg_inds, bg_inds,
GatherBoxesLabels<T>(context, boxes, gt_boxes, gt_classes, fg_inds, bg_inds,
gt_inds, &sampled_boxes, &sampled_labels, &sampled_gts);
// Compute targets
......@@ -397,8 +404,8 @@ class GenerateProposalLabelsKernel : public framework::OpKernel<T> {
gt_boxes->Slice(gt_boxes_lod[i], gt_boxes_lod[i + 1]);
Tensor im_info_slice = im_info->Slice(i, i + 1);
std::vector<Tensor> tensor_output = SampleRoisForOneImage<T>(
dev_ctx, &rpn_rois_slice, &gt_classes_slice, &is_crowd_slice,
&gt_boxes_slice, &im_info_slice, batch_size_per_im, fg_fraction,
dev_ctx, rpn_rois_slice, gt_classes_slice, is_crowd_slice,
gt_boxes_slice, im_info_slice, batch_size_per_im, fg_fraction,
fg_thresh, bg_thresh_hi, bg_thresh_lo, bbox_reg_weights, class_nums,
engine, use_random);
Tensor sampled_rois = tensor_output[0];
......@@ -467,7 +474,7 @@ class GenerateProposalLabelsOpMaker : public framework::OpProtoAndCheckerMaker {
"P usuall equal to batch_size_per_im * batch_size, "
"each element is a bounding box with [xmin, ymin, xmax, ymax] format.");
AddOutput("LabelsInt32",
"(LoDTensor), This output is a 2D LoDTensor with shape [P], "
"(LoDTensor), This output is a 2D LoDTensor with shape [P, 1], "
"each element repersents a class label of a roi");
AddOutput("BboxTargets",
"(LoDTensor), This output is a 2D LoDTensor with shape [P, 4 * "
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/mask_util.h"
#include <math.h>
#include <stdlib.h>
#include <algorithm>
#include <limits>
#include <utility>
#include "paddle/fluid/memory/memory.h"
namespace paddle {
namespace operators {
uint32_t UMax(uint32_t a, uint32_t b) { return (a > b) ? a : b; }
static inline int Compare(const void* a, const void* b) {
uint32_t c = *(reinterpret_cast<const uint32_t*>(a));
uint32_t d = *(reinterpret_cast<const uint32_t*>(b));
return c > d ? 1 : c < d ? -1 : 0;
}
void Decode(const uint32_t* cnts, int m, uint8_t* mask) {
uint8_t v = 0;
for (int j = 0; j < m; j++) {
for (uint32_t k = 0; k < cnts[j]; k++) {
*(mask++) = v;
}
v = !v;
}
}
typedef uint32_t uint;
void Poly2Mask(const float* xy, int k, int h, int w, uint8_t* mask) {
int j, m = 0;
double scale = 5;
int *x, *y, *u, *v;
uint *a, *b;
platform::CPUPlace cpu;
auto xptr = memory::Alloc(cpu, sizeof(int) * (k + 1) * 2);
x = reinterpret_cast<int*>(xptr->ptr());
y = x + (k + 1);
for (j = 0; j < k; j++) x[j] = static_cast<int>(scale * xy[j * 2 + 0] + .5);
x[k] = x[0];
for (j = 0; j < k; j++) y[j] = static_cast<int>(scale * xy[j * 2 + 1] + .5);
y[k] = y[0];
for (j = 0; j < k; j++) {
m += UMax(abs(x[j] - x[j + 1]), abs(y[j] - y[j + 1])) + 1;
}
auto vptr = memory::Alloc(cpu, sizeof(int) * m * 2);
u = reinterpret_cast<int*>(vptr->ptr());
v = u + m;
m = 0;
for (j = 0; j < k; j++) {
int xs = x[j], xe = x[j + 1], ys = y[j], ye = y[j + 1], dx, dy, t, d;
int flip;
double s;
dx = abs(xe - xs);
dy = abs(ys - ye);
flip = (dx >= dy && xs > xe) || (dx < dy && ys > ye);
if (flip) {
t = xs;
xs = xe;
xe = t;
t = ys;
ys = ye;
ye = t;
}
if (dx >= dy) {
s = dx == 0 ? 0 : static_cast<double>(ye - ys) / dx;
for (d = 0; d <= dx; d++) {
t = flip ? dx - d : d;
u[m] = t + xs;
v[m] = static_cast<int>(ys + s * t + .5);
m++;
}
} else {
s = dy == 0 ? 0 : static_cast<double>(xe - xs) / dy;
for (d = 0; d <= dy; d++) {
t = flip ? dy - d : d;
v[m] = t + ys;
u[m] = static_cast<int>(xs + s * t + .5);
m++;
}
}
}
/* get points along y-boundary and downsample */
k = m;
m = 0;
double xd, yd;
auto xyptr = memory::Alloc(cpu, sizeof(int) * k * 2);
x = reinterpret_cast<int*>(xyptr->ptr());
y = x + k;
for (j = 1; j < k; j++) {
if (u[j] != u[j - 1]) {
xd = static_cast<double>(u[j] < u[j - 1] ? u[j] : u[j] - 1);
xd = (xd + .5) / scale - .5;
if (floor(xd) != xd || xd < 0 || xd > w - 1) continue;
yd = static_cast<double>(v[j] < v[j - 1] ? v[j] : v[j - 1]);
yd = (yd + .5) / scale - .5;
if (yd < 0)
yd = 0;
else if (yd > h)
yd = h;
yd = ceil(yd);
x[m] = static_cast<int>(xd);
y[m] = static_cast<int>(yd);
m++;
}
}
/* compute rle encoding given y-boundary points */
k = m;
auto aptr = memory::Alloc(cpu, sizeof(uint) * (k + 1));
a = reinterpret_cast<uint*>(aptr->ptr());
for (j = 0; j < k; j++) a[j] = static_cast<uint>(x[j] * h + y[j]);
a[k++] = static_cast<uint>(h * w);
qsort(a, k, sizeof(uint), Compare);
uint p = 0;
for (j = 0; j < k; j++) {
uint t = a[j];
a[j] -= p;
p = t;
}
auto bptr = memory::Alloc(cpu, sizeof(uint32_t) * k);
b = reinterpret_cast<uint32_t*>(bptr->ptr());
j = m = 0;
b[m++] = a[j++];
while (j < k) {
if (a[j] > 0) {
b[m++] = a[j++];
} else {
j++;
if (j < k) b[m - 1] += a[j++];
}
}
// convert to mask
auto mskptr = memory::Alloc(cpu, sizeof(uint8_t) * h * w);
uint8_t* msk = reinterpret_cast<uint8_t*>(mskptr->ptr());
Decode(b, m, msk);
for (int ii = 0; ii < h; ++ii) {
for (int jj = 0; jj < w; ++jj) {
mask[ii * w + jj] = msk[jj * h + ii];
}
}
}
void Poly2Boxes(const std::vector<std::vector<std::vector<float>>>& polys,
float* boxes) {
// lists
for (size_t i = 0; i < polys.size(); ++i) {
float x0 = std::numeric_limits<float>::max();
float x1 = std::numeric_limits<float>::min();
float y0 = std::numeric_limits<float>::max();
float y1 = std::numeric_limits<float>::min();
// each list may have more than one polys
for (size_t j = 0; j < polys[i].size(); ++j) {
for (size_t k = 0; k < polys[i][j].size() / 2; ++k) {
x0 = std::min(x0, polys[i][j][2 * k]);
x1 = std::max(x1, polys[i][j][2 * k]);
y0 = std::min(y0, polys[i][j][2 * k + 1]);
y1 = std::max(y1, polys[i][j][2 * k + 1]);
}
}
boxes[i * 4] = x0;
boxes[i * 4 + 1] = y0;
boxes[i * 4 + 2] = x1;
boxes[i * 4 + 3] = y1;
}
}
void Polys2MaskWrtBox(const std::vector<std::vector<float>>& polygons,
const float* box, int M, uint8_t* mask) {
float w = box[2] - box[0];
float h = box[3] - box[1];
w = std::max(w, static_cast<float>(1.));
h = std::max(h, static_cast<float>(1.));
uint8_t* msk = nullptr;
if (polygons.size() == 1UL) {
msk = mask;
} else {
msk = reinterpret_cast<uint8_t*>(
malloc(M * M * polygons.size() * sizeof(uint8_t)));
}
for (size_t i = 0; i < polygons.size(); ++i) {
int k = polygons[i].size() / 2;
std::vector<float> p;
for (int j = 0; j < k; ++j) {
float pw = (polygons[i][2 * j] - box[0]) * M / w;
float ph = (polygons[i][2 * j + 1] - box[1]) * M / h;
p.push_back(pw);
p.push_back(ph);
}
uint8_t* msk_i = msk + i * M * M;
Poly2Mask(p.data(), k, M, M, msk_i);
}
if (polygons.size() > 1UL) {
for (size_t i = 0; i < polygons.size(); ++i) {
uint8_t* msk_i = msk + i * M * M;
for (int j = 0; j < M * M; ++j) {
if (i == 0) {
mask[j] = msk_i[j];
} else {
mask[j] = (mask[j] + msk_i[j]) > 0 ? 1 : 0;
}
}
}
free(msk);
}
}
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdint.h>
#include <vector>
namespace paddle {
namespace operators {
void Poly2Mask(const float* ploy, int k, int h, int w, uint8_t* mask);
void Poly2Boxes(const std::vector<std::vector<std::vector<float>>>& polys,
float* boxes);
void Polys2MaskWrtBox(const std::vector<std::vector<float>>& polygons,
const float* box, int M, uint8_t* mask);
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/mask_util.h"
#include <gtest/gtest.h>
#include "paddle/fluid/memory/memory.h"
namespace paddle {
namespace operators {
template <typename T>
void Compare(const T* a, const T* b, const int n) {
for (int i = 0; i < n; i++) {
EXPECT_EQ(a[i], b[i]);
}
}
TEST(MaskUtil, Poly2MaskTest) {
float polys[] = {1.97f, 1.88f, 5.81f, 1.88f, 1.69f,
6.53f, 5.94f, 6.38f, 1.97f, 1.88f};
int h = 8, w = 8;
int k = 5; // length(polys) / 2
// clang-format off
uint8_t expect_mask[] = {
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 1, 1, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 1, 0, 0, 0, 0,
0, 0, 1, 1, 1, 0, 0, 0,
0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0
};
// clang-format on
// the groud-truth mask is computed by coco API:
//
// import pycocotools.mask as mask_util
// import numpy as np
// segm = [1.97, 1.88, 5.81, 1.88, 1.69, 6.53, 5.94, 6.38, 1.97, 1.88]
// rles = mask_util.frPyObjects([segm], im_h, im_w)
// mask = mask_util.decode(rles)
// print mask
platform::CPUPlace cpu;
auto allocation = memory::Alloc(cpu, sizeof(expect_mask));
uint8_t* mask = reinterpret_cast<uint8_t*>(allocation->ptr());
Poly2Mask(polys, k, h, w, mask);
Compare<uint8_t>(expect_mask, mask, h * w);
}
TEST(MaskUtil, Poly2BoxesTest) {
// clang-format off
std::vector<std::vector<std::vector<float>>> polys = {
{{1.97f, 1.88f, 5.81f, 1.88f, 1.69f, 6.53f, 5.94f, 6.38f, 1.97f, 1.88f}},
{{2.97f, 1.88f, 3.81f, 1.68f, 1.69f, 6.63f, 6.94f, 6.58f, 2.97f, 0.88f}}
};
float expect_boxes[] = {
1.69f, 1.88f, 5.94f, 6.53f,
1.69f, 0.88f, 6.94f, 6.63f
};
// clang-format on
platform::CPUPlace cpu;
auto allocation = memory::Alloc(cpu, sizeof(expect_boxes));
float* boxes = reinterpret_cast<float*>(allocation->ptr());
Poly2Boxes(polys, boxes);
Compare<float>(expect_boxes, boxes, 8);
}
TEST(MaskUtil, Polys2MaskWrtBoxTest) {
// clang-format off
std::vector<std::vector<std::vector<float>>> polys = {{
{1.97f, 1.88f, 5.81f, 1.88f, 1.69f, 6.53f, 5.94f, 6.38f, 1.97f, 1.88f},
{2.97f, 1.88f, 3.81f, 1.68f, 1.69f, 6.63f, 6.94f, 6.58f, 2.97f, 0.88f}}};
float expect_boxes[] = {
1.69f, 0.88f, 6.94f, 6.63f
};
uint8_t expect_mask[] = {
0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 1, 1, 1, 1, 0, 0,
0, 0, 1, 1, 1, 0, 0, 0,
0, 0, 1, 1, 1, 0, 0, 0,
0, 0, 1, 1, 1, 0, 0, 0,
0, 1, 1, 1, 1, 1, 0, 0,
0, 1, 1, 1, 1, 1, 1, 0,
1, 1, 1, 1, 1, 1, 1, 1
};
// clang-format on
platform::CPUPlace cpu;
auto allocation = memory::Alloc(cpu, sizeof(expect_boxes));
float* boxes = reinterpret_cast<float*>(allocation->ptr());
Poly2Boxes(polys, boxes);
Compare<float>(expect_boxes, boxes, 4);
auto allocat_mask = memory::Alloc(cpu, sizeof(expect_mask));
uint8_t* mask = reinterpret_cast<uint8_t*>(allocat_mask->ptr());
int M = 8;
Polys2MaskWrtBox(polys[0], expect_boxes, M, mask);
Compare<uint8_t>(expect_mask, mask, M * M);
}
} // namespace operators
} // namespace paddle
......@@ -103,8 +103,10 @@ REGISTER_OPERATOR(gather, ops::GatherOp, ops::GatherOpMaker,
REGISTER_OPERATOR(gather_grad, ops::GatherGradOp);
REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel<float>,
ops::GatherOpKernel<double>, ops::GatherOpKernel<int>,
ops::GatherOpKernel<uint8_t>,
ops::GatherOpKernel<int64_t>);
REGISTER_OP_CPU_KERNEL(gather_grad, ops::GatherGradientOpKernel<float>,
ops::GatherGradientOpKernel<double>,
ops::GatherGradientOpKernel<int>,
ops::GatherGradientOpKernel<uint8_t>,
ops::GatherGradientOpKernel<int64_t>);
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/roi_align_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
......@@ -255,8 +256,8 @@ class GPUROIAlignOpKernel : public framework::OpKernel<T> {
Tensor roi_batch_id_list;
roi_batch_id_list.Resize({rois_num});
int* roi_batch_id_data =
roi_batch_id_list.mutable_data<int>(platform::CPUPlace());
auto cplace = platform::CPUPlace();
int* roi_batch_id_data = roi_batch_id_list.mutable_data<int>(cplace);
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ(
......@@ -270,14 +271,18 @@ class GPUROIAlignOpKernel : public framework::OpKernel<T> {
roi_batch_id_data[i] = n;
}
}
Tensor roi_batch_id_list_gpu;
framework::TensorCopySync(roi_batch_id_list, ctx.GetPlace(),
&roi_batch_id_list_gpu);
GPUROIAlignForward<
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
auto& dev_ctx = ctx.cuda_device_context();
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
int bytes = roi_batch_id_list.numel() * sizeof(int);
auto roi_ptr = allocator.Allocate(bytes);
int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes,
dev_ctx.stream());
GPUROIAlignForward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
output_size, in->data<T>(), rois->data<T>(), spatial_scale, channels,
height, width, pooled_height, pooled_width, sampling_ratio,
roi_batch_id_list_gpu.data<int>(),
height, width, pooled_height, pooled_width, sampling_ratio, roi_id_data,
out->mutable_data<T>(ctx.GetPlace()));
}
};
......@@ -307,8 +312,8 @@ class GPUROIAlignGradOpKernel : public framework::OpKernel<T> {
}
Tensor roi_batch_id_list;
roi_batch_id_list.Resize({rois_num});
int* roi_batch_id_data =
roi_batch_id_list.mutable_data<int>(platform::CPUPlace());
auto cplace = platform::CPUPlace();
int* roi_batch_id_data = roi_batch_id_list.mutable_data<int>(cplace);
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
for (int n = 0; n < rois_batch_size; ++n) {
......@@ -316,24 +321,28 @@ class GPUROIAlignGradOpKernel : public framework::OpKernel<T> {
roi_batch_id_data[i] = n;
}
}
Tensor roi_batch_id_list_gpu;
framework::TensorCopySync(roi_batch_id_list, ctx.GetPlace(),
&roi_batch_id_list_gpu);
auto& dev_ctx = ctx.cuda_device_context();
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
auto roi_ptr = allocator.Allocate(roi_batch_id_list.numel() * sizeof(int));
int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
int bytes = roi_batch_id_list.numel() * sizeof(int);
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes,
dev_ctx.stream());
in_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<Place, T> set_zero;
set_zero(ctx.cuda_device_context(), in_grad, static_cast<T>(0));
set_zero(dev_ctx, in_grad, static_cast<T>(0));
int output_grad_size = out_grad->numel();
int blocks = NumBlocks(output_grad_size);
int threads = kNumCUDAThreads;
if (output_grad_size > 0) {
GPUROIAlignBackward<
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
GPUROIAlignBackward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
output_grad_size, rois->data<T>(), out_grad->data<T>(), rois_num,
spatial_scale, channels, height, width, pooled_height, pooled_width,
sampling_ratio, roi_batch_id_list_gpu.data<int>(),
sampling_ratio, roi_id_data,
in_grad->mutable_data<T>(ctx.GetPlace()));
}
}
......
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/roi_pool_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
......@@ -152,8 +153,8 @@ class GPUROIPoolOpKernel : public framework::OpKernel<T> {
framework::Tensor roi_batch_id_list;
roi_batch_id_list.Resize({rois_num});
int* roi_batch_id_data =
roi_batch_id_list.mutable_data<int>(platform::CPUPlace());
auto cplace = platform::CPUPlace();
int* roi_batch_id_data = roi_batch_id_list.mutable_data<int>(cplace);
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
PADDLE_ENFORCE_EQ(
......@@ -168,15 +169,20 @@ class GPUROIPoolOpKernel : public framework::OpKernel<T> {
}
}
framework::Tensor roi_batch_id_list_gpu;
framework::TensorCopy(roi_batch_id_list, ctx.GetPlace(),
ctx.device_context(), &roi_batch_id_list_gpu);
GPUROIPoolForward<
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
auto& dev_ctx = ctx.cuda_device_context();
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
int bytes = roi_batch_id_list.numel() * sizeof(int);
auto roi_ptr = allocator.Allocate(bytes);
int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes,
dev_ctx.stream());
GPUROIPoolForward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
output_size, in->data<T>(), rois->data<T>(), spatial_scale, channels,
height, width, pooled_height, pooled_width,
roi_batch_id_list_gpu.data<int>(), out->mutable_data<T>(ctx.GetPlace()),
height, width, pooled_height, pooled_width, roi_id_data,
out->mutable_data<T>(ctx.GetPlace()),
argmax->mutable_data<int64_t>(ctx.GetPlace()));
}
};
......@@ -204,8 +210,8 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
if (x_grad) {
framework::Tensor roi_batch_id_list;
roi_batch_id_list.Resize({rois_num});
int* roi_batch_id_data =
roi_batch_id_list.mutable_data<int>(platform::CPUPlace());
auto cplace = platform::CPUPlace();
int* roi_batch_id_data = roi_batch_id_list.mutable_data<int>(cplace);
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
for (int n = 0; n < rois_batch_size; ++n) {
......@@ -213,25 +219,30 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
roi_batch_id_data[i] = n;
}
}
framework::Tensor roi_batch_id_list_gpu;
framework::TensorCopy(roi_batch_id_list, ctx.GetPlace(),
ctx.device_context(), &roi_batch_id_list_gpu);
auto& dev_ctx = ctx.cuda_device_context();
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
int bytes = roi_batch_id_list.numel() * sizeof(int);
auto roi_ptr = allocator.Allocate(bytes);
int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes,
dev_ctx.stream());
x_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<Place, T> set_zero;
set_zero(ctx.cuda_device_context(), x_grad, static_cast<T>(0));
set_zero(dev_ctx, x_grad, static_cast<T>(0));
int output_grad_size = out_grad->numel();
int blocks = NumBlocks(output_grad_size);
int threads = kNumCUDAThreads;
if (output_grad_size > 0) {
GPUROIPoolBackward<
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
GPUROIPoolBackward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
output_grad_size, rois->data<T>(), out_grad->data<T>(),
argmax->data<int64_t>(), rois_num, spatial_scale, channels, height,
width, pooled_height, pooled_width,
roi_batch_id_list_gpu.data<int>(),
width, pooled_height, pooled_width, roi_id_data,
x_grad->mutable_data<T>(ctx.GetPlace()));
}
}
......
......@@ -101,6 +101,10 @@ class SigmoidCrossEntropyWithLogitsOpMaker
AddOutput("Out",
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x D "
" of elementwise logistic losses.");
AddAttr<bool>("normalize",
"if true, divide the loss by the number of "
"targets != ignore_index.")
.SetDefault(false);
AddAttr<int>("ignore_index",
"(int, default kIgnoreIndex), Specifies a target value that "
"is ignored and"
......@@ -145,9 +149,14 @@ REGISTER_OPERATOR(sigmoid_cross_entropy_with_logits,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(sigmoid_cross_entropy_with_logits_grad,
ops::SigmoidCrossEntropyWithLogitsGradOp);
REGISTER_OP_CPU_KERNEL(sigmoid_cross_entropy_with_logits,
ops::SigmoidCrossEntropyWithLogitsKernel<
paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
sigmoid_cross_entropy_with_logits,
ops::SigmoidCrossEntropyWithLogitsKernel<paddle::platform::CPUDeviceContext,
float>,
ops::SigmoidCrossEntropyWithLogitsKernel<paddle::platform::CPUDeviceContext,
double>);
REGISTER_OP_CPU_KERNEL(sigmoid_cross_entropy_with_logits_grad,
ops::SigmoidCrossEntropyWithLogitsGradKernel<
paddle::platform::CPUDeviceContext, float>);
paddle::platform::CPUDeviceContext, float>,
ops::SigmoidCrossEntropyWithLogitsGradKernel<
paddle::platform::CPUDeviceContext, double>);
......@@ -11,12 +11,184 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "cub/cub.cuh"
#include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
static HOSTDEVICE float real_exp(float x) { return expf(x); }
static HOSTDEVICE float real_exp(double x) { return exp(x); }
static HOSTDEVICE float real_log(float x) { return logf(x); }
static HOSTDEVICE float real_log(double x) { return log(x); }
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
template <typename T>
__global__ void GPUSigmoidForward(const T *x_data, const T *label_data,
const int ignore_index, const int limit,
T *out_data, T *counts) {
CUDA_1D_KERNEL_LOOP(i, limit) {
T x = x_data[i];
T label = label_data[i];
T eps = static_cast<T>(1e-5);
T diff = label - static_cast<T>(ignore_index);
if ((diff > -eps) && (diff < eps)) {
out_data[i] = static_cast<T>(0.);
counts[i] = 0;
} else {
T term1 = (x > 0) ? x : 0;
T term2 = x * label;
T term3 = real_log(static_cast<T>(1) + real_exp(static_cast<T>(-abs(x))));
out_data[i] = term1 - term2 + term3;
counts[i] = 1;
}
}
}
template <typename T, int BlockDim>
__global__ void Sum(const T *counts, int num, const T eps, T *sum) {
typedef cub::BlockReduce<double, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
T in = 0;
for (int i = threadIdx.x; i < num; i += BlockDim) {
in += counts[i];
}
__syncthreads();
auto out =
BlockReduce(temp_storage).Reduce(static_cast<double>(in), cub::Sum());
__syncthreads();
if (threadIdx.x == 0) {
T a = out > eps ? out : eps;
sum[0] = a;
}
}
template <typename T>
__global__ void Div(T *loss, const int num, const T *norm) {
CUDA_1D_KERNEL_LOOP(i, num) { loss[i] /= norm[0]; }
}
template <typename T>
__global__ void GPUSigmoidBackward(const T *x_data, const T *label_data,
const int ignore_index, const T *dout_data,
const int limit, T *dx_data, T *counts) {
CUDA_1D_KERNEL_LOOP(i, limit) {
T x = x_data[i];
T label = label_data[i];
T dout = dout_data[i];
T eps = static_cast<T>(1e-5);
T diff = label - static_cast<T>(ignore_index);
if ((diff > -eps) && (diff < eps)) {
dx_data[i] = static_cast<T>(0.);
counts[i] = 0;
} else {
T simoid_x = static_cast<T>(1) / (static_cast<T>(1) + real_exp(-x));
T diff = simoid_x - label;
dx_data[i] = dout * diff;
counts[i] = 1;
}
}
}
// Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X)))
template <typename DeviceContext, typename T>
class GPUSigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const Tensor *X = context.Input<Tensor>("X");
const Tensor *Labels = context.Input<Tensor>("Label");
Tensor *Out = context.Output<Tensor>("Out");
int ignore_index = context.Attr<int>("ignore_index");
auto out_data = Out->mutable_data<T>(context.GetPlace());
auto &dev_ctx = context.cuda_device_context();
bool normalize = context.Attr<bool>("normalize");
// Temporary memory
auto &allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
auto cnt_ptr = allocator.Allocate(Labels->numel() * sizeof(T));
T *counts = reinterpret_cast<T *>(cnt_ptr->ptr());
int limit = Out->numel();
int blocks = NumBlocks(limit);
int threads = kNumCUDAThreads;
GPUSigmoidForward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
X->data<T>(), Labels->data<T>(), ignore_index, limit, out_data, counts);
if (normalize) {
auto norm_ptr = allocator.Allocate(sizeof(T));
T *norm = reinterpret_cast<T *>(norm_ptr->ptr());
Sum<T, kNumCUDAThreads><<<1, kNumCUDAThreads, 0, dev_ctx.stream()>>>(
counts, limit, static_cast<T>(1e-5), norm);
Div<T><<<blocks, threads, 0, dev_ctx.stream()>>>(out_data, limit, norm);
}
}
};
// dX = sigmoid(X) - labels
template <typename DeviceContext, typename T>
class GPUSigmoidCrossEntropyWithLogitsGradKernel
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const Tensor *X = context.Input<Tensor>("X");
const Tensor *Labels = context.Input<Tensor>("Label");
const Tensor *dOut = context.Input<Tensor>(framework::GradVarName("Out"));
Tensor *dX = context.Output<Tensor>(framework::GradVarName("X"));
auto dx_data = dX->mutable_data<T>(context.GetPlace());
int ignore_index = context.Attr<int>("ignore_index");
auto &dev_ctx = context.cuda_device_context();
// Temporary memory
auto &allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
auto cnt_ptr = allocator.Allocate(X->numel() * sizeof(T));
T *counts = reinterpret_cast<T *>(cnt_ptr->ptr());
int limit = dX->numel();
int blocks = NumBlocks(limit);
int threads = kNumCUDAThreads;
GPUSigmoidBackward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
X->data<T>(), Labels->data<T>(), ignore_index, dOut->data<T>(), limit,
dx_data, counts);
bool normalize = context.Attr<bool>("normalize");
if (normalize) {
auto norm_ptr = allocator.Allocate(sizeof(T));
T *norm = reinterpret_cast<T *>(norm_ptr->ptr());
Sum<T, kNumCUDAThreads><<<1, kNumCUDAThreads, 0, dev_ctx.stream()>>>(
counts, limit, static_cast<T>(1e-5), norm);
Div<T><<<blocks, threads, 0, dev_ctx.stream()>>>(dx_data, limit, norm);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(sigmoid_cross_entropy_with_logits,
ops::SigmoidCrossEntropyWithLogitsKernel<
paddle::platform::CUDADeviceContext, float>);
ops::GPUSigmoidCrossEntropyWithLogitsKernel<
paddle::platform::CUDADeviceContext, float>,
ops::GPUSigmoidCrossEntropyWithLogitsKernel<
paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(sigmoid_cross_entropy_with_logits_grad,
ops::SigmoidCrossEntropyWithLogitsGradKernel<
paddle::platform::CUDADeviceContext, float>);
ops::GPUSigmoidCrossEntropyWithLogitsGradKernel<
paddle::platform::CUDADeviceContext, float>,
ops::GPUSigmoidCrossEntropyWithLogitsGradKernel<
paddle::platform::CUDADeviceContext, double>);
......@@ -13,54 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include <algorithm>
#include <limits>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T>
struct SigmoidCrossEntropyWithLogitsForward {
HOSTDEVICE SigmoidCrossEntropyWithLogitsForward(const int &ignore_index)
: ignore_index(ignore_index) {}
HOSTDEVICE T operator()(const T &x, const T &label) const {
if (static_cast<int>(label) == ignore_index) {
return static_cast<T>(0.);
}
T term1 = (x > 0) ? x : 0;
T term2 = x * label;
T term3 = std::log(static_cast<T>(1) + std::exp(-(std::abs(x))));
return term1 - term2 + term3;
}
int ignore_index;
};
template <typename T>
struct SigmoidCrossEntropyWithLogitsBackward {
HOSTDEVICE SigmoidCrossEntropyWithLogitsBackward(const int &ignore_index)
: ignore_index(ignore_index) {}
HOSTDEVICE T operator()(const T &x, const T &label) const {
if (static_cast<int>(label) == ignore_index) {
return static_cast<T>(0.);
}
T simoid_x = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-x));
return simoid_x - label;
}
int ignore_index;
};
// Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X)))
template <typename DeviceContext, typename T>
......@@ -70,16 +30,37 @@ class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel<T> {
const Tensor *X = context.Input<Tensor>("X");
const Tensor *Labels = context.Input<Tensor>("Label");
Tensor *Out = context.Output<Tensor>("Out");
Out->mutable_data<T>(context.GetPlace());
int ignore_index = context.Attr<int>("ignore_index");
auto x = EigenVector<T>::Flatten(*X);
auto labels = EigenVector<T>::Flatten(*Labels);
auto out = EigenVector<T>::Flatten(*Out);
auto &place = *context.device_context<DeviceContext>().eigen_device();
out.device(place) = x.binaryExpr(
labels, SigmoidCrossEntropyWithLogitsForward<T>(ignore_index));
auto out_data = Out->mutable_data<T>(context.GetPlace());
int limit = Out->numel();
auto x_data = X->data<T>();
auto label_data = Labels->data<T>();
for (int idx = 0; idx < limit; ++idx) {
T x = x_data[idx];
T label = label_data[idx];
if (static_cast<int>(label) == ignore_index) {
out_data[idx] = static_cast<T>(0.);
} else {
T term1 = (x > 0) ? x : 0;
T term2 = x * label;
T term3 = std::log(static_cast<T>(1) + std::exp(-std::abs(x)));
out_data[idx] = term1 - term2 + term3;
}
}
bool normalize = context.Attr<bool>("normalize");
if (normalize) {
int norm = 0;
T eps = static_cast<T>(1e-6);
for (int idx = 0; idx < limit; ++idx) {
T diff = label_data[idx] - static_cast<T>(ignore_index);
if ((diff < -eps) || (diff > eps)) {
norm += 1;
}
}
eps = static_cast<T>(1e-5);
norm = norm > eps ? norm : eps;
std::for_each(out_data, out_data + limit, [norm](T &v) { v = v / norm; });
}
}
};
......@@ -92,19 +73,39 @@ class SigmoidCrossEntropyWithLogitsGradKernel : public framework::OpKernel<T> {
const Tensor *Labels = context.Input<Tensor>("Label");
const Tensor *dOut = context.Input<Tensor>(framework::GradVarName("Out"));
Tensor *dX = context.Output<Tensor>(framework::GradVarName("X"));
dX->mutable_data<T>(context.GetPlace());
auto ignore_index = context.Attr<int>("ignore_index");
auto x = EigenVector<T>::Flatten(*X);
auto labels = EigenVector<T>::Flatten(*Labels);
auto dout = EigenVector<T>::Flatten(*dOut);
auto dx = EigenVector<T>::Flatten(*dX);
auto &place =
*context.template device_context<DeviceContext>().eigen_device();
auto dx_data = dX->mutable_data<T>(context.GetPlace());
auto diff = x.binaryExpr(labels, SigmoidCrossEntropyWithLogitsBackward<T>(
static_cast<int>(ignore_index)));
dx.device(place) = dout * diff;
int ignore_index = context.Attr<int>("ignore_index");
int limit = dX->numel();
auto x_data = X->data<T>();
auto label_data = Labels->data<T>();
auto dout_data = dOut->data<T>();
for (int idx = 0; idx < limit; ++idx) {
T x = x_data[idx];
T label = label_data[idx];
T dout = dout_data[idx];
if (static_cast<int>(label) == ignore_index) {
dx_data[idx] = static_cast<T>(0.);
} else {
T simoid_x = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-x));
T diff = simoid_x - label;
dx_data[idx] = dout * diff;
}
}
bool normalize = context.Attr<bool>("normalize");
if (normalize) {
int norm = 0;
T eps = static_cast<T>(1e-6);
for (int idx = 0; idx < limit; ++idx) {
T diff = label_data[idx] - static_cast<T>(ignore_index);
if ((diff < -eps) || (diff > eps)) {
norm += 1;
}
}
eps = static_cast<T>(1e-5);
norm = norm > eps ? norm : eps;
std::for_each(dx_data, dx_data + limit, [norm](T &v) { v = v / norm; });
}
}
};
......
......@@ -88,8 +88,8 @@ class DataToLoDTensorConverter(object):
raise ValueError(
"Reshape error. What is defined in data layer is {}, but receive {}"
.format(self.shape, arr.shape))
else:
self._check_shape(arr.shape)
#else:
# self._check_shape(arr.shape)
t = core.LoDTensor()
t.set(arr, self.place)
if self.lod_level > 0:
......
......@@ -44,6 +44,7 @@ __all__ = [
'roi_perspective_transform',
'generate_proposal_labels',
'generate_proposals',
'generate_mask_labels',
'iou_similarity',
'box_coder',
'polygon_box_transform',
......@@ -1659,7 +1660,7 @@ def generate_proposal_labels(rpn_rois,
class_nums=None,
use_random=True):
"""
** Generate proposal labels Faster-RCNN **
** Generate Proposal Labels of Faster-RCNN **
This operator can be, for given the GenerateProposalOp output bounding boxes and groundtruth,
to sample foreground boxes and background boxes, and compute loss target.
......@@ -1740,6 +1741,140 @@ def generate_proposal_labels(rpn_rois,
return rois, labels_int32, bbox_targets, bbox_inside_weights, bbox_outside_weights
def generate_mask_labels(im_info, gt_classes, is_crowd, gt_segms, rois,
labels_int32, num_classes, resolution):
"""
** Generate Mask Labels for Mask-RCNN **
This operator can be, for given the RoIs and corresponding labels,
to sample foreground RoIs. This mask branch also has
a :math: `K \\times M^{2}` dimensional output targets for each foreground
RoI, which encodes K binary masks of resolution M x M, one for each of the
K classes. This mask targets are used to compute loss of mask branch.
Please note, the data format of groud-truth segmentation, assumed the
segmentations are as follows. The first instance has two gt objects.
The second instance has one gt object, this object has two gt segmentations.
.. code-block:: python
#[
# [[[229.14, 370.9, 229.14, 370.9, ...]],
# [[343.7, 139.85, 349.01, 138.46, ...]]], # 0-th instance
# [[[500.0, 390.62, ...],[115.48, 187.86, ...]]] # 1-th instance
#]
batch_masks = []
for semgs in batch_semgs:
gt_masks = []
for semg in semgs:
gt_segm = []
for polys in semg:
gt_segm.append(np.array(polys).reshape(-1, 2))
gt_masks.append(gt_segm)
batch_masks.append(gt_masks)
place = fluid.CPUPlace()
feeder = fluid.DataFeeder(place=place, feed_list=feeds)
feeder.feed(batch_masks)
Args:
im_info(Variable): A 2-D Tensor with shape [N, 3]. N is the batch size,
each element is [height, width, scale] of image. Image scale is
target_size) / original_size.
gt_classes(Variable): A 2-D LoDTensor with shape [M, 1]. M is the total
number of ground-truth, each element is a class label.
is_crowd(Variable): A 2-D LoDTensor with shape as gt_classes,
each element is a flag indicating whether a groundtruth is crowd.
gt_segms(Variable): This input is a 2D LoDTensor with shape [S, 2],
it's LoD level is 3. Usually users do not needs to understand LoD,
The users should return correct data format in reader.
The LoD[0] represents the gt objects number of
each instance. LoD[1] represents the segmentation counts of each
objects. LoD[2] represents the polygons number of each segmentation.
S the total number of polygons coordinate points. Each element is
(x, y) coordinate points.
rois(Variable): A 2-D LoDTensor with shape [R, 4]. R is the total
number of RoIs, each element is a bounding box with
(xmin, ymin, xmax, ymax) format in the range of original image.
labels_int32(Variable): A 2-D LoDTensor in shape of [R, 1] with type
of int32. R is the same as it in `rois`. Each element repersents
a class label of a RoI.
num_classes(int): Class number.
resolution(int): Resolution of mask predictions.
Returns:
mask_rois (Variable): A 2D LoDTensor with shape [P, 4]. P is the total
number of sampled RoIs. Each element is a bounding box with
[xmin, ymin, xmax, ymax] format in range of orignal image size.
mask_rois_has_mask_int32 (Variable): A 2D LoDTensor with shape [P, 1],
each element repersents the output mask RoI index with regard to
to input RoIs.
mask_int32 (Variable): A 2D LoDTensor with shape [P, K * M * M],
K is the classes number and M is the resolution of mask predictions.
Each element repersents the binary mask targets.
Examples:
.. code-block:: python
im_info = fluid.layers.data(name="im_info", shape=[3],
dtype="float32")
gt_classes = fluid.layers.data(name="gt_classes", shape=[1],
dtype="float32", lod_level=1)
is_crowd = fluid.layers.data(name="is_crowd", shape=[1],
dtype="float32", lod_level=1)
gt_masks = fluid.layers.data(name="gt_masks", shape=[2],
dtype="float32", lod_level=3)
# rois, labels_int32 can be the output of
# fluid.layers.generate_proposal_labels.
mask_rois, mask_index, mask_int32 = fluid.layers.generate_mask_labels(
im_info=im_info,
gt_classes=gt_classes,
is_crowd=is_crowd,
gt_segms=gt_masks,
rois=rois,
labels_int32=labels_int32,
num_classes=81,
resolution=14)
"""
helper = LayerHelper('generate_mask_labels', **locals())
mask_rois = helper.create_variable_for_type_inference(dtype=rois.dtype)
roi_has_mask_int32 = helper.create_variable_for_type_inference(
dtype=gt_classes.dtype)
mask_int32 = helper.create_variable_for_type_inference(
dtype=gt_classes.dtype)
helper.append_op(
type="generate_mask_labels",
inputs={
'ImInfo': im_info,
'GtClasses': gt_classes,
'IsCrowd': is_crowd,
'GtSegms': gt_segms,
'Rois': rois,
'LabelsInt32': labels_int32
},
outputs={
'MaskRois': mask_rois,
'RoiHasMaskInt32': roi_has_mask_int32,
'MaskInt32': mask_int32
},
attrs={'num_classes': num_classes,
'resolution': resolution})
mask_rois.stop_gradient = True
roi_has_mask_int32.stop_gradient = True
mask_int32.stop_gradient = True
return mask_rois, roi_has_mask_int32, mask_int32
def generate_proposals(scores,
bbox_deltas,
im_info,
......@@ -1754,33 +1889,48 @@ def generate_proposals(scores,
"""
**Generate proposal Faster-RCNN**
This operation proposes RoIs according to each box with their probability to be a foreground object and
the box can be calculated by anchors. Bbox_deltais and scores to be an object are the output of RPN. Final proposals
This operation proposes RoIs according to each box with their
probability to be a foreground object and
the box can be calculated by anchors. Bbox_deltais and scores
to be an object are the output of RPN. Final proposals
could be used to train detection net.
For generating proposals, this operation performs following steps:
1. Transposes and resizes scores and bbox_deltas in size of (H*W*A, 1) and (H*W*A, 4)
1. Transposes and resizes scores and bbox_deltas in size of
(H*W*A, 1) and (H*W*A, 4)
2. Calculate box locations as proposals candidates.
3. Clip boxes to image
4. Remove predicted boxes with small area.
5. Apply NMS to get final proposals as output.
Args:
scores(Variable): A 4-D Tensor with shape [N, A, H, W] represents the probability for each box to be an object.
N is batch size, A is number of anchors, H and W are height and width of the feature map.
bbox_deltas(Variable): A 4-D Tensor with shape [N, 4*A, H, W] represents the differece between predicted box locatoin and anchor location.
im_info(Variable): A 2-D Tensor with shape [N, 3] represents origin image information for N batch. Info contains height, width and scale
scores(Variable): A 4-D Tensor with shape [N, A, H, W] represents
the probability for each box to be an object.
N is batch size, A is number of anchors, H and W are height and
width of the feature map.
bbox_deltas(Variable): A 4-D Tensor with shape [N, 4*A, H, W]
represents the differece between predicted box locatoin and
anchor location.
im_info(Variable): A 2-D Tensor with shape [N, 3] represents origin
image information for N batch. Info contains height, width and scale
between origin image size and the size of feature map.
anchors(Variable): A 4-D Tensor represents the anchors with a layout of [H, W, A, 4]. H and W are height and width of the feature map,
num_anchors is the box count of each position. Each anchor is in (xmin, ymin, xmax, ymax) format an unnormalized.
variances(Variable): The expanded variances of anchors with a layout of [H, W, num_priors, 4]. Each variance is in (xcenter, ycenter, w, h) format.
pre_nms_top_n(float): Number of total bboxes to be kept per image before NMS. 6000 by default.
post_nms_top_n(float): Number of total bboxes to be kept per image after NMS. 1000 by default.
anchors(Variable): A 4-D Tensor represents the anchors with a layout
of [H, W, A, 4]. H and W are height and width of the feature map,
num_anchors is the box count of each position. Each anchor is
in (xmin, ymin, xmax, ymax) format an unnormalized.
variances(Variable): The expanded variances of anchors with a layout of
[H, W, num_priors, 4]. Each variance is in
(xcenter, ycenter, w, h) format.
pre_nms_top_n(float): Number of total bboxes to be kept per
image before NMS. 6000 by default.
post_nms_top_n(float): Number of total bboxes to be kept per
image after NMS. 1000 by default.
nms_thresh(float): Threshold in NMS, 0.5 by default.
min_size(float): Remove predicted boxes with either height or width < min_size. 0.1 by default.
eta(float): Apply in adaptive NMS, if adaptive threshold > 0.5, adaptive_threshold = adaptive_threshold * eta in each iteration.
min_size(float): Remove predicted boxes with either height or
width < min_size. 0.1 by default.
eta(float): Apply in adaptive NMS, if adaptive threshold > 0.5,
adaptive_threshold = adaptive_threshold * eta in each iteration.
"""
helper = LayerHelper('generate_proposals', **locals())
......
......@@ -8927,7 +8927,8 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
def sigmoid_cross_entropy_with_logits(x,
label,
ignore_index=kIgnoreIndex,
name=None):
name=None,
normalize=False):
"""
${comment}
......@@ -8936,9 +8937,25 @@ def sigmoid_cross_entropy_with_logits(x,
label(${label_type}): ${label_comment}
ignore_index(&{ignore_index}): ${ignore_index_comment}
name(basestring|None): Name of the output.
normalize(bool): If true, divide the output by the number of
targets != ignore_index.
Returns:
out(${out_type}): ${out_comment}
Examples:
.. code-block:: python
input = fluid.layers.data(
name='data', shape=[10], dtype='float32')
label = fluid.layers.data(
name='data', shape=[10], dtype='float32')
loss = fluid.layers.sigmoid_cross_entropy_with_logits(
x=input,
label=label,
ignore_index=-1,
normalize=True) # or False
# loss = fluid.layers.reduce_sum(loss) # summation of loss
"""
helper = LayerHelper("sigmoid_cross_entropy_with_logits", **locals())
......@@ -8953,7 +8970,8 @@ def sigmoid_cross_entropy_with_logits(x,
type="sigmoid_cross_entropy_with_logits",
inputs={"X": x,
"Label": label},
attrs={"ignore_index": ignore_index},
attrs={"ignore_index": ignore_index,
'normalize': normalize},
outputs={"Out": out})
return out
......
......@@ -203,7 +203,7 @@ class TestGenerateProposalLabels(unittest.TestCase):
lod_level=1,
append_batch_size=False)
class_nums = 5
rois, labels_int32, bbox_targets, bbox_inside_weights, bbox_outside_weights = fluid.layers.generate_proposal_labels(
outs = fluid.layers.generate_proposal_labels(
rpn_rois=rpn_rois,
gt_classes=gt_classes,
is_crowd=is_crowd,
......@@ -216,6 +216,11 @@ class TestGenerateProposalLabels(unittest.TestCase):
bg_thresh_lo=0.0,
bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
class_nums=class_nums)
rois = outs[0]
labels_int32 = outs[1]
bbox_targets = outs[2]
bbox_inside_weights = outs[3]
bbox_outside_weights = outs[4]
assert rois.shape[1] == 4
assert rois.shape[0] == labels_int32.shape[0]
assert rois.shape[0] == bbox_targets.shape[0]
......@@ -226,6 +231,62 @@ class TestGenerateProposalLabels(unittest.TestCase):
assert bbox_outside_weights.shape[1] == 4 * class_nums
class TestGenerateMaskLabels(unittest.TestCase):
def test_generate_mask_labels(self):
program = Program()
with program_guard(program):
im_info = layers.data(
name='im_info',
shape=[1, 3],
dtype='float32',
lod_level=1,
append_batch_size=False)
gt_classes = layers.data(
name='gt_classes',
shape=[2, 1],
dtype='int32',
lod_level=1,
append_batch_size=False)
is_crowd = layers.data(
name='is_crowd',
shape=[2, 1],
dtype='int32',
lod_level=1,
append_batch_size=False)
gt_segms = layers.data(
name='gt_segms',
shape=[20, 2],
dtype='float32',
lod_level=3,
append_batch_size=False)
rois = layers.data(
name='rois',
shape=[4, 4],
dtype='float32',
lod_level=1,
append_batch_size=False)
labels_int32 = layers.data(
name='labels_int32',
shape=[4, 1],
dtype='int32',
lod_level=1,
append_batch_size=False)
num_classes = 5
resolution = 14
outs = fluid.layers.generate_mask_labels(
im_info=im_info,
gt_classes=gt_classes,
is_crowd=is_crowd,
gt_segms=gt_segms,
rois=rois,
labels_int32=labels_int32,
num_classes=num_classes,
resolution=resolution)
mask_rois, roi_has_mask_int32, mask_int32 = outs
assert mask_rois.shape[1] == 4
assert mask_int32.shape[1] == num_classes * resolution * resolution
class TestMultiBoxHead(unittest.TestCase):
def test_multi_box_head(self):
data_shape = [3, 224, 224]
......@@ -313,7 +374,7 @@ class TestRpnTargetAssign(unittest.TestCase):
name='gt_boxes', shape=[4], lod_level=1, dtype='float32')
is_crowd = layers.data(
name='is_crowd',
shape=[10],
shape=[1, 10],
dtype='int32',
lod_level=1,
append_batch_size=False)
......@@ -323,7 +384,7 @@ class TestRpnTargetAssign(unittest.TestCase):
dtype='float32',
lod_level=1,
append_batch_size=False)
pred_scores, pred_loc, tgt_lbl, tgt_bbox, bbox_inside_weight = layers.rpn_target_assign(
outs = layers.rpn_target_assign(
bbox_pred=bbox_pred,
cls_logits=cls_logits,
anchor_box=anchor_box,
......@@ -337,6 +398,11 @@ class TestRpnTargetAssign(unittest.TestCase):
rpn_positive_overlap=0.7,
rpn_negative_overlap=0.3,
use_random=False)
pred_scores = outs[0]
pred_loc = outs[1]
tgt_lbl = outs[2]
tgt_bbox = outs[3]
bbox_inside_weight = outs[4]
self.assertIsNotNone(pred_scores)
self.assertIsNotNone(pred_loc)
......@@ -351,41 +417,43 @@ class TestRpnTargetAssign(unittest.TestCase):
class TestGenerateProposals(unittest.TestCase):
def test_generate_proposals(self):
data_shape = [20, 64, 64]
images = fluid.layers.data(
name='images', shape=data_shape, dtype='float32')
im_info = fluid.layers.data(
name='im_info', shape=[1, 3], dtype='float32')
anchors, variances = fluid.layers.anchor_generator(
name='anchor_generator',
input=images,
anchor_sizes=[32, 64],
aspect_ratios=[1.0],
variance=[0.1, 0.1, 0.2, 0.2],
stride=[16.0, 16.0],
offset=0.5)
num_anchors = anchors.shape[2]
scores = fluid.layers.data(
name='scores', shape=[1, num_anchors, 8, 8], dtype='float32')
bbox_deltas = fluid.layers.data(
name='bbox_deltas',
shape=[1, num_anchors * 4, 8, 8],
dtype='float32')
rpn_rois, rpn_roi_probs = fluid.layers.generate_proposals(
name='generate_proposals',
scores=scores,
bbox_deltas=bbox_deltas,
im_info=im_info,
anchors=anchors,
variances=variances,
pre_nms_top_n=6000,
post_nms_top_n=1000,
nms_thresh=0.5,
min_size=0.1,
eta=1.0)
self.assertIsNotNone(rpn_rois)
self.assertIsNotNone(rpn_roi_probs)
print(rpn_rois.shape)
program = Program()
with program_guard(program):
data_shape = [20, 64, 64]
images = fluid.layers.data(
name='images', shape=data_shape, dtype='float32')
im_info = fluid.layers.data(
name='im_info', shape=[3], dtype='float32')
anchors, variances = fluid.layers.anchor_generator(
name='anchor_generator',
input=images,
anchor_sizes=[32, 64],
aspect_ratios=[1.0],
variance=[0.1, 0.1, 0.2, 0.2],
stride=[16.0, 16.0],
offset=0.5)
num_anchors = anchors.shape[2]
scores = fluid.layers.data(
name='scores', shape=[num_anchors, 8, 8], dtype='float32')
bbox_deltas = fluid.layers.data(
name='bbox_deltas',
shape=[num_anchors * 4, 8, 8],
dtype='float32')
rpn_rois, rpn_roi_probs = fluid.layers.generate_proposals(
name='generate_proposals',
scores=scores,
bbox_deltas=bbox_deltas,
im_info=im_info,
anchors=anchors,
variances=variances,
pre_nms_top_n=6000,
post_nms_top_n=1000,
nms_thresh=0.5,
min_size=0.1,
eta=1.0)
self.assertIsNotNone(rpn_rois)
self.assertIsNotNone(rpn_roi_probs)
print(rpn_rois.shape)
class TestYoloDetection(unittest.TestCase):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
import math
import six
import paddle.fluid as fluid
from op_test import OpTest
'''
# Equivalent code
rles = mask_util.frPyObjects([segm], im_h, im_w)
mask = mask_util.decode(rles)
'''
def decode(cnts, m):
v = 0
mask = []
for j in range(m):
for k in range(cnts[j]):
mask.append(v)
v = 1 - v
return mask
def poly2mask(xy, k, h, w):
scale = 5.
x = [int(scale * p + 0.5) for p in xy[::2]]
x = x + [x[0]]
y = [int(scale * p + 0.5) for p in xy[1::2]]
y = y + [y[0]]
m = sum([
int(max(abs(x[j] - x[j + 1]), abs(y[j] - y[j + 1]))) + int(1)
for j in range(k)
])
u, v = [], []
for j in range(k):
xs = x[j]
xe = x[j + 1]
ys = y[j]
ye = y[j + 1]
dx = abs(xe - xs)
dy = abs(ys - ye)
flip = (dx >= dy and xs > xe) or (dx < dy and ys > ye)
if flip:
xs, xe = xe, xs
ys, ye = ye, ys
if dx >= dy:
if (dx == 0): assert ye - ys == 0
s = 0 if dx == 0 else float(ye - ys) / dx
else:
if (dy == 0): assert xe - xs == 0
s = 0 if dy == 0 else float(xe - xs) / dy
if dx >= dy:
ts = [dx - d if flip else d for d in range(dx + 1)]
u.extend([xs + t for t in ts])
v.extend([int(ys + s * t + .5) for t in ts])
else:
ts = [dy - d if flip else d for d in range(dy + 1)]
v.extend([t + ys for t in ts])
u.extend([int(xs + s * t + .5) for t in ts])
k = len(u)
x = np.zeros((k), np.int)
y = np.zeros((k), np.int)
m = 0
for j in six.moves.xrange(1, k):
if u[j] != u[j - 1]:
xd = float(u[j] if (u[j] < u[j - 1]) else (u[j] - 1))
xd = (xd + .5) / scale - .5
if (math.floor(xd) != xd or xd < 0 or xd > (w - 1)):
continue
yd = float(v[j] if v[j] < v[j - 1] else v[j - 1])
yd = (yd + .5) / scale - .5
yd = math.ceil(0 if yd < 0 else (h if yd > h else yd))
x[m] = int(xd)
y[m] = int(yd)
m += 1
k = m
a = [int(x[i] * h + y[i]) for i in range(k)]
a.append(h * w)
a.sort()
b = [0] + a[:len(a) - 1]
a = [c - d for (c, d) in zip(a, b)]
k += 1
b = [0 for i in range(k)]
b[0] = a[0]
m, j = 1, 1
while (j < k):
if a[j] > 0:
b[m] = a[j]
m += 1
j += 1
else:
j += 1
if (j < k):
b[m - 1] += a[j]
j += 1
mask = decode(b, m)
mask = np.array(mask, dtype=np.int).reshape((w, h))
mask = mask.transpose((1, 0))
return mask
def polys_to_boxes(polys):
"""Convert a list of polygons into an array of tight bounding boxes."""
boxes_from_polys = np.zeros((len(polys), 4), dtype=np.float32)
for i in range(len(polys)):
poly = polys[i]
x0 = min(min(p[::2]) for p in poly)
x1 = max(max(p[::2]) for p in poly)
y0 = min(min(p[1::2]) for p in poly)
y1 = max(max(p[1::2]) for p in poly)
boxes_from_polys[i, :] = [x0, y0, x1, y1]
return boxes_from_polys
def bbox_overlaps(boxes, query_boxes):
N = boxes.shape[0]
K = query_boxes.shape[0]
overlaps = np.zeros((N, K), dtype=boxes.dtype)
for k in range(K):
box_area = (query_boxes[k, 2] - query_boxes[k, 0] + 1) *\
(query_boxes[k, 3] - query_boxes[k, 1] + 1)
for n in range(N):
iw = min(boxes[n, 2], query_boxes[k, 2]) -\
max(boxes[n, 0], query_boxes[k, 0]) + 1
if iw > 0:
ih = min(boxes[n, 3], query_boxes[k, 3]) -\
max(boxes[n, 1], query_boxes[k, 1]) + 1
if ih > 0:
ua = float(
(boxes[n, 2] - boxes[n, 0] + 1) *\
(boxes[n, 3] - boxes[n, 1] + 1) +\
box_area - iw * ih)
overlaps[n, k] = iw * ih / ua
return overlaps
def polys_to_mask_wrt_box(polygons, box, M):
"""Convert from the COCO polygon segmentation format to a binary mask
encoded as a 2D array of data type numpy.float32. The polygon segmentation
is understood to be enclosed in the given box and rasterized to an M x M
mask. The resulting mask is therefore of shape (M, M).
"""
w = box[2] - box[0]
h = box[3] - box[1]
w = np.maximum(w, 1)
h = np.maximum(h, 1)
polygons_norm = []
for poly in polygons:
p = np.array(poly, dtype=np.float32)
p[0::2] = (p[0::2] - box[0]) * M / w
p[1::2] = (p[1::2] - box[1]) * M / h
polygons_norm.append(p)
mask = []
for polygons in polygons_norm:
assert polygons.shape[0] % 2 == 0
k = polygons.shape[0] // 2
mask.append(poly2mask(polygons, k, M, M))
mask = np.array(mask)
# Flatten in case polygons was a list
mask = np.sum(mask, axis=0)
mask = np.array(mask > 0, dtype=np.float32)
return mask
def expand_mask_targets(masks, mask_class_labels, resolution, num_classes):
"""Expand masks from shape (#masks, resolution ** 2)
to (#masks, #classes * resolution ** 2) to encode class
specific mask targets.
"""
assert masks.shape[0] == mask_class_labels.shape[0]
# Target values of -1 are "don't care" / ignore labels
mask_targets = -np.ones(
(masks.shape[0], num_classes * resolution**2), dtype=np.int32)
for i in range(masks.shape[0]):
cls = int(mask_class_labels[i])
start = resolution**2 * cls
end = start + resolution**2
# Ignore background instance
# (only happens when there is no fg samples in an image)
if cls > 0:
mask_targets[i, start:end] = masks[i, :]
return mask_targets
def generate_mask_labels(num_classes, im_info, gt_classes, is_crowd,
label_int32, gt_polys, resolution, rois, roi_lod,
gt_lod):
mask_rois = []
roi_has_mask_int32 = []
mask_int32 = []
new_lod = []
for i in range(len(im_info)):
roi_s = roi_lod[i]
roi_e = roi_lod[i + 1]
gt_s = gt_lod[i]
gt_e = gt_lod[i + 1]
mask_blob = _sample_mask(num_classes, im_info[i], gt_classes[gt_s:gt_e],
is_crowd[gt_s:gt_e], label_int32[roi_s:roi_e],
gt_polys[i], resolution, rois[roi_s:roi_e])
new_lod.append(mask_blob['mask_rois'].shape[0])
mask_rois.append(mask_blob['mask_rois'])
roi_has_mask_int32.append(mask_blob['roi_has_mask_int32'])
mask_int32.append(mask_blob['mask_int32'])
return mask_rois, roi_has_mask_int32, mask_int32, new_lod
def _sample_mask(
num_classes,
im_info,
gt_classes,
is_crowd,
label_int32,
gt_polys, # [[[], []], []]
resolution,
rois):
mask_blob = {}
im_scale = im_info[2]
sample_boxes = rois
polys_gt_inds = np.where((gt_classes > 0) & (is_crowd == 0))[0]
polys_gt = [gt_polys[i] for i in polys_gt_inds]
boxes_from_polys = polys_to_boxes(polys_gt)
fg_inds = np.where(label_int32 > 0)[0]
roi_has_mask = fg_inds.copy()
if fg_inds.shape[0] > 0:
mask_class_labels = label_int32[fg_inds]
masks = np.zeros((fg_inds.shape[0], resolution**2), dtype=np.int32)
rois_fg = sample_boxes[fg_inds]
overlaps_bbfg_bbpolys = bbox_overlaps(
rois_fg.astype(np.float32), boxes_from_polys.astype(np.float32))
fg_polys_inds = np.argmax(overlaps_bbfg_bbpolys, axis=1)
for i in range(rois_fg.shape[0]):
fg_polys_ind = fg_polys_inds[i]
poly_gt = polys_gt[fg_polys_ind]
roi_fg = rois_fg[i]
mask = polys_to_mask_wrt_box(poly_gt, roi_fg, resolution)
mask = np.array(mask > 0, dtype=np.int32)
masks[i, :] = np.reshape(mask, resolution**2)
else:
bg_inds = np.where(label_int32 == 0)[0]
rois_fg = sample_boxes[bg_inds[0]].reshape((1, -1))
masks = -np.ones((1, resolution**2), dtype=np.int32)
mask_class_labels = np.zeros((1, ))
roi_has_mask = np.append(roi_has_mask, 0)
masks = expand_mask_targets(masks, mask_class_labels, resolution,
num_classes)
rois_fg *= im_scale
mask_blob['mask_rois'] = rois_fg
mask_blob['roi_has_mask_int32'] = roi_has_mask
mask_blob['mask_int32'] = masks
return mask_blob
def trans_lod(lod):
new_lod = [0]
for i in range(len(lod)):
new_lod.append(lod[i] + new_lod[i])
return new_lod
class TestGenerateMaskLabels(OpTest):
def set_data(self):
self.init_test_case()
self.make_generate_proposal_labels_out()
self.generate_gt_polys()
self.generate_groundtruth()
self.init_test_output()
self.inputs = {
'ImInfo': self.im_info,
'GtClasses': (self.gt_classes.astype(np.int32), self.gt_lod),
'IsCrowd': (self.is_crowd.astype(np.int32), self.gt_lod),
'LabelsInt32': (self.label_int32.astype(np.int32), self.rois_lod),
'GtSegms': (self.gt_polys.astype(np.float32), self.masks_lod),
'Rois': (self.rois.astype(np.float32), self.rois_lod)
}
self.attrs = {
'num_classes': self.num_classes,
'resolution': self.resolution
}
self.outputs = {
'MaskRois': (self.mask_rois, [self.new_lod]),
'RoiHasMaskInt32': (self.roi_has_mask_int32, [self.new_lod]),
'MaskInt32': (self.mask_int32, [self.new_lod])
}
def init_test_case(self):
self.num_classes = 81
self.resolution = 14
self.batch_size = 2
self.batch_size_per_im = 64
self.images_shape = [100, 200]
np.random.seed(0)
def make_generate_proposal_labels_out(self):
rois = []
self.rois_lod = [[]]
self.label_int32 = []
for bno in range(self.batch_size):
self.rois_lod[0].append(self.batch_size_per_im)
for i in range(self.batch_size_per_im):
xywh = np.random.rand(4)
xy1 = xywh[0:2] * 2
wh = xywh[2:4] * (self.images_shape[0] - xy1)
xy2 = xy1 + wh
roi = [xy1[0], xy1[1], xy2[0], xy2[1]]
rois.append(roi)
self.rois = np.array(rois).astype("float32")
for idx, roi_num in enumerate(self.rois_lod[0]):
for roi_id in range(roi_num):
class_id = np.random.random_integers(self.num_classes - 1)
if idx == 0:
# set an image with no foreground, to test the empty case
self.label_int32.append(0)
else:
self.label_int32.append(class_id)
label_np = np.array(self.label_int32)
self.label_int32 = label_np[:, np.newaxis]
def generate_gt_polys(self):
h, w = self.images_shape[0:2]
self.gt_polys = []
self.gt_polys_list = []
max_gt = 4
max_poly_num = 5
min_poly_size = 4
max_poly_size = 16
lod0 = []
lod1 = []
lod2 = []
for i in range(self.batch_size):
gt_num = np.random.randint(1, high=max_gt, size=1)[0]
lod0.append(gt_num)
ptss = []
for i in range(gt_num):
poly_num = np.random.randint(1, max_poly_num, size=1)[0]
lod1.append(poly_num)
pts = []
for j in range(poly_num):
poly_size = np.random.randint(
min_poly_size, max_poly_size, size=1)[0]
x = np.random.rand(poly_size, 1) * w
y = np.random.rand(poly_size, 1) * h
xy = np.concatenate((x, y), axis=1)
pts.append(xy.flatten().tolist())
self.gt_polys.extend(xy.flatten().tolist())
lod2.append(poly_size)
ptss.append(pts)
self.gt_polys_list.append(ptss)
self.masks_lod = [lod0, lod1, lod2]
self.gt_lod = [lod0]
self.gt_polys = np.array(self.gt_polys).astype('float32').reshape(-1, 2)
def generate_groundtruth(self):
self.im_info = []
self.gt_classes = []
self.is_crowd = []
for roi_num in self.gt_lod[0]:
self.im_info.append(self.images_shape + [1.0])
for roi_id in range(roi_num):
class_id = np.random.random_integers(self.num_classes - 1)
self.gt_classes.append(class_id)
self.is_crowd.append(0)
self.im_info = np.array(self.im_info).astype(np.float32)
gt_classes_np = np.array(self.gt_classes)
self.gt_classes = gt_classes_np[:, np.newaxis]
is_crowd_np = np.array(self.is_crowd)
self.is_crowd = is_crowd_np[:, np.newaxis]
def init_test_output(self):
roi_lod = trans_lod(self.rois_lod[0])
gt_lod = trans_lod(self.gt_lod[0])
outs = generate_mask_labels(self.num_classes, self.im_info,
self.gt_classes, self.is_crowd,
self.label_int32, self.gt_polys_list,
self.resolution, self.rois, roi_lod, gt_lod)
self.mask_rois = outs[0]
self.roi_has_mask_int32 = outs[1]
self.mask_int32 = outs[2]
self.new_lod = outs[3]
self.mask_rois = np.vstack(self.mask_rois)
self.roi_has_mask_int32 = np.hstack(self.roi_has_mask_int32)[:,
np.newaxis]
self.mask_int32 = np.vstack(self.mask_int32)
def setUp(self):
self.op_type = "generate_mask_labels"
self.set_data()
def test_check_output(self):
self.check_output()
if __name__ == '__main__':
unittest.main()
......@@ -4,7 +4,7 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://w_idxw.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
......
......@@ -4,7 +4,7 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://w_idxw.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
......
......@@ -18,6 +18,7 @@ import numpy as np
from op_test import OpTest
from scipy.special import logit
from scipy.special import expit
import paddle.fluid.core as core
import unittest
......@@ -117,5 +118,36 @@ class TestSigmoidCrossEntropyWithLogitsOp3(OpTest):
self.check_grad(['X'], 'Out')
class TestSigmoidCrossEntropyWithNorm(OpTest):
def setUp(self):
self.op_type = "sigmoid_cross_entropy_with_logits"
batch_size = 64
num_classes = 20
ignore_index = -1
self.inputs = {
'X': logit(
np.random.uniform(0, 1, (batch_size, num_classes))
.astype("float32")),
'Label': np.random.randint(-1, 2, (batch_size, num_classes))
.astype("float32")
}
self.attrs = {'ignore_index': ignore_index, 'normalize': True}
sigmoid_X = expit(self.inputs['X'])
term1 = self.inputs['Label'] * np.log(sigmoid_X)
term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X)
out = -term1 - term2
out[np.where(self.inputs['Label'] == ignore_index)] = 0
if self.attrs['normalize']:
out = out / float(
np.where(self.inputs['Label'] != ignore_index)[0].size)
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册