Unverified commit 0e241384, authored by Q qingqing01, committed by GitHub

Merge pull request #13991 from qingqing01/refine_generate_proposals_op

Refine generate proposals op
@@ -12,10 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include <cmath>
+#include <cstring>
 #include <string>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/var_type.h"
+#include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/operators/gather.h"
 #include "paddle/fluid/operators/math/math_function.h"
@@ -25,21 +27,17 @@ namespace operators {
 using Tensor = framework::Tensor;
 using LoDTensor = framework::LoDTensor;
 
-struct AppendProposalsFunctor {
-  LoDTensor *out_;
-  int64_t offset_;
-  Tensor *to_add_;
-
-  AppendProposalsFunctor(LoDTensor *out, int64_t offset, Tensor *to_add)
-      : out_(out), offset_(offset), to_add_(to_add) {}
-
-  template <typename T>
-  void apply() const {
-    auto *out_data = out_->data<T>();
-    auto *to_add_data = to_add_->data<T>();
-    memcpy(out_data + offset_, to_add_data, to_add_->numel() * sizeof(T));
-  }
-};
+static const double kBBoxClipDefault = std::log(1000.0 / 16.0);
+
+static void AppendProposals(Tensor *dst, int64_t offset, const Tensor &src) {
+  auto *out_data = dst->data<void>();
+  auto *to_add_data = src.data<void>();
+  size_t size_of_t = framework::SizeOfType(src.type());
+  offset *= size_of_t;
+  std::memcpy(
+      reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(out_data) + offset),
+      to_add_data, src.numel() * size_of_t);
+}
 
 class GenerateProposalsOp : public framework::OperatorWithKernel {
  public:
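Note: the new AppendProposals is type-erased, so the per-dtype framework::VisitDataType dispatch in Compute (further down) can be dropped; the element offset is simply scaled by the runtime element size before a raw byte copy. A minimal standalone sketch of the same idea, with plain arrays standing in for paddle Tensors and a hypothetical AppendBytes helper:

    #include <cstdint>
    #include <cstring>

    // Type-erased append: copies `count` elements of `elem_size` bytes from
    // `src` into `dst`, starting at element offset `offset`.
    static void AppendBytes(void *dst, int64_t offset, const void *src,
                            int64_t count, size_t elem_size) {
      auto *out = reinterpret_cast<std::uint8_t *>(dst) + offset * elem_size;
      std::memcpy(out, src, count * elem_size);
    }

    int main() {
      float dst[8] = {0};
      float src[4] = {1.f, 2.f, 3.f, 4.f};
      AppendBytes(dst, 4, src, 4, sizeof(float));  // dst[4..7] = src[0..3]
      return 0;
    }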
@@ -75,8 +73,9 @@ class GenerateProposalsOp : public framework::OperatorWithKernel {
 };
 
 template <class T>
-void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
-              Tensor *bbox_deltas, Tensor *variances, Tensor *proposals) {
+static inline void BoxCoder(const platform::DeviceContext &ctx,
+                            Tensor *all_anchors, Tensor *bbox_deltas,
+                            Tensor *variances, Tensor *proposals) {
   T *proposals_data = proposals->mutable_data<T>(ctx.GetPlace());
 
   int64_t row = all_anchors->dims()[0];
@@ -108,11 +107,11 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
                       anchor_center_y;
       bbox_width = std::exp(std::min<T>(variances_data[i * len + 2] *
                                             bbox_deltas_data[i * len + 2],
-                                        std::log(1000.0 / 16.0))) *
+                                        kBBoxClipDefault)) *
                    anchor_width;
       bbox_height = std::exp(std::min<T>(variances_data[i * len + 3] *
                                              bbox_deltas_data[i * len + 3],
-                                         std::log(1000.0 / 16.0))) *
+                                         kBBoxClipDefault)) *
                     anchor_height;
     } else {
       bbox_center_x =
@@ -120,10 +119,10 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
       bbox_center_y =
           bbox_deltas_data[i * len + 1] * anchor_height + anchor_center_y;
       bbox_width = std::exp(std::min<T>(bbox_deltas_data[i * len + 2],
-                                        std::log(1000.0 / 16.0))) *
+                                        kBBoxClipDefault)) *
                    anchor_width;
       bbox_height = std::exp(std::min<T>(bbox_deltas_data[i * len + 3],
-                                         std::log(1000.0 / 16.0))) *
+                                         kBBoxClipDefault)) *
                     anchor_height;
     }
@@ -136,30 +135,32 @@ void BoxCoder(const platform::DeviceContext &ctx, Tensor *all_anchors,
 }
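Note: both clamp sites now share the file-level constant kBBoxClipDefault instead of recomputing std::log(1000.0 / 16.0) per box. For reference, the decode these hunks implement, reduced to a single box with unit variances (a standalone sketch under the pixel +1 width convention, not the operator's code):

    #include <algorithm>
    #include <cmath>

    static const double kBBoxClipDefault = std::log(1000.0 / 16.0);

    // Decode one (dx, dy, dw, dh) delta against an anchor given as
    // [xmin, ymin, xmax, ymax]; writes the proposal in the same format.
    void DecodeBox(const float a[4], const float d[4], float out[4]) {
      float w = a[2] - a[0] + 1.0f, h = a[3] - a[1] + 1.0f;
      float cx = a[0] + 0.5f * w, cy = a[1] + 0.5f * h;
      float dcx = d[0] * w + cx, dcy = d[1] * h + cy;
      // Clamp dw/dh before exp so an extreme regression cannot overflow.
      float clip = static_cast<float>(kBBoxClipDefault);
      float dw = std::exp(std::min(d[2], clip)) * w;
      float dh = std::exp(std::min(d[3], clip)) * h;
      out[0] = dcx - 0.5f * dw;
      out[1] = dcy - 0.5f * dh;
      out[2] = dcx + 0.5f * dw - 1.0f;
      out[3] = dcy + 0.5f * dh - 1.0f;
    }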
 template <class T>
-void ClipTiledBoxes(const platform::DeviceContext &ctx, const Tensor &im_info,
-                    Tensor *boxes) {
+static inline void ClipTiledBoxes(const platform::DeviceContext &ctx,
+                                  const Tensor &im_info, Tensor *boxes) {
   T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace());
   const T *im_info_data = im_info.data<T>();
+  T zero(0);
   for (int64_t i = 0; i < boxes->numel(); ++i) {
     if (i % 4 == 0) {
       boxes_data[i] =
-          std::max(std::min(boxes_data[i], im_info_data[1] - 1), 0.0f);
+          std::max(std::min(boxes_data[i], im_info_data[1] - 1), zero);
     } else if (i % 4 == 1) {
       boxes_data[i] =
-          std::max(std::min(boxes_data[i], im_info_data[0] - 1), 0.0f);
+          std::max(std::min(boxes_data[i], im_info_data[0] - 1), zero);
     } else if (i % 4 == 2) {
       boxes_data[i] =
-          std::max(std::min(boxes_data[i], im_info_data[1] - 1), 0.0f);
+          std::max(std::min(boxes_data[i], im_info_data[1] - 1), zero);
     } else {
       boxes_data[i] =
-          std::max(std::min(boxes_data[i], im_info_data[0] - 1), 0.0f);
+          std::max(std::min(boxes_data[i], im_info_data[0] - 1), zero);
     }
   }
 }
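Note: replacing the 0.0f literal with T zero(0) matters now that a double kernel is registered (see the registration change at the end of this file): std::max deduces a single template type, so mixing a double expression with a float literal does not compile. A quick illustration:

    #include <algorithm>

    int main() {
      double x = -3.0;
      // std::max(x, 0.0f);             // ill-formed: T deduces to both double and float
      double zero(0);
      double clipped = std::max(x, zero);  // works for float and double instantiations
      (void)clipped;
      return 0;
    }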
 template <class T>
-void FilterBoxes(const platform::DeviceContext &ctx, Tensor *boxes,
-                 float min_size, const Tensor &im_info, Tensor *keep) {
+static inline void FilterBoxes(const platform::DeviceContext &ctx,
+                               Tensor *boxes, float min_size,
+                               const Tensor &im_info, Tensor *keep) {
   const T *im_info_data = im_info.data<T>();
   T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace());
   T im_scale = im_info_data[2];
@@ -185,24 +186,24 @@ void FilterBoxes(const platform::DeviceContext &ctx, Tensor *boxes,
   keep->Resize({keep_len});
 }
-bool SortScorePairDescend(const std::pair<float, int> &pair1,
-                          const std::pair<float, int> &pair2) {
-  return pair1.first > pair2.first;
-}
-
 template <class T>
-void GetMaxScoreIndex(const std::vector<T> &scores,
-                      std::vector<std::pair<T, int>> *sorted_indices) {
+static inline std::vector<std::pair<T, int>> GetSortedScoreIndex(
+    const std::vector<T> &scores) {
+  std::vector<std::pair<T, int>> sorted_indices;
+  sorted_indices.reserve(scores.size());
   for (size_t i = 0; i < scores.size(); ++i) {
-    sorted_indices->push_back(std::make_pair(scores[i], i));
+    sorted_indices.emplace_back(scores[i], i);
   }
-  // Sort the score pair according to the scores in descending order
-  std::stable_sort(sorted_indices->begin(), sorted_indices->end(),
-                   SortScorePairDescend);
+  // Sort the score pairs in ascending order, so NMS can pop the
+  // highest-scoring box from the back of the vector.
+  std::stable_sort(sorted_indices.begin(), sorted_indices.end(),
+                   [](const std::pair<T, int> &a, const std::pair<T, int> &b) {
+                     return a.first < b.first;
+                   });
+  return sorted_indices;
 }
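Note: the ascending sort order is deliberate. NMS (below) now consumes candidates from the back of the vector, so removing the current best is an O(1) pop instead of the old O(n) erase from the front. A sketch of the consumption pattern, over plain std::vector:

    #include <utility>
    #include <vector>

    // Consume score/index pairs from highest to lowest score. With the vector
    // sorted ascending, the best candidate is at the back and removing it is
    // O(1); erasing from the front would shift every remaining element.
    void ConsumeDescending(std::vector<std::pair<float, int>> sorted_ascending) {
      while (!sorted_ascending.empty()) {
        int idx = sorted_ascending.back().second;
        sorted_ascending.pop_back();
        (void)idx;  // ... per-box work goes here ...
      }
    }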
 template <class T>
-T BBoxArea(const T *box, const bool normalized) {
+static inline T BBoxArea(const T *box, bool normalized) {
   if (box[2] < box[0] || box[3] < box[1]) {
     // If coordinate values are invalid
     // (e.g. xmax < xmin or ymax < ymin), return 0.
@@ -220,7 +221,7 @@ T BBoxArea(const T *box, const bool normalized) {
 }
 
 template <class T>
-T JaccardOverlap(const T *box1, const T *box2, const bool normalized) {
+static inline T JaccardOverlap(const T *box1, const T *box2, bool normalized) {
   if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] ||
       box2[3] < box1[1]) {
     return static_cast<T>(0.);
@@ -229,8 +230,8 @@ T JaccardOverlap(const T *box1, const T *box2, const bool normalized) {
   const T inter_ymin = std::max(box1[1], box2[1]);
   const T inter_xmax = std::min(box1[2], box2[2]);
   const T inter_ymax = std::min(box1[3], box2[3]);
-  const T inter_w = std::max(0.0f, inter_xmax - inter_xmin + 1);
-  const T inter_h = std::max(0.0f, inter_ymax - inter_ymin + 1);
+  const T inter_w = std::max(T(0), inter_xmax - inter_xmin + 1);
+  const T inter_h = std::max(T(0), inter_ymax - inter_ymin + 1);
   const T inter_area = inter_w * inter_h;
   const T bbox1_area = BBoxArea<T>(box1, normalized);
   const T bbox2_area = BBoxArea<T>(box2, normalized);
@@ -238,9 +239,21 @@ T JaccardOverlap(const T *box1, const T *box2, const bool normalized) {
   }
 }
 
+template <typename T>
+static inline Tensor VectorToTensor(const std::vector<T> &selected_indices,
+                                    int selected_num) {
+  Tensor keep_nms;
+  keep_nms.Resize({selected_num});
+  auto *keep_data = keep_nms.mutable_data<T>(platform::CPUPlace());
+  for (int i = 0; i < selected_num; ++i) {
+    keep_data[i] = selected_indices[i];
+  }
+  return keep_nms;
+}
+
 template <class T>
-Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox, Tensor *scores,
-           const T nms_threshold, const float eta) {
+static inline Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox,
+                         Tensor *scores, T nms_threshold, float eta) {
   PADDLE_ENFORCE_NOT_NULL(bbox);
   int64_t num_boxes = bbox->dims()[0];
   // 4: [xmin ymin xmax ymax]
@@ -248,20 +261,18 @@ Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox, Tensor *scores,
   std::vector<T> scores_data(num_boxes);
   std::copy_n(scores->data<T>(), num_boxes, scores_data.begin());
-  std::vector<std::pair<T, int>> sorted_indices;
-  GetMaxScoreIndex<T>(scores_data, &sorted_indices);
+  std::vector<std::pair<T, int>> sorted_indices =
+      GetSortedScoreIndex<T>(scores_data);
 
   std::vector<int> selected_indices;
   int selected_num = 0;
   T adaptive_threshold = nms_threshold;
   const T *bbox_data = bbox->data<T>();
-  bool flag;
   while (sorted_indices.size() != 0) {
-    int idx = sorted_indices.front().second;
-    flag = true;
-    for (size_t k = 0; k < selected_indices.size(); ++k) {
+    int idx = sorted_indices.back().second;
+    bool flag = true;
+    for (int kept_idx : selected_indices) {
       if (flag) {
-        const int kept_idx = selected_indices[k];
         T overlap = JaccardOverlap<T>(bbox_data + idx * box_size,
                                       bbox_data + kept_idx * box_size, false);
         flag = (overlap <= adaptive_threshold);
@@ -271,32 +282,29 @@ Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox, Tensor *scores,
     }
     if (flag) {
       selected_indices.push_back(idx);
-      selected_num++;
+      ++selected_num;
    }
-    sorted_indices.erase(sorted_indices.begin());
+    sorted_indices.erase(sorted_indices.end() - 1);
     if (flag && eta < 1 && adaptive_threshold > 0.5) {
       adaptive_threshold *= eta;
     }
   }
-
-  Tensor keep_nms;
-  keep_nms.Resize({selected_num});
-  int *keep_data = keep_nms.mutable_data<int>(ctx.GetPlace());
-  for (int i = 0; i < selected_num; ++i) {
-    keep_data[i] = selected_indices[i];
-  }
-
-  return keep_nms;
+  return VectorToTensor(selected_indices, selected_num);
 }
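Note: for reference, the greedy adaptive-NMS loop above in isolation: keep the highest-scoring box, drop any later box whose IoU with a kept box exceeds the threshold, and for eta < 1 shrink the threshold after each kept box. A minimal sketch over plain vectors (not the operator's code; the IoU helper here is local to the sketch):

    #include <algorithm>
    #include <utility>
    #include <vector>

    // Plain box IoU in the same pixel (+1) convention as JaccardOverlap above.
    static float IoU(const float *a, const float *b) {
      float ix = std::max(0.f, std::min(a[2], b[2]) - std::max(a[0], b[0]) + 1);
      float iy = std::max(0.f, std::min(a[3], b[3]) - std::max(a[1], b[1]) + 1);
      float inter = ix * iy;
      float sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
      float sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
      return inter / (sa + sb - inter);
    }

    // Greedy adaptive NMS over a score-ascending candidate list.
    std::vector<int> GreedyNMS(const std::vector<float> &boxes,  // flat [n, 4]
                               std::vector<std::pair<float, int>> sorted_ascending,
                               float thresh, float eta) {
      std::vector<int> kept;
      float adaptive = thresh;
      while (!sorted_ascending.empty()) {
        int idx = sorted_ascending.back().second;  // current best score
        sorted_ascending.pop_back();
        bool keep = true;
        for (int k : kept) {
          if (IoU(&boxes[4 * idx], &boxes[4 * k]) > adaptive) {
            keep = false;
            break;
          }
        }
        if (keep) {
          kept.push_back(idx);
          if (eta < 1.f && adaptive > 0.5f) adaptive *= eta;  // adaptive shrink
        }
      }
      return kept;
    }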
-template <typename DeviceContext, typename T>
+template <typename T>
 class GenerateProposalsKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
     auto *scores = context.Input<Tensor>("Scores");
     auto *bbox_deltas = context.Input<Tensor>("BboxDeltas");
     auto *im_info = context.Input<Tensor>("ImInfo");
-    auto *anchors = context.Input<Tensor>("Anchors");
-    auto *variances = context.Input<Tensor>("Variances");
+    auto anchors = detail::Ref(context.Input<Tensor>("Anchors"),
+                               "Cannot find input Anchors(%s) in scope",
+                               context.Inputs("Anchors")[0]);
+    auto variances = detail::Ref(context.Input<Tensor>("Variances"),
+                                 "Cannot find input Variances(%s) in scope",
+                                 context.Inputs("Variances")[0]);
     auto *rpn_rois = context.Output<LoDTensor>("RpnRois");
     auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs");
@@ -307,15 +315,16 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
     float min_size = context.Attr<float>("min_size");
     float eta = context.Attr<float>("eta");
 
-    auto &dev_ctx = context.template device_context<DeviceContext>();
+    auto &dev_ctx =
+        context.template device_context<platform::CPUDeviceContext>();
 
-    auto scores_dim = scores->dims();
+    auto &scores_dim = scores->dims();
     int64_t num = scores_dim[0];
     int64_t c_score = scores_dim[1];
     int64_t h_score = scores_dim[2];
     int64_t w_score = scores_dim[3];
 
-    auto bbox_dim = bbox_deltas->dims();
+    auto &bbox_dim = bbox_deltas->dims();
     int64_t c_bbox = bbox_dim[1];
     int64_t h_bbox = bbox_dim[2];
     int64_t w_bbox = bbox_dim[3];
@@ -330,17 +339,17 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
     scores_swap.mutable_data<T>({num, h_score, w_score, c_score},
                                 dev_ctx.GetPlace());
 
-    math::Transpose<DeviceContext, T, 4> trans;
+    math::Transpose<platform::CPUDeviceContext, T, 4> trans;
     std::vector<int> axis = {0, 2, 3, 1};
     trans(dev_ctx, *bbox_deltas, &bbox_deltas_swap, axis);
     trans(dev_ctx, *scores, &scores_swap, axis);
 
     framework::LoD lod;
-    std::vector<size_t> lod0(1, 0);
-    Tensor *anchor = const_cast<framework::Tensor *>(anchors);
-    anchor->Resize({anchors->numel() / 4, 4});
-    Tensor *var = const_cast<framework::Tensor *>(variances);
-    var->Resize({var->numel() / 4, 4});
+    lod.resize(1);
+    auto &lod0 = lod[0];
+    lod0.push_back(0);
+    anchors.Resize({anchors.numel() / 4, 4});
+    variances.Resize({variances.numel() / 4, 4});
 
     int64_t num_proposals = 0;
     for (int64_t i = 0; i < num; ++i) {
@@ -352,24 +361,17 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
       scores_slice.Resize({h_score * w_score * c_score, 1});
 
       std::pair<Tensor, Tensor> tensor_pair =
-          ProposalForOneImage(dev_ctx, im_info_slice, *anchor, *var,
+          ProposalForOneImage(dev_ctx, im_info_slice, anchors, variances,
                               bbox_deltas_slice, scores_slice, pre_nms_top_n,
                               post_nms_top_n, nms_thresh, min_size, eta);
-      Tensor proposals = tensor_pair.first;
-      Tensor scores = tensor_pair.second;
-
-      framework::VisitDataType(
-          framework::ToDataType(rpn_rois->type()),
-          AppendProposalsFunctor(rpn_rois, 4 * num_proposals, &proposals));
-      framework::VisitDataType(
-          framework::ToDataType(rpn_roi_probs->type()),
-          AppendProposalsFunctor(rpn_roi_probs, num_proposals, &scores));
+      Tensor &proposals = tensor_pair.first;
+      Tensor &scores = tensor_pair.second;
 
+      AppendProposals(rpn_rois, 4 * num_proposals, proposals);
+      AppendProposals(rpn_roi_probs, num_proposals, scores);
       num_proposals += proposals.dims()[0];
-      lod0.emplace_back(num_proposals);
+      lod0.push_back(num_proposals);
     }
-    lod.emplace_back(lod0);
     rpn_rois->set_lod(lod);
     rpn_roi_probs->set_lod(lod);
     rpn_rois->Resize({num_proposals, 4});
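Note: building the LoD in place (lod.resize(1) plus a reference to lod[0]) avoids the old copy of lod0 into lod at the end. The level-0 LoD is a running offset table: entry i+1 is the total proposal count after image i. A standalone sketch of the construction, with std::vector standing in for framework::LoD's level vector:

    #include <cstddef>
    #include <vector>

    // Build a level-0 LoD offset table from per-image proposal counts:
    // counts {3, 5, 2} -> offsets {0, 3, 8, 10}.
    std::vector<size_t> BuildLod0(const std::vector<size_t> &counts) {
      std::vector<size_t> lod0;
      lod0.push_back(0);
      size_t total = 0;
      for (size_t n : counts) {
        total += n;
        lod0.push_back(total);
      }
      return lod0;
    }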
@@ -377,7 +379,7 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
   }
 
   std::pair<Tensor, Tensor> ProposalForOneImage(
-      const DeviceContext &ctx, const Tensor &im_info_slice,
+      const platform::CPUDeviceContext &ctx, const Tensor &im_info_slice,
       const Tensor &anchors, const Tensor &variances,
       const Tensor &bbox_deltas_slice,  // [M, 4]
       const Tensor &scores_slice,       // [N, 1]
@@ -392,10 +394,9 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
     for (int i = 0; i < scores_slice.numel(); ++i) {
       index[i] = i;
     }
-    std::function<bool(const int64_t &, const int64_t &)> compare =
-        [scores_data](const int64_t &i, const int64_t &j) {
-          return scores_data[i] > scores_data[j];
-        };
+    auto compare = [scores_data](const int64_t &i, const int64_t &j) {
+      return scores_data[i] > scores_data[j];
+    };
 
     if (pre_nms_top_n <= 0 || pre_nms_top_n >= scores_slice.numel()) {
       std::sort(index, index + scores_slice.numel(), compare);
@@ -452,33 +453,45 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
 class GenerateProposalsOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
-    AddInput("Scores", "The scores of anchors should be foreground.");
-    AddInput("BboxDeltas", "bbox_deltas.");
-    AddInput("ImInfo", "Information for image reshape.");
-    AddInput("Anchors", "All anchors.");
-    AddInput("Variances", " variances");
-
-    AddOutput("RpnRois", "Anchors.");
-    AddOutput("RpnRoiProbs", "Anchors.");
-    AddAttr<int>("pre_nms_topN", "pre_nms_topN");
-    AddAttr<int>("post_nms_topN", "post_nms_topN");
-    AddAttr<float>("nms_thresh", "nms_thres");
-    AddAttr<float>("min_size", "min size");
+    AddInput("Scores",
+             "(Tensor) The scores from conv are in shape (N, A, H, W), "
+             "N is batch size, A is number of anchors, "
+             "H and W are height and width of the feature map");
+    AddInput("BboxDeltas",
+             "(Tensor) Bounding box deltas from conv are in "
+             "shape (N, 4*A, H, W).");
+    AddInput("ImInfo",
+             "(Tensor) Information for image reshape is in shape (N, 3), "
+             "in format (height, width, scale)");
+    AddInput("Anchors",
+             "(Tensor) Bounding box anchors from anchor_generator_op "
+             "are in shape (A, H, W, 4).");
+    AddInput("Variances",
+             "(Tensor) Bounding box variances with same shape as `Anchors`.");
+
+    AddOutput("RpnRois",
+              "(LoDTensor) Output proposals with shape (rois_num, 4).");
+    AddOutput("RpnRoiProbs",
+              "(LoDTensor) Scores of proposals with shape (rois_num, 1).");
+    AddAttr<int>("pre_nms_topN",
+                 "Number of top scoring RPN proposals to keep before "
+                 "applying NMS.");
+    AddAttr<int>("post_nms_topN",
+                 "Number of top scoring RPN proposals to keep after "
+                 "applying NMS.");
+    AddAttr<float>("nms_thresh", "NMS threshold used on RPN proposals.");
+    AddAttr<float>("min_size",
+                   "Proposal height and width both need to be greater "
+                   "than this min_size.");
     AddAttr<float>("eta", "The parameter for adaptive NMS.");
     AddComment(R"DOC(
-Generate Proposals OP
-
-This operator proposes rois according to each box with their probability to be a foreground object and
-the box can be calculated by anchors. Bbox_deltais and scores are the output of RPN. Final proposals
-could be used to train detection net.
-
-Scores is the probability for each box to be an object. In format of (N, A, H, W) where N is batch size, A is number
-of anchors, H and W are height and width of the feature map.
-BboxDeltas is the differece between predicted box locatoin and anchor location. In format of (N, 4*A, H, W)
-
-For generating proposals, this operator transposes and resizes scores and bbox_deltas in size of (H*W*A, 1) and (H*W*A, 4) and
-calculate box locations as proposals candidates. Then clip boxes to image and remove predicted boxes with small area.
-Finally, apply nms to get final proposals as output.
+This operator generates bounding box proposals for Faster RCNN.
+The proposals are generated for a list of images based on image
+score 'Scores', bounding box regression result 'BboxDeltas' as
+well as predefined bounding box shapes 'anchors'. Greedy
+non-maximum suppression is applied to generate the final bounding
+boxes.
 )DOC");
   }
 };
@@ -490,6 +503,5 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(generate_proposals, ops::GenerateProposalsOp,
                   ops::GenerateProposalsOpMaker,
                   paddle::framework::EmptyGradOpMaker);
-REGISTER_OP_CPU_KERNEL(
-    generate_proposals,
-    ops::GenerateProposalsKernel<paddle::platform::CPUDeviceContext, float>);
+REGISTER_OP_CPU_KERNEL(generate_proposals, ops::GenerateProposalsKernel<float>,
+                       ops::GenerateProposalsKernel<double>);
@@ -16,10 +16,13 @@ limitations under the License. */
 #include <string>
 #include <vector>
 #include "cub/cub.cuh"
+#include "paddle/fluid/framework/mixed_vector.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/memory/memory.h"
+#include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/operators/gather.cu.h"
 #include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/for_range.h"
 
 namespace paddle {
 namespace operators {
@@ -36,36 +39,38 @@ namespace {
 int const kThreadsPerBlock = sizeof(uint64_t) * 8;
 
-template <typename T>
-__global__ void RangeInitKernel(const T start, const T delta, const int size,
-                                T *out) {
-  CUDA_1D_KERNEL_LOOP(i, size) { out[i] = start + i * delta; }
-}
+static const double kBBoxClipDefault = std::log(1000.0 / 16.0);
+
+struct RangeInitFunctor {
+  int start_;
+  int delta_;
+  int *out_;
+  __device__ void operator()(size_t i) { out_[i] = start_ + i * delta_; }
+};
 
 template <typename T>
-void SortDescending(const platform::CUDADeviceContext &ctx, const Tensor &value,
-                    Tensor *value_out, Tensor *index_out) {
-  int num = value.numel();
+static void SortDescending(const platform::CUDADeviceContext &ctx,
+                           const Tensor &value, Tensor *value_out,
+                           Tensor *index_out) {
+  int num = static_cast<int>(value.numel());
   Tensor index_in_t;
   int *idx_in = index_in_t.mutable_data<int>({num}, ctx.GetPlace());
-  int block = 512;
-  auto stream = ctx.stream();
-  RangeInitKernel<<<DIVUP(num, block), block, 0, stream>>>(0, 1, num, idx_in);
+  platform::ForRange<platform::CUDADeviceContext> for_range(ctx, num);
+  for_range(RangeInitFunctor{0, 1, idx_in});
+
   int *idx_out = index_out->mutable_data<int>({num}, ctx.GetPlace());
 
   const T *keys_in = value.data<T>();
   T *keys_out = value_out->mutable_data<T>({num}, ctx.GetPlace());
 
   // Determine temporary device storage requirements
-  void *d_temp_storage = NULL;
   size_t temp_storage_bytes = 0;
   cub::DeviceRadixSort::SortPairsDescending<T, int>(
-      d_temp_storage, temp_storage_bytes, keys_in, keys_out, idx_in, idx_out,
-      num);
+      nullptr, temp_storage_bytes, keys_in, keys_out, idx_in, idx_out, num);
   // Allocate temporary storage
   auto place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
-  d_temp_storage = memory::Alloc(place, temp_storage_bytes);
+  void *d_temp_storage = memory::Alloc(place, temp_storage_bytes);
 
   // Run sorting operation
   cub::DeviceRadixSort::SortPairsDescending<T, int>(
@@ -76,22 +81,27 @@ void SortDescending(const platform::CUDADeviceContext &ctx, const Tensor &value,
 }
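Note: CUB's device-wide sorts use a two-pass protocol, which this hunk preserves: a call with a null temp-storage pointer only reports the workspace size, and the second call with real storage does the sort. A minimal self-contained sketch with raw cudaMalloc in place of Paddle's allocator:

    #include <cuda_runtime.h>
    #include <cub/cub.cuh>

    // Sort `num` float keys in descending order, carrying int values along.
    void SortPairsDesc(const float *d_keys_in, float *d_keys_out,
                       const int *d_vals_in, int *d_vals_out, int num) {
      void *d_temp = nullptr;
      size_t temp_bytes = 0;
      // Pass 1: d_temp == nullptr, so CUB only reports the workspace size.
      cub::DeviceRadixSort::SortPairsDescending(
          d_temp, temp_bytes, d_keys_in, d_keys_out, d_vals_in, d_vals_out, num);
      cudaMalloc(&d_temp, temp_bytes);
      // Pass 2: the same call with real storage performs the sort.
      cub::DeviceRadixSort::SortPairsDescending(
          d_temp, temp_bytes, d_keys_in, d_keys_out, d_vals_in, d_vals_out, num);
      cudaFree(d_temp);
    }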
 template <typename T>
-__device__ __forceinline__ T Min(T x, T y) {
-  return x < y ? x : y;
-}
-
-template <typename T>
-__device__ __forceinline__ T Max(T x, T y) {
-  return x > y ? x : y;
-}
-
-template <typename T>
-__global__ void BoxDecodeAndClipKernel(const T *anchor, const T *deltas,
-                                       const T *var, const int *index,
-                                       const T *im_info, const int num,
-                                       T *proposals) {
-  T kBBoxClipDefault = log(1000.0 / 16.0);
-  CUDA_1D_KERNEL_LOOP(i, num) {
+struct BoxDecodeAndClipFunctor {
+  const T *anchor;
+  const T *deltas;
+  const T *var;
+  const int *index;
+  const T *im_info;
+
+  T *proposals;
+
+  BoxDecodeAndClipFunctor(const T *anchor, const T *deltas, const T *var,
+                          const int *index, const T *im_info, T *proposals)
+      : anchor(anchor),
+        deltas(deltas),
+        var(var),
+        index(index),
+        im_info(im_info),
+        proposals(proposals) {}
+
+  T bbox_clip_default{static_cast<T>(kBBoxClipDefault)};
+
+  __device__ void operator()(size_t i) {
     int k = index[i] * 4;
     T axmin = anchor[k];
     T aymin = anchor[k + 1];
@@ -108,17 +118,17 @@ __global__ void BoxDecodeAndClipKernel(const T *anchor, const T *deltas,
     T dxmax = deltas[k + 2];
     T dymax = deltas[k + 3];
 
-    T d_cx = 0., d_cy = 0., d_w = 0., d_h = 0.;
+    T d_cx, d_cy, d_w, d_h;
     if (var) {
       d_cx = cx + dxmin * w * var[k];
       d_cy = cy + dymin * h * var[k + 1];
-      d_w = exp(Min<T>(dxmax * var[k + 2], kBBoxClipDefault)) * w;
-      d_h = exp(Min<T>(dymax * var[k + 3], kBBoxClipDefault)) * h;
+      d_w = exp(Min(dxmax * var[k + 2], bbox_clip_default)) * w;
+      d_h = exp(Min(dymax * var[k + 3], bbox_clip_default)) * h;
     } else {
       d_cx = cx + dxmin * w;
       d_cy = cy + dymin * h;
-      d_w = exp(Min<T>(dxmax, kBBoxClipDefault)) * w;
-      d_h = exp(Min<T>(dymax, kBBoxClipDefault)) * h;
+      d_w = exp(Min(dxmax, bbox_clip_default)) * w;
+      d_h = exp(Min(dymax, bbox_clip_default)) * h;
     }
 
     T oxmin = d_cx - d_w * 0.5;
@@ -126,17 +136,21 @@ __global__ void BoxDecodeAndClipKernel(const T *anchor, const T *deltas,
     T oxmax = d_cx + d_w * 0.5 - 1.;
     T oymax = d_cy + d_h * 0.5 - 1.;
 
-    proposals[i * 4] = Max<T>(Min<T>(oxmin, im_info[1] - 1.), 0.);
-    proposals[i * 4 + 1] = Max<T>(Min<T>(oymin, im_info[0] - 1.), 0.);
-    proposals[i * 4 + 2] = Max<T>(Min<T>(oxmax, im_info[1] - 1.), 0.);
-    proposals[i * 4 + 3] = Max<T>(Min<T>(oymax, im_info[0] - 1.), 0.);
+    proposals[i * 4] = Max(Min(oxmin, im_info[1] - 1.), 0.);
+    proposals[i * 4 + 1] = Max(Min(oymin, im_info[0] - 1.), 0.);
+    proposals[i * 4 + 2] = Max(Min(oxmax, im_info[1] - 1.), 0.);
+    proposals[i * 4 + 3] = Max(Min(oymax, im_info[0] - 1.), 0.);
   }
-}
+
+  __device__ __forceinline__ T Min(T a, T b) const { return a > b ? b : a; }
+
+  __device__ __forceinline__ T Max(T a, T b) const { return a > b ? a : b; }
+};
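Note: rewriting the raw __global__ kernel as a functor dispatched through platform::ForRange centralizes the launch configuration (grid/block sizing) in one place instead of repeating it at every call site. A minimal sketch of the dispatch idea, with a hand-rolled stand-in for ForRange (not Paddle's implementation):

    #include <cstddef>
    #include <cuda_runtime.h>

    // Generic kernel: applies any device functor to each index in [0, n).
    template <typename Functor>
    __global__ void ForRangeKernel(Functor f, size_t n) {
      size_t i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) f(i);
    }

    struct RangeInit {  // same shape as the RangeInitFunctor in this diff
      int start_, delta_;
      int *out_;
      __device__ void operator()(size_t i) const { out_[i] = start_ + i * delta_; }
    };

    void FillRange(int *d_out, int n) {
      const int block = 256;
      ForRangeKernel<<<(n + block - 1) / block, block>>>(
          RangeInit{0, 1, d_out}, static_cast<size_t>(n));
    }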
 template <typename T, int BlockSize>
-__global__ void FilterBBoxes(const T *bboxes, const T *im_info,
-                             const T min_size, const int num, int *keep_num,
-                             int *keep) {
+static __global__ void FilterBBoxes(const T *bboxes, const T *im_info,
+                                    const T min_size, const int num,
+                                    int *keep_num, int *keep) {
   T im_h = im_info[0];
   T im_w = im_info[1];
   T im_scale = im_info[2];
@@ -181,7 +195,7 @@ __global__ void FilterBBoxes(const T *bboxes, const T *im_info,
   }
 }
 
-__device__ inline float IoU(const float *a, const float *b) {
+static __device__ inline float IoU(const float *a, const float *b) {
   float left = max(a[0], b[0]), right = min(a[2], b[2]);
   float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
   float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f);
@@ -191,8 +205,9 @@ __device__ inline float IoU(const float *a, const float *b) {
   return inter_s / (s_a + s_b - inter_s);
 }
 
-__global__ void NMSKernel(const int n_boxes, const float nms_overlap_thresh,
-                          const float *dev_boxes, uint64_t *dev_mask) {
+static __global__ void NMSKernel(const int n_boxes,
+                                 const float nms_overlap_thresh,
+                                 const float *dev_boxes, uint64_t *dev_mask) {
   const int row_start = blockIdx.y;
   const int col_start = blockIdx.x;
@@ -234,9 +249,9 @@ __global__ void NMSKernel(const int n_boxes, const float nms_overlap_thresh,
 }
 template <typename T>
-void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals,
-         const Tensor &sorted_indices, const T nms_threshold,
-         Tensor *keep_out) {
+static void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals,
+                const Tensor &sorted_indices, const T nms_threshold,
+                Tensor *keep_out) {
   int boxes_num = proposals.dims()[0];
   PADDLE_ENFORCE_EQ(boxes_num, sorted_indices.dims()[0]);
@@ -247,13 +262,10 @@ void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals,
   const T *boxes = proposals.data<T>();
   auto place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
-  int size_bytes = boxes_num * col_blocks * sizeof(uint64_t);
-  uint64_t *d_mask =
-      reinterpret_cast<uint64_t *>(memory::Alloc(place, size_bytes));
-  NMSKernel<<<blocks, threads>>>(boxes_num, nms_threshold, boxes, d_mask);
-  uint64_t *h_mask = reinterpret_cast<uint64_t *>(
-      memory::Alloc(platform::CPUPlace(), size_bytes));
-  memory::Copy(platform::CPUPlace(), h_mask, place, d_mask, size_bytes, 0);
+  framework::Vector<uint64_t> mask(boxes_num * col_blocks);
+  NMSKernel<<<blocks, threads>>>(
+      boxes_num, nms_threshold, boxes,
+      mask.CUDAMutableData(boost::get<platform::CUDAPlace>(ctx.GetPlace())));
 
   std::vector<uint64_t> remv(col_blocks);
   memset(&remv[0], 0, sizeof(uint64_t) * col_blocks);
@@ -267,7 +279,7 @@ void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals,
       if (!(remv[nblock] & (1ULL << inblock))) {
         ++num_to_keep;
         keep_vec.push_back(i);
-        uint64_t *p = &h_mask[0] + i * col_blocks;
+        uint64_t *p = &mask[0] + i * col_blocks;
         for (int j = nblock; j < col_blocks; j++) {
           remv[j] |= p[j];
         }
@@ -276,12 +288,10 @@ void NMS(const platform::CUDADeviceContext &ctx, const Tensor &proposals,
   int *keep = keep_out->mutable_data<int>({num_to_keep}, ctx.GetPlace());
   memory::Copy(place, keep, platform::CPUPlace(), keep_vec.data(),
                sizeof(int) * num_to_keep, 0);
-  memory::Free(place, d_mask);
-  memory::Free(platform::CPUPlace(), h_mask);
 }
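Note: framework::Vector manages the device buffer and its host mirror, replacing the manual Alloc/Copy/Free triple. The kernel writes one 64-bit suppression word per (box, block-of-64-boxes) pair; the host pass then keeps a box only if no previously kept box suppresses it, OR-ing its mask row into a running filter. A standalone sketch of that host-side reduction, assuming a precomputed mask sorted by descending score:

    #include <cstdint>
    #include <vector>

    // mask[i * col_blocks + j] has bit b set iff box i suppresses
    // box (j * 64 + b). Boxes are assumed pre-sorted by descending score.
    std::vector<int> ReduceMask(const std::vector<uint64_t> &mask,
                                int boxes_num, int col_blocks) {
      std::vector<uint64_t> remv(col_blocks, 0);
      std::vector<int> keep;
      for (int i = 0; i < boxes_num; ++i) {
        int nblock = i / 64, inblock = i % 64;
        if (!(remv[nblock] & (1ULL << inblock))) {  // not yet suppressed
          keep.push_back(i);
          const uint64_t *p = &mask[i * col_blocks];
          for (int j = nblock; j < col_blocks; ++j) remv[j] |= p[j];
        }
      }
      return keep;
    }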
 template <typename T>
-std::pair<Tensor, Tensor> ProposalForOneImage(
+static std::pair<Tensor, Tensor> ProposalForOneImage(
     const platform::CUDADeviceContext &ctx, const Tensor &im_info,
     const Tensor &anchors, const Tensor &variances,
     const Tensor &bbox_deltas,  // [M, 4]
@@ -300,18 +310,20 @@ std::pair<Tensor, Tensor> ProposalForOneImage(
   // 2. box decode and clipping
   Tensor proposals;
   proposals.mutable_data<T>({pre_nms_num, 4}, ctx.GetPlace());
-  int block = 512;
-  auto stream = ctx.stream();
-  BoxDecodeAndClipKernel<T><<<DIVUP(pre_nms_num, block), block, 0, stream>>>(
-      anchors.data<T>(), bbox_deltas.data<T>(), variances.data<T>(),
-      index_sort.data<int>(), im_info.data<T>(), pre_nms_num,
-      proposals.data<T>());
+
+  {
+    platform::ForRange<platform::CUDADeviceContext> for_range(ctx, pre_nms_num);
+    for_range(BoxDecodeAndClipFunctor<T>{
+        anchors.data<T>(), bbox_deltas.data<T>(), variances.data<T>(),
+        index_sort.data<int>(), im_info.data<T>(), proposals.data<T>()});
+  }
 
   // 3. filter
   Tensor keep_index, keep_num_t;
   keep_index.mutable_data<int>({pre_nms_num}, ctx.GetPlace());
   keep_num_t.mutable_data<int>({1}, ctx.GetPlace());
   min_size = std::max(min_size, 1.0f);
+  auto stream = ctx.stream();
   FilterBBoxes<T, 512><<<1, 512, 0, stream>>>(
       proposals.data<T>(), im_info.data<T>(), min_size, pre_nms_num,
       keep_num_t.data<int>(), keep_index.data<int>());
@@ -355,8 +367,12 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> {
     auto *scores = context.Input<Tensor>("Scores");
     auto *bbox_deltas = context.Input<Tensor>("BboxDeltas");
     auto *im_info = context.Input<Tensor>("ImInfo");
-    auto *anchors = context.Input<Tensor>("Anchors");
-    auto *variances = context.Input<Tensor>("Variances");
+    auto anchors = detail::Ref(context.Input<Tensor>("Anchors"),
+                               "Cannot find input Anchors(%s) in scope",
+                               context.Inputs("Anchors")[0]);
+    auto variances = detail::Ref(context.Input<Tensor>("Variances"),
+                                 "Cannot find input Variances(%s) in scope",
+                                 context.Inputs("Variances")[0]);
     auto *rpn_rois = context.Output<LoDTensor>("RpnRois");
     auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs");
@@ -392,10 +408,8 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> {
     trans(dev_ctx, *bbox_deltas, &bbox_deltas_swap, axis);
     trans(dev_ctx, *scores, &scores_swap, axis);
 
-    Tensor *anchor = const_cast<framework::Tensor *>(anchors);
-    anchor->Resize({anchors->numel() / 4, 4});
-    Tensor *var = const_cast<framework::Tensor *>(variances);
-    var->Resize({var->numel() / 4, 4});
+    anchors.Resize({anchors.numel() / 4, 4});
+    variances.Resize({variances.numel() / 4, 4});
 
     rpn_rois->mutable_data<T>({bbox_deltas->numel() / 4, 4},
                               context.GetPlace());
@@ -417,12 +431,12 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> {
       scores_slice.Resize({h_score * w_score * c_score, 1});
 
       std::pair<Tensor, Tensor> box_score_pair =
-          ProposalForOneImage<T>(dev_ctx, im_info_slice, *anchor, *var,
+          ProposalForOneImage<T>(dev_ctx, im_info_slice, anchors, variances,
                                  bbox_deltas_slice, scores_slice, pre_nms_top_n,
                                  post_nms_top_n, nms_thresh, min_size, eta);
-      Tensor proposals = box_score_pair.first;
-      Tensor scores = box_score_pair.second;
+      Tensor &proposals = box_score_pair.first;
+      Tensor &scores = box_score_pair.second;
 
       memory::Copy(place, rpn_rois_data + num_proposals * 4, place,
                    proposals.data<T>(), sizeof(T) * proposals.numel(), 0);
......
@@ -39,11 +39,9 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
   PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
   // check index of shape 1-D
   PADDLE_ENFORCE(index.dims().size() == 1);
-  int index_size = index.dims()[0];
+  int64_t index_size = index.dims()[0];
 
   auto src_dims = src.dims();
-  framework::DDim output_dims(src_dims);
-  output_dims[0] = index_size;
 
   const T* p_src = src.data<T>();
   const int* p_index = index.data<int>();
@@ -55,7 +53,7 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
   const size_t slice_bytes = slice_size * sizeof(T);
 
-  for (int i = 0; i < index_size; ++i) {
+  for (int64_t i = 0; i < index_size; ++i) {
     int index_ = p_index[i];
     memcpy(p_output + i * slice_size, p_src + index_ * slice_size, slice_bytes);
   }
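Note: CPUGather copies one row ("slice") per index with memcpy; widening the loop counter to int64_t matches the widened index_size and avoids overflow when the index tensor is very large. The dead output_dims computation is dropped. A minimal standalone sketch of the same row gather, with raw pointers standing in for Tensors:

    #include <cstdint>
    #include <cstring>

    // Gather rows of `src` (shape [n, slice_size]) into `dst`
    // (shape [index_size, slice_size]) according to `index`.
    template <typename T>
    void GatherRows(const T *src, const int *index, int64_t index_size,
                    size_t slice_size, T *dst) {
      const size_t slice_bytes = slice_size * sizeof(T);
      for (int64_t i = 0; i < index_size; ++i) {
        std::memcpy(dst + i * slice_size, src + index[i] * slice_size,
                    slice_bytes);
      }
    }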
......