diff --git a/mace/core/operator.cc b/mace/core/operator.cc index ad3c8e5820d469802cbe0c9cea3da2c12661c227..4c501759db2c9ad160cf6f0b8f111c087afbabc4 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -89,6 +89,8 @@ extern void Register_Reshape(OperatorRegistry *op_registry); extern void Register_Eltwise(OperatorRegistry *op_registry); extern void Register_FullyConnected(OperatorRegistry *op_registry); extern void Register_Slice(OperatorRegistry *op_registry); +extern void Register_Proposal(OperatorRegistry *op_registry); +extern void Register_PSROIAlign(OperatorRegistry *op_registry); } // namespace ops @@ -118,6 +120,8 @@ OperatorRegistry::OperatorRegistry() { ops::Register_Eltwise(this); ops::Register_FullyConnected(this); ops::Register_Slice(this); + ops::Register_Proposal(this); + ops::Register_PSROIAlign(this); } } // namespace mace diff --git a/mace/core/tensor.h b/mace/core/tensor.h index 1d7d5debf9dfefaeab59205d6de67d29867d2c35..53ac3c2e34e65f72d75a5aa518a48f0eeab3ed28 100644 --- a/mace/core/tensor.h +++ b/mace/core/tensor.h @@ -178,6 +178,18 @@ class Tensor { } } + inline void ResizeWithBuffer(const std::vector &shape, + BufferBase *buffer) { + MACE_CHECK(!has_opencl_image(), "Cannot resize image, use ResizeImage."); + shape_ = shape; + image_shape_.clear(); + if (buffer_ != nullptr && is_buffer_owner_) { + delete buffer_; + } + buffer_ = buffer; + is_buffer_owner_ = false; + } + inline void ResizeImage(const std::vector &shape, const std::vector &image_shape) { shape_ = shape; diff --git a/mace/kernels/proposal.h b/mace/kernels/proposal.h new file mode 100644 index 0000000000000000000000000000000000000000..d3afe966e2da1892f7fc3862fb1e49cbe14ea082 --- /dev/null +++ b/mace/kernels/proposal.h @@ -0,0 +1,284 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_KERNELS_PROPOSAL_H_ +#define MACE_KERNELS_PROPOSAL_H_ + +#include +#include + +#include "mace/core/future.h" +#include "mace/core/tensor.h" +#include "mace/public/mace.h" + +namespace mace { +namespace kernels { + +static std::vector WHCenters(const std::vector &anchor) { + // width, height, width_center, height_center + std::vector window(4); + window[0] = anchor[2] - anchor[0] + 1; + window[1] = anchor[3] - anchor[1] + 1; + window[2] = anchor[0] + (window[0] - 1) / 2; + window[3] = anchor[1] + (window[1] - 1) / 2; + return window; +} + +std::vector> GenerateAnchors( + const std::vector &scales, + const std::vector &ratios, + const int base_size) { + const std::vector base_anchor = + {0, 0, + static_cast(base_size-1), + static_cast(base_size-1)}; + + const size_t scales_size = scales.size(); + const size_t ratios_size = ratios.size(); + // get height, width, centers + std::vector base_window = WHCenters(base_anchor); + const float size = base_window[0] * base_window[1]; + std::vector> anchors(scales_size * ratios_size, + std::vector(4)); + +#pragma omp parallel for + for (size_t ratio_idx = 0; ratio_idx < ratios_size; ++ratio_idx) { + float ws = ::roundf(::sqrtf(size / ratios[ratio_idx])); + float hs = ::roundf(ws * ratios[ratio_idx]); + std::vector tmp_anchor(4); + tmp_anchor[0] = base_window[2] - (ws - 1) / 2; + tmp_anchor[1] = base_window[3] - (hs - 1) / 2; + tmp_anchor[2] = base_window[2] + (ws - 1) / 2; + tmp_anchor[3] = base_window[3] + (hs - 1) / 2; + auto window = WHCenters(tmp_anchor); + for (size_t scale_idx = 0; scale_idx < scales_size; ++scale_idx) { + const size_t idx = ratio_idx * scales_size + scale_idx; + ws = window[0] * scales[scale_idx]; + hs = window[1] * scales[scale_idx]; + anchors[idx][0] = window[2] - (ws - 1) / 2; + anchors[idx][1] = window[3] - (hs - 1) / 2; + anchors[idx][2] = window[2] + (ws - 1) / 2; + anchors[idx][3] = window[3] + (hs - 1) / 2; + } + } + return anchors; +} + +std::vector nms(const float *bboxes_ptr, + const index_t num_bboxes, + const float thresh, + const int post_nms_top_n) { + std::vector keep; + std::vector suppressed(num_bboxes, 0); + + std::vector areas(num_bboxes, 0); + for (index_t i = 0; i < num_bboxes; ++i) { + const index_t idx = (i << 2); + areas[i] = (bboxes_ptr[idx + 2] - bboxes_ptr[idx] + 1) * + (bboxes_ptr[idx + 3] - bboxes_ptr[idx + 1] + 1); + } + + for (int i = 0; i < num_bboxes; ++i) { + if (suppressed[i] == 1) continue; + keep.push_back(i); + if (keep.size() >= post_nms_top_n) break; + int coord_idx = i << 2; + const float x1 = bboxes_ptr[coord_idx]; + const float y1 = bboxes_ptr[coord_idx + 1]; + const float x2 = bboxes_ptr[coord_idx + 2]; + const float y2 = bboxes_ptr[coord_idx + 3]; + const float area1 = areas[i]; + for (int j = i + 1; j < num_bboxes; ++j) { + if (suppressed[j] == 1) continue; + + coord_idx = j << 2; + const float iou = + std::max(0.0, + std::min(x2, bboxes_ptr[coord_idx + 2]) - + std::max(x1, bboxes_ptr[coord_idx]) + 1) + * std::max(0.0, + std::min(y2, bboxes_ptr[coord_idx + 3]) - + std::max(y1, bboxes_ptr[coord_idx + 1]) + 1); + if ((iou / (area1 + areas[j] - iou)) >= thresh) { + suppressed[j] = 1; + } + } + } + return keep; +} + + +template +struct ProposalFunctor { + ProposalFunctor(const int min_size, + const float nms_thresh, + const int pre_nms_top_n, + const int post_nms_top_n, + const int feat_stride, + const int base_size, + const std::vector &scales, + const std::vector &ratios) : + min_size_(min_size), + thresh_(nms_thresh), + pre_nms_top_n_(pre_nms_top_n), + post_nms_top_n_(post_nms_top_n), + feat_stride_(feat_stride), + anchors_(GenerateAnchors(scales, ratios, base_size)) {} + + void operator()(const Tensor *rpn_cls_prob, + const Tensor *rpn_bbox_pred, + const Tensor *img_info_tensor, + Tensor *output, + StatsFuture *future) { + MACE_CHECK(rpn_cls_prob->dim(1) == rpn_bbox_pred->dim(1) && + rpn_cls_prob->dim(2) == rpn_bbox_pred->dim(2)); + MACE_CHECK((rpn_cls_prob->dim(3) / 2 == rpn_bbox_pred->dim(3) / 4) && + (rpn_cls_prob->dim(3) / 2 == anchors_.size())); + const float *img_info = img_info_tensor->data(); + const int im_height = static_cast(img_info[0] - 1); + const int im_width = static_cast(img_info[1] - 1); + const index_t feat_height = rpn_cls_prob->dim(1); + const index_t feat_width = rpn_cls_prob->dim(2); + const int anchors_size = anchors_.size(); + + // shift anchors to original input + std::vector> proposals( + anchors_size * feat_height * feat_width, + std::vector(4)); + +#pragma omp parallel for collapse(3) + for (int h_idx = 0; h_idx < feat_height; ++h_idx) { + for (int w_idx = 0; w_idx < feat_width; ++w_idx) { + for (int a_idx = 0; a_idx < anchors_size; ++a_idx) { + const int shift_h = h_idx * feat_stride_; + const int shift_w = w_idx * feat_stride_; + const index_t sanc_idx = (h_idx * feat_width + w_idx) * anchors_size + + a_idx; + proposals[sanc_idx][0] = anchors_[a_idx][0] + shift_w; + proposals[sanc_idx][1] = anchors_[a_idx][1] + shift_h; + proposals[sanc_idx][2] = anchors_[a_idx][2] + shift_w; + proposals[sanc_idx][3] = anchors_[a_idx][3] + shift_h; + } + } + } + // Convert anchors into proposals via bbox transformations + // 2. clip predicted boxes to image + const float *bbox_deltas = rpn_bbox_pred->data(); +#pragma omp parallel for collapse(3) + for (int h_idx = 0; h_idx < feat_height; ++h_idx) { + for (int w_idx = 0; w_idx < feat_width; ++w_idx) { + for (int a_idx = 0; a_idx < anchors_size; ++a_idx) { + const int sanc_idx = (h_idx * feat_width + w_idx) * anchors_size + + a_idx; + const float width = proposals[sanc_idx][2] - + proposals[sanc_idx][0] + 1; + const float height = proposals[sanc_idx][3] - + proposals[sanc_idx][1] + 1; + int delta_offset = sanc_idx * 4; + float pred_ctr_x = bbox_deltas[delta_offset + 0] * width + + (proposals[sanc_idx][0] + width / 2); + float pred_ctr_y = bbox_deltas[delta_offset + 1] * height + + (proposals[sanc_idx][1] + height / 2); + float pred_w = std::exp(bbox_deltas[delta_offset + 2]) * width; + float pred_h = std::exp(bbox_deltas[delta_offset + 3]) * height; + + proposals[sanc_idx][0] = std::max( + std::min(pred_ctr_x - pred_w / 2, im_width), + 0); + proposals[sanc_idx][1] = std::max( + std::min(pred_ctr_y - pred_h / 2, im_height), + 0); + proposals[sanc_idx][2] = std::max( + std::min(pred_ctr_x + pred_w / 2, im_width), + 0); + proposals[sanc_idx][3] = std::max( + std::min(pred_ctr_y + pred_h / 2, im_height), + 0); + } + } + } + // 3. remove predicted boxes with either height or width < threshold + // (NOTE: convert min_size to input image scale stored in im_info[2]) + std::vector keep; + const float min_size = min_size_ * img_info[2]; + for (int h_idx = 0; h_idx < feat_height; ++h_idx) { + for (int w_idx = 0; w_idx < feat_width; ++w_idx) { + for (int a_idx = 0; a_idx < anchors_size; ++a_idx) { + const int sanc_idx = (h_idx * feat_width + w_idx) * anchors_size + + a_idx; + const float width = proposals[sanc_idx][2] + - proposals[sanc_idx][0] + 1; + const float height = proposals[sanc_idx][3] + - proposals[sanc_idx][1] + 1; + if (width >= min_size && height >= min_size) { + keep.push_back(sanc_idx); + } + } + } + } + + // 4. sort all (proposal, score) pairs by score from highest to lowest + // 5. take top pre_nms_topN (e.g. 6000) + auto scores = rpn_cls_prob->data(); + const int scores_chan = static_cast(rpn_cls_prob->dim(3)); + + auto score_idx_func = [&](int idx) -> int { + return (idx / anchors_size) * scores_chan + + (idx % anchors_size) + anchors_size; + }; + std::sort(keep.begin(), keep.end(), [&](int left, int right) -> bool{ + return scores[score_idx_func(left)] > + scores[score_idx_func(right)]; + }); + + int size = std::min(pre_nms_top_n_, keep.size()); + std::vector nms_scores(size, 0); + std::vector nms_proposals((size << 2), 0); +#pragma omp parallel for + for (int i = 0; i < size; ++i) { + nms_scores[i] = scores[score_idx_func(keep[i])]; + nms_proposals[i << 2] = proposals[keep[i]][0]; + nms_proposals[(i << 2) + 1] = proposals[keep[i]][1]; + nms_proposals[(i << 2) + 2] = proposals[keep[i]][2]; + nms_proposals[(i << 2) + 3] = proposals[keep[i]][3]; + } + + /* 6. apply nms (e.g. threshold = 0.7) + 7. take after_nms_topN (e.g. 300) + 8. return the top proposals (-> RoIs top) */ + auto nms_result = nms(nms_proposals.data(), + nms_scores.size(), + thresh_, + post_nms_top_n_); + + // Output rois blob + // Our RPN implementation only supports a single input image, so all + // batch inds are 0 + size = static_cast(nms_result.size()); + output->Resize({size, 1, 1, 5}); + auto output_ptr = output->mutable_data(); +#pragma omp parallel for + for (int i = 0; i < size; ++i) { + const int out_idx = i * 5; + const int nms_idx = nms_result[i] * 4; + output_ptr[out_idx] = 0; + output_ptr[out_idx + 1] = nms_proposals[nms_idx]; + output_ptr[out_idx + 2] = nms_proposals[nms_idx + 1]; + output_ptr[out_idx + 3] = nms_proposals[nms_idx + 2]; + output_ptr[out_idx + 4] = nms_proposals[nms_idx + 3]; + } + } + + const int min_size_; + const float thresh_; + const int pre_nms_top_n_; + const int post_nms_top_n_; + const int feat_stride_; + std::vector> anchors_; +}; + +} // namespace kernels +} // namespace mace + +#endif // MACE_KERNELS_PROPOSAL_H_ diff --git a/mace/kernels/psroi_align.h b/mace/kernels/psroi_align.h new file mode 100644 index 0000000000000000000000000000000000000000..57e19c3c0a8fd297a46f9b47d2a2ff83e7dd0afe --- /dev/null +++ b/mace/kernels/psroi_align.h @@ -0,0 +1,178 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_KERNELS_PSROI_ALIGN_H_ +#define MACE_KERNELS_PSROI_ALIGN_H_ + +#include +#include + +#include "mace/core/future.h" +#include "mace/core/tensor.h" +#include "mace/public/mace.h" + +namespace mace { +namespace kernels { + +template +struct PSROIAlignFunctor { + PSROIAlignFunctor(const T spatial_scale, + const int output_dim, + const int group_size) : + spatial_scale_(spatial_scale), + output_dim_(output_dim), + group_size_(group_size) {} + + void operator()(const Tensor *input, + const Tensor *rois, + Tensor *output, + StatsFuture *future) { + const int height = static_cast(input->dim(1)); + const int width = static_cast(input->dim(2)); + const int channels = static_cast(input->dim(3)); + const int pooled_height = group_size_; + const int pooled_width = group_size_; + const T *input_ptr = input->data(); + const T *rois_ptr = rois->data(); + // Number of ROIs + const int num_rois = rois->dim(0); + const int batch_size = input->dim(0); + + output->Resize({num_rois, pooled_height, pooled_width, output_dim_}); + T *output_ptr = output->mutable_data(); + + for (int n = 0; n < num_rois; ++n) { + int roi_batch_ind = rois_ptr[0]; + T roi_start_w = + static_cast(rois_ptr[1]) * spatial_scale_; + T roi_start_h = + static_cast(rois_ptr[2]) * spatial_scale_; + T roi_end_w = + static_cast(rois_ptr[3] + 1.) * spatial_scale_; + T roi_end_h = + static_cast(rois_ptr[4] + 1.) * spatial_scale_; + MACE_CHECK(roi_batch_ind >= 0); + MACE_CHECK(roi_batch_ind < batch_size); + + // Force too small ROIs to be 1x1 + T roi_width = std::max(roi_end_w - roi_start_w, static_cast(0.1)); + T roi_height = std::max(roi_end_h - roi_start_h, static_cast(0.1)); + + // Compute w and h at bottom + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + const T *batch_data = input_ptr + + roi_batch_ind * height * width * channels; + + std::vector vhstart, vwstart, vhend, vwend; + + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + T hstart = static_cast(ph) * bin_size_h + + roi_start_h; + T wstart = static_cast(pw) * bin_size_w + + roi_start_w; + T hend = static_cast(ph + 1) * bin_size_h + + roi_start_h; + T wend = static_cast(pw + 1) * bin_size_w + + roi_start_w; + // Add roi offsets and clip to input boundaries + hstart = std::min(std::max(hstart, static_cast(0.)), + static_cast(height)); + hend = std::min(std::max(hend, static_cast(0.)), + static_cast(height)); + wstart = std::min(std::max(wstart, static_cast(0.)), + static_cast(width)); + wend = std::min(std::max(wend, static_cast(0.)), + static_cast(width)); + + vhstart.push_back(hstart); + vwstart.push_back(wstart); + vhend.push_back(hend); + vwend.push_back(wend); + } + } + +#pragma omp parallel for collapse(3) + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + for (int c = 0; c < output_dim_; ++c) { + const int pool_index = ph * pooled_width + pw; + const int out_idx = pool_index * output_dim_ + c; + const int in_chan_idx = (c * pooled_height + ph) + * pooled_width + pw; + T hstart = vhstart[pool_index]; + T hend = vhend[pool_index]; + T wstart = vwstart[pool_index]; + T wend = vwend[pool_index]; + bool is_empty = (hend <= hstart) || (wend <= wstart); + + T out_sum = 0; + for (T h = hstart; h < hend; h += 1.) { + for (T w = wstart; w < wend; w += 1.) { + // Selecting four regular locations for bilinear interpolation + int x_left = std::floor(w); + int x_right = std::ceil(w); + int y_bottom = std::floor(h); + int y_top = std::ceil(h); + + int top_left_index = (y_top * width + x_left) + * channels + in_chan_idx; + int top_right_index = (y_top * width + x_right) + * channels + in_chan_idx; + int bottom_left_index = (y_bottom * width + x_left) + * channels + in_chan_idx; + int bottom_right_index = (y_bottom * width + x_right) + * channels + in_chan_idx; + + bool is_top_left_in = x_left >= 0 && x_left <= width - 1 + && y_top >= 0 && y_top <= height - 1; + bool is_top_right_in = x_right >= 0 && x_right <= width - 1 + && y_top >= 0 && y_top <= height - 1; + bool is_bottom_left_in = x_left >= 0 && x_left <= width - 1 + && y_bottom >= 0 && y_bottom <= height - 1; + bool is_bottom_right_in = x_right >= 0 && x_right <= width - 1 + && y_bottom >= 0 && y_bottom <= height - 1; + + if (is_top_left_in) { + out_sum += (1 - w + x_left) * (1 - y_top + h) + * batch_data[top_left_index]; + } + if (is_top_right_in) { + out_sum += (1 - x_right + w) * (1 - y_top + h) + * batch_data[top_right_index]; + } + if (is_bottom_left_in) { + out_sum += (1 - w + x_left) * (1 - h + y_bottom) + * batch_data[bottom_left_index]; + } + if (is_bottom_right_in) { + out_sum += (1 - x_right + w) * (1 - h + y_bottom) + * batch_data[bottom_right_index]; + } + } + } + + T bin_area = (hend - hstart) * (wend - wstart); + output_ptr[out_idx] = is_empty ? 0. : out_sum / bin_area; + } + } + } + + // Increment ROI data pointer + rois_ptr += 5; + output_ptr += pooled_height * pooled_width * output_dim_; + } + } + + const T spatial_scale_; + const int output_dim_; + const int group_size_; +}; + +} // namespace kernels +} // namespace mace + +#endif // MACE_KERNELS_PSROI_ALIGN_H_ diff --git a/mace/kernels/reshape.h b/mace/kernels/reshape.h index 14e560789db709464400136116ba02d373207c65..ddcd0dba58c5241554623c67a884aca7cbe0c060 100644 --- a/mace/kernels/reshape.h +++ b/mace/kernels/reshape.h @@ -21,9 +21,7 @@ struct ReshapeFunctor { const std::vector &out_shape, Tensor *output, StatsFuture *future) { - output->Resize(out_shape); - // TODO(liuqi): copy on write to avoid this copy. - output->CopyBytes(input->raw_data(), input->size() * sizeof(T)); + output->ResizeWithBuffer(out_shape, input->UnderlyingBuffer()); } }; diff --git a/mace/ops/proposal.cc b/mace/ops/proposal.cc new file mode 100644 index 0000000000000000000000000000000000000000..853a4e5b36aae7bb426ac387d90d60a27f663c3a --- /dev/null +++ b/mace/ops/proposal.cc @@ -0,0 +1,19 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/ops/proposal.h" + +namespace mace { +namespace ops { + +void Register_Proposal(OperatorRegistry *op_registry) { + REGISTER_OPERATOR(op_registry, OpKeyBuilder("Proposal") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + ProposalOp); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/proposal.h b/mace/ops/proposal.h new file mode 100644 index 0000000000000000000000000000000000000000..06dcc8a1b02b030a82e8bf5508421f0342decc46 --- /dev/null +++ b/mace/ops/proposal.h @@ -0,0 +1,50 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_OPS_PROPOSAL_H_ +#define MACE_OPS_PROPOSAL_H_ + +#include "mace/core/operator.h" +#include "mace/kernels/proposal.h" + +namespace mace { +namespace ops { + +template +class ProposalOp : public Operator { + public: + ProposalOp(const OperatorDef &operator_def, Workspace *ws) + : Operator(operator_def, ws), + functor_(OperatorBase::GetSingleArgument("min_size", 0), + OperatorBase::GetSingleArgument("nms_thresh", 0), + OperatorBase::GetSingleArgument("pre_nms_top_n", 0), + OperatorBase::GetSingleArgument("post_nms_top_n", 0), + OperatorBase::GetSingleArgument("feat_stride", 0), + OperatorBase::GetSingleArgument("base_size", 16), + OperatorBase::GetRepeatedArgument("scales"), + OperatorBase::GetRepeatedArgument("ratios")) {} + + bool Run(StatsFuture *future) override { + const Tensor *rpn_cls_prob = this->Input(RPN_CLS_PROB); + const Tensor *rpn_bbox_pred = this->Input(RPN_BBOX_PRED); + const Tensor *img_info = this->Input(IMG_INFO); + + Tensor *output = this->Output(ROIS); + + functor_(rpn_cls_prob, rpn_bbox_pred, img_info, output, future); + return true; + } + + private: + kernels::ProposalFunctor functor_; + + protected: + OP_INPUT_TAGS(RPN_CLS_PROB, RPN_BBOX_PRED, IMG_INFO); + OP_OUTPUT_TAGS(ROIS); +}; + +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_PROPOSAL_H_ diff --git a/mace/ops/proposal_test.cc b/mace/ops/proposal_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..94203deb4b82bc1a1d2d96e3a122e6856f0ce109 --- /dev/null +++ b/mace/ops/proposal_test.cc @@ -0,0 +1,61 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class ProposalOpTest : public OpsTestBase {}; + +TEST_F(ProposalOpTest, CPUSimple) { + const int img_height = 256; + const int img_width = 256; + const int height = 3; + const int width = 4; + + OpsTestNet net; + + OpDefBuilder("Proposal", "ProposalTest") + .Input("RpnCLSProb") + .Input("RpnBBoxPred") + .Input("ImgInfo") + .AddIntArg("min_size", 16) + .AddFloatArg("nms_thresh", 0.7) + .AddIntArg("pre_nms_top_n", 12000) + .AddIntArg("post_nms_top_n", 2000) + .AddIntArg("feat_stride", 16) + .AddIntArg("base_size", 16) + .AddIntsArg("scales", {8, 16, 32}) + .AddFloatsArg("ratios", {0.5, 1, 2}) + .Output("Output") + .Finalize(net.NewOperatorDef()); + + std::vector scores(height * width * 18); + for (int i = 0 ; i < scores.size(); ++i) { + scores[i] = i; + } + + // Add input data + net.AddInputFromArray( + "RpnCLSProb", {1, height, width, 18}, scores); + net.AddRepeatedInput( + "RpnBBoxPred", {1, height, width, 4 * 9}, 1); + net.AddInputFromArray( + "ImgInfo", {1, 1, 1, 3}, {img_height, img_width, 2}); + + // Run + net.RunOp(); + + auto expected_tensor = CreateTensor({1, 1, 1, 5}, {0, 0, 0, 255, 255}); + + ExpectTensorNear(*expected_tensor, *net.GetTensor("Output"), 1e-5); +} + + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/psroi_align.cc b/mace/ops/psroi_align.cc new file mode 100644 index 0000000000000000000000000000000000000000..d2ba2b45327558beedf511ab8c170d4b9ea42188 --- /dev/null +++ b/mace/ops/psroi_align.cc @@ -0,0 +1,19 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/ops/psroi_align.h" + +namespace mace { +namespace ops { + +void Register_PSROIAlign(OperatorRegistry *op_registry) { + REGISTER_OPERATOR(op_registry, OpKeyBuilder("PSROIAlign") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + PSROIAlignOp); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/psroi_align.h b/mace/ops/psroi_align.h new file mode 100644 index 0000000000000000000000000000000000000000..e134b4a43cb78022895ee62a7b3ccfde3c5e110d --- /dev/null +++ b/mace/ops/psroi_align.h @@ -0,0 +1,44 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_OPS_PSROI_ALIGN_H_ +#define MACE_OPS_PSROI_ALIGN_H_ + +#include "mace/core/operator.h" +#include "mace/kernels/psroi_align.h" + +namespace mace { +namespace ops { + +template +class PSROIAlignOp : public Operator { + public: + PSROIAlignOp(const OperatorDef &operator_def, Workspace *ws) + : Operator(operator_def, ws), + functor_(OperatorBase::GetSingleArgument("spatial_scale", 0), + OperatorBase::GetSingleArgument("output_dim", 0), + OperatorBase::GetSingleArgument("group_size", 0)) {} + + bool Run(StatsFuture *future) override { + const Tensor *input = this->Input(INPUT); + const Tensor *rois = this->Input(ROIS); + + Tensor *output = this->Output(OUTPUT); + + functor_(input, rois, output, future); + return true; + } + + private: + kernels::PSROIAlignFunctor functor_; + + protected: + OP_INPUT_TAGS(INPUT, ROIS); + OP_OUTPUT_TAGS(OUTPUT); +}; + +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_PSROI_ALIGN_H_