提交 1e073dcc 编写于 作者: E eclipsess

modify kernel.cpp to arm_func.h

上级 90344f8c
......@@ -15,130 +15,21 @@ limitations under the License. */
#ifdef BOXCODER_OP
#include "operators/kernel/box_coder_kernel.h"
#include <cmath>
#include "operators/kernel/central-arm-func/box_coder_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <typename T>
void EncodeCenterSize(const framework::Tensor& target_box,
const framework::Tensor& prior_box,
const framework::Tensor& prior_box_var, T* output) {
int64_t row = target_box.dims()[0];
int64_t col = prior_box.dims()[0];
int64_t len = prior_box.dims()[1];
auto* target_box_data = target_box.data<T>();
auto* prior_box_data = prior_box.data<T>();
auto* prior_box_var_data = prior_box_var.data<T>();
for (int64_t i = 0; i < row; ++i) {
for (int64_t j = 0; j < col; ++j) {
T prior_box_width = prior_box_data[j * len + 2] - prior_box_data[j * len];
T prior_box_height =
prior_box_data[j * len + 3] - prior_box_data[j * len + 1];
T prior_box_center_x =
(prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2;
T prior_box_center_y =
(prior_box_data[j * len + 3] + prior_box_data[j * len + 1]) / 2;
T target_box_center_x =
(target_box_data[i * len + 2] + target_box_data[i * len]) / 2;
T target_box_center_y =
(target_box_data[i * len + 3] + target_box_data[i * len + 1]) / 2;
T target_box_width =
target_box_data[i * len + 2] - target_box_data[i * len];
T target_box_height =
target_box_data[i * len + 3] - target_box_data[i * len + 1];
size_t offset = i * col * len + j * len;
output[offset] = (target_box_center_x - prior_box_center_x) /
prior_box_width / prior_box_var_data[j * len];
output[offset + 1] = (target_box_center_y - prior_box_center_y) /
prior_box_height / prior_box_var_data[j * len + 1];
output[offset + 2] =
std::log(std::fabs(target_box_width / prior_box_width)) /
prior_box_var_data[j * len + 2];
output[offset + 3] =
std::log(std::fabs(target_box_height / prior_box_height)) /
prior_box_var_data[j * len + 3];
}
}
}
template <typename T>
void DecodeCenterSize(const framework::Tensor& target_box,
const framework::Tensor& prior_box,
const framework::Tensor& prior_box_var, T* output) {
int64_t row = target_box.dims()[0];
int64_t col = prior_box.dims()[0];
int64_t len = prior_box.dims()[1];
auto* target_box_data = target_box.data<T>();
auto* prior_box_data = prior_box.data<T>();
auto* prior_box_var_data = prior_box_var.data<T>();
for (int64_t i = 0; i < row; ++i) {
for (int64_t j = 0; j < col; ++j) {
size_t offset = i * col * len + j * len;
T prior_box_width = prior_box_data[j * len + 2] - prior_box_data[j * len];
T prior_box_height =
prior_box_data[j * len + 3] - prior_box_data[j * len + 1];
T prior_box_center_x =
(prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2;
T prior_box_center_y =
(prior_box_data[j * len + 3] + prior_box_data[j * len + 1]) / 2;
T target_box_center_x = prior_box_var_data[j * len] *
target_box_data[offset] * prior_box_width +
prior_box_center_x;
T target_box_center_y = prior_box_var_data[j * len + 1] *
target_box_data[offset + 1] *
prior_box_height +
prior_box_center_y;
T target_box_width = std::exp(prior_box_var_data[j * len + 2] *
target_box_data[offset + 2]) *
prior_box_width;
T target_box_height = std::exp(prior_box_var_data[j * len + 3] *
target_box_data[offset + 3]) *
prior_box_height;
output[offset] = target_box_center_x - target_box_width / 2;
output[offset + 1] = target_box_center_y - target_box_height / 2;
output[offset + 2] = target_box_center_x + target_box_width / 2;
output[offset + 3] = target_box_center_y + target_box_height / 2;
}
}
}
template <>
bool BoxCoderKernel<CPU, float>::Init(BoxCoderParam* param) {
bool BoxCoderKernel<CPU, float>::Init(BoxCoderParam *param) {
return true;
}
template <>
void BoxCoderKernel<CPU, float>::Compute(const BoxCoderParam& param) const {
const auto* input_priorbox = param.InputPriorBox();
const auto* input_priorboxvar = param.InputPriorBoxVar();
const auto* input_targetbox = param.InputTargetBox();
const auto& code_type = param.CodeType();
auto row = input_targetbox->dims()[0];
auto col = input_priorbox->dims()[0];
auto len = input_priorbox->dims()[1];
Tensor* output_box = param.OutputBox();
auto* output_box_dataptr = output_box->mutable_data<float>({row, col, len});
if (code_type == "encode_center_size") {
EncodeCenterSize<float>(*input_targetbox, *input_priorbox,
*input_priorboxvar, output_box_dataptr);
}
if (code_type == "decode_center_size") {
DecodeCenterSize<float>(*input_targetbox, *input_priorbox,
*input_priorboxvar, output_box_dataptr);
}
void BoxCoderKernel<CPU, float>::Compute(const BoxCoderParam &param) const {
BoxCoderCompute<float>(param);
}
} // namespace operators
} // namespace paddle_mobile
......
......@@ -15,42 +15,10 @@ limitations under the License. */
#ifdef CONCAT_OP
#include "operators/kernel/concat_kernel.h"
#include "operators/kernel/central-arm-func/concat_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <typename T>
class ConcatFunctor {
public:
void operator()(const std::vector<framework::Tensor> &input, const int axis,
framework::Tensor *output) {
size_t num = input.size();
int rows = 1;
auto dim_0 = input[0].dims();
for (int i = 0; i < axis; ++i) {
rows *= dim_0[i];
}
int out_rows = rows, out_cols = 0;
std::vector<int64_t> input_cols(input.size());
for (int i = 0; i < num; ++i) {
int t_cols = input[i].numel() / rows;
out_cols += t_cols;
input_cols[i] = t_cols;
}
// computation
for (int k = 0; k < out_rows; ++k) {
T *dst_ptr = output->data<T>() + k * out_cols;
int col_idx = 0;
for (int j = 0; j < num; ++j) {
int col_len = input_cols[j];
const T *src_prt = input[j].data<T>() + k * col_len;
memory::Copy(dst_ptr + col_idx, src_prt, sizeof(T) * col_len);
col_idx += col_len;
}
}
}
};
template <>
bool ConcatKernel<CPU, float>::Init(ConcatParam *param) {
......@@ -59,33 +27,7 @@ bool ConcatKernel<CPU, float>::Init(ConcatParam *param) {
template <>
void ConcatKernel<CPU, float>::Compute(const ConcatParam &param) const {
auto inputs = param.Inputs();
auto *out = param.Out();
int64_t axis = param.Axis();
out->mutable_data<float>();
/// Sometimes direct copies will be faster, this maybe need deeply analysis.
if (axis == 0 && inputs.size() < 10) {
size_t output_offset = 0;
for (auto *in : inputs) {
auto in_stride = framework::stride_numel(in->dims());
auto out_stride = framework::stride_numel(out->dims());
auto dst = out->data<float>() + output_offset;
auto src = in->data<float>();
PADDLE_MOBILE_ENFORCE(
in_stride.size() == out_stride.size(),
"src and dst tensor should have the same dims size.");
memory::Copy(dst, src, sizeof(float) * in_stride[0]);
output_offset += in_stride[0];
}
} else {
std::vector<framework::Tensor> inputs_concat(inputs.size());
for (int j = 0; j < inputs.size(); ++j) {
inputs_concat[j] = *inputs[j];
}
ConcatFunctor<float> concat_functor;
concat_functor(inputs_concat, static_cast<int>(axis), out);
}
ConcatCompute<float>(param);
}
} // namespace operators
......
......@@ -14,18 +14,12 @@ limitations under the License. */
#ifdef ELEMENTWISEADD_OP
#pragma once
#include "operators/kernel/elementwise_add_kernel.h"
#include "operators/kernel/central-arm-func/elementwise_add_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <typename T>
struct AddFunctor {
inline T operator()(T a, T b) const { return a + b; }
};
template <>
bool ElementwiseAddKernel<CPU, float>::Init(ElementwiseAddParam *param) {
return true;
......@@ -34,17 +28,9 @@ bool ElementwiseAddKernel<CPU, float>::Init(ElementwiseAddParam *param) {
template <>
void ElementwiseAddKernel<CPU, float>::Compute(
const ElementwiseAddParam &param) const {
const Tensor *input_x = param.InputX();
const Tensor *input_y = param.InputY();
Tensor *Out = param.Out();
Out->mutable_data<float>();
int axis = param.Axis();
ElementwiseComputeEx<AddFunctor<float>, float>(input_x, input_y, axis,
AddFunctor<float>(), Out);
ElementwiseAddCompute<float>(param);
}
template class ElementwiseAddKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
......
......@@ -14,9 +14,8 @@ limitations under the License. */
#ifdef FUSION_FC_OP
#pragma once
#include "operators/kernel/fusion_fc_kernel.h"
#include "operators/kernel/central-arm-func/fusion_fc_arm_func.h"
namespace paddle_mobile {
namespace operators {
......@@ -28,46 +27,7 @@ bool FusionFcKernel<CPU, float>::Init(FusionFcParam *param) {
template <>
void FusionFcKernel<CPU, float>::Compute(const FusionFcParam &param) const {
const Tensor *input_x = param.InputX();
const Tensor *input_y = param.InputY();
const Tensor *input_z = param.InputZ();
auto *input_z_data = input_z->data<float>();
int axis = param.Axis();
Tensor *out = param.Out();
auto *out_data = out->mutable_data<float>();
const Tensor x_matrix =
input_x->dims().size() > 2
? framework::ReshapeToMatrix(*input_x, param.XNumColDims())
: *input_x;
const Tensor y_matrix =
input_y->dims().size() > 2
? framework::ReshapeToMatrix(*input_y, param.YNumColDims())
: *input_y;
auto out_dim = out->dims();
if (out_dim.size() != 2) {
out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
}
PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2.");
PADDLE_MOBILE_ENFORCE(input_z->dims().size() == 1, "inpu_z size must be 1");
PADDLE_MOBILE_ENFORCE(out_dim[1] == input_z->dims()[0],
" out_dim.size must be 2.");
axis = (axis == -1 ? out_dim.size() - input_z->dims().size() : axis);
PADDLE_MOBILE_ENFORCE(axis == 1, " to fit broadcast, axis = 1. ")
int64_t classes = input_z->numel();
for (int i = 0; i < out_dim[0]; i++) {
memory::Copy(out_data + i * classes, input_z_data, sizeof(float) * classes);
}
for (int i = 0; i < out->numel(); i++) {
DLOG << out_data[i];
}
math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
out, static_cast<float>(1));
PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2.");
// if (out_dim.size() != 2) {
// out->Resize(out_dim);
// }
FusionFcCompute<float>(param);
}
} // namespace operators
......
......@@ -14,9 +14,8 @@ limitations under the License. */
#ifdef LRN_OP
#pragma once
#include "operators/kernel/lrn_kernel.h"
#include "operators/kernel/central-arm-func/lrn_arm_func.h"
namespace paddle_mobile {
namespace operators {
......@@ -28,26 +27,9 @@ bool LrnKernel<CPU, float>::Init(LrnParam *param) {
template <>
void LrnKernel<CPU, float>::Compute(const LrnParam &param) const {
const Tensor *input_x = param.InputX();
auto x_dims = input_x->dims();
Tensor *out = param.Out();
out->mutable_data<float>();
/// data_format = NCHW
const int N = x_dims[0];
const int C = x_dims[1];
const int H = x_dims[2];
const int W = x_dims[3];
const int n = param.N();
const float alpha = param.Alpha();
const float beta = param.Beta();
const float k = param.K();
LRNFunctor<float> lrnFunctor;
lrnFunctor(*input_x, out, N, C, H, W, n, k, alpha, beta);
LrnCompute<float>(param);
}
template class LrnKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
......
......@@ -14,9 +14,8 @@ limitations under the License. */
#ifdef MUL_OP
#pragma once
#include "operators/kernel/mul_kernel.h"
#include "operators/kernel/central-arm-func/mul_arm_func.h"
namespace paddle_mobile {
namespace operators {
......@@ -28,31 +27,9 @@ bool MulKernel<CPU, float>::Init(MulParam *param) {
template <>
void MulKernel<CPU, float>::Compute(const MulParam &param) const {
const Tensor *input_x = param.InputX();
const Tensor *input_y = param.InputY();
Tensor *out = param.Out();
out->mutable_data<float>();
const Tensor x_matrix =
input_x->dims().size() > 2
? framework::ReshapeToMatrix(*input_x, param.XNumColDims())
: *input_x;
const Tensor y_matrix =
input_y->dims().size() > 2
? framework::ReshapeToMatrix(*input_y, param.YNumColDims())
: *input_y;
auto out_dim = out->dims();
if (out_dim.size() != 2) {
out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
}
math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
out, static_cast<float>(0));
if (out_dim.size() != 2) {
out->Resize(out_dim);
}
MulCompute<float>(param);
}
template class MulKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
......
......@@ -15,265 +15,20 @@ limitations under the License. */
#ifdef MULTICLASSNMS_OP
#include "operators/kernel/multiclass_nms_kernel.h"
#include <algorithm>
#include "operators/kernel/central-arm-func/multiclass_nms_arm_func.h"
namespace paddle_mobile {
namespace operators {
constexpr int kOutputDim = 6;
constexpr int kBBoxSize = 4;
template <class T>
bool SortScorePairDescend(const std::pair<float, T>& pair1,
const std::pair<float, T>& pair2) {
return pair1.first > pair2.first;
}
template <class T>
static inline void GetMaxScoreIndex(
const std::vector<T>& scores, const T threshold, int top_k,
std::vector<std::pair<T, int>>* sorted_indices) {
for (size_t i = 0; i < scores.size(); ++i) {
if (scores[i] > threshold) {
sorted_indices->push_back(std::make_pair(scores[i], i));
}
}
// Sort the score pair according to the scores in descending order
std::stable_sort(sorted_indices->begin(), sorted_indices->end(),
SortScorePairDescend<int>);
// Keep top_k scores if needed.
if (top_k > -1 && top_k < static_cast<int>(sorted_indices->size())) {
sorted_indices->resize(top_k);
}
}
template <class T>
static inline T BBoxArea(const T* box, const bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) {
// If coordinate values are is invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return static_cast<T>(0.);
} else {
const T w = box[2] - box[0];
const T h = box[3] - box[1];
if (normalized) {
return w * h;
} else {
// If coordinate values are not within range [0, 1].
return (w + 1) * (h + 1);
}
}
}
template <class T>
static inline T JaccardOverlap(const T* box1, const T* box2,
const bool normalized) {
if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] ||
box2[3] < box1[1]) {
return static_cast<T>(0.);
} else {
const T inter_xmin = std::max(box1[0], box2[0]);
const T inter_ymin = std::max(box1[1], box2[1]);
const T inter_xmax = std::min(box1[2], box2[2]);
const T inter_ymax = std::min(box1[3], box2[3]);
const T inter_w = inter_xmax - inter_xmin;
const T inter_h = inter_ymax - inter_ymin;
const T inter_area = inter_w * inter_h;
const T bbox1_area = BBoxArea<T>(box1, normalized);
const T bbox2_area = BBoxArea<T>(box2, normalized);
return inter_area / (bbox1_area + bbox2_area - inter_area);
}
}
template <typename T>
static inline void NMSFast(const Tensor& bbox, const Tensor& scores,
const T score_threshold, const T nms_threshold,
const T eta, const int64_t top_k,
std::vector<int>* selected_indices) {
// The total boxes for each instance.
int64_t num_boxes = bbox.dims()[0];
// 4: [xmin ymin xmax ymax]
int64_t box_size = bbox.dims()[1];
std::vector<T> scores_data(num_boxes);
std::copy_n(scores.data<T>(), num_boxes, scores_data.begin());
std::vector<std::pair<T, int>> sorted_indices;
GetMaxScoreIndex(scores_data, score_threshold, top_k, &sorted_indices);
selected_indices->clear();
T adaptive_threshold = nms_threshold;
const T* bbox_data = bbox.data<T>();
while (sorted_indices.size() != 0) {
const int idx = sorted_indices.front().second;
bool keep = true;
for (size_t k = 0; k < selected_indices->size(); ++k) {
if (keep) {
const int kept_idx = (*selected_indices)[k];
T overlap = JaccardOverlap<T>(bbox_data + idx * box_size,
bbox_data + kept_idx * box_size, true);
keep = overlap <= adaptive_threshold;
} else {
break;
}
}
if (keep) {
selected_indices->push_back(idx);
}
sorted_indices.erase(sorted_indices.begin());
if (keep && eta < 1 && adaptive_threshold > 0.5) {
adaptive_threshold *= eta;
}
}
}
template <typename T>
void MultiClassNMS(const Tensor& scores, const Tensor& bboxes,
std::map<int, std::vector<int>>* indices, int* num_nmsed_out,
const int& background_label, const int& nms_top_k,
const int& keep_top_k, const T& nms_threshold,
const T& nms_eta, const T& score_threshold) {
int64_t class_num = scores.dims()[0];
int64_t predict_dim = scores.dims()[1];
int num_det = 0;
for (int64_t c = 0; c < class_num; ++c) {
if (c == background_label) continue;
Tensor score = scores.Slice(c, c + 1);
/// [c] is key
NMSFast<float>(bboxes, score, score_threshold, nms_threshold, nms_eta,
nms_top_k, &((*indices)[c]));
num_det += (*indices)[c].size();
}
*num_nmsed_out = num_det;
const T* scores_data = scores.data<T>();
if (keep_top_k > -1 && num_det > keep_top_k) {
std::vector<std::pair<float, std::pair<int, int>>> score_index_pairs;
for (const auto& it : *indices) {
int label = it.first;
const T* sdata = scores_data + label * predict_dim;
const std::vector<int>& label_indices = it.second;
for (size_t j = 0; j < label_indices.size(); ++j) {
int idx = label_indices[j];
// PADDLE_ENFORCE_LT(idx, predict_dim);
score_index_pairs.push_back(
std::make_pair(sdata[idx], std::make_pair(label, idx)));
}
}
// Keep top k results per image.
std::stable_sort(score_index_pairs.begin(), score_index_pairs.end(),
SortScorePairDescend<std::pair<int, int>>);
score_index_pairs.resize(keep_top_k);
// Store the new indices.
std::map<int, std::vector<int>> new_indices;
for (size_t j = 0; j < score_index_pairs.size(); ++j) {
int label = score_index_pairs[j].second.first;
int idx = score_index_pairs[j].second.second;
new_indices[label].push_back(idx);
}
new_indices.swap(*indices);
*num_nmsed_out = keep_top_k;
}
}
template <typename T>
void MultiClassOutput(const Tensor& scores, const Tensor& bboxes,
const std::map<int, std::vector<int>>& selected_indices,
Tensor* outs) {
int predict_dim = scores.dims()[1];
auto* scores_data = scores.data<T>();
auto* bboxes_data = bboxes.data<T>();
auto* odata = outs->data<T>();
int count = 0;
for (const auto& it : selected_indices) {
/// one batch
int label = it.first;
const T* sdata = scores_data + label * predict_dim;
const std::vector<int>& indices = it.second;
for (size_t j = 0; j < indices.size(); ++j) {
int idx = indices[j];
const T* bdata = bboxes_data + idx * kBBoxSize;
odata[count * kOutputDim] = label; // label
odata[count * kOutputDim + 1] = sdata[idx]; // score
// xmin, ymin, xmax, ymax
std::memcpy(odata + count * kOutputDim + 2, bdata, 4 * sizeof(T));
count++;
}
}
}
template <>
bool MultiClassNMSKernel<CPU, float>::Init(MultiClassNMSParam* param) {
bool MultiClassNMSKernel<CPU, float>::Init(MultiClassNMSParam *param) {
return true;
}
template <>
void MultiClassNMSKernel<CPU, float>::Compute(
const MultiClassNMSParam& param) const {
const auto* input_bboxes = param.InputBBoxes();
const auto& input_bboxes_dims = input_bboxes->dims();
const auto* input_scores = param.InputScores();
const auto& input_scores_dims = input_scores->dims();
auto* outs = param.Out();
auto background_label = param.BackGroundLabel();
auto nms_top_k = param.NMSTopK();
auto keep_top_k = param.KeepTopK();
auto nms_threshold = param.NMSThreshold();
auto nms_eta = param.NMSEta();
auto score_threshold = param.ScoreThreshold();
int64_t batch_size = input_scores_dims[0];
int64_t class_num = input_scores_dims[1];
int64_t predict_dim = input_scores_dims[2];
int64_t box_dim = input_bboxes_dims[2];
std::vector<std::map<int, std::vector<int>>> all_indices;
std::vector<size_t> batch_starts = {0};
for (int64_t i = 0; i < batch_size; ++i) {
Tensor ins_score = input_scores->Slice(i, i + 1);
ins_score.Resize({class_num, predict_dim});
Tensor ins_boxes = input_bboxes->Slice(i, i + 1);
ins_boxes.Resize({predict_dim, box_dim});
std::map<int, std::vector<int>> indices;
int num_nmsed_out = 0;
MultiClassNMS<float>(ins_score, ins_boxes, &indices, &num_nmsed_out,
background_label, nms_top_k, keep_top_k, nms_threshold,
nms_eta, score_threshold);
all_indices.push_back(indices);
batch_starts.push_back(batch_starts.back() + num_nmsed_out);
}
int num_kept = batch_starts.back();
if (num_kept == 0) {
float* od = outs->mutable_data<float>({1});
od[0] = -1;
} else {
outs->mutable_data<float>({num_kept, kOutputDim});
for (int64_t i = 0; i < batch_size; ++i) {
Tensor ins_score = input_scores->Slice(i, i + 1);
ins_score.Resize({class_num, predict_dim});
Tensor ins_boxes = input_bboxes->Slice(i, i + 1);
ins_boxes.Resize({predict_dim, box_dim});
int64_t s = batch_starts[i];
int64_t e = batch_starts[i + 1];
if (e > s) {
Tensor out = outs->Slice(s, e);
MultiClassOutput<float>(ins_score, ins_boxes, all_indices[i], &out);
}
}
}
// framework::LoD lod;
// lod.emplace_back(batch_starts);
//
// outs->set_lod(lod);
const MultiClassNMSParam &param) const {
MultiClassNMSCompute<float>(param);
}
} // namespace operators
......
......@@ -15,17 +15,11 @@ limitations under the License. */
#ifdef PRIORBOX_OP
#include "operators/kernel/prior_box_kernel.h"
#include "operators/kernel/central-arm-func/prior_box_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <typename T>
struct ClipFunctor {
inline T operator()(T in) const {
return std::min<T>(std::max<T>(in, 0.), 1.);
}
};
template <>
bool PriorBoxKernel<CPU, float>::Init(PriorBoxParam *param) {
return true;
......@@ -33,117 +27,7 @@ bool PriorBoxKernel<CPU, float>::Init(PriorBoxParam *param) {
template <>
void PriorBoxKernel<CPU, float>::Compute(const PriorBoxParam &param) const {
const auto *input_ = param.Input();
const auto &input_dims = input_->dims();
const auto *input_image = param.InputImage();
const auto &input_image_dims = input_image->dims();
const auto &min_sizes = param.MinSizes();
const auto &max_sizes = param.MaxSizes();
const auto &variances = param.Variances();
const auto &input_aspect_ratio = param.AspectRatios();
const bool &flip = param.Flip();
const bool &clip = param.Clip();
const float &step_w = param.StepW();
const float &step_h = param.StepH();
const float &offset = param.Offset();
Tensor *output_boxes = param.OutputBoxes();
auto output_boxes_dataptr = output_boxes->mutable_data<float>();
Tensor *output_variances = param.OutputVariances();
auto output_variances_dataptr = output_variances->mutable_data<float>();
std::vector<float> aspect_ratios;
ExpandAspectRatios(input_aspect_ratio, flip, &aspect_ratios);
auto img_width = input_image_dims[3];
auto img_height = input_image_dims[2];
auto feature_width = input_dims[3];
auto feature_height = input_dims[2];
auto stride0 = output_boxes->dims()[1] * output_boxes->dims()[2] *
output_boxes->dims()[3];
auto stride1 = output_boxes->dims()[2] * output_boxes->dims()[3];
auto stride2 = output_boxes->dims()[3];
float step_width, step_height;
/// 300 / 19
if (step_w == 0 || step_h == 0) {
step_width = static_cast<float>(img_width) / feature_width;
step_height = static_cast<float>(img_height) / feature_height;
} else {
step_width = step_w;
step_height = step_h;
}
int num_priors = aspect_ratios.size() * min_sizes.size();
if (!max_sizes.empty()) {
num_priors += max_sizes.size();
}
for (int h = 0; h < feature_height; ++h) {
for (int w = 0; w < feature_width; ++w) {
/// map origin image
float center_x = (w + offset) * step_width;
float center_y = (h + offset) * step_height;
float box_width, box_height;
int idx = 0;
for (size_t s = 0; s < min_sizes.size(); ++s) {
auto min_size = min_sizes[s];
// priors with different aspect ratios
for (float ar : aspect_ratios) {
box_width = min_size * sqrt(ar) / 2.;
box_height = min_size / sqrt(ar) / 2.;
/// box_width/2 , / img_width 为了得到feature map 相对于
/// 原图的归一化位置的比例。
output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 0] =
(center_x - box_width) / img_width;
output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 1] =
(center_y - box_height) / img_height;
output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 2] =
(center_x + box_width) / img_width;
output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 3] =
(center_y + box_height) / img_height;
idx++;
}
if (!max_sizes.empty()) {
auto max_size = max_sizes[s];
// square prior with size sqrt(minSize * maxSize)
box_width = box_height = sqrt(min_size * max_size) / 2.;
output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 0] =
(center_x - box_width) / img_width;
output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 1] =
(center_y - box_height) / img_height;
output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 2] =
(center_x + box_width) / img_width;
output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 3] =
(center_y + box_height) / img_height;
idx++;
}
}
}
}
if (clip) {
math::Transform trans;
ClipFunctor<float> clip_func;
trans(output_boxes_dataptr, output_boxes_dataptr + output_boxes->numel(),
output_boxes_dataptr, clip_func);
}
if ((variances.size() != 4)) {
LOG(kLOG_ERROR) << " variances.size() must be 4.";
}
int64_t box_num = feature_height * feature_width * num_priors;
for (int i = 0; i < box_num; i++) {
output_variances_dataptr[4 * i] = variances[0];
output_variances_dataptr[4 * i + 1] = variances[1];
output_variances_dataptr[4 * i + 2] = variances[2];
output_variances_dataptr[4 * i + 3] = variances[3];
}
PriorBoxCompute<float>(param);
}
} // namespace operators
......
......@@ -15,98 +15,21 @@ limitations under the License. */
#ifdef RELU_OP
#include "operators/kernel/relu_kernel.h"
#include <operators/math/transform.h>
#include "operators/kernel/central-arm-func/relu_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <typename T>
struct ReluFunctor {
inline T operator()(T in) const { return in > 0 ? in : 0; }
};
template <>
bool ReluKernel<CPU, float>::Init(ReluParam *param) {
return true;
}
/*
* @b 特化到具体平台的实现, param 从 op 层传入
* */
template <>
void ReluKernel<CPU, float>::Compute(const ReluParam &param) const {
const auto *input_x = param.InputX();
auto *input_x_ptr = input_x->data<float>();
auto *out = param.Out();
auto *out_ptr = out->mutable_data<float>();
int numel = input_x->numel();
// if (numel > 64) {
// asm volatile(
// "pld [%[input_x_ptr], #0] \n\t"
// "vmov.f32 q8, #0.0 \n\t"
// "subs %[num], %[num], #32 \n\t"
// "blt end_num_%= \n\t"
// "loop_num_%=: \n\t"
// "pld [%[input_x_ptr], #1024] \n\t"
//
// "vld1.32 {q0, q1}, [%[input_x_ptr]]! \n\t"
// "vld1.32 {q2, q3}, [%[input_x_ptr]]! \n\t"
// "vld1.32 {q4, q5}, [%[input_x_ptr]]! \n\t"
// "vld1.32 {q6, q7}, [%[input_x_ptr]]! \n\t"
//
// "vmax.f32 q0, q0, q8 \n\t"
// "vmax.f32 q1, q1, q8 \n\t"
// "vmax.f32 q2, q2, q8 \n\t"
// "vmax.f32 q3, q3, q8 \n\t"
// "vmax.f32 q4, q4, q8 \n\t"
// "vmax.f32 q5, q5, q8 \n\t"
// "vmax.f32 q6, q6, q8 \n\t"
// "vmax.f32 q7, q7, q8 \n\t"
//
// "vst1.32 {q0, q1}, [%[out_ptr]]! \n\t"
// "vst1.32 {q2, q3}, [%[out_ptr]]! \n\t"
// "vst1.32 {q4, q5}, [%[out_ptr]]! \n\t"
// "vst1.32 {q6, q7}, [%[out_ptr]]! \n\t"
//
// "subs %[num], %[num], #32 \n\t"
// "bge loop_num_%= \n\t"
// "end_num_%=: \n\t"
// "cmp %[num], #0 \n\t"
// "bge end_%= \n\t"
// "mov r6, #4 \n\t"
// "mul r5, %[num], r6 \n\t"
// "add %[input_x_ptr], %[input_x_ptr], r5 \n\t"
// "vld1.32 {q0, q1}, [%[input_x_ptr]]! \n\t"
// "vld1.32 {q2, q3}, [%[input_x_ptr]]! \n\t"
// "vld1.32 {q4, q5}, [%[input_x_ptr]]! \n\t"
// "vld1.32 {q6, q7}, [%[input_x_ptr]]! \n\t"
// "vmax.f32 q0, q0, q8 \n\t"
// "vmax.f32 q1, q1, q8 \n\t"
// "vmax.f32 q2, q2, q8 \n\t"
// "vmax.f32 q3, q3, q8 \n\t"
// "vmax.f32 q4, q4, q8 \n\t"
// "vmax.f32 q5, q5, q8 \n\t"
// "vmax.f32 q6, q6, q8 \n\t"
// "vmax.f32 q7, q7, q8 \n\t"
// "add %[out_ptr], %[out_ptr], r5 \n\t"
// "vst1.32 {q0, q1}, [%[out_ptr]]! \n\t"
// "vst1.32 {q2, q3}, [%[out_ptr]]! \n\t"
// "vst1.32 {q4, q5}, [%[out_ptr]]! \n\t"
// "vst1.32 {q6, q7}, [%[out_ptr]]! \n\t"
// "end_%=: \n\t"
// :
// :
// [out_ptr] "r"(out_ptr), [input_x_ptr] "r"(input_x_ptr), [num]
// "r"(numel) : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
// "q7", "q8", "r5",
// "r6");
// } else {
ReluFunctor<float> func_;
math::Transform trans;
trans(input_x_ptr, input_x_ptr + numel, out_ptr, func_);
// }
ReluCompute<float>(param);
}
} // namespace operators
} // namespace paddle_mobile
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#ifdef RESHAPE_OP
#include "operators/kernel/reshape_kernel.h"
#include "operators/kernel/central-arm-func/reshape_arm_func.h"
namespace paddle_mobile {
namespace operators {
......@@ -26,30 +27,7 @@ bool ReshapeKernel<CPU, float>::Init(ReshapeParam *param) {
template <>
void ReshapeKernel<CPU, float>::Compute(const ReshapeParam &param) const {
const auto *input_x = param.InputX();
const auto &input_x_dims = input_x->dims();
auto *out = param.Out();
framework::DDim out_dims = out->dims();
const auto *input_shape = param.InputShape();
if (input_shape) {
auto *shape_data = input_shape->data<int>();
framework::Tensor cpu_shape_tensor;
auto shape =
std::vector<int>(shape_data, shape_data + input_shape->numel());
out_dims = ValidateShape(shape, input_x->dims());
}
bool inplace = param.Inplace();
out->Resize(out_dims);
if (!inplace) {
out->mutable_data<float>();
framework::TensorCopy(*input_x, out);
out->Resize(out_dims);
} else {
out->ShareDataWith(*input_x);
out->Resize(out_dims);
}
ReshapeCompute<float>(param);
}
} // namespace operators
......
......@@ -14,72 +14,19 @@ limitations under the License. */
#ifdef TRANSPOSE_OP
#include "operators/kernel/transpose_kernel.h"
#include "operators/kernel/central-arm-func/transpose_arm_func.h"
namespace paddle_mobile {
namespace operators {
// vector<int> pos;
// template <typename T>
// void TransposeFunc(const int numel, const T* input, const vector<int> axis,
// const vector<int> old_strides, const vector<int>
// new_strides, T* output) {
// for (int i = 0; i < numel; ++i) {
// int old_idx = 0;
// int idx = i;
// for (int j = 0; j < axis.size(); ++j) {
// int order = axis[j];
// old_idx += (idx / new_strides[j]) * old_strides[order];
// idx %= new_strides[j];
// }
// output[i] = input[old_idx];
// }
// }
template <>
bool TransposeKernel<CPU, float>::Init(TransposeParam* param) {
bool TransposeKernel<CPU, float>::Init(TransposeParam *param) {
return true;
}
template <>
void TransposeKernel<CPU, float>::Compute(const TransposeParam& param) const {
const auto* input_x = param.InputX();
const auto input_x_dims = input_x->dims();
auto* out = param.Out();
const auto axis = param.Axis();
const auto* input_x_data = input_x->data<float>();
auto* out_data = out->mutable_data<float>();
size_t ndim = axis.size();
std::vector<int> xdim(ndim);
std::vector<int> xstride(ndim);
std::vector<int> xout(ndim);
for (int i = 0; i < ndim; i++) {
int j = ndim - 1 - i;
xdim[j] = input_x_dims[axis[i]];
xstride[j] = 1;
for (int k = axis[i] + 1; k < ndim; k++) {
xstride[j] *= input_x_dims[k];
}
xout[j] = xstride[j] * xdim[j];
}
auto numel = input_x->numel();
size_t pind = 0;
std::vector<int> ind(ndim);
for (int i = 0; i < numel; i++) {
out_data[i] = input_x_data[pind];
ind[0]++;
pind += xstride[0];
for (int j = 0; j < ndim - 1; j++) {
if (ind[j] == xdim[j]) {
ind[j + 1]++;
ind[j] = 0;
pind += xstride[j + 1];
pind -= xout[j];
} else {
break;
}
}
}
void TransposeKernel<CPU, float>::Compute(const TransposeParam &param) const {
TransposeCompute<float>(param);
}
} // namespace operators
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef BOXCODER_OP
#pragma once
#include <cmath>
namespace paddle_mobile {
namespace operators {
template <typename T>
void EncodeCenterSize(const framework::Tensor& target_box,
const framework::Tensor& prior_box,
const framework::Tensor& prior_box_var, T* output) {
int64_t row = target_box.dims()[0];
int64_t col = prior_box.dims()[0];
int64_t len = prior_box.dims()[1];
auto* target_box_data = target_box.data<T>();
auto* prior_box_data = prior_box.data<T>();
auto* prior_box_var_data = prior_box_var.data<T>();
for (int64_t i = 0; i < row; ++i) {
for (int64_t j = 0; j < col; ++j) {
T prior_box_width = prior_box_data[j * len + 2] - prior_box_data[j * len];
T prior_box_height =
prior_box_data[j * len + 3] - prior_box_data[j * len + 1];
T prior_box_center_x =
(prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2;
T prior_box_center_y =
(prior_box_data[j * len + 3] + prior_box_data[j * len + 1]) / 2;
T target_box_center_x =
(target_box_data[i * len + 2] + target_box_data[i * len]) / 2;
T target_box_center_y =
(target_box_data[i * len + 3] + target_box_data[i * len + 1]) / 2;
T target_box_width =
target_box_data[i * len + 2] - target_box_data[i * len];
T target_box_height =
target_box_data[i * len + 3] - target_box_data[i * len + 1];
size_t offset = i * col * len + j * len;
output[offset] = (target_box_center_x - prior_box_center_x) /
prior_box_width / prior_box_var_data[j * len];
output[offset + 1] = (target_box_center_y - prior_box_center_y) /
prior_box_height / prior_box_var_data[j * len + 1];
output[offset + 2] =
std::log(std::fabs(target_box_width / prior_box_width)) /
prior_box_var_data[j * len + 2];
output[offset + 3] =
std::log(std::fabs(target_box_height / prior_box_height)) /
prior_box_var_data[j * len + 3];
}
}
}
template <typename T>
void DecodeCenterSize(const framework::Tensor& target_box,
const framework::Tensor& prior_box,
const framework::Tensor& prior_box_var, T* output) {
int64_t row = target_box.dims()[0];
int64_t col = prior_box.dims()[0];
int64_t len = prior_box.dims()[1];
auto* target_box_data = target_box.data<T>();
auto* prior_box_data = prior_box.data<T>();
auto* prior_box_var_data = prior_box_var.data<T>();
for (int64_t i = 0; i < row; ++i) {
for (int64_t j = 0; j < col; ++j) {
size_t offset = i * col * len + j * len;
T prior_box_width = prior_box_data[j * len + 2] - prior_box_data[j * len];
T prior_box_height =
prior_box_data[j * len + 3] - prior_box_data[j * len + 1];
T prior_box_center_x =
(prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2;
T prior_box_center_y =
(prior_box_data[j * len + 3] + prior_box_data[j * len + 1]) / 2;
T target_box_center_x = prior_box_var_data[j * len] *
target_box_data[offset] * prior_box_width +
prior_box_center_x;
T target_box_center_y = prior_box_var_data[j * len + 1] *
target_box_data[offset + 1] *
prior_box_height +
prior_box_center_y;
T target_box_width = std::exp(prior_box_var_data[j * len + 2] *
target_box_data[offset + 2]) *
prior_box_width;
T target_box_height = std::exp(prior_box_var_data[j * len + 3] *
target_box_data[offset + 3]) *
prior_box_height;
output[offset] = target_box_center_x - target_box_width / 2;
output[offset + 1] = target_box_center_y - target_box_height / 2;
output[offset + 2] = target_box_center_x + target_box_width / 2;
output[offset + 3] = target_box_center_y + target_box_height / 2;
}
}
}
template <typename P>
void BoxCoderCompute(const BoxCoderParam& param) {
const auto* input_priorbox = param.InputPriorBox();
const auto* input_priorboxvar = param.InputPriorBoxVar();
const auto* input_targetbox = param.InputTargetBox();
const auto& code_type = param.CodeType();
auto row = input_targetbox->dims()[0];
auto col = input_priorbox->dims()[0];
auto len = input_priorbox->dims()[1];
Tensor* output_box = param.OutputBox();
auto* output_box_dataptr = output_box->mutable_data<float>({row, col, len});
if (code_type == "encode_center_size") {
EncodeCenterSize<float>(*input_targetbox, *input_priorbox,
*input_priorboxvar, output_box_dataptr);
}
if (code_type == "decode_center_size") {
DecodeCenterSize<float>(*input_targetbox, *input_priorbox,
*input_priorboxvar, output_box_dataptr);
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CONCAT_OP
#pragma once
#include <vector>
namespace paddle_mobile {
namespace operators {
template <typename T>
class ConcatFunctor {
public:
void operator()(const std::vector<framework::Tensor> &input, const int axis,
framework::Tensor *output) {
size_t num = input.size();
int rows = 1;
auto dim_0 = input[0].dims();
for (int i = 0; i < axis; ++i) {
rows *= dim_0[i];
}
int out_rows = rows, out_cols = 0;
std::vector<int64_t> input_cols(input.size());
for (int i = 0; i < num; ++i) {
int t_cols = input[i].numel() / rows;
out_cols += t_cols;
input_cols[i] = t_cols;
}
// computation
for (int k = 0; k < out_rows; ++k) {
T *dst_ptr = output->data<T>() + k * out_cols;
int col_idx = 0;
for (int j = 0; j < num; ++j) {
int col_len = input_cols[j];
const T *src_prt = input[j].data<T>() + k * col_len;
memory::Copy(dst_ptr + col_idx, src_prt, sizeof(T) * col_len);
col_idx += col_len;
}
}
}
};
template <typename P>
void ConcatCompute(const ConcatParam &param) {
auto inputs = param.Inputs();
auto *out = param.Out();
int64_t axis = param.Axis();
out->mutable_data<float>();
/// Sometimes direct copies will be faster, this maybe need deeply analysis.
if (axis == 0 && inputs.size() < 10) {
size_t output_offset = 0;
for (auto *in : inputs) {
auto in_stride = framework::stride_numel(in->dims());
auto out_stride = framework::stride_numel(out->dims());
auto dst = out->data<float>() + output_offset;
auto src = in->data<float>();
PADDLE_MOBILE_ENFORCE(
in_stride.size() == out_stride.size(),
"src and dst tensor should have the same dims size.");
memory::Copy(dst, src, sizeof(float) * in_stride[0]);
output_offset += in_stride[0];
}
} else {
std::vector<framework::Tensor> inputs_concat(inputs.size());
for (int j = 0; j < inputs.size(); ++j) {
inputs_concat[j] = *inputs[j];
}
ConcatFunctor<float> concat_functor;
concat_functor(inputs_concat, static_cast<int>(axis), out);
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -15,7 +15,6 @@ limitations under the License. */
#ifdef FUSION_CONVADDBNRELU_OP
#pragma once
#include "operators/kernel/conv_add_bn_relu_kernel.h"
#include "operators/math/depthwise_conv_3x3.h"
#include "operators/op_param.h"
namespace paddle_mobile {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef ELEMENTWISEADD_OP
#pragma once
namespace paddle_mobile {
namespace operators {
template <typename T>
struct AddFunctor {
inline T operator()(T a, T b) const { return a + b; }
};
template <typename P>
void ElementwiseAddCompute(const ElementwiseAddParam &param) {
const Tensor *input_x = param.InputX();
const Tensor *input_y = param.InputY();
Tensor *Out = param.Out();
Out->mutable_data<float>();
int axis = param.Axis();
ElementwiseComputeEx<AddFunctor<float>, float>(input_x, input_y, axis,
AddFunctor<float>(), Out);
}
template class ElementwiseAddKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_FC_OP
#pragma once
namespace paddle_mobile {
namespace operators {
template <typename P>
void FusionFcCompute(const FusionFcParam &param) {
const Tensor *input_x = param.InputX();
const Tensor *input_y = param.InputY();
const Tensor *input_z = param.InputZ();
auto *input_z_data = input_z->data<float>();
int axis = param.Axis();
Tensor *out = param.Out();
auto *out_data = out->mutable_data<float>();
const Tensor x_matrix =
input_x->dims().size() > 2
? framework::ReshapeToMatrix(*input_x, param.XNumColDims())
: *input_x;
const Tensor y_matrix =
input_y->dims().size() > 2
? framework::ReshapeToMatrix(*input_y, param.YNumColDims())
: *input_y;
auto out_dim = out->dims();
if (out_dim.size() != 2) {
out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
}
PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2.");
PADDLE_MOBILE_ENFORCE(input_z->dims().size() == 1, "inpu_z size must be 1");
PADDLE_MOBILE_ENFORCE(out_dim[1] == input_z->dims()[0],
" out_dim.size must be 2.");
axis = (axis == -1 ? out_dim.size() - input_z->dims().size() : axis);
PADDLE_MOBILE_ENFORCE(axis == 1, " to fit broadcast, axis = 1. ")
int64_t classes = input_z->numel();
for (int i = 0; i < out_dim[0]; i++) {
memory::Copy(out_data + i * classes, input_z_data, sizeof(float) * classes);
}
for (int i = 0; i < out->numel(); i++) {
DLOG << out_data[i];
}
math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
out, static_cast<float>(1));
PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2.");
// if (out_dim.size() != 2) {
// out->Resize(out_dim);
// }
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef LRN_OP
#pragma once
namespace paddle_mobile {
namespace operators {
template <typename P>
void LrnCompute(const LrnParam &param) {
const Tensor *input_x = param.InputX();
auto x_dims = input_x->dims();
Tensor *out = param.Out();
out->mutable_data<float>();
/// data_format = NCHW
const int N = x_dims[0];
const int C = x_dims[1];
const int H = x_dims[2];
const int W = x_dims[3];
const int n = param.N();
const float alpha = param.Alpha();
const float beta = param.Beta();
const float k = param.K();
LRNFunctor<float> lrnFunctor;
lrnFunctor(*input_x, out, N, C, H, W, n, k, alpha, beta);
}
template class LrnKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef MUL_OP
#pragma once
namespace paddle_mobile {
namespace operators {
template <typename P>
void MulCompute(const MulParam &param) {
const Tensor *input_x = param.InputX();
const Tensor *input_y = param.InputY();
Tensor *out = param.Out();
out->mutable_data<float>();
const Tensor x_matrix =
input_x->dims().size() > 2
? framework::ReshapeToMatrix(*input_x, param.XNumColDims())
: *input_x;
const Tensor y_matrix =
input_y->dims().size() > 2
? framework::ReshapeToMatrix(*input_y, param.YNumColDims())
: *input_y;
auto out_dim = out->dims();
if (out_dim.size() != 2) {
out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
}
math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
out, static_cast<float>(0));
if (out_dim.size() != 2) {
out->Resize(out_dim);
}
}
template class MulKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef MULTICLASSNMS_OP
#pragma once
#include <algorithm>
#include <map>
#include <utility>
#include <vector>
namespace paddle_mobile {
namespace operators {
constexpr int kOutputDim = 6;
constexpr int kBBoxSize = 4;
template <class T>
bool SortScorePairDescend(const std::pair<float, T>& pair1,
const std::pair<float, T>& pair2) {
return pair1.first > pair2.first;
}
template <class T>
static inline void GetMaxScoreIndex(
const std::vector<T>& scores, const T threshold, int top_k,
std::vector<std::pair<T, int>>* sorted_indices) {
for (size_t i = 0; i < scores.size(); ++i) {
if (scores[i] > threshold) {
sorted_indices->push_back(std::make_pair(scores[i], i));
}
}
// Sort the score pair according to the scores in descending order
std::stable_sort(sorted_indices->begin(), sorted_indices->end(),
SortScorePairDescend<int>);
// Keep top_k scores if needed.
if (top_k > -1 && top_k < static_cast<int>(sorted_indices->size())) {
sorted_indices->resize(top_k);
}
}
template <class T>
static inline T BBoxArea(const T* box, const bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) {
// If coordinate values are is invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return static_cast<T>(0.);
} else {
const T w = box[2] - box[0];
const T h = box[3] - box[1];
if (normalized) {
return w * h;
} else {
// If coordinate values are not within range [0, 1].
return (w + 1) * (h + 1);
}
}
}
template <class T>
static inline T JaccardOverlap(const T* box1, const T* box2,
const bool normalized) {
if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] ||
box2[3] < box1[1]) {
return static_cast<T>(0.);
} else {
const T inter_xmin = std::max(box1[0], box2[0]);
const T inter_ymin = std::max(box1[1], box2[1]);
const T inter_xmax = std::min(box1[2], box2[2]);
const T inter_ymax = std::min(box1[3], box2[3]);
const T inter_w = inter_xmax - inter_xmin;
const T inter_h = inter_ymax - inter_ymin;
const T inter_area = inter_w * inter_h;
const T bbox1_area = BBoxArea<T>(box1, normalized);
const T bbox2_area = BBoxArea<T>(box2, normalized);
return inter_area / (bbox1_area + bbox2_area - inter_area);
}
}
template <typename T>
static inline void NMSFast(const Tensor& bbox, const Tensor& scores,
const T score_threshold, const T nms_threshold,
const T eta, const int64_t top_k,
std::vector<int>* selected_indices) {
// The total boxes for each instance.
int64_t num_boxes = bbox.dims()[0];
// 4: [xmin ymin xmax ymax]
int64_t box_size = bbox.dims()[1];
std::vector<T> scores_data(num_boxes);
std::copy_n(scores.data<T>(), num_boxes, scores_data.begin());
std::vector<std::pair<T, int>> sorted_indices;
GetMaxScoreIndex(scores_data, score_threshold, top_k, &sorted_indices);
selected_indices->clear();
T adaptive_threshold = nms_threshold;
const T* bbox_data = bbox.data<T>();
while (sorted_indices.size() != 0) {
const int idx = sorted_indices.front().second;
bool keep = true;
for (size_t k = 0; k < selected_indices->size(); ++k) {
if (keep) {
const int kept_idx = (*selected_indices)[k];
T overlap = JaccardOverlap<T>(bbox_data + idx * box_size,
bbox_data + kept_idx * box_size, true);
keep = overlap <= adaptive_threshold;
} else {
break;
}
}
if (keep) {
selected_indices->push_back(idx);
}
sorted_indices.erase(sorted_indices.begin());
if (keep && eta < 1 && adaptive_threshold > 0.5) {
adaptive_threshold *= eta;
}
}
}
template <typename T>
void MultiClassNMS(const Tensor& scores, const Tensor& bboxes,
std::map<int, std::vector<int>>* indices, int* num_nmsed_out,
const int& background_label, const int& nms_top_k,
const int& keep_top_k, const T& nms_threshold,
const T& nms_eta, const T& score_threshold) {
int64_t class_num = scores.dims()[0];
int64_t predict_dim = scores.dims()[1];
int num_det = 0;
for (int64_t c = 0; c < class_num; ++c) {
if (c == background_label) continue;
Tensor score = scores.Slice(c, c + 1);
/// [c] is key
NMSFast<float>(bboxes, score, score_threshold, nms_threshold, nms_eta,
nms_top_k, &((*indices)[c]));
num_det += (*indices)[c].size();
}
*num_nmsed_out = num_det;
const T* scores_data = scores.data<T>();
if (keep_top_k > -1 && num_det > keep_top_k) {
std::vector<std::pair<float, std::pair<int, int>>> score_index_pairs;
for (const auto& it : *indices) {
int label = it.first;
const T* sdata = scores_data + label * predict_dim;
const std::vector<int>& label_indices = it.second;
for (size_t j = 0; j < label_indices.size(); ++j) {
int idx = label_indices[j];
// PADDLE_ENFORCE_LT(idx, predict_dim);
score_index_pairs.push_back(
std::make_pair(sdata[idx], std::make_pair(label, idx)));
}
}
// Keep top k results per image.
std::stable_sort(score_index_pairs.begin(), score_index_pairs.end(),
SortScorePairDescend<std::pair<int, int>>);
score_index_pairs.resize(keep_top_k);
// Store the new indices.
std::map<int, std::vector<int>> new_indices;
for (size_t j = 0; j < score_index_pairs.size(); ++j) {
int label = score_index_pairs[j].second.first;
int idx = score_index_pairs[j].second.second;
new_indices[label].push_back(idx);
}
new_indices.swap(*indices);
*num_nmsed_out = keep_top_k;
}
}
template <typename T>
void MultiClassOutput(const Tensor& scores, const Tensor& bboxes,
const std::map<int, std::vector<int>>& selected_indices,
Tensor* outs) {
int predict_dim = scores.dims()[1];
auto* scores_data = scores.data<T>();
auto* bboxes_data = bboxes.data<T>();
auto* odata = outs->data<T>();
int count = 0;
for (const auto& it : selected_indices) {
/// one batch
int label = it.first;
const T* sdata = scores_data + label * predict_dim;
const std::vector<int>& indices = it.second;
for (size_t j = 0; j < indices.size(); ++j) {
int idx = indices[j];
const T* bdata = bboxes_data + idx * kBBoxSize;
odata[count * kOutputDim] = label; // label
odata[count * kOutputDim + 1] = sdata[idx]; // score
// xmin, ymin, xmax, ymax
std::memcpy(odata + count * kOutputDim + 2, bdata, 4 * sizeof(T));
count++;
}
}
}
template <typename P>
void MultiClassNMSCompute(const MultiClassNMSParam& param) {
const auto* input_bboxes = param.InputBBoxes();
const auto& input_bboxes_dims = input_bboxes->dims();
const auto* input_scores = param.InputScores();
const auto& input_scores_dims = input_scores->dims();
auto* outs = param.Out();
auto background_label = param.BackGroundLabel();
auto nms_top_k = param.NMSTopK();
auto keep_top_k = param.KeepTopK();
auto nms_threshold = param.NMSThreshold();
auto nms_eta = param.NMSEta();
auto score_threshold = param.ScoreThreshold();
int64_t batch_size = input_scores_dims[0];
int64_t class_num = input_scores_dims[1];
int64_t predict_dim = input_scores_dims[2];
int64_t box_dim = input_bboxes_dims[2];
std::vector<std::map<int, std::vector<int>>> all_indices;
std::vector<size_t> batch_starts = {0};
for (int64_t i = 0; i < batch_size; ++i) {
Tensor ins_score = input_scores->Slice(i, i + 1);
ins_score.Resize({class_num, predict_dim});
Tensor ins_boxes = input_bboxes->Slice(i, i + 1);
ins_boxes.Resize({predict_dim, box_dim});
std::map<int, std::vector<int>> indices;
int num_nmsed_out = 0;
MultiClassNMS<float>(ins_score, ins_boxes, &indices, &num_nmsed_out,
background_label, nms_top_k, keep_top_k, nms_threshold,
nms_eta, score_threshold);
all_indices.push_back(indices);
batch_starts.push_back(batch_starts.back() + num_nmsed_out);
}
int num_kept = batch_starts.back();
if (num_kept == 0) {
float* od = outs->mutable_data<float>({1});
od[0] = -1;
} else {
outs->mutable_data<float>({num_kept, kOutputDim});
for (int64_t i = 0; i < batch_size; ++i) {
Tensor ins_score = input_scores->Slice(i, i + 1);
ins_score.Resize({class_num, predict_dim});
Tensor ins_boxes = input_bboxes->Slice(i, i + 1);
ins_boxes.Resize({predict_dim, box_dim});
int64_t s = batch_starts[i];
int64_t e = batch_starts[i + 1];
if (e > s) {
Tensor out = outs->Slice(s, e);
MultiClassOutput<float>(ins_score, ins_boxes, all_indices[i], &out);
}
}
}
// framework::LoD lod;
// lod.emplace_back(batch_starts);
//
// outs->set_lod(lod);
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PRIORBOX_OP
#pragma once
#include <algorithm>
#include <vector>
namespace paddle_mobile {
namespace operators {
template <typename T>
struct ClipFunctor {
inline T operator()(T in) const {
return std::min<T>(std::max<T>(in, 0.), 1.);
}
};
template <typename P>
void PriorBoxCompute(const PriorBoxParam &param) {
const auto *input_ = param.Input();
const auto &input_dims = input_->dims();
const auto *input_image = param.InputImage();
const auto &input_image_dims = input_image->dims();
const auto &min_sizes = param.MinSizes();
const auto &max_sizes = param.MaxSizes();
const auto &variances = param.Variances();
const auto &input_aspect_ratio = param.AspectRatios();
const bool &flip = param.Flip();
const bool &clip = param.Clip();
const float &step_w = param.StepW();
const float &step_h = param.StepH();
const float &offset = param.Offset();
Tensor *output_boxes = param.OutputBoxes();
auto output_boxes_dataptr = output_boxes->mutable_data<float>();
Tensor *output_variances = param.OutputVariances();
auto output_variances_dataptr = output_variances->mutable_data<float>();
std::vector<float> aspect_ratios;
ExpandAspectRatios(input_aspect_ratio, flip, &aspect_ratios);
auto img_width = input_image_dims[3];
auto img_height = input_image_dims[2];
auto feature_width = input_dims[3];
auto feature_height = input_dims[2];
auto stride0 = output_boxes->dims()[1] * output_boxes->dims()[2] *
output_boxes->dims()[3];
auto stride1 = output_boxes->dims()[2] * output_boxes->dims()[3];
auto stride2 = output_boxes->dims()[3];
float step_width, step_height;
/// 300 / 19
if (step_w == 0 || step_h == 0) {
step_width = static_cast<float>(img_width) / feature_width;
step_height = static_cast<float>(img_height) / feature_height;
} else {
step_width = step_w;
step_height = step_h;
}
int num_priors = aspect_ratios.size() * min_sizes.size();
if (!max_sizes.empty()) {
num_priors += max_sizes.size();
}
for (int h = 0; h < feature_height; ++h) {
for (int w = 0; w < feature_width; ++w) {
/// map origin image
float center_x = (w + offset) * step_width;
float center_y = (h + offset) * step_height;
float box_width, box_height;
int idx = 0;
for (size_t s = 0; s < min_sizes.size(); ++s) {
auto min_size = min_sizes[s];
// priors with different aspect ratios
for (float ar : aspect_ratios) {
box_width = min_size * sqrt(ar) / 2.;
box_height = min_size / sqrt(ar) / 2.;
/// box_width/2 , / img_width 为了得到feature map 相对于
/// 原图的归一化位置的比例。
output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 0] =
(center_x - box_width) / img_width;
output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 1] =
(center_y - box_height) / img_height;
output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 2] =
(center_x + box_width) / img_width;
output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 3] =
(center_y + box_height) / img_height;
idx++;
}
if (!max_sizes.empty()) {
auto max_size = max_sizes[s];
// square prior with size sqrt(minSize * maxSize)
box_width = box_height = sqrt(min_size * max_size) / 2.;
output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 0] =
(center_x - box_width) / img_width;
output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 1] =
(center_y - box_height) / img_height;
output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 2] =
(center_x + box_width) / img_width;
output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 3] =
(center_y + box_height) / img_height;
idx++;
}
}
}
}
if (clip) {
math::Transform trans;
ClipFunctor<float> clip_func;
trans(output_boxes_dataptr, output_boxes_dataptr + output_boxes->numel(),
output_boxes_dataptr, clip_func);
}
if ((variances.size() != 4)) {
LOG(kLOG_ERROR) << " variances.size() must be 4.";
}
int64_t box_num = feature_height * feature_width * num_priors;
for (int i = 0; i < box_num; i++) {
output_variances_dataptr[4 * i] = variances[0];
output_variances_dataptr[4 * i + 1] = variances[1];
output_variances_dataptr[4 * i + 2] = variances[2];
output_variances_dataptr[4 * i + 3] = variances[3];
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef RELU_OP
#pragma once
#include <operators/math/transform.h>
namespace paddle_mobile {
namespace operators {
template <typename T>
struct ReluFunctor {
inline T operator()(T in) const { return in > 0 ? in : 0; }
};
/*
* @b 特化到具体平台的实现, param 从 op 层传入
* */
template <typename P>
void ReluCompute(const ReluParam &param) {
const auto *input_x = param.InputX();
auto *input_x_ptr = input_x->data<float>();
auto *out = param.Out();
auto *out_ptr = out->mutable_data<float>();
int numel = input_x->numel();
// if (numel > 64) {
// asm volatile(
// "pld [%[input_x_ptr], #0] \n\t"
// "vmov.f32 q8, #0.0 \n\t"
// "subs %[num], %[num], #32 \n\t"
// "blt end_num_%= \n\t"
// "loop_num_%=: \n\t"
// "pld [%[input_x_ptr], #1024] \n\t"
//
// "vld1.32 {q0, q1}, [%[input_x_ptr]]! \n\t"
// "vld1.32 {q2, q3}, [%[input_x_ptr]]! \n\t"
// "vld1.32 {q4, q5}, [%[input_x_ptr]]! \n\t"
// "vld1.32 {q6, q7}, [%[input_x_ptr]]! \n\t"
//
// "vmax.f32 q0, q0, q8 \n\t"
// "vmax.f32 q1, q1, q8 \n\t"
// "vmax.f32 q2, q2, q8 \n\t"
// "vmax.f32 q3, q3, q8 \n\t"
// "vmax.f32 q4, q4, q8 \n\t"
// "vmax.f32 q5, q5, q8 \n\t"
// "vmax.f32 q6, q6, q8 \n\t"
// "vmax.f32 q7, q7, q8 \n\t"
//
// "vst1.32 {q0, q1}, [%[out_ptr]]! \n\t"
// "vst1.32 {q2, q3}, [%[out_ptr]]! \n\t"
// "vst1.32 {q4, q5}, [%[out_ptr]]! \n\t"
// "vst1.32 {q6, q7}, [%[out_ptr]]! \n\t"
//
// "subs %[num], %[num], #32 \n\t"
// "bge loop_num_%= \n\t"
// "end_num_%=: \n\t"
// "cmp %[num], #0 \n\t"
// "bge end_%= \n\t"
// "mov r6, #4 \n\t"
// "mul r5, %[num], r6 \n\t"
// "add %[input_x_ptr], %[input_x_ptr], r5 \n\t"
// "vld1.32 {q0, q1}, [%[input_x_ptr]]! \n\t"
// "vld1.32 {q2, q3}, [%[input_x_ptr]]! \n\t"
// "vld1.32 {q4, q5}, [%[input_x_ptr]]! \n\t"
// "vld1.32 {q6, q7}, [%[input_x_ptr]]! \n\t"
// "vmax.f32 q0, q0, q8 \n\t"
// "vmax.f32 q1, q1, q8 \n\t"
// "vmax.f32 q2, q2, q8 \n\t"
// "vmax.f32 q3, q3, q8 \n\t"
// "vmax.f32 q4, q4, q8 \n\t"
// "vmax.f32 q5, q5, q8 \n\t"
// "vmax.f32 q6, q6, q8 \n\t"
// "vmax.f32 q7, q7, q8 \n\t"
// "add %[out_ptr], %[out_ptr], r5 \n\t"
// "vst1.32 {q0, q1}, [%[out_ptr]]! \n\t"
// "vst1.32 {q2, q3}, [%[out_ptr]]! \n\t"
// "vst1.32 {q4, q5}, [%[out_ptr]]! \n\t"
// "vst1.32 {q6, q7}, [%[out_ptr]]! \n\t"
// "end_%=: \n\t"
// :
// :
// [out_ptr] "r"(out_ptr), [input_x_ptr] "r"(input_x_ptr), [num]
// "r"(numel) : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
// "q7", "q8", "r5",
// "r6");
// } else {
ReluFunctor<float> func_;
math::Transform trans;
trans(input_x_ptr, input_x_ptr + numel, out_ptr, func_);
// }
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef RESHAPE_OP
#pragma once
#include <vector>
namespace paddle_mobile {
namespace operators {
template <typename P>
void ReshapeCompute(const ReshapeParam &param) {
const auto *input_x = param.InputX();
const auto &input_x_dims = input_x->dims();
auto *out = param.Out();
framework::DDim out_dims = out->dims();
const auto *input_shape = param.InputShape();
if (input_shape) {
auto *shape_data = input_shape->data<int>();
framework::Tensor cpu_shape_tensor;
auto shape =
std::vector<int>(shape_data, shape_data + input_shape->numel());
out_dims = ValidateShape(shape, input_x->dims());
}
bool inplace = param.Inplace();
out->Resize(out_dims);
if (!inplace) {
out->mutable_data<float>();
framework::TensorCopy(*input_x, out);
out->Resize(out_dims);
} else {
out->ShareDataWith(*input_x);
out->Resize(out_dims);
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef TRANSPOSE_OP
#pragma once
#include <vector>
namespace paddle_mobile {
namespace operators {
// vector<int> pos;
// template <typename T>
// void TransposeFunc(const int numel, const T* input, const vector<int> axis,
// const vector<int> old_strides, const vector<int>
// new_strides, T* output) {
// for (int i = 0; i < numel; ++i) {
// int old_idx = 0;
// int idx = i;
// for (int j = 0; j < axis.size(); ++j) {
// int order = axis[j];
// old_idx += (idx / new_strides[j]) * old_strides[order];
// idx %= new_strides[j];
// }
// output[i] = input[old_idx];
// }
// }
template <typename P>
void TransposeCompute(const TransposeParam& param) {
const auto* input_x = param.InputX();
const auto input_x_dims = input_x->dims();
auto* out = param.Out();
const auto axis = param.Axis();
const auto* input_x_data = input_x->data<float>();
auto* out_data = out->mutable_data<float>();
size_t ndim = axis.size();
std::vector<int> xdim(ndim);
std::vector<int> xstride(ndim);
std::vector<int> xout(ndim);
for (int i = 0; i < ndim; i++) {
int j = ndim - 1 - i;
xdim[j] = input_x_dims[axis[i]];
xstride[j] = 1;
for (int k = axis[i] + 1; k < ndim; k++) {
xstride[j] *= input_x_dims[k];
}
xout[j] = xstride[j] * xdim[j];
}
auto numel = input_x->numel();
size_t pind = 0;
std::vector<int> ind(ndim);
for (int i = 0; i < numel; i++) {
out_data[i] = input_x_data[pind];
ind[0]++;
pind += xstride[0];
for (int j = 0; j < ndim - 1; j++) {
if (ind[j] == xdim[j]) {
ind[j + 1]++;
ind[j] = 0;
pind += xstride[j + 1];
pind -= xout[j];
} else {
break;
}
}
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册