Commit 25df7e53 authored by Liangliang He

Merge branch 'face-rfcn' into 'master'

Add new ops Proposal and PSROIAlign, and refactor the reshape CPU kernel.

See merge request !321
...@@ -89,6 +89,8 @@ extern void Register_Reshape(OperatorRegistry *op_registry);
extern void Register_Eltwise(OperatorRegistry *op_registry);
extern void Register_FullyConnected(OperatorRegistry *op_registry);
extern void Register_Slice(OperatorRegistry *op_registry);
extern void Register_Proposal(OperatorRegistry *op_registry);
extern void Register_PSROIAlign(OperatorRegistry *op_registry);
}  // namespace ops
...@@ -118,6 +120,8 @@ OperatorRegistry::OperatorRegistry() {
  ops::Register_Eltwise(this);
  ops::Register_FullyConnected(this);
  ops::Register_Slice(this);
  ops::Register_Proposal(this);
  ops::Register_PSROIAlign(this);
}
}  // namespace mace
...@@ -178,6 +178,18 @@ class Tensor {
    }
  }

  inline void ResizeWithBuffer(const std::vector<index_t> &shape,
                               BufferBase *buffer) {
    MACE_CHECK(!has_opencl_image(), "Cannot resize image, use ResizeImage.");
    shape_ = shape;
    image_shape_.clear();
    if (buffer_ != nullptr && is_buffer_owner_) {
      delete buffer_;
    }
    buffer_ = buffer;
    is_buffer_owner_ = false;
  }

  inline void ResizeImage(const std::vector<index_t> &shape,
                          const std::vector<size_t> &image_shape) {
    shape_ = shape;
...
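The new Tensor::ResizeWithBuffer above turns a tensor into a non-owning view over an existing buffer; this is what lets the reshape kernel later in this diff drop its byte copy. A minimal sketch of the aliasing semantics, with the lifetime caveat inferred from is_buffer_owner_ rather than stated by the diff:

// Hypothetical helper, sketch only: after the call, `output` shares
// `input`'s bytes. No copy happens, and since is_buffer_owner_ stays
// false, `output` never frees the buffer; `input` must outlive it.
inline void AliasWithShape(const Tensor *input,
                           const std::vector<index_t> &out_shape,
                           Tensor *output) {
  output->ResizeWithBuffer(out_shape, input->UnderlyingBuffer());
}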
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#ifndef MACE_KERNELS_PROPOSAL_H_
#define MACE_KERNELS_PROPOSAL_H_

#include <algorithm>
#include <cmath>
#include <vector>

#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/public/mace.h"

namespace mace {
namespace kernels {

static std::vector<float> WHCenters(const std::vector<float> &anchor) {
  // width, height, width_center, height_center
  std::vector<float> window(4);
  window[0] = anchor[2] - anchor[0] + 1;
  window[1] = anchor[3] - anchor[1] + 1;
  window[2] = anchor[0] + (window[0] - 1) / 2;
  window[3] = anchor[1] + (window[1] - 1) / 2;
  return window;
}

inline std::vector<std::vector<float>> GenerateAnchors(
    const std::vector<int> &scales,
    const std::vector<float> &ratios,
    const int base_size) {
  const std::vector<float> base_anchor =
      {0, 0,
       static_cast<float>(base_size - 1),
       static_cast<float>(base_size - 1)};
  const size_t scales_size = scales.size();
  const size_t ratios_size = ratios.size();
  // get height, width, centers
  std::vector<float> base_window = WHCenters(base_anchor);
  const float size = base_window[0] * base_window[1];
  std::vector<std::vector<float>> anchors(scales_size * ratios_size,
                                          std::vector<float>(4));
#pragma omp parallel for
  for (size_t ratio_idx = 0; ratio_idx < ratios_size; ++ratio_idx) {
    float ws = ::roundf(::sqrtf(size / ratios[ratio_idx]));
    float hs = ::roundf(ws * ratios[ratio_idx]);
    std::vector<float> tmp_anchor(4);
    tmp_anchor[0] = base_window[2] - (ws - 1) / 2;
    tmp_anchor[1] = base_window[3] - (hs - 1) / 2;
    tmp_anchor[2] = base_window[2] + (ws - 1) / 2;
    tmp_anchor[3] = base_window[3] + (hs - 1) / 2;
    auto window = WHCenters(tmp_anchor);
    for (size_t scale_idx = 0; scale_idx < scales_size; ++scale_idx) {
      const size_t idx = ratio_idx * scales_size + scale_idx;
      ws = window[0] * scales[scale_idx];
      hs = window[1] * scales[scale_idx];
      anchors[idx][0] = window[2] - (ws - 1) / 2;
      anchors[idx][1] = window[3] - (hs - 1) / 2;
      anchors[idx][2] = window[2] + (ws - 1) / 2;
      anchors[idx][3] = window[3] + (hs - 1) / 2;
    }
  }
  return anchors;
}
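For orientation, a hedged sketch of what this enumeration produces (values hand-derived from the arithmetic above, not quoted from the diff):

// Sketch only: the 3-ratio x 3-scale enumeration used by the Proposal
// test below.
auto anchors = GenerateAnchors({8, 16, 32}, {0.5f, 1.f, 2.f}, 16);
// anchors[ratio_idx * 3 + scale_idx]; ratio 1 with scale 8 (anchors[3])
// is the 128x128 box {-56, -56, 71, 71} centered at (7.5, 7.5).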

inline std::vector<int> nms(const float *bboxes_ptr,
                            const index_t num_bboxes,
                            const float thresh,
                            const int post_nms_top_n) {
  std::vector<int> keep;
  std::vector<int> suppressed(num_bboxes, 0);

  std::vector<float> areas(num_bboxes, 0);
  for (index_t i = 0; i < num_bboxes; ++i) {
    const index_t idx = (i << 2);
    areas[i] = (bboxes_ptr[idx + 2] - bboxes_ptr[idx] + 1) *
        (bboxes_ptr[idx + 3] - bboxes_ptr[idx + 1] + 1);
  }

  for (int i = 0; i < num_bboxes; ++i) {
    if (suppressed[i] == 1) continue;
    keep.push_back(i);
    if (keep.size() >= static_cast<size_t>(post_nms_top_n)) break;
    int coord_idx = i << 2;
    const float x1 = bboxes_ptr[coord_idx];
    const float y1 = bboxes_ptr[coord_idx + 1];
    const float x2 = bboxes_ptr[coord_idx + 2];
    const float y2 = bboxes_ptr[coord_idx + 3];
    const float area1 = areas[i];
    for (int j = i + 1; j < num_bboxes; ++j) {
      if (suppressed[j] == 1) continue;
      coord_idx = j << 2;
      const float iou =
          std::max<float>(0.0,
                          std::min(x2, bboxes_ptr[coord_idx + 2]) -
                              std::max(x1, bboxes_ptr[coord_idx]) + 1) *
          std::max<float>(0.0,
                          std::min(y2, bboxes_ptr[coord_idx + 3]) -
                              std::max(y1, bboxes_ptr[coord_idx + 1]) + 1);
      if ((iou / (area1 + areas[j] - iou)) >= thresh) {
        suppressed[j] = 1;
      }
    }
  }
  return keep;
}
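A hand-worked sanity sketch of the greedy suppression above (not part of the diff): boxes are packed 4 floats apiece and are assumed to arrive sorted by descending score, which the caller guarantees in steps 4-5 of the functor below.

const float boxes[] = {0,  0,  9,  9,     // box 0: highest score, kept
                       0,  0,  9,  9,     // box 1: IoU 1.0 with box 0
                       20, 20, 29, 29};   // box 2: disjoint, IoU 0
std::vector<int> kept = nms(boxes, 3, 0.7f, 300);
// kept == {0, 2}; box 1 exceeds the 0.7 overlap threshold and is dropped.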

template<DeviceType D, typename T>
struct ProposalFunctor {
  ProposalFunctor(const int min_size,
                  const float nms_thresh,
                  const int pre_nms_top_n,
                  const int post_nms_top_n,
                  const int feat_stride,
                  const int base_size,
                  const std::vector<int> &scales,
                  const std::vector<float> &ratios) :
      min_size_(min_size),
      thresh_(nms_thresh),
      pre_nms_top_n_(pre_nms_top_n),
      post_nms_top_n_(post_nms_top_n),
      feat_stride_(feat_stride),
      anchors_(GenerateAnchors(scales, ratios, base_size)) {}

  void operator()(const Tensor *rpn_cls_prob,
                  const Tensor *rpn_bbox_pred,
                  const Tensor *img_info_tensor,
                  Tensor *output,
                  StatsFuture *future) {
    MACE_CHECK(rpn_cls_prob->dim(1) == rpn_bbox_pred->dim(1) &&
        rpn_cls_prob->dim(2) == rpn_bbox_pred->dim(2));
    MACE_CHECK((rpn_cls_prob->dim(3) / 2 == rpn_bbox_pred->dim(3) / 4) &&
        (rpn_cls_prob->dim(3) / 2 == anchors_.size()));
    const float *img_info = img_info_tensor->data<float>();
    const int im_height = static_cast<int>(img_info[0] - 1);
    const int im_width = static_cast<int>(img_info[1] - 1);
    const index_t feat_height = rpn_cls_prob->dim(1);
    const index_t feat_width = rpn_cls_prob->dim(2);
    const int anchors_size = anchors_.size();

    // shift anchors to original input
    std::vector<std::vector<float>> proposals(
        anchors_size * feat_height * feat_width,
        std::vector<float>(4));
#pragma omp parallel for collapse(3)
    for (int h_idx = 0; h_idx < feat_height; ++h_idx) {
      for (int w_idx = 0; w_idx < feat_width; ++w_idx) {
        for (int a_idx = 0; a_idx < anchors_size; ++a_idx) {
          const int shift_h = h_idx * feat_stride_;
          const int shift_w = w_idx * feat_stride_;
          const index_t sanc_idx = (h_idx * feat_width + w_idx) * anchors_size
              + a_idx;
          proposals[sanc_idx][0] = anchors_[a_idx][0] + shift_w;
          proposals[sanc_idx][1] = anchors_[a_idx][1] + shift_h;
          proposals[sanc_idx][2] = anchors_[a_idx][2] + shift_w;
          proposals[sanc_idx][3] = anchors_[a_idx][3] + shift_h;
        }
      }
    }

    // Convert anchors into proposals via bbox transformations
    // 2. clip predicted boxes to image
    const float *bbox_deltas = rpn_bbox_pred->data<float>();
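    // (Standard Faster R-CNN inverse transform: for deltas
    // (dx, dy, dw, dh), ctr' = d * size + ctr and size' = exp(d) * size,
    // then clip x to [0, im_width] and y to [0, im_height].)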
#pragma omp parallel for collapse(3)
    for (int h_idx = 0; h_idx < feat_height; ++h_idx) {
      for (int w_idx = 0; w_idx < feat_width; ++w_idx) {
        for (int a_idx = 0; a_idx < anchors_size; ++a_idx) {
          const int sanc_idx = (h_idx * feat_width + w_idx) * anchors_size
              + a_idx;
          const float width = proposals[sanc_idx][2] -
              proposals[sanc_idx][0] + 1;
          const float height = proposals[sanc_idx][3] -
              proposals[sanc_idx][1] + 1;
          int delta_offset = sanc_idx * 4;
          float pred_ctr_x = bbox_deltas[delta_offset + 0] * width +
              (proposals[sanc_idx][0] + width / 2);
          float pred_ctr_y = bbox_deltas[delta_offset + 1] * height +
              (proposals[sanc_idx][1] + height / 2);
          float pred_w = std::exp(bbox_deltas[delta_offset + 2]) * width;
          float pred_h = std::exp(bbox_deltas[delta_offset + 3]) * height;
          proposals[sanc_idx][0] = std::max<float>(
              std::min<float>(pred_ctr_x - pred_w / 2, im_width), 0);
          proposals[sanc_idx][1] = std::max<float>(
              std::min<float>(pred_ctr_y - pred_h / 2, im_height), 0);
          proposals[sanc_idx][2] = std::max<float>(
              std::min<float>(pred_ctr_x + pred_w / 2, im_width), 0);
          proposals[sanc_idx][3] = std::max<float>(
              std::min<float>(pred_ctr_y + pred_h / 2, im_height), 0);
        }
      }
    }

    // 3. remove predicted boxes with either height or width < threshold
    // (NOTE: convert min_size to input image scale stored in im_info[2])
    std::vector<int> keep;
    const float min_size = min_size_ * img_info[2];
    for (int h_idx = 0; h_idx < feat_height; ++h_idx) {
      for (int w_idx = 0; w_idx < feat_width; ++w_idx) {
        for (int a_idx = 0; a_idx < anchors_size; ++a_idx) {
          const int sanc_idx = (h_idx * feat_width + w_idx) * anchors_size
              + a_idx;
          const float width = proposals[sanc_idx][2]
              - proposals[sanc_idx][0] + 1;
          const float height = proposals[sanc_idx][3]
              - proposals[sanc_idx][1] + 1;
          if (width >= min_size && height >= min_size) {
            keep.push_back(sanc_idx);
          }
        }
      }
    }

    // 4. sort all (proposal, score) pairs by score from highest to lowest
    // 5. take top pre_nms_topN (e.g. 6000)
    auto scores = rpn_cls_prob->data<float>();
    const int scores_chan = static_cast<int>(rpn_cls_prob->dim(3));
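    // The lambda below indexes the second half of each pixel's 2*A channel
    // block, i.e. the foreground scores in the usual Faster R-CNN layout
    // [bg_0 .. bg_{A-1}, fg_0 .. fg_{A-1}] for A anchors.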
    auto score_idx_func = [&](int idx) -> int {
      return (idx / anchors_size) * scores_chan +
          (idx % anchors_size) + anchors_size;
    };
    std::sort(keep.begin(), keep.end(), [&](int left, int right) -> bool {
      return scores[score_idx_func(left)] >
          scores[score_idx_func(right)];
    });
    int size = std::min<int>(pre_nms_top_n_,
                             static_cast<int>(keep.size()));
    std::vector<float> nms_scores(size, 0);
    std::vector<float> nms_proposals((size << 2), 0);
#pragma omp parallel for
    for (int i = 0; i < size; ++i) {
      nms_scores[i] = scores[score_idx_func(keep[i])];
      nms_proposals[i << 2] = proposals[keep[i]][0];
      nms_proposals[(i << 2) + 1] = proposals[keep[i]][1];
      nms_proposals[(i << 2) + 2] = proposals[keep[i]][2];
      nms_proposals[(i << 2) + 3] = proposals[keep[i]][3];
    }

    /* 6. apply nms (e.g. threshold = 0.7)
       7. take after_nms_topN (e.g. 300)
       8. return the top proposals (-> RoIs top) */
    auto nms_result = nms(nms_proposals.data(),
                          nms_scores.size(),
                          thresh_,
                          post_nms_top_n_);

    // Output rois blob
    // Our RPN implementation only supports a single input image, so all
    // batch inds are 0
    size = static_cast<int>(nms_result.size());
    output->Resize({size, 1, 1, 5});
    auto output_ptr = output->mutable_data<float>();
#pragma omp parallel for
    for (int i = 0; i < size; ++i) {
      const int out_idx = i * 5;
      const int nms_idx = nms_result[i] * 4;
      output_ptr[out_idx] = 0;
      output_ptr[out_idx + 1] = nms_proposals[nms_idx];
      output_ptr[out_idx + 2] = nms_proposals[nms_idx + 1];
      output_ptr[out_idx + 3] = nms_proposals[nms_idx + 2];
      output_ptr[out_idx + 4] = nms_proposals[nms_idx + 3];
    }
  }

  const int min_size_;
  const float thresh_;
  const int pre_nms_top_n_;
  const int post_nms_top_n_;
  const int feat_stride_;
  std::vector<std::vector<float>> anchors_;
};

}  // namespace kernels
}  // namespace mace

#endif  // MACE_KERNELS_PROPOSAL_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#ifndef MACE_KERNELS_PSROI_ALIGN_H_
#define MACE_KERNELS_PSROI_ALIGN_H_

#include <algorithm>
#include <cmath>
#include <vector>

#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/public/mace.h"

namespace mace {
namespace kernels {

template<DeviceType D, typename T>
struct PSROIAlignFunctor {
  PSROIAlignFunctor(const T spatial_scale,
                    const int output_dim,
                    const int group_size) :
      spatial_scale_(spatial_scale),
      output_dim_(output_dim),
      group_size_(group_size) {}

  void operator()(const Tensor *input,
                  const Tensor *rois,
                  Tensor *output,
                  StatsFuture *future) {
    const int height = static_cast<int>(input->dim(1));
    const int width = static_cast<int>(input->dim(2));
    const int channels = static_cast<int>(input->dim(3));
    const int pooled_height = group_size_;
    const int pooled_width = group_size_;

    const T *input_ptr = input->data<T>();
    const T *rois_ptr = rois->data<T>();
    // Number of ROIs
    const int num_rois = static_cast<int>(rois->dim(0));
    const int batch_size = static_cast<int>(input->dim(0));

    output->Resize({num_rois, pooled_height, pooled_width, output_dim_});
    T *output_ptr = output->mutable_data<T>();

    for (int n = 0; n < num_rois; ++n) {
      int roi_batch_ind = rois_ptr[0];
      T roi_start_w = static_cast<T>(rois_ptr[1]) * spatial_scale_;
      T roi_start_h = static_cast<T>(rois_ptr[2]) * spatial_scale_;
      T roi_end_w = static_cast<T>(rois_ptr[3] + 1.) * spatial_scale_;
      T roi_end_h = static_cast<T>(rois_ptr[4] + 1.) * spatial_scale_;
      MACE_CHECK(roi_batch_ind >= 0);
      MACE_CHECK(roi_batch_ind < batch_size);

      // Force too small ROIs to be at least 0.1 x 0.1
      T roi_width = std::max(roi_end_w - roi_start_w, static_cast<T>(0.1));
      T roi_height = std::max(roi_end_h - roi_start_h, static_cast<T>(0.1));

      // Compute w and h at bottom
      T bin_size_h = roi_height / static_cast<T>(pooled_height);
      T bin_size_w = roi_width / static_cast<T>(pooled_width);

      const T *batch_data = input_ptr +
          roi_batch_ind * height * width * channels;

      std::vector<T> vhstart, vwstart, vhend, vwend;
      for (int ph = 0; ph < pooled_height; ++ph) {
        for (int pw = 0; pw < pooled_width; ++pw) {
          T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
          T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
          T hend = static_cast<T>(ph + 1) * bin_size_h + roi_start_h;
          T wend = static_cast<T>(pw + 1) * bin_size_w + roi_start_w;
          // Add roi offsets and clip to input boundaries
          hstart = std::min(std::max(hstart, static_cast<T>(0.)),
                            static_cast<T>(height));
          hend = std::min(std::max(hend, static_cast<T>(0.)),
                          static_cast<T>(height));
          wstart = std::min(std::max(wstart, static_cast<T>(0.)),
                            static_cast<T>(width));
          wend = std::min(std::max(wend, static_cast<T>(0.)),
                          static_cast<T>(width));
          vhstart.push_back(hstart);
          vwstart.push_back(wstart);
          vhend.push_back(hend);
          vwend.push_back(wend);
        }
      }

#pragma omp parallel for collapse(3)
      for (int ph = 0; ph < pooled_height; ++ph) {
        for (int pw = 0; pw < pooled_width; ++pw) {
          for (int c = 0; c < output_dim_; ++c) {
            const int pool_index = ph * pooled_width + pw;
            const int out_idx = pool_index * output_dim_ + c;
            const int in_chan_idx = (c * pooled_height + ph)
                * pooled_width + pw;
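            // Position-sensitive mapping: output channel c of bin (ph, pw)
            // pools from input channel c * group_size_^2 + ph * group_size_
            // + pw, so every bin reads its own dedicated slice of score maps.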
            T hstart = vhstart[pool_index];
            T hend = vhend[pool_index];
            T wstart = vwstart[pool_index];
            T wend = vwend[pool_index];

            bool is_empty = (hend <= hstart) || (wend <= wstart);
            T out_sum = 0;
            for (T h = hstart; h < hend; h += 1.) {
              for (T w = wstart; w < wend; w += 1.) {
                // Selecting four regular locations for bilinear interpolation
                int x_left = std::floor(w);
                int x_right = std::ceil(w);
                int y_bottom = std::floor(h);
                int y_top = std::ceil(h);

                int top_left_index = (y_top * width + x_left)
                    * channels + in_chan_idx;
                int top_right_index = (y_top * width + x_right)
                    * channels + in_chan_idx;
                int bottom_left_index = (y_bottom * width + x_left)
                    * channels + in_chan_idx;
                int bottom_right_index = (y_bottom * width + x_right)
                    * channels + in_chan_idx;

                bool is_top_left_in = x_left >= 0 && x_left <= width - 1
                    && y_top >= 0 && y_top <= height - 1;
                bool is_top_right_in = x_right >= 0 && x_right <= width - 1
                    && y_top >= 0 && y_top <= height - 1;
                bool is_bottom_left_in = x_left >= 0 && x_left <= width - 1
                    && y_bottom >= 0 && y_bottom <= height - 1;
                bool is_bottom_right_in = x_right >= 0 && x_right <= width - 1
                    && y_bottom >= 0 && y_bottom <= height - 1;

                if (is_top_left_in) {
                  out_sum += (1 - w + x_left) * (1 - y_top + h)
                      * batch_data[top_left_index];
                }
                if (is_top_right_in) {
                  out_sum += (1 - x_right + w) * (1 - y_top + h)
                      * batch_data[top_right_index];
                }
                if (is_bottom_left_in) {
                  out_sum += (1 - w + x_left) * (1 - h + y_bottom)
                      * batch_data[bottom_left_index];
                }
                if (is_bottom_right_in) {
                  out_sum += (1 - x_right + w) * (1 - h + y_bottom)
                      * batch_data[bottom_right_index];
                }
              }
            }

            T bin_area = (hend - hstart) * (wend - wstart);
            output_ptr[out_idx] = is_empty ? 0. : out_sum / bin_area;
          }
        }
      }
      // Increment ROI data pointer
      rois_ptr += 5;
      output_ptr += pooled_height * pooled_width * output_dim_;
    }
  }

  const T spatial_scale_;
  const int output_dim_;
  const int group_size_;
};

}  // namespace kernels
}  // namespace mace

#endif  // MACE_KERNELS_PSROI_ALIGN_H_
...@@ -21,9 +21,7 @@ struct ReshapeFunctor {
                  const std::vector<index_t> &out_shape,
                  Tensor *output,
                  StatsFuture *future) {
-    output->Resize(out_shape);
-    // TODO(liuqi): copy on write to avoid this copy.
-    output->CopyBytes(input->raw_data(), input->size() * sizeof(T));
+    output->ResizeWithBuffer(out_shape, input->UnderlyingBuffer());
  }
};
...
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#include "mace/ops/proposal.h"

namespace mace {
namespace ops {

void Register_Proposal(OperatorRegistry *op_registry) {
  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Proposal")
                        .Device(DeviceType::CPU)
                        .TypeConstraint<float>("T")
                        .Build(),
                    ProposalOp<DeviceType::CPU, float>);
}

}  // namespace ops
}  // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#ifndef MACE_OPS_PROPOSAL_H_
#define MACE_OPS_PROPOSAL_H_

#include "mace/core/operator.h"
#include "mace/kernels/proposal.h"

namespace mace {
namespace ops {

template <DeviceType D, class T>
class ProposalOp : public Operator<D, T> {
 public:
  ProposalOp(const OperatorDef &operator_def, Workspace *ws)
      : Operator<D, T>(operator_def, ws),
        functor_(OperatorBase::GetSingleArgument<int>("min_size", 0),
                 OperatorBase::GetSingleArgument<float>("nms_thresh", 0),
                 OperatorBase::GetSingleArgument<int>("pre_nms_top_n", 0),
                 OperatorBase::GetSingleArgument<int>("post_nms_top_n", 0),
                 OperatorBase::GetSingleArgument<int>("feat_stride", 0),
                 OperatorBase::GetSingleArgument<int>("base_size", 16),
                 OperatorBase::GetRepeatedArgument<int>("scales"),
                 OperatorBase::GetRepeatedArgument<float>("ratios")) {}

  bool Run(StatsFuture *future) override {
    const Tensor *rpn_cls_prob = this->Input(RPN_CLS_PROB);
    const Tensor *rpn_bbox_pred = this->Input(RPN_BBOX_PRED);
    const Tensor *img_info = this->Input(IMG_INFO);

    Tensor *output = this->Output(ROIS);

    functor_(rpn_cls_prob, rpn_bbox_pred, img_info, output, future);
    return true;
  }

 private:
  kernels::ProposalFunctor<D, T> functor_;

 protected:
  OP_INPUT_TAGS(RPN_CLS_PROB, RPN_BBOX_PRED, IMG_INFO);
  OP_OUTPUT_TAGS(ROIS);
};

}  // namespace ops
}  // namespace mace

#endif  // MACE_OPS_PROPOSAL_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"

namespace mace {
namespace ops {
namespace test {

class ProposalOpTest : public OpsTestBase {};

TEST_F(ProposalOpTest, CPUSimple) {
  const int img_height = 256;
  const int img_width = 256;
  const int height = 3;
  const int width = 4;

  OpsTestNet net;
  OpDefBuilder("Proposal", "ProposalTest")
      .Input("RpnCLSProb")
      .Input("RpnBBoxPred")
      .Input("ImgInfo")
      .AddIntArg("min_size", 16)
      .AddFloatArg("nms_thresh", 0.7)
      .AddIntArg("pre_nms_top_n", 12000)
      .AddIntArg("post_nms_top_n", 2000)
      .AddIntArg("feat_stride", 16)
      .AddIntArg("base_size", 16)
      .AddIntsArg("scales", {8, 16, 32})
      .AddFloatsArg("ratios", {0.5, 1, 2})
      .Output("Output")
      .Finalize(net.NewOperatorDef());

  std::vector<float> scores(height * width * 18);
  for (size_t i = 0; i < scores.size(); ++i) {
    scores[i] = i;
  }

  // Add input data
  net.AddInputFromArray<DeviceType::CPU, float>(
      "RpnCLSProb", {1, height, width, 18}, scores);
  net.AddRepeatedInput<DeviceType::CPU, float>(
      "RpnBBoxPred", {1, height, width, 4 * 9}, 1);
  net.AddInputFromArray<DeviceType::CPU, float>(
      "ImgInfo", {1, 1, 1, 3}, {img_height, img_width, 2});

  // Run
  net.RunOp();

  auto expected_tensor =
      CreateTensor<float>({1, 1, 1, 5}, {0, 0, 0, 255, 255});
  ExpectTensorNear<float>(*expected_tensor, *net.GetTensor("Output"), 1e-5);
}

}  // namespace test
}  // namespace ops
}  // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#include "mace/ops/psroi_align.h"

namespace mace {
namespace ops {

void Register_PSROIAlign(OperatorRegistry *op_registry) {
  REGISTER_OPERATOR(op_registry, OpKeyBuilder("PSROIAlign")
                        .Device(DeviceType::CPU)
                        .TypeConstraint<float>("T")
                        .Build(),
                    PSROIAlignOp<DeviceType::CPU, float>);
}

}  // namespace ops
}  // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#ifndef MACE_OPS_PSROI_ALIGN_H_
#define MACE_OPS_PSROI_ALIGN_H_

#include "mace/core/operator.h"
#include "mace/kernels/psroi_align.h"

namespace mace {
namespace ops {

template <DeviceType D, class T>
class PSROIAlignOp : public Operator<D, T> {
 public:
  PSROIAlignOp(const OperatorDef &operator_def, Workspace *ws)
      : Operator<D, T>(operator_def, ws),
        functor_(OperatorBase::GetSingleArgument<T>("spatial_scale", 0),
                 OperatorBase::GetSingleArgument<int>("output_dim", 0),
                 OperatorBase::GetSingleArgument<int>("group_size", 0)) {}

  bool Run(StatsFuture *future) override {
    const Tensor *input = this->Input(INPUT);
    const Tensor *rois = this->Input(ROIS);

    Tensor *output = this->Output(OUTPUT);

    functor_(input, rois, output, future);
    return true;
  }

 private:
  kernels::PSROIAlignFunctor<D, T> functor_;

 protected:
  OP_INPUT_TAGS(INPUT, ROIS);
  OP_OUTPUT_TAGS(OUTPUT);
};

}  // namespace ops
}  // namespace mace

#endif  // MACE_OPS_PSROI_ALIGN_H_
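This merge request ships no unit test for PSROIAlign; purely as illustration, a hedged OpDefBuilder sketch in the style of the Proposal test above (argument names come from the constructor; the concrete values are R-FCN-typical assumptions, not from the diff):

OpDefBuilder("PSROIAlign", "PSROIAlignTest")
    .Input("Input")   // assumed {1, H, W, output_dim * group_size^2}
    .Input("ROIs")    // assumed rows of {batch_idx, x1, y1, x2, y2}
    .AddFloatArg("spatial_scale", 1.0f / 16)  // backbone feature stride 16
    .AddIntArg("output_dim", 2)
    .AddIntArg("group_size", 7)
    .Output("Output")
    .Finalize(net.NewOperatorDef());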