diff --git a/src/operators/kernel/arm/box_coder_kernel.cpp b/src/operators/kernel/arm/box_coder_kernel.cpp index fb113b16f53bcd1b9fca7a1dbbf94a846e9a0f81..d2a479391fbbb416eea7d19ae64125cac4637ef1 100644 --- a/src/operators/kernel/arm/box_coder_kernel.cpp +++ b/src/operators/kernel/arm/box_coder_kernel.cpp @@ -15,130 +15,21 @@ limitations under the License. */ #ifdef BOXCODER_OP #include "operators/kernel/box_coder_kernel.h" -#include +#include "operators/kernel/central-arm-func/box_coder_arm_func.h" namespace paddle_mobile { namespace operators { -template -void EncodeCenterSize(const framework::Tensor& target_box, - const framework::Tensor& prior_box, - const framework::Tensor& prior_box_var, T* output) { - int64_t row = target_box.dims()[0]; - int64_t col = prior_box.dims()[0]; - int64_t len = prior_box.dims()[1]; - auto* target_box_data = target_box.data(); - auto* prior_box_data = prior_box.data(); - auto* prior_box_var_data = prior_box_var.data(); - - for (int64_t i = 0; i < row; ++i) { - for (int64_t j = 0; j < col; ++j) { - T prior_box_width = prior_box_data[j * len + 2] - prior_box_data[j * len]; - T prior_box_height = - prior_box_data[j * len + 3] - prior_box_data[j * len + 1]; - T prior_box_center_x = - (prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2; - T prior_box_center_y = - (prior_box_data[j * len + 3] + prior_box_data[j * len + 1]) / 2; - - T target_box_center_x = - (target_box_data[i * len + 2] + target_box_data[i * len]) / 2; - T target_box_center_y = - (target_box_data[i * len + 3] + target_box_data[i * len + 1]) / 2; - T target_box_width = - target_box_data[i * len + 2] - target_box_data[i * len]; - T target_box_height = - target_box_data[i * len + 3] - target_box_data[i * len + 1]; - - size_t offset = i * col * len + j * len; - output[offset] = (target_box_center_x - prior_box_center_x) / - prior_box_width / prior_box_var_data[j * len]; - output[offset + 1] = (target_box_center_y - prior_box_center_y) / - prior_box_height / prior_box_var_data[j * len + 1]; - output[offset + 2] = - std::log(std::fabs(target_box_width / prior_box_width)) / - prior_box_var_data[j * len + 2]; - output[offset + 3] = - std::log(std::fabs(target_box_height / prior_box_height)) / - prior_box_var_data[j * len + 3]; - } - } -} - -template -void DecodeCenterSize(const framework::Tensor& target_box, - const framework::Tensor& prior_box, - const framework::Tensor& prior_box_var, T* output) { - int64_t row = target_box.dims()[0]; - int64_t col = prior_box.dims()[0]; - int64_t len = prior_box.dims()[1]; - - auto* target_box_data = target_box.data(); - auto* prior_box_data = prior_box.data(); - auto* prior_box_var_data = prior_box_var.data(); - - for (int64_t i = 0; i < row; ++i) { - for (int64_t j = 0; j < col; ++j) { - size_t offset = i * col * len + j * len; - T prior_box_width = prior_box_data[j * len + 2] - prior_box_data[j * len]; - T prior_box_height = - prior_box_data[j * len + 3] - prior_box_data[j * len + 1]; - T prior_box_center_x = - (prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2; - T prior_box_center_y = - (prior_box_data[j * len + 3] + prior_box_data[j * len + 1]) / 2; - - T target_box_center_x = prior_box_var_data[j * len] * - target_box_data[offset] * prior_box_width + - prior_box_center_x; - T target_box_center_y = prior_box_var_data[j * len + 1] * - target_box_data[offset + 1] * - prior_box_height + - prior_box_center_y; - T target_box_width = std::exp(prior_box_var_data[j * len + 2] * - target_box_data[offset + 2]) * - prior_box_width; - T target_box_height = std::exp(prior_box_var_data[j * len + 3] * - target_box_data[offset + 3]) * - prior_box_height; - - output[offset] = target_box_center_x - target_box_width / 2; - output[offset + 1] = target_box_center_y - target_box_height / 2; - output[offset + 2] = target_box_center_x + target_box_width / 2; - output[offset + 3] = target_box_center_y + target_box_height / 2; - } - } -} - template <> -bool BoxCoderKernel::Init(BoxCoderParam* param) { +bool BoxCoderKernel::Init(BoxCoderParam *param) { return true; } template <> -void BoxCoderKernel::Compute(const BoxCoderParam& param) const { - const auto* input_priorbox = param.InputPriorBox(); - const auto* input_priorboxvar = param.InputPriorBoxVar(); - const auto* input_targetbox = param.InputTargetBox(); - - const auto& code_type = param.CodeType(); - - auto row = input_targetbox->dims()[0]; - auto col = input_priorbox->dims()[0]; - auto len = input_priorbox->dims()[1]; - - Tensor* output_box = param.OutputBox(); - auto* output_box_dataptr = output_box->mutable_data({row, col, len}); - - if (code_type == "encode_center_size") { - EncodeCenterSize(*input_targetbox, *input_priorbox, - *input_priorboxvar, output_box_dataptr); - } - if (code_type == "decode_center_size") { - DecodeCenterSize(*input_targetbox, *input_priorbox, - *input_priorboxvar, output_box_dataptr); - } +void BoxCoderKernel::Compute(const BoxCoderParam ¶m) const { + BoxCoderCompute(param); } + } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/arm/concat_kernel.cpp b/src/operators/kernel/arm/concat_kernel.cpp index 5daf3e104a04025165ce7281f3a16d8e3f9cb522..b6810bf76946bfb8151f3001b76fcbaa5e99e5fc 100644 --- a/src/operators/kernel/arm/concat_kernel.cpp +++ b/src/operators/kernel/arm/concat_kernel.cpp @@ -15,42 +15,10 @@ limitations under the License. */ #ifdef CONCAT_OP #include "operators/kernel/concat_kernel.h" +#include "operators/kernel/central-arm-func/concat_arm_func.h" namespace paddle_mobile { namespace operators { -template -class ConcatFunctor { - public: - void operator()(const std::vector &input, const int axis, - framework::Tensor *output) { - size_t num = input.size(); - int rows = 1; - auto dim_0 = input[0].dims(); - for (int i = 0; i < axis; ++i) { - rows *= dim_0[i]; - } - int out_rows = rows, out_cols = 0; - - std::vector input_cols(input.size()); - for (int i = 0; i < num; ++i) { - int t_cols = input[i].numel() / rows; - out_cols += t_cols; - input_cols[i] = t_cols; - } - - // computation - for (int k = 0; k < out_rows; ++k) { - T *dst_ptr = output->data() + k * out_cols; - int col_idx = 0; - for (int j = 0; j < num; ++j) { - int col_len = input_cols[j]; - const T *src_prt = input[j].data() + k * col_len; - memory::Copy(dst_ptr + col_idx, src_prt, sizeof(T) * col_len); - col_idx += col_len; - } - } - } -}; template <> bool ConcatKernel::Init(ConcatParam *param) { @@ -59,33 +27,7 @@ bool ConcatKernel::Init(ConcatParam *param) { template <> void ConcatKernel::Compute(const ConcatParam ¶m) const { - auto inputs = param.Inputs(); - auto *out = param.Out(); - int64_t axis = param.Axis(); - out->mutable_data(); - - /// Sometimes direct copies will be faster, this maybe need deeply analysis. - if (axis == 0 && inputs.size() < 10) { - size_t output_offset = 0; - for (auto *in : inputs) { - auto in_stride = framework::stride_numel(in->dims()); - auto out_stride = framework::stride_numel(out->dims()); - auto dst = out->data() + output_offset; - auto src = in->data(); - PADDLE_MOBILE_ENFORCE( - in_stride.size() == out_stride.size(), - "src and dst tensor should have the same dims size."); - memory::Copy(dst, src, sizeof(float) * in_stride[0]); - output_offset += in_stride[0]; - } - } else { - std::vector inputs_concat(inputs.size()); - for (int j = 0; j < inputs.size(); ++j) { - inputs_concat[j] = *inputs[j]; - } - ConcatFunctor concat_functor; - concat_functor(inputs_concat, static_cast(axis), out); - } + ConcatCompute(param); } } // namespace operators diff --git a/src/operators/kernel/arm/elementwise_add_kernel.cpp b/src/operators/kernel/arm/elementwise_add_kernel.cpp index bd9bb26d299bd340074965e41e5658df86bab347..fdab1c60a310480d8e59f3f84802001ea592433a 100644 --- a/src/operators/kernel/arm/elementwise_add_kernel.cpp +++ b/src/operators/kernel/arm/elementwise_add_kernel.cpp @@ -14,18 +14,12 @@ limitations under the License. */ #ifdef ELEMENTWISEADD_OP -#pragma once - #include "operators/kernel/elementwise_add_kernel.h" +#include "operators/kernel/central-arm-func/elementwise_add_arm_func.h" namespace paddle_mobile { namespace operators { -template -struct AddFunctor { - inline T operator()(T a, T b) const { return a + b; } -}; - template <> bool ElementwiseAddKernel::Init(ElementwiseAddParam *param) { return true; @@ -34,17 +28,9 @@ bool ElementwiseAddKernel::Init(ElementwiseAddParam *param) { template <> void ElementwiseAddKernel::Compute( const ElementwiseAddParam ¶m) const { - const Tensor *input_x = param.InputX(); - const Tensor *input_y = param.InputY(); - Tensor *Out = param.Out(); - Out->mutable_data(); - int axis = param.Axis(); - ElementwiseComputeEx, float>(input_x, input_y, axis, - AddFunctor(), Out); + ElementwiseAddCompute(param); } -template class ElementwiseAddKernel; - } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/arm/fusion_fc_kernel.cpp b/src/operators/kernel/arm/fusion_fc_kernel.cpp index e10f11c0b19edf710ffc45f199f096bea0a34b7d..c72960e67f19c601e6f27a3bedf7123c80875e0c 100644 --- a/src/operators/kernel/arm/fusion_fc_kernel.cpp +++ b/src/operators/kernel/arm/fusion_fc_kernel.cpp @@ -14,9 +14,8 @@ limitations under the License. */ #ifdef FUSION_FC_OP -#pragma once - #include "operators/kernel/fusion_fc_kernel.h" +#include "operators/kernel/central-arm-func/fusion_fc_arm_func.h" namespace paddle_mobile { namespace operators { @@ -28,46 +27,7 @@ bool FusionFcKernel::Init(FusionFcParam *param) { template <> void FusionFcKernel::Compute(const FusionFcParam ¶m) const { - const Tensor *input_x = param.InputX(); - const Tensor *input_y = param.InputY(); - const Tensor *input_z = param.InputZ(); - auto *input_z_data = input_z->data(); - int axis = param.Axis(); - Tensor *out = param.Out(); - auto *out_data = out->mutable_data(); - const Tensor x_matrix = - input_x->dims().size() > 2 - ? framework::ReshapeToMatrix(*input_x, param.XNumColDims()) - : *input_x; - const Tensor y_matrix = - input_y->dims().size() > 2 - ? framework::ReshapeToMatrix(*input_y, param.YNumColDims()) - : *input_y; - auto out_dim = out->dims(); - if (out_dim.size() != 2) { - out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); - } - PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2."); - PADDLE_MOBILE_ENFORCE(input_z->dims().size() == 1, "inpu_z size must be 1"); - PADDLE_MOBILE_ENFORCE(out_dim[1] == input_z->dims()[0], - " out_dim.size must be 2."); - axis = (axis == -1 ? out_dim.size() - input_z->dims().size() : axis); - PADDLE_MOBILE_ENFORCE(axis == 1, " to fit broadcast, axis = 1. ") - - int64_t classes = input_z->numel(); - for (int i = 0; i < out_dim[0]; i++) { - memory::Copy(out_data + i * classes, input_z_data, sizeof(float) * classes); - } - - for (int i = 0; i < out->numel(); i++) { - DLOG << out_data[i]; - } - math::matmul(x_matrix, false, y_matrix, false, static_cast(1), - out, static_cast(1)); - PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2."); - // if (out_dim.size() != 2) { - // out->Resize(out_dim); - // } + FusionFcCompute(param); } } // namespace operators diff --git a/src/operators/kernel/arm/lrn_kernel.cpp b/src/operators/kernel/arm/lrn_kernel.cpp index 356aa388276d9d0359b1a6b3a45c86bcb822fd9e..0c20c5167adee5165067cc5ab4935df255751755 100644 --- a/src/operators/kernel/arm/lrn_kernel.cpp +++ b/src/operators/kernel/arm/lrn_kernel.cpp @@ -14,9 +14,8 @@ limitations under the License. */ #ifdef LRN_OP -#pragma once - #include "operators/kernel/lrn_kernel.h" +#include "operators/kernel/central-arm-func/lrn_arm_func.h" namespace paddle_mobile { namespace operators { @@ -28,26 +27,9 @@ bool LrnKernel::Init(LrnParam *param) { template <> void LrnKernel::Compute(const LrnParam ¶m) const { - const Tensor *input_x = param.InputX(); - auto x_dims = input_x->dims(); - Tensor *out = param.Out(); - out->mutable_data(); - /// data_format = NCHW - const int N = x_dims[0]; - const int C = x_dims[1]; - const int H = x_dims[2]; - const int W = x_dims[3]; - - const int n = param.N(); - const float alpha = param.Alpha(); - const float beta = param.Beta(); - const float k = param.K(); - LRNFunctor lrnFunctor; - lrnFunctor(*input_x, out, N, C, H, W, n, k, alpha, beta); + LrnCompute(param); } -template class LrnKernel; - } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/arm/mul_kernel.cpp b/src/operators/kernel/arm/mul_kernel.cpp index 99b6576d364671be78efa9a8f2ebf85a6e133f33..ac5010ce5492ae1d99e59bfa761e22bb3aa5d1c9 100644 --- a/src/operators/kernel/arm/mul_kernel.cpp +++ b/src/operators/kernel/arm/mul_kernel.cpp @@ -14,9 +14,8 @@ limitations under the License. */ #ifdef MUL_OP -#pragma once - #include "operators/kernel/mul_kernel.h" +#include "operators/kernel/central-arm-func/mul_arm_func.h" namespace paddle_mobile { namespace operators { @@ -28,31 +27,9 @@ bool MulKernel::Init(MulParam *param) { template <> void MulKernel::Compute(const MulParam ¶m) const { - const Tensor *input_x = param.InputX(); - const Tensor *input_y = param.InputY(); - Tensor *out = param.Out(); - out->mutable_data(); - const Tensor x_matrix = - input_x->dims().size() > 2 - ? framework::ReshapeToMatrix(*input_x, param.XNumColDims()) - : *input_x; - const Tensor y_matrix = - input_y->dims().size() > 2 - ? framework::ReshapeToMatrix(*input_y, param.YNumColDims()) - : *input_y; - auto out_dim = out->dims(); - if (out_dim.size() != 2) { - out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); - } - math::matmul(x_matrix, false, y_matrix, false, static_cast(1), - out, static_cast(0)); - if (out_dim.size() != 2) { - out->Resize(out_dim); - } + MulCompute(param); } -template class MulKernel; - } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/arm/multiclass_nms_kernel.cpp b/src/operators/kernel/arm/multiclass_nms_kernel.cpp index ecdc60f77b0cad3af2e8b026ab3666394dc43fee..9ed8f1731afe2bab723c66ea1e2e8c5042f6ce28 100644 --- a/src/operators/kernel/arm/multiclass_nms_kernel.cpp +++ b/src/operators/kernel/arm/multiclass_nms_kernel.cpp @@ -15,265 +15,20 @@ limitations under the License. */ #ifdef MULTICLASSNMS_OP #include "operators/kernel/multiclass_nms_kernel.h" -#include +#include "operators/kernel/central-arm-func/multiclass_nms_arm_func.h" + namespace paddle_mobile { namespace operators { -constexpr int kOutputDim = 6; -constexpr int kBBoxSize = 4; - -template -bool SortScorePairDescend(const std::pair& pair1, - const std::pair& pair2) { - return pair1.first > pair2.first; -} - -template -static inline void GetMaxScoreIndex( - const std::vector& scores, const T threshold, int top_k, - std::vector>* sorted_indices) { - for (size_t i = 0; i < scores.size(); ++i) { - if (scores[i] > threshold) { - sorted_indices->push_back(std::make_pair(scores[i], i)); - } - } - // Sort the score pair according to the scores in descending order - std::stable_sort(sorted_indices->begin(), sorted_indices->end(), - SortScorePairDescend); - // Keep top_k scores if needed. - if (top_k > -1 && top_k < static_cast(sorted_indices->size())) { - sorted_indices->resize(top_k); - } -} - -template -static inline T BBoxArea(const T* box, const bool normalized) { - if (box[2] < box[0] || box[3] < box[1]) { - // If coordinate values are is invalid - // (e.g. xmax < xmin or ymax < ymin), return 0. - return static_cast(0.); - } else { - const T w = box[2] - box[0]; - const T h = box[3] - box[1]; - if (normalized) { - return w * h; - } else { - // If coordinate values are not within range [0, 1]. - return (w + 1) * (h + 1); - } - } -} - -template -static inline T JaccardOverlap(const T* box1, const T* box2, - const bool normalized) { - if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] || - box2[3] < box1[1]) { - return static_cast(0.); - } else { - const T inter_xmin = std::max(box1[0], box2[0]); - const T inter_ymin = std::max(box1[1], box2[1]); - const T inter_xmax = std::min(box1[2], box2[2]); - const T inter_ymax = std::min(box1[3], box2[3]); - const T inter_w = inter_xmax - inter_xmin; - const T inter_h = inter_ymax - inter_ymin; - const T inter_area = inter_w * inter_h; - const T bbox1_area = BBoxArea(box1, normalized); - const T bbox2_area = BBoxArea(box2, normalized); - return inter_area / (bbox1_area + bbox2_area - inter_area); - } -} - -template -static inline void NMSFast(const Tensor& bbox, const Tensor& scores, - const T score_threshold, const T nms_threshold, - const T eta, const int64_t top_k, - std::vector* selected_indices) { - // The total boxes for each instance. - int64_t num_boxes = bbox.dims()[0]; - // 4: [xmin ymin xmax ymax] - int64_t box_size = bbox.dims()[1]; - - std::vector scores_data(num_boxes); - std::copy_n(scores.data(), num_boxes, scores_data.begin()); - std::vector> sorted_indices; - GetMaxScoreIndex(scores_data, score_threshold, top_k, &sorted_indices); - - selected_indices->clear(); - T adaptive_threshold = nms_threshold; - const T* bbox_data = bbox.data(); - - while (sorted_indices.size() != 0) { - const int idx = sorted_indices.front().second; - bool keep = true; - for (size_t k = 0; k < selected_indices->size(); ++k) { - if (keep) { - const int kept_idx = (*selected_indices)[k]; - T overlap = JaccardOverlap(bbox_data + idx * box_size, - bbox_data + kept_idx * box_size, true); - keep = overlap <= adaptive_threshold; - } else { - break; - } - } - if (keep) { - selected_indices->push_back(idx); - } - sorted_indices.erase(sorted_indices.begin()); - if (keep && eta < 1 && adaptive_threshold > 0.5) { - adaptive_threshold *= eta; - } - } -} - -template -void MultiClassNMS(const Tensor& scores, const Tensor& bboxes, - std::map>* indices, int* num_nmsed_out, - const int& background_label, const int& nms_top_k, - const int& keep_top_k, const T& nms_threshold, - const T& nms_eta, const T& score_threshold) { - int64_t class_num = scores.dims()[0]; - int64_t predict_dim = scores.dims()[1]; - int num_det = 0; - for (int64_t c = 0; c < class_num; ++c) { - if (c == background_label) continue; - Tensor score = scores.Slice(c, c + 1); - /// [c] is key - NMSFast(bboxes, score, score_threshold, nms_threshold, nms_eta, - nms_top_k, &((*indices)[c])); - num_det += (*indices)[c].size(); - } - - *num_nmsed_out = num_det; - const T* scores_data = scores.data(); - if (keep_top_k > -1 && num_det > keep_top_k) { - std::vector>> score_index_pairs; - for (const auto& it : *indices) { - int label = it.first; - const T* sdata = scores_data + label * predict_dim; - const std::vector& label_indices = it.second; - for (size_t j = 0; j < label_indices.size(); ++j) { - int idx = label_indices[j]; - // PADDLE_ENFORCE_LT(idx, predict_dim); - score_index_pairs.push_back( - std::make_pair(sdata[idx], std::make_pair(label, idx))); - } - } - // Keep top k results per image. - std::stable_sort(score_index_pairs.begin(), score_index_pairs.end(), - SortScorePairDescend>); - score_index_pairs.resize(keep_top_k); - - // Store the new indices. - std::map> new_indices; - for (size_t j = 0; j < score_index_pairs.size(); ++j) { - int label = score_index_pairs[j].second.first; - int idx = score_index_pairs[j].second.second; - new_indices[label].push_back(idx); - } - new_indices.swap(*indices); - *num_nmsed_out = keep_top_k; - } -} - -template -void MultiClassOutput(const Tensor& scores, const Tensor& bboxes, - const std::map>& selected_indices, - Tensor* outs) { - int predict_dim = scores.dims()[1]; - auto* scores_data = scores.data(); - auto* bboxes_data = bboxes.data(); - auto* odata = outs->data(); - - int count = 0; - for (const auto& it : selected_indices) { - /// one batch - int label = it.first; - const T* sdata = scores_data + label * predict_dim; - const std::vector& indices = it.second; - for (size_t j = 0; j < indices.size(); ++j) { - int idx = indices[j]; - const T* bdata = bboxes_data + idx * kBBoxSize; - odata[count * kOutputDim] = label; // label - odata[count * kOutputDim + 1] = sdata[idx]; // score - // xmin, ymin, xmax, ymax - std::memcpy(odata + count * kOutputDim + 2, bdata, 4 * sizeof(T)); - count++; - } - } -} - template <> -bool MultiClassNMSKernel::Init(MultiClassNMSParam* param) { +bool MultiClassNMSKernel::Init(MultiClassNMSParam *param) { return true; } template <> void MultiClassNMSKernel::Compute( - const MultiClassNMSParam& param) const { - const auto* input_bboxes = param.InputBBoxes(); - const auto& input_bboxes_dims = input_bboxes->dims(); - - const auto* input_scores = param.InputScores(); - const auto& input_scores_dims = input_scores->dims(); - - auto* outs = param.Out(); - auto background_label = param.BackGroundLabel(); - auto nms_top_k = param.NMSTopK(); - auto keep_top_k = param.KeepTopK(); - auto nms_threshold = param.NMSThreshold(); - auto nms_eta = param.NMSEta(); - auto score_threshold = param.ScoreThreshold(); - - int64_t batch_size = input_scores_dims[0]; - int64_t class_num = input_scores_dims[1]; - int64_t predict_dim = input_scores_dims[2]; - int64_t box_dim = input_bboxes_dims[2]; - - std::vector>> all_indices; - std::vector batch_starts = {0}; - for (int64_t i = 0; i < batch_size; ++i) { - Tensor ins_score = input_scores->Slice(i, i + 1); - ins_score.Resize({class_num, predict_dim}); - - Tensor ins_boxes = input_bboxes->Slice(i, i + 1); - ins_boxes.Resize({predict_dim, box_dim}); - - std::map> indices; - int num_nmsed_out = 0; - MultiClassNMS(ins_score, ins_boxes, &indices, &num_nmsed_out, - background_label, nms_top_k, keep_top_k, nms_threshold, - nms_eta, score_threshold); - all_indices.push_back(indices); - batch_starts.push_back(batch_starts.back() + num_nmsed_out); - } - - int num_kept = batch_starts.back(); - if (num_kept == 0) { - float* od = outs->mutable_data({1}); - od[0] = -1; - } else { - outs->mutable_data({num_kept, kOutputDim}); - for (int64_t i = 0; i < batch_size; ++i) { - Tensor ins_score = input_scores->Slice(i, i + 1); - ins_score.Resize({class_num, predict_dim}); - - Tensor ins_boxes = input_bboxes->Slice(i, i + 1); - ins_boxes.Resize({predict_dim, box_dim}); - - int64_t s = batch_starts[i]; - int64_t e = batch_starts[i + 1]; - if (e > s) { - Tensor out = outs->Slice(s, e); - MultiClassOutput(ins_score, ins_boxes, all_indices[i], &out); - } - } - } - - // framework::LoD lod; - // lod.emplace_back(batch_starts); - // - // outs->set_lod(lod); + const MultiClassNMSParam ¶m) const { + MultiClassNMSCompute(param); } } // namespace operators diff --git a/src/operators/kernel/arm/prior_box_kernel.cpp b/src/operators/kernel/arm/prior_box_kernel.cpp index 32d3818ef244e4c2879167b4273b0538eef08c56..217d4b83cb1156a0e942c5ced5917546250e8bb1 100644 --- a/src/operators/kernel/arm/prior_box_kernel.cpp +++ b/src/operators/kernel/arm/prior_box_kernel.cpp @@ -15,17 +15,11 @@ limitations under the License. */ #ifdef PRIORBOX_OP #include "operators/kernel/prior_box_kernel.h" +#include "operators/kernel/central-arm-func/prior_box_arm_func.h" namespace paddle_mobile { namespace operators { -template -struct ClipFunctor { - inline T operator()(T in) const { - return std::min(std::max(in, 0.), 1.); - } -}; - template <> bool PriorBoxKernel::Init(PriorBoxParam *param) { return true; @@ -33,117 +27,7 @@ bool PriorBoxKernel::Init(PriorBoxParam *param) { template <> void PriorBoxKernel::Compute(const PriorBoxParam ¶m) const { - const auto *input_ = param.Input(); - const auto &input_dims = input_->dims(); - - const auto *input_image = param.InputImage(); - const auto &input_image_dims = input_image->dims(); - - const auto &min_sizes = param.MinSizes(); - const auto &max_sizes = param.MaxSizes(); - const auto &variances = param.Variances(); - const auto &input_aspect_ratio = param.AspectRatios(); - const bool &flip = param.Flip(); - const bool &clip = param.Clip(); - const float &step_w = param.StepW(); - const float &step_h = param.StepH(); - const float &offset = param.Offset(); - - Tensor *output_boxes = param.OutputBoxes(); - auto output_boxes_dataptr = output_boxes->mutable_data(); - Tensor *output_variances = param.OutputVariances(); - auto output_variances_dataptr = output_variances->mutable_data(); - - std::vector aspect_ratios; - ExpandAspectRatios(input_aspect_ratio, flip, &aspect_ratios); - - auto img_width = input_image_dims[3]; - auto img_height = input_image_dims[2]; - - auto feature_width = input_dims[3]; - auto feature_height = input_dims[2]; - - auto stride0 = output_boxes->dims()[1] * output_boxes->dims()[2] * - output_boxes->dims()[3]; - auto stride1 = output_boxes->dims()[2] * output_boxes->dims()[3]; - auto stride2 = output_boxes->dims()[3]; - - float step_width, step_height; - /// 300 / 19 - if (step_w == 0 || step_h == 0) { - step_width = static_cast(img_width) / feature_width; - step_height = static_cast(img_height) / feature_height; - } else { - step_width = step_w; - step_height = step_h; - } - - int num_priors = aspect_ratios.size() * min_sizes.size(); - if (!max_sizes.empty()) { - num_priors += max_sizes.size(); - } - - for (int h = 0; h < feature_height; ++h) { - for (int w = 0; w < feature_width; ++w) { - /// map origin image - float center_x = (w + offset) * step_width; - float center_y = (h + offset) * step_height; - float box_width, box_height; - int idx = 0; - for (size_t s = 0; s < min_sizes.size(); ++s) { - auto min_size = min_sizes[s]; - // priors with different aspect ratios - for (float ar : aspect_ratios) { - box_width = min_size * sqrt(ar) / 2.; - box_height = min_size / sqrt(ar) / 2.; - /// box_width/2 , / img_width 为了得到feature map 相对于 - /// 原图的归一化位置的比例。 - output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 0] = - (center_x - box_width) / img_width; - output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 1] = - (center_y - box_height) / img_height; - output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 2] = - (center_x + box_width) / img_width; - output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 3] = - (center_y + box_height) / img_height; - idx++; - } - if (!max_sizes.empty()) { - auto max_size = max_sizes[s]; - // square prior with size sqrt(minSize * maxSize) - box_width = box_height = sqrt(min_size * max_size) / 2.; - output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 0] = - (center_x - box_width) / img_width; - output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 1] = - (center_y - box_height) / img_height; - output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 2] = - (center_x + box_width) / img_width; - output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 3] = - (center_y + box_height) / img_height; - idx++; - } - } - } - } - if (clip) { - math::Transform trans; - ClipFunctor clip_func; - trans(output_boxes_dataptr, output_boxes_dataptr + output_boxes->numel(), - output_boxes_dataptr, clip_func); - } - - if ((variances.size() != 4)) { - LOG(kLOG_ERROR) << " variances.size() must be 4."; - } - - int64_t box_num = feature_height * feature_width * num_priors; - - for (int i = 0; i < box_num; i++) { - output_variances_dataptr[4 * i] = variances[0]; - output_variances_dataptr[4 * i + 1] = variances[1]; - output_variances_dataptr[4 * i + 2] = variances[2]; - output_variances_dataptr[4 * i + 3] = variances[3]; - } + PriorBoxCompute(param); } } // namespace operators diff --git a/src/operators/kernel/arm/relu_kernel.cpp b/src/operators/kernel/arm/relu_kernel.cpp index f6480dea75289cf6615a9737acfd913a3cb13008..63259a0c303f5e186f9eb90b98f2a8685f8ba5ca 100644 --- a/src/operators/kernel/arm/relu_kernel.cpp +++ b/src/operators/kernel/arm/relu_kernel.cpp @@ -15,98 +15,21 @@ limitations under the License. */ #ifdef RELU_OP #include "operators/kernel/relu_kernel.h" -#include +#include "operators/kernel/central-arm-func/relu_arm_func.h" namespace paddle_mobile { namespace operators { -template -struct ReluFunctor { - inline T operator()(T in) const { return in > 0 ? in : 0; } -}; - template <> bool ReluKernel::Init(ReluParam *param) { return true; } -/* - * @b 特化到具体平台的实现, param 从 op 层传入 - * */ template <> void ReluKernel::Compute(const ReluParam ¶m) const { - const auto *input_x = param.InputX(); - auto *input_x_ptr = input_x->data(); - auto *out = param.Out(); - auto *out_ptr = out->mutable_data(); - - int numel = input_x->numel(); - // if (numel > 64) { - // asm volatile( - // "pld [%[input_x_ptr], #0] \n\t" - // "vmov.f32 q8, #0.0 \n\t" - // "subs %[num], %[num], #32 \n\t" - // "blt end_num_%= \n\t" - // "loop_num_%=: \n\t" - // "pld [%[input_x_ptr], #1024] \n\t" - // - // "vld1.32 {q0, q1}, [%[input_x_ptr]]! \n\t" - // "vld1.32 {q2, q3}, [%[input_x_ptr]]! \n\t" - // "vld1.32 {q4, q5}, [%[input_x_ptr]]! \n\t" - // "vld1.32 {q6, q7}, [%[input_x_ptr]]! \n\t" - // - // "vmax.f32 q0, q0, q8 \n\t" - // "vmax.f32 q1, q1, q8 \n\t" - // "vmax.f32 q2, q2, q8 \n\t" - // "vmax.f32 q3, q3, q8 \n\t" - // "vmax.f32 q4, q4, q8 \n\t" - // "vmax.f32 q5, q5, q8 \n\t" - // "vmax.f32 q6, q6, q8 \n\t" - // "vmax.f32 q7, q7, q8 \n\t" - // - // "vst1.32 {q0, q1}, [%[out_ptr]]! \n\t" - // "vst1.32 {q2, q3}, [%[out_ptr]]! \n\t" - // "vst1.32 {q4, q5}, [%[out_ptr]]! \n\t" - // "vst1.32 {q6, q7}, [%[out_ptr]]! \n\t" - // - // "subs %[num], %[num], #32 \n\t" - // "bge loop_num_%= \n\t" - // "end_num_%=: \n\t" - // "cmp %[num], #0 \n\t" - // "bge end_%= \n\t" - // "mov r6, #4 \n\t" - // "mul r5, %[num], r6 \n\t" - // "add %[input_x_ptr], %[input_x_ptr], r5 \n\t" - // "vld1.32 {q0, q1}, [%[input_x_ptr]]! \n\t" - // "vld1.32 {q2, q3}, [%[input_x_ptr]]! \n\t" - // "vld1.32 {q4, q5}, [%[input_x_ptr]]! \n\t" - // "vld1.32 {q6, q7}, [%[input_x_ptr]]! \n\t" - // "vmax.f32 q0, q0, q8 \n\t" - // "vmax.f32 q1, q1, q8 \n\t" - // "vmax.f32 q2, q2, q8 \n\t" - // "vmax.f32 q3, q3, q8 \n\t" - // "vmax.f32 q4, q4, q8 \n\t" - // "vmax.f32 q5, q5, q8 \n\t" - // "vmax.f32 q6, q6, q8 \n\t" - // "vmax.f32 q7, q7, q8 \n\t" - // "add %[out_ptr], %[out_ptr], r5 \n\t" - // "vst1.32 {q0, q1}, [%[out_ptr]]! \n\t" - // "vst1.32 {q2, q3}, [%[out_ptr]]! \n\t" - // "vst1.32 {q4, q5}, [%[out_ptr]]! \n\t" - // "vst1.32 {q6, q7}, [%[out_ptr]]! \n\t" - // "end_%=: \n\t" - // : - // : - // [out_ptr] "r"(out_ptr), [input_x_ptr] "r"(input_x_ptr), [num] - // "r"(numel) : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", - // "q7", "q8", "r5", - // "r6"); - // } else { - ReluFunctor func_; - math::Transform trans; - trans(input_x_ptr, input_x_ptr + numel, out_ptr, func_); - // } + ReluCompute(param); } + } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/arm/reshape_kernel.cpp b/src/operators/kernel/arm/reshape_kernel.cpp index 9e0fd96d3ecd9772ef6e95bc12bb071a25a1d84a..5ae8e5e3f945d115215652ded58dc8571868fcd7 100644 --- a/src/operators/kernel/arm/reshape_kernel.cpp +++ b/src/operators/kernel/arm/reshape_kernel.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #ifdef RESHAPE_OP #include "operators/kernel/reshape_kernel.h" +#include "operators/kernel/central-arm-func/reshape_arm_func.h" namespace paddle_mobile { namespace operators { @@ -26,30 +27,7 @@ bool ReshapeKernel::Init(ReshapeParam *param) { template <> void ReshapeKernel::Compute(const ReshapeParam ¶m) const { - const auto *input_x = param.InputX(); - const auto &input_x_dims = input_x->dims(); - auto *out = param.Out(); - framework::DDim out_dims = out->dims(); - const auto *input_shape = param.InputShape(); - - if (input_shape) { - auto *shape_data = input_shape->data(); - framework::Tensor cpu_shape_tensor; - auto shape = - std::vector(shape_data, shape_data + input_shape->numel()); - out_dims = ValidateShape(shape, input_x->dims()); - } - - bool inplace = param.Inplace(); - out->Resize(out_dims); - if (!inplace) { - out->mutable_data(); - framework::TensorCopy(*input_x, out); - out->Resize(out_dims); - } else { - out->ShareDataWith(*input_x); - out->Resize(out_dims); - } + ReshapeCompute(param); } } // namespace operators diff --git a/src/operators/kernel/arm/transpose_kernel.cpp b/src/operators/kernel/arm/transpose_kernel.cpp index f697d4ca473d64b834fe1451afd8e0df7f84b3a6..c358edd76e93cee3f8be6086a70c34671c87d383 100644 --- a/src/operators/kernel/arm/transpose_kernel.cpp +++ b/src/operators/kernel/arm/transpose_kernel.cpp @@ -14,72 +14,19 @@ limitations under the License. */ #ifdef TRANSPOSE_OP #include "operators/kernel/transpose_kernel.h" +#include "operators/kernel/central-arm-func/transpose_arm_func.h" + namespace paddle_mobile { namespace operators { -// vector pos; -// template -// void TransposeFunc(const int numel, const T* input, const vector axis, -// const vector old_strides, const vector -// new_strides, T* output) { -// for (int i = 0; i < numel; ++i) { -// int old_idx = 0; -// int idx = i; -// for (int j = 0; j < axis.size(); ++j) { -// int order = axis[j]; -// old_idx += (idx / new_strides[j]) * old_strides[order]; -// idx %= new_strides[j]; -// } -// output[i] = input[old_idx]; -// } -// } - template <> -bool TransposeKernel::Init(TransposeParam* param) { +bool TransposeKernel::Init(TransposeParam *param) { return true; } template <> -void TransposeKernel::Compute(const TransposeParam& param) const { - const auto* input_x = param.InputX(); - const auto input_x_dims = input_x->dims(); - auto* out = param.Out(); - const auto axis = param.Axis(); - const auto* input_x_data = input_x->data(); - auto* out_data = out->mutable_data(); - - size_t ndim = axis.size(); - std::vector xdim(ndim); - std::vector xstride(ndim); - std::vector xout(ndim); - for (int i = 0; i < ndim; i++) { - int j = ndim - 1 - i; - xdim[j] = input_x_dims[axis[i]]; - xstride[j] = 1; - for (int k = axis[i] + 1; k < ndim; k++) { - xstride[j] *= input_x_dims[k]; - } - xout[j] = xstride[j] * xdim[j]; - } - - auto numel = input_x->numel(); - size_t pind = 0; - std::vector ind(ndim); - for (int i = 0; i < numel; i++) { - out_data[i] = input_x_data[pind]; - ind[0]++; - pind += xstride[0]; - for (int j = 0; j < ndim - 1; j++) { - if (ind[j] == xdim[j]) { - ind[j + 1]++; - ind[j] = 0; - pind += xstride[j + 1]; - pind -= xout[j]; - } else { - break; - } - } - } +void TransposeKernel::Compute(const TransposeParam ¶m) const { + TransposeCompute(param); } } // namespace operators diff --git a/src/operators/kernel/central-arm-func/box_coder_arm_func.h b/src/operators/kernel/central-arm-func/box_coder_arm_func.h new file mode 100644 index 0000000000000000000000000000000000000000..eeb05f31b744c9e55e78375a495c5a5debf095c2 --- /dev/null +++ b/src/operators/kernel/central-arm-func/box_coder_arm_func.h @@ -0,0 +1,140 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef BOXCODER_OP +#pragma once + +#include + +namespace paddle_mobile { +namespace operators { + +template +void EncodeCenterSize(const framework::Tensor& target_box, + const framework::Tensor& prior_box, + const framework::Tensor& prior_box_var, T* output) { + int64_t row = target_box.dims()[0]; + int64_t col = prior_box.dims()[0]; + int64_t len = prior_box.dims()[1]; + auto* target_box_data = target_box.data(); + auto* prior_box_data = prior_box.data(); + auto* prior_box_var_data = prior_box_var.data(); + + for (int64_t i = 0; i < row; ++i) { + for (int64_t j = 0; j < col; ++j) { + T prior_box_width = prior_box_data[j * len + 2] - prior_box_data[j * len]; + T prior_box_height = + prior_box_data[j * len + 3] - prior_box_data[j * len + 1]; + T prior_box_center_x = + (prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2; + T prior_box_center_y = + (prior_box_data[j * len + 3] + prior_box_data[j * len + 1]) / 2; + + T target_box_center_x = + (target_box_data[i * len + 2] + target_box_data[i * len]) / 2; + T target_box_center_y = + (target_box_data[i * len + 3] + target_box_data[i * len + 1]) / 2; + T target_box_width = + target_box_data[i * len + 2] - target_box_data[i * len]; + T target_box_height = + target_box_data[i * len + 3] - target_box_data[i * len + 1]; + + size_t offset = i * col * len + j * len; + output[offset] = (target_box_center_x - prior_box_center_x) / + prior_box_width / prior_box_var_data[j * len]; + output[offset + 1] = (target_box_center_y - prior_box_center_y) / + prior_box_height / prior_box_var_data[j * len + 1]; + output[offset + 2] = + std::log(std::fabs(target_box_width / prior_box_width)) / + prior_box_var_data[j * len + 2]; + output[offset + 3] = + std::log(std::fabs(target_box_height / prior_box_height)) / + prior_box_var_data[j * len + 3]; + } + } +} + +template +void DecodeCenterSize(const framework::Tensor& target_box, + const framework::Tensor& prior_box, + const framework::Tensor& prior_box_var, T* output) { + int64_t row = target_box.dims()[0]; + int64_t col = prior_box.dims()[0]; + int64_t len = prior_box.dims()[1]; + + auto* target_box_data = target_box.data(); + auto* prior_box_data = prior_box.data(); + auto* prior_box_var_data = prior_box_var.data(); + + for (int64_t i = 0; i < row; ++i) { + for (int64_t j = 0; j < col; ++j) { + size_t offset = i * col * len + j * len; + T prior_box_width = prior_box_data[j * len + 2] - prior_box_data[j * len]; + T prior_box_height = + prior_box_data[j * len + 3] - prior_box_data[j * len + 1]; + T prior_box_center_x = + (prior_box_data[j * len + 2] + prior_box_data[j * len]) / 2; + T prior_box_center_y = + (prior_box_data[j * len + 3] + prior_box_data[j * len + 1]) / 2; + + T target_box_center_x = prior_box_var_data[j * len] * + target_box_data[offset] * prior_box_width + + prior_box_center_x; + T target_box_center_y = prior_box_var_data[j * len + 1] * + target_box_data[offset + 1] * + prior_box_height + + prior_box_center_y; + T target_box_width = std::exp(prior_box_var_data[j * len + 2] * + target_box_data[offset + 2]) * + prior_box_width; + T target_box_height = std::exp(prior_box_var_data[j * len + 3] * + target_box_data[offset + 3]) * + prior_box_height; + + output[offset] = target_box_center_x - target_box_width / 2; + output[offset + 1] = target_box_center_y - target_box_height / 2; + output[offset + 2] = target_box_center_x + target_box_width / 2; + output[offset + 3] = target_box_center_y + target_box_height / 2; + } + } +} + +template +void BoxCoderCompute(const BoxCoderParam& param) { + const auto* input_priorbox = param.InputPriorBox(); + const auto* input_priorboxvar = param.InputPriorBoxVar(); + const auto* input_targetbox = param.InputTargetBox(); + + const auto& code_type = param.CodeType(); + + auto row = input_targetbox->dims()[0]; + auto col = input_priorbox->dims()[0]; + auto len = input_priorbox->dims()[1]; + + Tensor* output_box = param.OutputBox(); + auto* output_box_dataptr = output_box->mutable_data({row, col, len}); + + if (code_type == "encode_center_size") { + EncodeCenterSize(*input_targetbox, *input_priorbox, + *input_priorboxvar, output_box_dataptr); + } + if (code_type == "decode_center_size") { + DecodeCenterSize(*input_targetbox, *input_priorbox, + *input_priorboxvar, output_box_dataptr); + } +} +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/central-arm-func/concat_arm_func.h b/src/operators/kernel/central-arm-func/concat_arm_func.h new file mode 100644 index 0000000000000000000000000000000000000000..e9926505b33b32ee83a16f882cc0f775797f154a --- /dev/null +++ b/src/operators/kernel/central-arm-func/concat_arm_func.h @@ -0,0 +1,90 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef CONCAT_OP +#pragma once + +#include + +namespace paddle_mobile { +namespace operators { +template +class ConcatFunctor { + public: + void operator()(const std::vector &input, const int axis, + framework::Tensor *output) { + size_t num = input.size(); + int rows = 1; + auto dim_0 = input[0].dims(); + for (int i = 0; i < axis; ++i) { + rows *= dim_0[i]; + } + int out_rows = rows, out_cols = 0; + + std::vector input_cols(input.size()); + for (int i = 0; i < num; ++i) { + int t_cols = input[i].numel() / rows; + out_cols += t_cols; + input_cols[i] = t_cols; + } + + // computation + for (int k = 0; k < out_rows; ++k) { + T *dst_ptr = output->data() + k * out_cols; + int col_idx = 0; + for (int j = 0; j < num; ++j) { + int col_len = input_cols[j]; + const T *src_prt = input[j].data() + k * col_len; + memory::Copy(dst_ptr + col_idx, src_prt, sizeof(T) * col_len); + col_idx += col_len; + } + } + } +}; + +template +void ConcatCompute(const ConcatParam ¶m) { + auto inputs = param.Inputs(); + auto *out = param.Out(); + int64_t axis = param.Axis(); + out->mutable_data(); + + /// Sometimes direct copies will be faster, this maybe need deeply analysis. + if (axis == 0 && inputs.size() < 10) { + size_t output_offset = 0; + for (auto *in : inputs) { + auto in_stride = framework::stride_numel(in->dims()); + auto out_stride = framework::stride_numel(out->dims()); + auto dst = out->data() + output_offset; + auto src = in->data(); + PADDLE_MOBILE_ENFORCE( + in_stride.size() == out_stride.size(), + "src and dst tensor should have the same dims size."); + memory::Copy(dst, src, sizeof(float) * in_stride[0]); + output_offset += in_stride[0]; + } + } else { + std::vector inputs_concat(inputs.size()); + for (int j = 0; j < inputs.size(); ++j) { + inputs_concat[j] = *inputs[j]; + } + ConcatFunctor concat_functor; + concat_functor(inputs_concat, static_cast(axis), out); + } +} + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/central-arm-func/conv_add_bn_relu_func.h b/src/operators/kernel/central-arm-func/conv_add_bn_relu_func.h index 6fce2c26347183c47ae756b07156de76f37ea6e5..bf96a2d46fd96516743127b71db57496e35b8a77 100644 --- a/src/operators/kernel/central-arm-func/conv_add_bn_relu_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_bn_relu_func.h @@ -15,7 +15,6 @@ limitations under the License. */ #ifdef FUSION_CONVADDBNRELU_OP #pragma once -#include "operators/kernel/conv_add_bn_relu_kernel.h" #include "operators/math/depthwise_conv_3x3.h" #include "operators/op_param.h" namespace paddle_mobile { diff --git a/src/operators/kernel/central-arm-func/elementwise_add_arm_func.h b/src/operators/kernel/central-arm-func/elementwise_add_arm_func.h new file mode 100644 index 0000000000000000000000000000000000000000..8b3f5d0a8083b63334319b2054f9bf463efa66c7 --- /dev/null +++ b/src/operators/kernel/central-arm-func/elementwise_add_arm_func.h @@ -0,0 +1,43 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef ELEMENTWISEADD_OP + +#pragma once + +namespace paddle_mobile { +namespace operators { + +template +struct AddFunctor { + inline T operator()(T a, T b) const { return a + b; } +}; + +template +void ElementwiseAddCompute(const ElementwiseAddParam ¶m) { + const Tensor *input_x = param.InputX(); + const Tensor *input_y = param.InputY(); + Tensor *Out = param.Out(); + Out->mutable_data(); + int axis = param.Axis(); + ElementwiseComputeEx, float>(input_x, input_y, axis, + AddFunctor(), Out); +} + +template class ElementwiseAddKernel; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h b/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h new file mode 100644 index 0000000000000000000000000000000000000000..8a01f554140712c6a941b40372cbcfe35a951ce7 --- /dev/null +++ b/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h @@ -0,0 +1,69 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FUSION_FC_OP + +#pragma once + +namespace paddle_mobile { +namespace operators { + +template +void FusionFcCompute(const FusionFcParam ¶m) { + const Tensor *input_x = param.InputX(); + const Tensor *input_y = param.InputY(); + const Tensor *input_z = param.InputZ(); + auto *input_z_data = input_z->data(); + int axis = param.Axis(); + Tensor *out = param.Out(); + auto *out_data = out->mutable_data(); + const Tensor x_matrix = + input_x->dims().size() > 2 + ? framework::ReshapeToMatrix(*input_x, param.XNumColDims()) + : *input_x; + const Tensor y_matrix = + input_y->dims().size() > 2 + ? framework::ReshapeToMatrix(*input_y, param.YNumColDims()) + : *input_y; + auto out_dim = out->dims(); + if (out_dim.size() != 2) { + out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); + } + PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2."); + PADDLE_MOBILE_ENFORCE(input_z->dims().size() == 1, "inpu_z size must be 1"); + PADDLE_MOBILE_ENFORCE(out_dim[1] == input_z->dims()[0], + " out_dim.size must be 2."); + axis = (axis == -1 ? out_dim.size() - input_z->dims().size() : axis); + PADDLE_MOBILE_ENFORCE(axis == 1, " to fit broadcast, axis = 1. ") + + int64_t classes = input_z->numel(); + for (int i = 0; i < out_dim[0]; i++) { + memory::Copy(out_data + i * classes, input_z_data, sizeof(float) * classes); + } + + for (int i = 0; i < out->numel(); i++) { + DLOG << out_data[i]; + } + math::matmul(x_matrix, false, y_matrix, false, static_cast(1), + out, static_cast(1)); + PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2."); + // if (out_dim.size() != 2) { + // out->Resize(out_dim); + // } +} + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/central-arm-func/lrn_arm_func.h b/src/operators/kernel/central-arm-func/lrn_arm_func.h new file mode 100644 index 0000000000000000000000000000000000000000..52bb1b67dee83c28f513649a8763034a8d538d73 --- /dev/null +++ b/src/operators/kernel/central-arm-func/lrn_arm_func.h @@ -0,0 +1,47 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef LRN_OP + +#pragma once + +namespace paddle_mobile { +namespace operators { + +template +void LrnCompute(const LrnParam ¶m) { + const Tensor *input_x = param.InputX(); + auto x_dims = input_x->dims(); + Tensor *out = param.Out(); + out->mutable_data(); + /// data_format = NCHW + const int N = x_dims[0]; + const int C = x_dims[1]; + const int H = x_dims[2]; + const int W = x_dims[3]; + + const int n = param.N(); + const float alpha = param.Alpha(); + const float beta = param.Beta(); + const float k = param.K(); + LRNFunctor lrnFunctor; + lrnFunctor(*input_x, out, N, C, H, W, n, k, alpha, beta); +} + +template class LrnKernel; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/central-arm-func/mul_arm_func.h b/src/operators/kernel/central-arm-func/mul_arm_func.h new file mode 100644 index 0000000000000000000000000000000000000000..9dfb1f48a574156f1b026fc6af3a03d77b81263f --- /dev/null +++ b/src/operators/kernel/central-arm-func/mul_arm_func.h @@ -0,0 +1,52 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef MUL_OP + +#pragma once + +namespace paddle_mobile { +namespace operators { + +template +void MulCompute(const MulParam ¶m) { + const Tensor *input_x = param.InputX(); + const Tensor *input_y = param.InputY(); + Tensor *out = param.Out(); + out->mutable_data(); + const Tensor x_matrix = + input_x->dims().size() > 2 + ? framework::ReshapeToMatrix(*input_x, param.XNumColDims()) + : *input_x; + const Tensor y_matrix = + input_y->dims().size() > 2 + ? framework::ReshapeToMatrix(*input_y, param.YNumColDims()) + : *input_y; + auto out_dim = out->dims(); + if (out_dim.size() != 2) { + out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]}); + } + math::matmul(x_matrix, false, y_matrix, false, static_cast(1), + out, static_cast(0)); + if (out_dim.size() != 2) { + out->Resize(out_dim); + } +} + +template class MulKernel; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/central-arm-func/multiclass_nms_arm_func.h b/src/operators/kernel/central-arm-func/multiclass_nms_arm_func.h new file mode 100644 index 0000000000000000000000000000000000000000..8833f012d97390e758ac6fc394ef237cb86632b1 --- /dev/null +++ b/src/operators/kernel/central-arm-func/multiclass_nms_arm_func.h @@ -0,0 +1,280 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef MULTICLASSNMS_OP +#pragma once + +#include +#include +#include +#include + +namespace paddle_mobile { +namespace operators { + +constexpr int kOutputDim = 6; +constexpr int kBBoxSize = 4; + +template +bool SortScorePairDescend(const std::pair& pair1, + const std::pair& pair2) { + return pair1.first > pair2.first; +} + +template +static inline void GetMaxScoreIndex( + const std::vector& scores, const T threshold, int top_k, + std::vector>* sorted_indices) { + for (size_t i = 0; i < scores.size(); ++i) { + if (scores[i] > threshold) { + sorted_indices->push_back(std::make_pair(scores[i], i)); + } + } + // Sort the score pair according to the scores in descending order + std::stable_sort(sorted_indices->begin(), sorted_indices->end(), + SortScorePairDescend); + // Keep top_k scores if needed. + if (top_k > -1 && top_k < static_cast(sorted_indices->size())) { + sorted_indices->resize(top_k); + } +} + +template +static inline T BBoxArea(const T* box, const bool normalized) { + if (box[2] < box[0] || box[3] < box[1]) { + // If coordinate values are is invalid + // (e.g. xmax < xmin or ymax < ymin), return 0. + return static_cast(0.); + } else { + const T w = box[2] - box[0]; + const T h = box[3] - box[1]; + if (normalized) { + return w * h; + } else { + // If coordinate values are not within range [0, 1]. + return (w + 1) * (h + 1); + } + } +} + +template +static inline T JaccardOverlap(const T* box1, const T* box2, + const bool normalized) { + if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] || + box2[3] < box1[1]) { + return static_cast(0.); + } else { + const T inter_xmin = std::max(box1[0], box2[0]); + const T inter_ymin = std::max(box1[1], box2[1]); + const T inter_xmax = std::min(box1[2], box2[2]); + const T inter_ymax = std::min(box1[3], box2[3]); + const T inter_w = inter_xmax - inter_xmin; + const T inter_h = inter_ymax - inter_ymin; + const T inter_area = inter_w * inter_h; + const T bbox1_area = BBoxArea(box1, normalized); + const T bbox2_area = BBoxArea(box2, normalized); + return inter_area / (bbox1_area + bbox2_area - inter_area); + } +} + +template +static inline void NMSFast(const Tensor& bbox, const Tensor& scores, + const T score_threshold, const T nms_threshold, + const T eta, const int64_t top_k, + std::vector* selected_indices) { + // The total boxes for each instance. + int64_t num_boxes = bbox.dims()[0]; + // 4: [xmin ymin xmax ymax] + int64_t box_size = bbox.dims()[1]; + + std::vector scores_data(num_boxes); + std::copy_n(scores.data(), num_boxes, scores_data.begin()); + std::vector> sorted_indices; + GetMaxScoreIndex(scores_data, score_threshold, top_k, &sorted_indices); + + selected_indices->clear(); + T adaptive_threshold = nms_threshold; + const T* bbox_data = bbox.data(); + + while (sorted_indices.size() != 0) { + const int idx = sorted_indices.front().second; + bool keep = true; + for (size_t k = 0; k < selected_indices->size(); ++k) { + if (keep) { + const int kept_idx = (*selected_indices)[k]; + T overlap = JaccardOverlap(bbox_data + idx * box_size, + bbox_data + kept_idx * box_size, true); + keep = overlap <= adaptive_threshold; + } else { + break; + } + } + if (keep) { + selected_indices->push_back(idx); + } + sorted_indices.erase(sorted_indices.begin()); + if (keep && eta < 1 && adaptive_threshold > 0.5) { + adaptive_threshold *= eta; + } + } +} + +template +void MultiClassNMS(const Tensor& scores, const Tensor& bboxes, + std::map>* indices, int* num_nmsed_out, + const int& background_label, const int& nms_top_k, + const int& keep_top_k, const T& nms_threshold, + const T& nms_eta, const T& score_threshold) { + int64_t class_num = scores.dims()[0]; + int64_t predict_dim = scores.dims()[1]; + int num_det = 0; + for (int64_t c = 0; c < class_num; ++c) { + if (c == background_label) continue; + Tensor score = scores.Slice(c, c + 1); + /// [c] is key + NMSFast(bboxes, score, score_threshold, nms_threshold, nms_eta, + nms_top_k, &((*indices)[c])); + num_det += (*indices)[c].size(); + } + + *num_nmsed_out = num_det; + const T* scores_data = scores.data(); + if (keep_top_k > -1 && num_det > keep_top_k) { + std::vector>> score_index_pairs; + for (const auto& it : *indices) { + int label = it.first; + const T* sdata = scores_data + label * predict_dim; + const std::vector& label_indices = it.second; + for (size_t j = 0; j < label_indices.size(); ++j) { + int idx = label_indices[j]; + // PADDLE_ENFORCE_LT(idx, predict_dim); + score_index_pairs.push_back( + std::make_pair(sdata[idx], std::make_pair(label, idx))); + } + } + // Keep top k results per image. + std::stable_sort(score_index_pairs.begin(), score_index_pairs.end(), + SortScorePairDescend>); + score_index_pairs.resize(keep_top_k); + + // Store the new indices. + std::map> new_indices; + for (size_t j = 0; j < score_index_pairs.size(); ++j) { + int label = score_index_pairs[j].second.first; + int idx = score_index_pairs[j].second.second; + new_indices[label].push_back(idx); + } + new_indices.swap(*indices); + *num_nmsed_out = keep_top_k; + } +} + +template +void MultiClassOutput(const Tensor& scores, const Tensor& bboxes, + const std::map>& selected_indices, + Tensor* outs) { + int predict_dim = scores.dims()[1]; + auto* scores_data = scores.data(); + auto* bboxes_data = bboxes.data(); + auto* odata = outs->data(); + + int count = 0; + for (const auto& it : selected_indices) { + /// one batch + int label = it.first; + const T* sdata = scores_data + label * predict_dim; + const std::vector& indices = it.second; + for (size_t j = 0; j < indices.size(); ++j) { + int idx = indices[j]; + const T* bdata = bboxes_data + idx * kBBoxSize; + odata[count * kOutputDim] = label; // label + odata[count * kOutputDim + 1] = sdata[idx]; // score + // xmin, ymin, xmax, ymax + std::memcpy(odata + count * kOutputDim + 2, bdata, 4 * sizeof(T)); + count++; + } + } +} + +template +void MultiClassNMSCompute(const MultiClassNMSParam& param) { + const auto* input_bboxes = param.InputBBoxes(); + const auto& input_bboxes_dims = input_bboxes->dims(); + + const auto* input_scores = param.InputScores(); + const auto& input_scores_dims = input_scores->dims(); + + auto* outs = param.Out(); + auto background_label = param.BackGroundLabel(); + auto nms_top_k = param.NMSTopK(); + auto keep_top_k = param.KeepTopK(); + auto nms_threshold = param.NMSThreshold(); + auto nms_eta = param.NMSEta(); + auto score_threshold = param.ScoreThreshold(); + + int64_t batch_size = input_scores_dims[0]; + int64_t class_num = input_scores_dims[1]; + int64_t predict_dim = input_scores_dims[2]; + int64_t box_dim = input_bboxes_dims[2]; + + std::vector>> all_indices; + std::vector batch_starts = {0}; + for (int64_t i = 0; i < batch_size; ++i) { + Tensor ins_score = input_scores->Slice(i, i + 1); + ins_score.Resize({class_num, predict_dim}); + + Tensor ins_boxes = input_bboxes->Slice(i, i + 1); + ins_boxes.Resize({predict_dim, box_dim}); + + std::map> indices; + int num_nmsed_out = 0; + MultiClassNMS(ins_score, ins_boxes, &indices, &num_nmsed_out, + background_label, nms_top_k, keep_top_k, nms_threshold, + nms_eta, score_threshold); + all_indices.push_back(indices); + batch_starts.push_back(batch_starts.back() + num_nmsed_out); + } + + int num_kept = batch_starts.back(); + if (num_kept == 0) { + float* od = outs->mutable_data({1}); + od[0] = -1; + } else { + outs->mutable_data({num_kept, kOutputDim}); + for (int64_t i = 0; i < batch_size; ++i) { + Tensor ins_score = input_scores->Slice(i, i + 1); + ins_score.Resize({class_num, predict_dim}); + + Tensor ins_boxes = input_bboxes->Slice(i, i + 1); + ins_boxes.Resize({predict_dim, box_dim}); + + int64_t s = batch_starts[i]; + int64_t e = batch_starts[i + 1]; + if (e > s) { + Tensor out = outs->Slice(s, e); + MultiClassOutput(ins_score, ins_boxes, all_indices[i], &out); + } + } + } + + // framework::LoD lod; + // lod.emplace_back(batch_starts); + // + // outs->set_lod(lod); +} + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/central-arm-func/prior_box_arm_func.h b/src/operators/kernel/central-arm-func/prior_box_arm_func.h new file mode 100644 index 0000000000000000000000000000000000000000..892dceb9254ac423d3591a0fc9e9347bc375831b --- /dev/null +++ b/src/operators/kernel/central-arm-func/prior_box_arm_func.h @@ -0,0 +1,149 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PRIORBOX_OP +#pragma once + +#include +#include + +namespace paddle_mobile { +namespace operators { + +template +struct ClipFunctor { + inline T operator()(T in) const { + return std::min(std::max(in, 0.), 1.); + } +}; + +template +void PriorBoxCompute(const PriorBoxParam ¶m) { + const auto *input_ = param.Input(); + const auto &input_dims = input_->dims(); + + const auto *input_image = param.InputImage(); + const auto &input_image_dims = input_image->dims(); + + const auto &min_sizes = param.MinSizes(); + const auto &max_sizes = param.MaxSizes(); + const auto &variances = param.Variances(); + const auto &input_aspect_ratio = param.AspectRatios(); + const bool &flip = param.Flip(); + const bool &clip = param.Clip(); + const float &step_w = param.StepW(); + const float &step_h = param.StepH(); + const float &offset = param.Offset(); + + Tensor *output_boxes = param.OutputBoxes(); + auto output_boxes_dataptr = output_boxes->mutable_data(); + Tensor *output_variances = param.OutputVariances(); + auto output_variances_dataptr = output_variances->mutable_data(); + + std::vector aspect_ratios; + ExpandAspectRatios(input_aspect_ratio, flip, &aspect_ratios); + + auto img_width = input_image_dims[3]; + auto img_height = input_image_dims[2]; + + auto feature_width = input_dims[3]; + auto feature_height = input_dims[2]; + + auto stride0 = output_boxes->dims()[1] * output_boxes->dims()[2] * + output_boxes->dims()[3]; + auto stride1 = output_boxes->dims()[2] * output_boxes->dims()[3]; + auto stride2 = output_boxes->dims()[3]; + + float step_width, step_height; + /// 300 / 19 + if (step_w == 0 || step_h == 0) { + step_width = static_cast(img_width) / feature_width; + step_height = static_cast(img_height) / feature_height; + } else { + step_width = step_w; + step_height = step_h; + } + + int num_priors = aspect_ratios.size() * min_sizes.size(); + if (!max_sizes.empty()) { + num_priors += max_sizes.size(); + } + + for (int h = 0; h < feature_height; ++h) { + for (int w = 0; w < feature_width; ++w) { + /// map origin image + float center_x = (w + offset) * step_width; + float center_y = (h + offset) * step_height; + float box_width, box_height; + int idx = 0; + for (size_t s = 0; s < min_sizes.size(); ++s) { + auto min_size = min_sizes[s]; + // priors with different aspect ratios + for (float ar : aspect_ratios) { + box_width = min_size * sqrt(ar) / 2.; + box_height = min_size / sqrt(ar) / 2.; + /// box_width/2 , / img_width 为了得到feature map 相对于 + /// 原图的归一化位置的比例。 + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 0] = + (center_x - box_width) / img_width; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 1] = + (center_y - box_height) / img_height; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 2] = + (center_x + box_width) / img_width; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 3] = + (center_y + box_height) / img_height; + idx++; + } + if (!max_sizes.empty()) { + auto max_size = max_sizes[s]; + // square prior with size sqrt(minSize * maxSize) + box_width = box_height = sqrt(min_size * max_size) / 2.; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 0] = + (center_x - box_width) / img_width; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 1] = + (center_y - box_height) / img_height; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 2] = + (center_x + box_width) / img_width; + output_boxes_dataptr[h * stride0 + w * stride1 + idx * stride2 + 3] = + (center_y + box_height) / img_height; + idx++; + } + } + } + } + if (clip) { + math::Transform trans; + ClipFunctor clip_func; + trans(output_boxes_dataptr, output_boxes_dataptr + output_boxes->numel(), + output_boxes_dataptr, clip_func); + } + + if ((variances.size() != 4)) { + LOG(kLOG_ERROR) << " variances.size() must be 4."; + } + + int64_t box_num = feature_height * feature_width * num_priors; + + for (int i = 0; i < box_num; i++) { + output_variances_dataptr[4 * i] = variances[0]; + output_variances_dataptr[4 * i + 1] = variances[1]; + output_variances_dataptr[4 * i + 2] = variances[2]; + output_variances_dataptr[4 * i + 3] = variances[3]; + } +} + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/central-arm-func/relu_arm_func.h b/src/operators/kernel/central-arm-func/relu_arm_func.h new file mode 100644 index 0000000000000000000000000000000000000000..19ccb3e862a29cab79453572b24ed0c5a2a8301d --- /dev/null +++ b/src/operators/kernel/central-arm-func/relu_arm_func.h @@ -0,0 +1,108 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef RELU_OP +#pragma once + +#include + +namespace paddle_mobile { +namespace operators { + +template +struct ReluFunctor { + inline T operator()(T in) const { return in > 0 ? in : 0; } +}; + +/* + * @b 特化到具体平台的实现, param 从 op 层传入 + * */ +template +void ReluCompute(const ReluParam ¶m) { + const auto *input_x = param.InputX(); + auto *input_x_ptr = input_x->data(); + auto *out = param.Out(); + auto *out_ptr = out->mutable_data(); + + int numel = input_x->numel(); + // if (numel > 64) { + // asm volatile( + // "pld [%[input_x_ptr], #0] \n\t" + // "vmov.f32 q8, #0.0 \n\t" + // "subs %[num], %[num], #32 \n\t" + // "blt end_num_%= \n\t" + // "loop_num_%=: \n\t" + // "pld [%[input_x_ptr], #1024] \n\t" + // + // "vld1.32 {q0, q1}, [%[input_x_ptr]]! \n\t" + // "vld1.32 {q2, q3}, [%[input_x_ptr]]! \n\t" + // "vld1.32 {q4, q5}, [%[input_x_ptr]]! \n\t" + // "vld1.32 {q6, q7}, [%[input_x_ptr]]! \n\t" + // + // "vmax.f32 q0, q0, q8 \n\t" + // "vmax.f32 q1, q1, q8 \n\t" + // "vmax.f32 q2, q2, q8 \n\t" + // "vmax.f32 q3, q3, q8 \n\t" + // "vmax.f32 q4, q4, q8 \n\t" + // "vmax.f32 q5, q5, q8 \n\t" + // "vmax.f32 q6, q6, q8 \n\t" + // "vmax.f32 q7, q7, q8 \n\t" + // + // "vst1.32 {q0, q1}, [%[out_ptr]]! \n\t" + // "vst1.32 {q2, q3}, [%[out_ptr]]! \n\t" + // "vst1.32 {q4, q5}, [%[out_ptr]]! \n\t" + // "vst1.32 {q6, q7}, [%[out_ptr]]! \n\t" + // + // "subs %[num], %[num], #32 \n\t" + // "bge loop_num_%= \n\t" + // "end_num_%=: \n\t" + // "cmp %[num], #0 \n\t" + // "bge end_%= \n\t" + // "mov r6, #4 \n\t" + // "mul r5, %[num], r6 \n\t" + // "add %[input_x_ptr], %[input_x_ptr], r5 \n\t" + // "vld1.32 {q0, q1}, [%[input_x_ptr]]! \n\t" + // "vld1.32 {q2, q3}, [%[input_x_ptr]]! \n\t" + // "vld1.32 {q4, q5}, [%[input_x_ptr]]! \n\t" + // "vld1.32 {q6, q7}, [%[input_x_ptr]]! \n\t" + // "vmax.f32 q0, q0, q8 \n\t" + // "vmax.f32 q1, q1, q8 \n\t" + // "vmax.f32 q2, q2, q8 \n\t" + // "vmax.f32 q3, q3, q8 \n\t" + // "vmax.f32 q4, q4, q8 \n\t" + // "vmax.f32 q5, q5, q8 \n\t" + // "vmax.f32 q6, q6, q8 \n\t" + // "vmax.f32 q7, q7, q8 \n\t" + // "add %[out_ptr], %[out_ptr], r5 \n\t" + // "vst1.32 {q0, q1}, [%[out_ptr]]! \n\t" + // "vst1.32 {q2, q3}, [%[out_ptr]]! \n\t" + // "vst1.32 {q4, q5}, [%[out_ptr]]! \n\t" + // "vst1.32 {q6, q7}, [%[out_ptr]]! \n\t" + // "end_%=: \n\t" + // : + // : + // [out_ptr] "r"(out_ptr), [input_x_ptr] "r"(input_x_ptr), [num] + // "r"(numel) : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", + // "q7", "q8", "r5", + // "r6"); + // } else { + ReluFunctor func_; + math::Transform trans; + trans(input_x_ptr, input_x_ptr + numel, out_ptr, func_); + // } +} +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/central-arm-func/reshape_arm_func.h b/src/operators/kernel/central-arm-func/reshape_arm_func.h new file mode 100644 index 0000000000000000000000000000000000000000..a2fb836257418923f41e94ceaf499e38033c6b4c --- /dev/null +++ b/src/operators/kernel/central-arm-func/reshape_arm_func.h @@ -0,0 +1,54 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef RESHAPE_OP +#pragma once + +#include + +namespace paddle_mobile { +namespace operators { + +template +void ReshapeCompute(const ReshapeParam ¶m) { + const auto *input_x = param.InputX(); + const auto &input_x_dims = input_x->dims(); + auto *out = param.Out(); + framework::DDim out_dims = out->dims(); + const auto *input_shape = param.InputShape(); + + if (input_shape) { + auto *shape_data = input_shape->data(); + framework::Tensor cpu_shape_tensor; + auto shape = + std::vector(shape_data, shape_data + input_shape->numel()); + out_dims = ValidateShape(shape, input_x->dims()); + } + + bool inplace = param.Inplace(); + out->Resize(out_dims); + if (!inplace) { + out->mutable_data(); + framework::TensorCopy(*input_x, out); + out->Resize(out_dims); + } else { + out->ShareDataWith(*input_x); + out->Resize(out_dims); + } +} + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/central-arm-func/transpose_arm_func.h b/src/operators/kernel/central-arm-func/transpose_arm_func.h new file mode 100644 index 0000000000000000000000000000000000000000..1cbebc4525113374061541518775a94c6a64401f --- /dev/null +++ b/src/operators/kernel/central-arm-func/transpose_arm_func.h @@ -0,0 +1,86 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef TRANSPOSE_OP +#pragma once + +#include + +namespace paddle_mobile { +namespace operators { + +// vector pos; +// template +// void TransposeFunc(const int numel, const T* input, const vector axis, +// const vector old_strides, const vector +// new_strides, T* output) { +// for (int i = 0; i < numel; ++i) { +// int old_idx = 0; +// int idx = i; +// for (int j = 0; j < axis.size(); ++j) { +// int order = axis[j]; +// old_idx += (idx / new_strides[j]) * old_strides[order]; +// idx %= new_strides[j]; +// } +// output[i] = input[old_idx]; +// } +// } + +template +void TransposeCompute(const TransposeParam& param) { + const auto* input_x = param.InputX(); + const auto input_x_dims = input_x->dims(); + auto* out = param.Out(); + const auto axis = param.Axis(); + const auto* input_x_data = input_x->data(); + auto* out_data = out->mutable_data(); + + size_t ndim = axis.size(); + std::vector xdim(ndim); + std::vector xstride(ndim); + std::vector xout(ndim); + for (int i = 0; i < ndim; i++) { + int j = ndim - 1 - i; + xdim[j] = input_x_dims[axis[i]]; + xstride[j] = 1; + for (int k = axis[i] + 1; k < ndim; k++) { + xstride[j] *= input_x_dims[k]; + } + xout[j] = xstride[j] * xdim[j]; + } + + auto numel = input_x->numel(); + size_t pind = 0; + std::vector ind(ndim); + for (int i = 0; i < numel; i++) { + out_data[i] = input_x_data[pind]; + ind[0]++; + pind += xstride[0]; + for (int j = 0; j < ndim - 1; j++) { + if (ind[j] == xdim[j]) { + ind[j + 1]++; + ind[j] = 0; + pind += xstride[j + 1]; + pind -= xout[j]; + } else { + break; + } + } + } +} + +} // namespace operators +} // namespace paddle_mobile + +#endif