Unverified commit bbfedb25, authored by Ruilong Liu, committed by GitHub

Merge branch 'develop' into develop

@@ -61,7 +61,14 @@ struct PaddleMobileException : public std::exception {
  }
#else
#define PADDLE_MOBILE_THROW_EXCEPTION(...)
-#define PADDLE_MOBILE_ENFORCE(stat, ...)
+#define PADDLE_MOBILE_ENFORCE(stat, ...) \
+  {                                      \
+    if (stat) {                          \
+    } else {                             \
+    }                                    \
+  }
#endif
}  // namespace paddle_mobile
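A note on the change above: the old no-exception definition of PADDLE_MOBILE_ENFORCE expanded to nothing, so any side effects in the checked condition were silently skipped; the new empty if/else body still evaluates the condition while discarding the message arguments. A minimal sketch of a call site under the new no-exception definition (the helper function and message are hypothetical, not from this commit):

#define PADDLE_MOBILE_ENFORCE(stat, ...) \
  {                                      \
    if (stat) {                          \
    } else {                             \
    }                                    \
  }

int CheckAxis(int axis) {
  // Hypothetical call site: in the no-exception build the condition is still
  // evaluated, but the message arguments are ignored.
  PADDLE_MOBILE_ENFORCE(axis == 1, " to fit broadcast, axis = 1. ");
  return axis;
}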
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
+#include "cstring"
#include "io/paddle_inference_api.h"
namespace paddle_mobile {
......
@@ -30,9 +30,6 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
      : framework::OperatorBase<DeviceType>(type, inputs, outputs, attrs,
                                            scope),
        param_(inputs, outputs, attrs, scope.get()) {}
-  void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); }
-  void Init() {}
  void InferShape() const {
    auto out_dims = param_.Out()->dims();
@@ -40,6 +37,29 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
    param_.Out()->Resize(out_dims);
  }
#ifdef PADDLE_MOBILE_FPGA
void RunImpl() const { fpga::PerformBypass(param_.FpgaArgs()); }
void Init() {
const Tensor *input = param_.InputX();
auto input_ptr = input->data<float>();
Tensor *output = param_.Out();
auto output_ptr = output->mutable_data<half>();
fpga::BypassArgs args;
args.convert_type = fpga::DATA_FP32_TO_FP16;
args.layout_type = fpga::LAYOUT_CHW_TO_HWC;
args.image.address = (void *)input_ptr;
args.image.channels = input->dims()[1];
args.image.height = input->dims()[2];
args.image.width = input->dims()[3];
args.output.address = output_ptr;
param_.SetFpgaArgs(args);
}
#else
void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); }
void Init() {}
#endif
protected: protected:
FeedParam param_; FeedParam param_;
}; };
@@ -54,4 +74,5 @@ USE_OP_CPU(feed);
USE_OP_MALI_GPU(feed);
#endif
#ifdef PADDLE_MOBILE_FPGA
+USE_OP_FPGA(feed);
#endif
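For readers unfamiliar with the FPGA path: the Init body added above fills fpga::BypassArgs with DATA_FP32_TO_FP16 and LAYOUT_CHW_TO_HWC, i.e. the feed tensor is handed to the device as FP16 in HWC order. A CPU sketch of the layout half of that conversion follows; the function name and raw-pointer interface are illustrative assumptions, and the FP32-to-FP16 narrowing performed by the bypass is omitted to keep the sketch portable.

// CPU sketch of the CHW -> HWC reorder implied by LAYOUT_CHW_TO_HWC above.
// The real bypass additionally narrows each element to FP16.
void ChwToHwcReference(const float *src, float *dst, int channels, int height,
                       int width) {
  for (int h = 0; h < height; ++h) {
    for (int w = 0; w < width; ++w) {
      for (int c = 0; c < channels; ++c) {
        // src is indexed as [c][h][w], dst as [h][w][c].
        dst[(h * width + w) * channels + c] =
            src[(c * height + h) * width + w];
      }
    }
  }
}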
@@ -25,4 +25,5 @@ REGISTER_OPERATOR_CPU(fetch, ops::FetchOp);
REGISTER_OPERATOR_MALI_GPU(fetch, ops::FetchOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
+REGISTER_OPERATOR_FPGA(fetch, ops::FetchOp);
#endif
@@ -54,4 +54,5 @@ USE_OP_CPU(fetch);
USE_OP_MALI_GPU(fetch);
#endif
#ifdef PADDLE_MOBILE_FPGA
+USE_OP_FPGA(fetch);
#endif
@@ -16,6 +16,8 @@ limitations under the License. */
#pragma once
#include <cmath>
+#include "framework/tensor.h"
+#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
@@ -122,7 +124,7 @@ void BoxCoderCompute(const BoxCoderParam& param) {
  auto col = input_priorbox->dims()[0];
  auto len = input_priorbox->dims()[1];
-  Tensor* output_box = param.OutputBox();
+  framework::Tensor* output_box = param.OutputBox();
  auto* output_box_dataptr = output_box->mutable_data<float>({row, col, len});
  if (code_type == "encode_center_size") {
......
@@ -31,12 +31,7 @@ void ConvAddBasic(const FusionConvAddParam &param) {
  Tensor bias = *param.Bias();
  int axis = param.Axis();
  Tensor *output = param.Output();
-  math::expand_bias(bias, axis, output->dims());
-  float *output_data = output->data<float>();
  float *biase_data = bias.data<float>();
-  for (int k = 0; k < output->numel(); ++k) {
-    output_data[k] = biase_data[k];
-  }
  int groups = param.Groups();
  std::vector<int> strides = param.Strides();
@@ -113,7 +108,7 @@ void ConvAddBasic(const FusionConvAddParam &param) {
      Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
      math::matmul<float>(filter_slice, false, col_matrix, false,
                          static_cast<float>(1), &out_slice,
-                          static_cast<float>(1));
+                          static_cast<float>(1), false, biase_data);
    }
  }
}
......
@@ -18,6 +18,9 @@ limitations under the License. */
#include <vector>
#include "operators/math/depthwise_conv_3x3.h"
+#include "operators/math/im2col.h"
+#include "operators/math/math_function.h"
+#include "operators/math/vol2col.h"
#include "operators/op_param.h"
namespace paddle_mobile {
......
@@ -16,6 +16,10 @@ limitations under the License. */
#pragma once
#include <vector>
+#include "operators/math/conv_func.h"
+#include "operators/math/im2col.h"
+#include "operators/math/math_function.h"
+#include "operators/math/vol2col.h"
#include "operators/op_param.h"
namespace paddle_mobile {
@@ -28,12 +32,7 @@ void ConvAddReluCompute(const FusionConvAddReluParam &param) {
  Tensor bias = *param.Bias();
  int axis = param.Axis();
  Tensor *output = param.Output();
-  math::expand_bias(bias, axis, output->dims());
-  float *output_data = output->data<float>();
  float *biase_data = bias.data<float>();
-  for (int k = 0; k < output->numel(); ++k) {
-    output_data[k] = biase_data[k];
-  }
  int groups = param.Groups();
  std::vector<int> strides = param.Strides();
@@ -111,7 +110,7 @@ void ConvAddReluCompute(const FusionConvAddReluParam &param) {
      Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
      math::matmul<float>(filter_slice, false, col_matrix, false,
                          static_cast<float>(1), &out_slice,
-                          static_cast<float>(1), true);
+                          static_cast<float>(1), true, biase_data);
    }
  }
}
......
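Both ConvAddBasic and ConvAddReluCompute now pass `biase_data` (plus a relu flag) straight into math::matmul instead of pre-filling the output with an expanded bias, which suggests the bias add and optional ReLU were fused into the GEMM epilogue. A scalar model of that assumed contract is sketched below; the signature is inferred from the call sites, not the declared one, and the real math::matmul is a packed NEON GEMM.

// Scalar model of the assumed fused epilogue: C is (m x n) and bias has one
// entry per row of C (one per output channel in the conv-as-GEMM view).
void MatmulBiasReluReference(const float *A, const float *B, float *C, int m,
                             int n, int k, bool relu, const float *bias) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = bias != nullptr ? bias[i] : 0.0f;  // bias folded into the GEMM
      for (int p = 0; p < k; ++p) {
        acc += A[i * k + p] * B[p * n + j];
      }
      C[i * n + j] = (relu && acc < 0.0f) ? 0.0f : acc;
    }
  }
}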
@@ -17,6 +17,9 @@ limitations under the License. */
#pragma once
#include <vector>
#include "operators/math/depthwise_conv_3x3.h"
+#include "operators/math/im2col.h"
+#include "operators/math/math_function.h"
+#include "operators/math/vol2col.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
......
@@ -17,6 +17,9 @@ limitations under the License. */
#pragma once
#include <vector>
#include "operators/math/depthwise_conv_3x3.h"
+#include "operators/math/im2col.h"
+#include "operators/math/math_function.h"
+#include "operators/math/vol2col.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
......
@@ -15,6 +15,8 @@ limitations under the License. */
#ifdef ELEMENTWISEADD_OP
#pragma once
+#include "operators/math/elementwise_op_function.h"
+#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
......
@@ -15,6 +15,8 @@ limitations under the License. */
#ifdef FUSION_FC_OP
#pragma once
+#include "operators/math/math_function.h"
+#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
@@ -28,6 +30,7 @@ void FusionFcCompute(const FusionFcParam &param) {
  int axis = param.Axis();
  Tensor *out = param.Out();
  auto *out_data = out->mutable_data<float>();
+  float *bias_data = out->mutable_data<float>();
  const Tensor x_matrix =
      input_x->dims().size() > 2
          ? framework::ReshapeToMatrix(*input_x, param.XNumColDims())
@@ -45,18 +48,18 @@ void FusionFcCompute(const FusionFcParam &param) {
  PADDLE_MOBILE_ENFORCE(out_dim[1] == input_z->dims()[0],
                        " out_dim.size must be 2.");
  axis = (axis == -1 ? out_dim.size() - input_z->dims().size() : axis);
-  PADDLE_MOBILE_ENFORCE(axis == 1, " to fit broadcast, axis = 1. ")
+  PADDLE_MOBILE_ENFORCE(axis == 1, " to fit broadcast, axis = 1. ");
  int64_t classes = input_z->numel();
  for (int i = 0; i < out_dim[0]; i++) {
    memory::Copy(out_data + i * classes, input_z_data, sizeof(float) * classes);
  }
-  for (int i = 0; i < out->numel(); i++) {
-    DLOG << out_data[i];
-  }
+  // for (int i = 0; i < out->numel(); i++) {
+  //   DLOG << out_data[i];
+  // }
  math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
-                      out, static_cast<float>(1));
+                      out, static_cast<float>(1), false, bias_data);
  PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2.");
  // if (out_dim.size() != 2) {
  //   out->Resize(out_dim);
......
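The fused FC path above keeps the explicit bias broadcast: input_z (one value per class) is copied into every output row via memory::Copy, and the GEMM then accumulates x_matrix * y_matrix on top; the new bias_data pointer handed to math::matmul aliases the same output buffer. A self-contained scalar sketch of that computation, with illustrative names and a flat row-major layout assumed:

#include <cstring>

// out is (m x classes); z holds one bias value per class.
void FusionFcReference(const float *x, const float *y, const float *z,
                       float *out, int m, int k, int classes) {
  for (int i = 0; i < m; ++i) {
    // Broadcast the class bias into each output row (the memory::Copy loop).
    std::memcpy(out + i * classes, z, sizeof(float) * classes);
  }
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < classes; ++j) {
      for (int p = 0; p < k; ++p) {
        out[i * classes + j] += x[i * k + p] * y[p * classes + j];  // beta == 1
      }
    }
  }
}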
@@ -15,7 +15,7 @@ limitations under the License. */
#ifdef LRN_OP
#pragma once
+#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
......
@@ -19,6 +19,8 @@ limitations under the License. */
#include <map>
#include <utility>
#include <vector>
+#include "framework/tensor.h"
+#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
@@ -89,7 +91,8 @@ static inline T JaccardOverlap(const T* box1, const T* box2,
}
template <typename T>
-static inline void NMSFast(const Tensor& bbox, const Tensor& scores,
+static inline void NMSFast(const framework::Tensor& bbox,
+                           const framework::Tensor& scores,
                           const T score_threshold, const T nms_threshold,
                           const T eta, const int64_t top_k,
                           std::vector<int>* selected_indices) {
@@ -131,7 +134,8 @@ static inline void NMSFast(const Tensor& bbox, const Tensor& scores,
}
template <typename T>
-void MultiClassNMS(const Tensor& scores, const Tensor& bboxes,
+void MultiClassNMS(const framework::Tensor& scores,
+                   const framework::Tensor& bboxes,
                   std::map<int, std::vector<int>>* indices, int* num_nmsed_out,
                   const int& background_label, const int& nms_top_k,
                   const int& keep_top_k, const T& nms_threshold,
@@ -141,7 +145,7 @@ void MultiClassNMS(const Tensor& scores, const Tensor& bboxes,
  int num_det = 0;
  for (int64_t c = 0; c < class_num; ++c) {
    if (c == background_label) continue;
-    Tensor score = scores.Slice(c, c + 1);
+    framework::Tensor score = scores.Slice(c, c + 1);
    /// [c] is key
    NMSFast<float>(bboxes, score, score_threshold, nms_threshold, nms_eta,
                   nms_top_k, &((*indices)[c]));
@@ -181,9 +185,10 @@ void MultiClassNMS(const Tensor& scores, const Tensor& bboxes,
}
template <typename T>
-void MultiClassOutput(const Tensor& scores, const Tensor& bboxes,
+void MultiClassOutput(const framework::Tensor& scores,
+                      const framework::Tensor& bboxes,
                      const std::map<int, std::vector<int>>& selected_indices,
-                      Tensor* outs) {
+                      framework::Tensor* outs) {
  int predict_dim = scores.dims()[1];
  auto* scores_data = scores.data<T>();
  auto* bboxes_data = bboxes.data<T>();
@@ -231,10 +236,10 @@ void MultiClassNMSCompute(const MultiClassNMSParam& param) {
  std::vector<std::map<int, std::vector<int>>> all_indices;
  std::vector<size_t> batch_starts = {0};
  for (int64_t i = 0; i < batch_size; ++i) {
-    Tensor ins_score = input_scores->Slice(i, i + 1);
+    framework::Tensor ins_score = input_scores->Slice(i, i + 1);
    ins_score.Resize({class_num, predict_dim});
-    Tensor ins_boxes = input_bboxes->Slice(i, i + 1);
+    framework::Tensor ins_boxes = input_bboxes->Slice(i, i + 1);
    ins_boxes.Resize({predict_dim, box_dim});
    std::map<int, std::vector<int>> indices;
@@ -253,16 +258,16 @@ void MultiClassNMSCompute(const MultiClassNMSParam& param) {
  } else {
    outs->mutable_data<float>({num_kept, kOutputDim});
    for (int64_t i = 0; i < batch_size; ++i) {
-      Tensor ins_score = input_scores->Slice(i, i + 1);
+      framework::Tensor ins_score = input_scores->Slice(i, i + 1);
      ins_score.Resize({class_num, predict_dim});
-      Tensor ins_boxes = input_bboxes->Slice(i, i + 1);
+      framework::Tensor ins_boxes = input_bboxes->Slice(i, i + 1);
      ins_boxes.Resize({predict_dim, box_dim});
      int64_t s = batch_starts[i];
      int64_t e = batch_starts[i + 1];
      if (e > s) {
-        Tensor out = outs->Slice(s, e);
+        framework::Tensor out = outs->Slice(s, e);
        MultiClassOutput<float>(ins_score, ins_boxes, all_indices[i], &out);
      }
    }
......
@@ -16,6 +16,7 @@ limitations under the License. */
#pragma once
#include <operators/math/transform.h>
+#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
......
@@ -16,6 +16,8 @@ limitations under the License. */
#pragma once
#include <vector>
+#include "operators/kernel/reshape_kernel.h"
+#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
......
@@ -15,6 +15,7 @@ limitations under the License. */
#ifdef SOFTMAX_OP
#pragma once
#include "../../math/softmax.h"
+#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename P>
......
@@ -16,6 +16,7 @@ limitations under the License. */
#pragma once
#include <vector>
+#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
......
@@ -24,13 +24,13 @@ template <>
bool ConvAddBNKernel<FPGA, float>::Init(FusionConvAddBNParam *param) {
  bool relu_enabled = false;
  const Tensor *input = param->Input();
-  auto input_ptr = input->data<float>();
+  auto input_ptr = input->data<half>();
  const Tensor *bias = param->Bias();
  auto bias_ptr = bias->data<float>();
  const Tensor *filter = param->Filter();
  auto filter_ptr = filter->data<float>();
  Tensor *out = param->Output();
-  auto out_ptr = out->mutable_data<float>();
+  auto out_ptr = out->mutable_data<half>();
  auto bn_mean_ptr = param->InputMean()->data<float>();
  auto bn_var_ptr = param->InputVariance()->data<float>();
  auto bn_scale_ptr = param->InputScale()->data<float>();
......
@@ -24,13 +24,13 @@ template <>
bool ConvAddBNReluKernel<FPGA, float>::Init(FusionConvAddBNReluParam *param) {
  bool relu_enabled = true;
  const Tensor *input = param->Input();
-  auto input_ptr = input->data<float>();
+  auto input_ptr = input->data<half>();
  const Tensor *bias = param->Bias();
  auto bias_ptr = bias->data<float>();
  const Tensor *filter = param->Filter();
  auto filter_ptr = filter->data<float>();
  Tensor *out = param->Output();
-  auto out_ptr = out->mutable_data<float>();
+  auto out_ptr = out->mutable_data<half>();
  auto bn_mean_ptr = param->InputMean()->data<float>();
  auto bn_var_ptr = param->InputVariance()->data<float>();
  auto bn_scale_ptr = param->InputScale()->data<float>();
......
@@ -24,13 +24,13 @@ template <>
bool ConvAddReluKernel<FPGA, float>::Init(FusionConvAddReluParam *param) {
  bool relu_enabled = true;
  const Tensor *input = param->Input();
-  auto input_ptr = input->data<float>();
+  auto input_ptr = input->data<half>();
  const Tensor *bias = param->Bias();
  auto bias_ptr = bias->data<float>();
  const Tensor *filter = param->Filter();
  auto filter_ptr = filter->data<float>();
  Tensor *out = param->Output();
-  auto out_ptr = out->mutable_data<float>();
+  auto out_ptr = out->mutable_data<half>();
  PADDLE_MOBILE_ENFORCE(input->dims()[1] == bias->dims()[0],
                        "Image channel should be equal to bias number");
......
@@ -25,9 +25,9 @@ bool ElementwiseAddReluKernel<FPGA, float>::Init(
  const Tensor *input_x = param->InputX();
  const Tensor *input_y = param->InputY();
  Tensor *out = param->Out();
-  auto input_x_ptr = input_x->data<float>();
-  auto input_y_ptr = input_y->data<float>();
-  auto out_ptr = out->mutable_data<float>();
+  auto input_x_ptr = input_x->data<half>();
+  auto input_y_ptr = input_y->data<half>();
+  auto out_ptr = out->mutable_data<half>();
  fpga::EWAddArgs ewaddArgs;
  ewaddArgs.relu_enabled = relu_enabled;
......
@@ -22,13 +22,13 @@ template <>
bool FusionFcReluKernel<FPGA, float>::Init(FusionFcReluParam *param) {
  bool relu_enabled = true;
  const Tensor *input_x = param->InputX();
-  auto input_x_ptr = input_x->data<float>();
+  auto input_x_ptr = input_x->data<half>();
  const Tensor *input_y = param->InputY();
  auto input_y_ptr = input_y->data<float>();
  const Tensor *input_z = param->InputZ();
  auto input_z_ptr = input_z->data<float>();
  Tensor *out = param->Out();
-  auto out_ptr = out->mutable_data<float>();
+  auto out_ptr = out->mutable_data<half>();
  PADDLE_MOBILE_ENFORCE(input_x->dims()[1] == input_y->dims()[0],
                        "Image channel should be equal to weight number");
......
@@ -22,13 +22,13 @@ template <>
bool FusionFcKernel<FPGA, float>::Init(FusionFcParam *param) {
  bool relu_enabled = false;
  const Tensor *input_x = param->InputX();
-  auto input_x_ptr = input_x->data<float>();
+  auto input_x_ptr = input_x->data<half>();
  const Tensor *input_y = param->InputY();
  auto input_y_ptr = input_y->data<float>();
  const Tensor *input_z = param->InputZ();
  auto input_z_ptr = input_z->data<float>();
  Tensor *out = param->Out();
-  auto out_ptr = out->mutable_data<float>();
+  auto out_ptr = out->mutable_data<half>();
  PADDLE_MOBILE_ENFORCE(input_x->dims()[1] == input_y->dims()[0],
                        "Image channel should be equal to weight number");
......
@@ -22,9 +22,9 @@ namespace operators {
template <>
bool PoolKernel<FPGA, float>::Init(PoolParam *param) {
  const Tensor *input = param->Input();
-  auto input_ptr = input->data<float>();
+  auto input_ptr = input->data<half>();
  Tensor *output = param->Output();
-  auto output_ptr = output->mutable_data<float>();
+  auto output_ptr = output->mutable_data<half>();
  vector<int> ksize = param->Ksize();
  vector<int> strides = param->Strides();
  vector<int> paddings = param->Paddings();
......
@@ -529,42 +529,42 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
  const float *newscale_data = new_scale->data<float>();
  const float *newbias_data = new_bias->data<float>();
-  const int h = static_cast<int>(input->dims()[2]);
-  const int w = static_cast<int>(input->dims()[3]);
-  const int l = h;
  const int batch_size = static_cast<int>(input->dims()[0]);
-  const int c = static_cast<int>(input->dims()[1]);
-  const int hxw = h * w;
+  const int input_channel = static_cast<int>(input->dims()[1]);
+  const int input_height = static_cast<int>(input->dims()[2]);
+  const int input_width = static_cast<int>(input->dims()[3]);
+  const int output_height = static_cast<int>(output->dims()[2]);
+  const int output_width = static_cast<int>(output->dims()[3]);
+  const int hxw = input_height * input_width;
+  const int l = input_height;
  float32x4_t vnewbias = vdupq_n_f32(0.0);
  float32x4_t vnewscale = vdupq_n_f32(1.0);
  float32x4_t vzero = vdupq_n_f32(0);
-  for (int b = 0; b < batch_size; ++b) {
-    const float *filter_data_tmp = filter_data;
-    for (int j = 0; j < c; ++j) {
-      vnewbias = vdupq_n_f32(newbias_data[j]);
-      vnewscale = vdupq_n_f32(newscale_data[j]);
-      int l_mid = l - 2;  // l=1->l_mid=-1,l=2->l_mid=0
-      float w00 = filter_data_tmp[0];
-      float w01 = filter_data_tmp[1];
-      float w02 = filter_data_tmp[2];
-      float w10 = filter_data_tmp[3];
-      float w11 = filter_data_tmp[4];
-      float w12 = filter_data_tmp[5];
-      float w20 = filter_data_tmp[6];
-      float w21 = filter_data_tmp[7];
-      float w22 = filter_data_tmp[8];
+  for (int b = 0; b < batch_size; b++) {
+    filter_data = filter->data<float>();
+    for (int c = 0; c < input_channel; c++) {
+      vnewbias = vdupq_n_f32(newbias_data[c]);
+      vnewscale = vdupq_n_f32(newscale_data[c]);
+      float w00 = filter_data[0];
+      float w01 = filter_data[1];
+      float w02 = filter_data[2];
+      float w10 = filter_data[3];
+      float w11 = filter_data[4];
+      float w12 = filter_data[5];
+      float w20 = filter_data[6];
+      float w21 = filter_data[7];
+      float w22 = filter_data[8];
      output_data[0] = w11 * input_data[0] + w12 * input_data[1] +
                       w21 * input_data[l] + w22 * input_data[l + 1];
      output_data[l - 1] = w10 * input_data[l - 2] + w11 * input_data[l - 1] +
                           w20 * input_data[2 * l - 2] +
                           w21 * input_data[2 * l - 1];
      output_data[(l - 1) * l] =
          w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l + 1] +
          w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1];
@@ -572,13 +572,13 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
          w01 * input_data[(l - 2) * (l + 1) + 1] +
          w10 * input_data[l * l - 2] +
          w11 * input_data[l * l - 1];
-      output_data[0] = output_data[0] * newscale_data[j] + newbias_data[j];
+      output_data[0] = output_data[0] * newscale_data[c] + newbias_data[c];
      output_data[l - 1] =
-          output_data[l - 1] * newscale_data[j] + newbias_data[j];
+          output_data[l - 1] * newscale_data[c] + newbias_data[c];
      output_data[(l - 1) * l] =
-          output_data[(l - 1) * l] * newscale_data[j] + newbias_data[j];
+          output_data[(l - 1) * l] * newscale_data[c] + newbias_data[c];
      output_data[l * l - 1] =
-          output_data[l * l - 1] * newscale_data[j] + newbias_data[j];
+          output_data[l * l - 1] * newscale_data[c] + newbias_data[c];
      if (if_relu) {
        output_data[0] = output_data[0] < 0 ? 0 : output_data[0];
@@ -593,6 +593,7 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
            w01 * input_data[i * l - l] + w02 * input_data[i * l - l + 1] +
            w11 * input_data[i * l] + w12 * input_data[i * l + 1] +
            w21 * input_data[i * l + l] + w22 * input_data[i * l + l + 1];
        output_data[i * l + l - 1] = w00 * input_data[i * l + l - 1 - l - 1] +
                                     w01 * input_data[i * l + l - 1 - l] +
                                     w10 * input_data[i * l + l - 1 - 1] +
@@ -600,9 +601,9 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
                                     w20 * input_data[i * l + l - 1 + l - 1] +
                                     w21 * input_data[i * l + l - 1 + l];
        output_data[i * l] =
-            output_data[i * l] * newscale_data[j] + newbias_data[j];
+            output_data[i * l] * newscale_data[c] + newbias_data[c];
        output_data[i * l + l - 1] =
-            output_data[i * l + l - 1] * newscale_data[j] + newbias_data[j];
+            output_data[i * l + l - 1] * newscale_data[c] + newbias_data[c];
        if (if_relu) {
          output_data[i * l] = output_data[i * l] < 0 ? 0 : output_data[i * l];
@@ -611,28 +612,19 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
        }
      }
// top 1 row and bottom 1 row int m;
const float *input_tmp = input_data;
float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2,
tmp3, tmp4, tmp5, out0;
in0 = vld1q_f32(input_tmp);
in2 = vld1q_f32(input_tmp + l);
const float *input_tmp_end = input_tmp + (l - 2) * l;
in4 = vld1q_f32(input_tmp_end);
in6 = vld1q_f32(input_tmp_end + l);
int c_mid = l_mid;
auto output_ptr = output_data + 1;
for (; c_mid > 3; c_mid -= 4) {
in1 = vld1q_f32(input_tmp + 4);
in3 = vld1q_f32(input_tmp + l + 4);
for (m = 1; m < output_width - 4; m += 4) {
float *output_ptr = output_data + m;
float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0;
in0 = vld1q_f32(input_data + m - 1);
in1 = vld1q_f32(input_data + m + 3);
in2 = vld1q_f32(input_data + input_width + m - 1);
in3 = vld1q_f32(input_data + input_width + m + 3);
tmp0 = vextq_f32(in0, in1, 1); tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2); tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1); tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2); tmp3 = vextq_f32(in2, in3, 2);
out0 = vmulq_n_f32(in0, w10); out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11); out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12); out0 = vmlaq_n_f32(out0, tmp1, w12);
...@@ -644,182 +636,438 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, ...@@ -644,182 +636,438 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
out0 = vmaxq_f32(out0, vzero); out0 = vmaxq_f32(out0, vzero);
} }
vst1q_f32(output_ptr, out0); vst1q_f32(output_ptr, out0);
}
for (m = 1; (m + 3) < output_width - 1; m = m + 4) {
}
for (int j = m; j < output_width - 1; j++) {
output_data[j] = input_data[j - 1] * w10 + input_data[j] * w11 +
input_data[j + 1] * w12 +
input_data[input_width + j - 1] * w20 +
input_data[input_width + j] * w21 +
input_data[input_width + j + 1] * w22;
output_data[j] = output_data[j] * newscale_data[c] + newbias_data[c];
in5 = vld1q_f32(input_tmp_end + 4); if (if_relu) {
in7 = vld1q_f32(input_tmp_end + l + 4); output_data[j] = output_data[j] < 0 ? 0 : output_data[j];
}
}
tmp0 = vextq_f32(in4, in5, 1); for (m = 1; (m + 3) < output_width - 1; m = m + 4) {
tmp1 = vextq_f32(in4, in5, 2); float *output_ptr =
tmp2 = vextq_f32(in6, in7, 1); output_data + (output_height - 1) * output_width + m;
tmp3 = vextq_f32(in6, in7, 2);
out0 = vmulq_n_f32(in4, w00); float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0;
in0 = vld1q_f32(input_data + (output_height - 2) * input_width + m - 1);
in1 = vld1q_f32(input_data + (output_height - 2) * input_width + m + 3);
in2 = vld1q_f32(input_data + (output_height - 1) * input_width + m - 1);
in3 = vld1q_f32(input_data + (output_height - 1) * input_width + m + 3);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
out0 = vmulq_n_f32(in0, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01); out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02); out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in6, w10); out0 = vmlaq_n_f32(out0, in2, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11); out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12); out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_f32(vnewbias, vnewscale, out0); out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) { if (if_relu) {
out0 = vmaxq_f32(out0, vzero); out0 = vmaxq_f32(out0, vzero);
} }
vst1q_f32(output_ptr + (l - 1) * l, out0); vst1q_f32(output_ptr, out0);
// can optimize to each 8 stride.
input_tmp += 4;
input_tmp_end += 4;
output_ptr += 4;
in0 = in1;
in2 = in3;
in4 = in5;
in6 = in7;
} }
for (m = 1; (m + 3) < output_width - 1; m = m + 4) {
}
for (int j = m; j < output_width - 1; j++) {
output_data[(output_height - 1) * input_width + j] =
input_data[(output_height - 2) * input_width + j - 1] * w00 +
input_data[(output_height - 2) * input_width + j] * w01 +
input_data[(output_height - 2) * input_width + j + 1] * w02 +
input_data[(output_height - 1) * input_width + j - 1] * w10 +
input_data[(output_height - 1) * input_width + j] * w11 +
input_data[(output_height - 1) * input_width + j + 1] * w12;
output_data[(output_height - 1) * output_width + j] =
output_data[(output_height - 1) * output_width + j] *
newscale_data[c] +
newbias_data[c];
// top right pad if (if_relu) {
float32x4_t pad0 = vdupq_n_f32(input_data[l - 1]); output_data[(output_height - 1) * output_width + j] =
float32x4_t pad1 = vdupq_n_f32(input_data[2 * l - 1]); output_data[(output_height - 1) * output_width + j] < 0
? 0
tmp0 = vextq_f32(in0, pad0, 1); : output_data[(output_height - 1) * output_width + j];
tmp1 = vextq_f32(in0, pad0, 2); }
tmp2 = vextq_f32(in2, pad1, 1);
tmp3 = vextq_f32(in2, pad1, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
} }
for (int i = 0; i < c_mid; ++i) { #pragma omp parallel for
if (i == 0) { for (int i = 1; i < output_height - 1; i++) {
vst1q_lane_f32(output_ptr + i, out0, 0); for (int m = 1; (m + 3) < output_width - 1; m = m + 4) {
float *output_ptr = output_data + i * output_width + m;
float32x4_t in0, in1, in2, in3, in4, in5, tmp0, tmp1, tmp2, tmp3,
tmp4, tmp5, out0;
in0 = vld1q_f32(input_data + (i - 1) * input_width + m - 1);
in1 = vld1q_f32(input_data + (i - 1) * input_width + m + 3);
in2 = vld1q_f32(input_data + i * input_width + m - 1);
in3 = vld1q_f32(input_data + i * input_width + m + 3);
in4 = vld1q_f32(input_data + (i + 1) * input_width + m - 1);
in5 = vld1q_f32(input_data + (i + 1) * input_width + m + 3);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
tmp4 = vextq_f32(in4, in5, 1);
tmp5 = vextq_f32(in4, in5, 2);
out0 = vmulq_n_f32(in0, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
} }
if (i == 1) { int m;
vst1q_lane_f32(output_ptr + i, out0, 1); for (m = 1; (m + 3) < output_width - 1; m = m + 4) {
} }
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2); for (int j = m; j < output_width - 1; j++) {
output_data[i * output_width + j] =
input_data[(i - 1) * input_width + j - 1] * w00 +
input_data[(i - 1) * input_width + j] * w01 +
input_data[(i - 1) * input_width + j + 1] * w02 +
input_data[(i)*input_width + j - 1] * w10 +
input_data[(i)*input_width + j] * w11 +
input_data[(i)*input_width + j + 1] * w12 +
input_data[(i + 1) * input_width + j - 1] * w20 +
input_data[(i + 1) * input_width + j] * w21 +
input_data[(i + 1) * input_width + j + 1] * w22;
output_data[i * output_width + j] =
newscale_data[c] * output_data[i * output_width + j] +
newbias_data[c];
if (if_relu) {
output_data[i * output_width + j] =
output_data[i * output_width + j] < 0
? 0
: output_data[i * output_width + j];
}
} }
} }
// bottom right pad input_data = input_data + hxw;
float32x4_t pad2 = vdupq_n_f32(input_data[l * l - 1 - l]); output_data = output_data + hxw;
float32x4_t pad3 = vdupq_n_f32(input_data[l * l - 1]); filter_data = filter_data + 9;
}
}
tmp0 = vextq_f32(in4, pad2, 1); /*
tmp1 = vextq_f32(in4, pad2, 2); const float *input_data = input->data<float>();
tmp2 = vextq_f32(in6, pad3, 1); const float *filter_data = filter->data<float>();
tmp3 = vextq_f32(in6, pad3, 2); float *output_data = output->data<float>();
const float *newscale_data = new_scale->data<float>();
const float *newbias_data = new_bias->data<float>();
const int h = static_cast<int>(input->dims()[2]);
const int w = static_cast<int>(input->dims()[3]);
const int l = h;
const int batch_size = static_cast<int>(input->dims()[0]);
const int c = static_cast<int>(input->dims()[1]);
const int hxw = h * w;
float32x4_t vnewbias = vdupq_n_f32(0.0);
float32x4_t vnewscale = vdupq_n_f32(1.0);
float32x4_t vzero = vdupq_n_f32(0);
for (int b = 0; b < batch_size; ++b) {
const float *filter_data_tmp = filter_data;
for (int j = 0; j < c; ++j) {
vnewbias = vdupq_n_f32(newbias_data[j]);
vnewscale = vdupq_n_f32(newscale_data[j]);
int l_mid = l - 2; // l=1->l_mid=-1,l=2->l_mid=0
float w00 = filter_data_tmp[0];
float w01 = filter_data_tmp[1];
float w02 = filter_data_tmp[2];
float w10 = filter_data_tmp[3];
float w11 = filter_data_tmp[4];
float w12 = filter_data_tmp[5];
float w20 = filter_data_tmp[6];
float w21 = filter_data_tmp[7];
float w22 = filter_data_tmp[8];
output_data[0] = w11 * input_data[0] + w12 * input_data[1] +
w21 * input_data[l] + w22 * input_data[l + 1];
output_data[l - 1] = w10 * input_data[l - 2] + w11 * input_data[l - 1] +
w20 * input_data[2 * l - 2] +
w21 * input_data[2 * l - 1];
out0 = vmulq_n_f32(in4, w00); output_data[(l - 1) * l] =
out0 = vmlaq_n_f32(out0, tmp0, w01); w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l + 1] +
out0 = vmlaq_n_f32(out0, tmp1, w02); w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1];
out0 = vmlaq_n_f32(out0, in6, w10); output_data[l * l - 1] = w00 * input_data[(l - 2) * (l + 1)] +
out0 = vmlaq_n_f32(out0, tmp2, w11); w01 * input_data[(l - 2) * (l + 1) + 1] +
out0 = vmlaq_n_f32(out0, tmp3, w12); w10 * input_data[l * l - 2] +
out0 = vmlaq_f32(vnewbias, vnewscale, out0); w11 * input_data[l * l - 1];
if (if_relu) { output_data[0] = output_data[0] * newscale_data[j] + newbias_data[j];
out0 = vmaxq_f32(out0, vzero); output_data[l - 1] =
} output_data[l - 1] * newscale_data[j] + newbias_data[j];
for (int i = 0; i < c_mid; ++i) { output_data[(l - 1) * l] =
if (i == 0) { output_data[(l - 1) * l] * newscale_data[j] + newbias_data[j];
vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 0); output_data[l * l - 1] =
} output_data[l * l - 1] * newscale_data[j] + newbias_data[j];
if (i == 1) {
vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 1); if (if_relu) {
output_data[0] = output_data[0] < 0 ? 0 : output_data[0];
output_data[l - 1] = output_data[l - 1] < 0 ? 0 : output_data[l - 1];
output_data[(l - 1) * l] =
output_data[(l - 1) * l] < 0 ? 0 : output_data[(l - 1) * l];
output_data[l * l - 1] =
output_data[l * l - 1] < 0 ? 0 : output_data[l * l - 1];
} }
if (i == 2) { for (int i = 1; i < l - 1; ++i) {
vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 2); output_data[i * l] =
w01 * input_data[i * l - l] + w02 * input_data[i * l - l + 1] +
w11 * input_data[i * l] + w12 * input_data[i * l + 1] +
w21 * input_data[i * l + l] + w22 * input_data[i * l + l + 1];
output_data[i * l + l - 1] = w00 * input_data[i * l + l - 1 - l - 1] +
w01 * input_data[i * l + l - 1 - l] +
w10 * input_data[i * l + l - 1 - 1] +
w11 * input_data[i * l + l - 1] +
w20 * input_data[i * l + l - 1 + l - 1] +
w21 * input_data[i * l + l - 1 + l];
output_data[i * l] =
output_data[i * l] * newscale_data[j] + newbias_data[j];
output_data[i * l + l - 1] =
output_data[i * l + l - 1] * newscale_data[j] + newbias_data[j];
if (if_relu) {
output_data[i * l] = output_data[i * l] < 0 ? 0 : output_data[i *
l]; output_data[i * l + l - 1] =
output_data[i * l + l - 1] < 0 ? 0 : output_data[i * l + l - 1];
}
} }
}
// mid
for (int i = 0; i < l - 2; ++i) { // top 1 row and bottom 1 row
auto output_ptr = output_data + (i + 1) * l + 1; const float *input_tmp = input_data;
input_tmp = input_data + i * l;
auto in0_tmp = vld1q_f32(input_tmp); float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2,
auto in2_tmp = vld1q_f32(input_tmp + l); tmp3, tmp4, tmp5, out0;
auto in4_tmp = vld1q_f32(input_tmp + l + l); in0 = vld1q_f32(input_tmp);
c_mid = l_mid; in2 = vld1q_f32(input_tmp + l);
const float *input_tmp_end = input_tmp + (l - 2) * l;
in4 = vld1q_f32(input_tmp_end);
in6 = vld1q_f32(input_tmp_end + l);
int c_mid = l_mid;
auto output_ptr = output_data + 1;
for (; c_mid > 3; c_mid -= 4) { for (; c_mid > 3; c_mid -= 4) {
auto in1_tmp = vld1q_f32(input_tmp + 4); in1 = vld1q_f32(input_tmp + 4);
auto in3_tmp = vld1q_f32(input_tmp + l + 4); in3 = vld1q_f32(input_tmp + l + 4);
auto in5_tmp = vld1q_f32(input_tmp + l + l + 4);
tmp0 = vextq_f32(in0_tmp, in1_tmp, 1); tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0_tmp, in1_tmp, 2); tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2_tmp, in3_tmp, 1);
tmp3 = vextq_f32(in2_tmp, in3_tmp, 2);
tmp4 = vextq_f32(in4_tmp, in5_tmp, 1);
tmp5 = vextq_f32(in4_tmp, in5_tmp, 2);
out0 = vmulq_n_f32(in0_tmp, w00); tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
in5 = vld1q_f32(input_tmp_end + 4);
in7 = vld1q_f32(input_tmp_end + l + 4);
tmp0 = vextq_f32(in4, in5, 1);
tmp1 = vextq_f32(in4, in5, 2);
tmp2 = vextq_f32(in6, in7, 1);
tmp3 = vextq_f32(in6, in7, 2);
out0 = vmulq_n_f32(in4, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01); out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02); out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10); out0 = vmlaq_n_f32(out0, in6, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11); out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12); out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0); out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) { if (if_relu) {
out0 = vmaxq_f32(out0, vzero); out0 = vmaxq_f32(out0, vzero);
} }
vst1q_f32(output_ptr, out0); vst1q_f32(output_ptr + (l - 1) * l, out0);
output_ptr += 4; // can optimize to each 8 stride.
input_tmp += 4; input_tmp += 4;
in0_tmp = in1_tmp; input_tmp_end += 4;
in2_tmp = in3_tmp; output_ptr += 4;
in4_tmp = in5_tmp; in0 = in1;
in2 = in3;
in4 = in5;
in6 = in7;
} }
float32x4_t pad0 = vdupq_n_f32(input_data[i * l + l - 1]); // top right pad
float32x4_t pad1 = vdupq_n_f32(input_data[i * l + l - 1 + l]); float32x4_t pad0 = vdupq_n_f32(input_data[l - 1]);
float32x4_t pad2 = vdupq_n_f32(input_data[i * l + l - 1 + l + l]); float32x4_t pad1 = vdupq_n_f32(input_data[2 * l - 1]);
tmp0 = vextq_f32(in0_tmp, pad0, 1); tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0_tmp, pad0, 2); tmp1 = vextq_f32(in0, pad0, 2);
tmp2 = vextq_f32(in2_tmp, pad1, 1); tmp2 = vextq_f32(in2, pad1, 1);
tmp3 = vextq_f32(in2_tmp, pad1, 2); tmp3 = vextq_f32(in2, pad1, 2);
tmp4 = vextq_f32(in4_tmp, pad2, 1);
tmp5 = vextq_f32(in4_tmp, pad2, 2);
out0 = vmulq_n_f32(in0_tmp, w00); out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_n_f32(out0, tmp3, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
}
// bottom right pad
float32x4_t pad2 = vdupq_n_f32(input_data[l * l - 1 - l]);
float32x4_t pad3 = vdupq_n_f32(input_data[l * l - 1]);
tmp0 = vextq_f32(in4, pad2, 1);
tmp1 = vextq_f32(in4, pad2, 2);
tmp2 = vextq_f32(in6, pad3, 1);
tmp3 = vextq_f32(in6, pad3, 2);
out0 = vmulq_n_f32(in4, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01); out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02); out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10); out0 = vmlaq_n_f32(out0, in6, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11); out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12); out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0); out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) { if (if_relu) {
out0 = vmaxq_f32(out0, vzero); out0 = vmaxq_f32(out0, vzero);
} }
for (int i = 0; i < c_mid; ++i) { for (int i = 0; i < c_mid; ++i) {
if (i == 0) { if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0); vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 0);
} }
if (i == 1) { if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1); vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 1);
} }
if (i == 2) { if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2); vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 2);
}
}
// mid
for (int i = 0; i < l - 2; ++i) {
auto output_ptr = output_data + (i + 1) * l + 1;
input_tmp = input_data + i * l;
auto in0_tmp = vld1q_f32(input_tmp);
auto in2_tmp = vld1q_f32(input_tmp + l);
auto in4_tmp = vld1q_f32(input_tmp + l + l);
c_mid = l_mid;
for (; c_mid > 3; c_mid -= 4) {
auto in1_tmp = vld1q_f32(input_tmp + 4);
auto in3_tmp = vld1q_f32(input_tmp + l + 4);
auto in5_tmp = vld1q_f32(input_tmp + l + l + 4);
tmp0 = vextq_f32(in0_tmp, in1_tmp, 1);
tmp1 = vextq_f32(in0_tmp, in1_tmp, 2);
tmp2 = vextq_f32(in2_tmp, in3_tmp, 1);
tmp3 = vextq_f32(in2_tmp, in3_tmp, 2);
tmp4 = vextq_f32(in4_tmp, in5_tmp, 1);
tmp5 = vextq_f32(in4_tmp, in5_tmp, 2);
out0 = vmulq_n_f32(in0_tmp, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
vst1q_f32(output_ptr, out0);
output_ptr += 4;
input_tmp += 4;
in0_tmp = in1_tmp;
in2_tmp = in3_tmp;
in4_tmp = in5_tmp;
}
float32x4_t pad0 = vdupq_n_f32(input_data[i * l + l - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[i * l + l - 1 + l]);
float32x4_t pad2 = vdupq_n_f32(input_data[i * l + l - 1 + l + l]);
tmp0 = vextq_f32(in0_tmp, pad0, 1);
tmp1 = vextq_f32(in0_tmp, pad0, 2);
tmp2 = vextq_f32(in2_tmp, pad1, 1);
tmp3 = vextq_f32(in2_tmp, pad1, 2);
tmp4 = vextq_f32(in4_tmp, pad2, 1);
tmp5 = vextq_f32(in4_tmp, pad2, 2);
out0 = vmulq_n_f32(in0_tmp, w00);
out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, in2_tmp, w10);
out0 = vmlaq_n_f32(out0, tmp2, w11);
out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmlaq_n_f32(out0, in4_tmp, w20);
out0 = vmlaq_n_f32(out0, tmp4, w21);
out0 = vmlaq_n_f32(out0, tmp5, w22);
out0 = vmlaq_f32(vnewbias, vnewscale, out0);
if (if_relu) {
out0 = vmaxq_f32(out0, vzero);
}
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
} }
} }
output_data += hxw;
input_data += hxw;
filter_data_tmp += 9;
} }
output_data += hxw;
input_data += hxw;
filter_data_tmp += 9;
} }
} */
#endif
}
......
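The long NEON routine above is the vectorized form of a 3x3, stride-1, pad-1 depthwise convolution whose output is immediately scaled and shifted by the folded batch-norm parameters and optionally clamped by ReLU; the border pixels are handled with the scalar expressions visible in the diff. A plain scalar reference for a single channel, useful for checking the intrinsics, might look like the sketch below (the function name and argument layout are illustrative assumptions).

#include <algorithm>

// Scalar reference for one channel: 3x3 depthwise conv, stride 1, zero
// padding 1, followed by the folded BN transform and optional ReLU.
void DepthwiseConv3x3Reference(const float *in, const float *w9, float *out,
                               int height, int width, float new_scale,
                               float new_bias, bool if_relu) {
  for (int oh = 0; oh < height; ++oh) {
    for (int ow = 0; ow < width; ++ow) {
      float acc = 0.0f;
      for (int kh = 0; kh < 3; ++kh) {
        for (int kw = 0; kw < 3; ++kw) {
          int ih = oh + kh - 1;  // padding = 1
          int iw = ow + kw - 1;
          if (ih >= 0 && ih < height && iw >= 0 && iw < width) {
            acc += w9[kh * 3 + kw] * in[ih * width + iw];
          }
        }
      }
      acc = acc * new_scale + new_bias;
      out[oh * width + ow] = if_relu ? std::max(acc, 0.0f) : acc;
    }
  }
}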
@@ -33,6 +33,14 @@ float *packedA;
float *packedB;
float *packedC;
float *zero;
typedef void (*FnPack)(int, int, int, const float *, int, float *);
typedef void (*FnAddDot)(int, const float *, const float *, float *, int);
FnPack procPackA;
FnPack procPackB;
FnAddDot procAddDot;
/*
// Pack blocks of matrix A into contiguous memory (ColMajor)
void PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
@@ -135,30 +143,32 @@ void PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda,
void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
                    float *buffer) {
-  const float *a0, *a1, *a2, *a3, *a4, *a5;
-  for (int i = 0; i < m - m_tail; i += MR) {
-    a0 = A + i * lda;
-    a1 = A + (i + 1) * lda;
-    a2 = A + (i + 2) * lda;
-    a3 = A + (i + 3) * lda;
-    a4 = A + (i + 4) * lda;
-    a5 = A + (i + 5) * lda;
+  const int i_length = m - m_tail;
+  for (int i = 0; i < i_length; i += MR) {
+    const float *a0 = A + i * lda;
+    const float *a1 = A + (i + 1) * lda;
+    const float *a2 = A + (i + 2) * lda;
+    const float *a3 = A + (i + 3) * lda;
+    const float *a4 = A + (i + 4) * lda;
+    const float *a5 = A + (i + 5) * lda;
+    float *local_buffer = buffer + i * k;
    for (int j = 0; j < k; ++j) {
-      *buffer++ = *a0++;
-      *buffer++ = *a1++;
-      *buffer++ = *a2++;
-      *buffer++ = *a3++;
-      *buffer++ = *a4++;
-      *buffer++ = *a5++;
+      *local_buffer++ = *a0++;
+      *local_buffer++ = *a1++;
+      *local_buffer++ = *a2++;
+      *local_buffer++ = *a3++;
+      *local_buffer++ = *a4++;
+      *local_buffer++ = *a5++;
    }
  }
  if (m_tail != 0) {
-    a0 = &A(m - m_tail, 0);
-    a1 = a0 + lda;
-    a2 = a0 + 2 * lda;
-    a3 = a0 + 3 * lda;
-    a4 = a0 + 4 * lda;
-    a5 = a0 + 5 * lda;
+    const float *a0 = &A(i_length, 0);
+    const float *a1 = a0 + lda;
+    const float *a2 = a0 + 2 * lda;
+    const float *a3 = a0 + 3 * lda;
+    const float *a4 = a0 + 4 * lda;
+    const float *a5 = a0 + 5 * lda;
+    float *local_buffer = buffer + i_length * k;
    switch (m_tail) {
      case 1:
        a1 = zero;
@@ -175,48 +185,105 @@ void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
        break;
    }
    for (int j = 0; j < k; ++j) {
-      *buffer++ = *a0++;
-      *buffer++ = *a1++;
-      *buffer++ = *a2++;
-      *buffer++ = *a3++;
-      *buffer++ = *a4++;
-      *buffer++ = *a5++;
+      *local_buffer++ = *a0++;
+      *local_buffer++ = *a1++;
+      *local_buffer++ = *a2++;
+      *local_buffer++ = *a3++;
+      *local_buffer++ = *a4++;
+      *local_buffer++ = *a5++;
}
}
}
void PackMatrixA_omp_6r(int m, int k, int m_tail, const float *A, int lda,
float *buffer) {
const int i_length = m - m_tail;
#pragma omp parallel for
for (int i = 0; i < i_length; i += MR) {
const float *a0 = A + i * lda;
const float *a1 = A + (i + 1) * lda;
const float *a2 = A + (i + 2) * lda;
const float *a3 = A + (i + 3) * lda;
const float *a4 = A + (i + 4) * lda;
const float *a5 = A + (i + 5) * lda;
float *local_buffer = buffer + i * k;
for (int j = 0; j < k; ++j) {
*local_buffer++ = *a0++;
*local_buffer++ = *a1++;
*local_buffer++ = *a2++;
*local_buffer++ = *a3++;
*local_buffer++ = *a4++;
*local_buffer++ = *a5++;
}
}
if (m_tail != 0) {
const float *a0 = &A(i_length, 0);
const float *a1 = a0 + lda;
const float *a2 = a0 + 2 * lda;
const float *a3 = a0 + 3 * lda;
const float *a4 = a0 + 4 * lda;
const float *a5 = a0 + 5 * lda;
float *local_buffer = buffer + i_length * k;
switch (m_tail) {
case 1:
a1 = zero;
case 2:
a2 = zero;
case 3:
a3 = zero;
case 4:
a4 = zero;
case 5:
a5 = zero;
break;
default:
break;
}
for (int j = 0; j < k; ++j) {
*local_buffer++ = *a0++;
*local_buffer++ = *a1++;
*local_buffer++ = *a2++;
*local_buffer++ = *a3++;
*local_buffer++ = *a4++;
*local_buffer++ = *a5++;
} }
} }
} }
void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
                    float *buffer) {
-  const float *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7;
-  for (int i = 0; i < m - m_tail; i += MR) {
-    a0 = A + i * lda;
-    a1 = A + (i + 1) * lda;
-    a2 = A + (i + 2) * lda;
-    a3 = A + (i + 3) * lda;
-    a4 = A + (i + 4) * lda;
-    a5 = A + (i + 5) * lda;
-    a6 = A + (i + 6) * lda;
-    a7 = A + (i + 7) * lda;
+  const int i_length = m - m_tail;
+  for (int i = 0; i < i_length; i += MR) {
+    const float *a0 = A + i * lda;
+    const float *a1 = A + (i + 1) * lda;
+    const float *a2 = A + (i + 2) * lda;
+    const float *a3 = A + (i + 3) * lda;
+    const float *a4 = A + (i + 4) * lda;
+    const float *a5 = A + (i + 5) * lda;
+    const float *a6 = A + (i + 6) * lda;
+    const float *a7 = A + (i + 7) * lda;
+    float *local_buffer = buffer + i * k;
    for (int j = 0; j < k; ++j) {
-      *buffer++ = *a0++;
-      *buffer++ = *a1++;
-      *buffer++ = *a2++;
-      *buffer++ = *a3++;
-      *buffer++ = *a4++;
-      *buffer++ = *a5++;
-      *buffer++ = *a6++;
-      *buffer++ = *a7++;
+      *local_buffer++ = *a0++;
+      *local_buffer++ = *a1++;
+      *local_buffer++ = *a2++;
+      *local_buffer++ = *a3++;
+      *local_buffer++ = *a4++;
+      *local_buffer++ = *a5++;
+      *local_buffer++ = *a6++;
+      *local_buffer++ = *a7++;
    }
  }
  if (m_tail != 0) {
-    a0 = &A(m - m_tail, 0);
-    a1 = a0 + lda;
-    a2 = a0 + 2 * lda;
-    a3 = a0 + 3 * lda;
-    a4 = a0 + 4 * lda;
-    a5 = a0 + 5 * lda;
-    a6 = a0 + 6 * lda;
-    a7 = a0 + 7 * lda;
+    const float *a0 = &A(i_length, 0);
+    const float *a1 = a0 + lda;
+    const float *a2 = a0 + 2 * lda;
+    const float *a3 = a0 + 3 * lda;
+    const float *a4 = a0 + 4 * lda;
+    const float *a5 = a0 + 5 * lda;
+    const float *a6 = a0 + 6 * lda;
+    const float *a7 = a0 + 7 * lda;
+    float *local_buffer = buffer + i_length * k;
    switch (m_tail) {
      case 1:
        a1 = zero;
@@ -237,14 +304,81 @@ void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
        break;
    }
    for (int j = 0; j < k; ++j) {
-      *buffer++ = *a0++;
-      *buffer++ = *a1++;
-      *buffer++ = *a2++;
-      *buffer++ = *a3++;
-      *buffer++ = *a4++;
-      *buffer++ = *a5++;
-      *buffer++ = *a6++;
-      *buffer++ = *a7++;
+      *local_buffer++ = *a0++;
+      *local_buffer++ = *a1++;
+      *local_buffer++ = *a2++;
+      *local_buffer++ = *a3++;
+      *local_buffer++ = *a4++;
+      *local_buffer++ = *a5++;
+      *local_buffer++ = *a6++;
+      *local_buffer++ = *a7++;
}
}
}
void PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer) {
const int i_length = m - m_tail;
#pragma omp parallel for
for (int i = 0; i < i_length; i += MR) {
const float *a0 = A + i * lda;
const float *a1 = A + (i + 1) * lda;
const float *a2 = A + (i + 2) * lda;
const float *a3 = A + (i + 3) * lda;
const float *a4 = A + (i + 4) * lda;
const float *a5 = A + (i + 5) * lda;
const float *a6 = A + (i + 6) * lda;
const float *a7 = A + (i + 7) * lda;
float *local_buffer = buffer + i * k;
for (int j = 0; j < k; ++j) {
*local_buffer++ = *a0++;
*local_buffer++ = *a1++;
*local_buffer++ = *a2++;
*local_buffer++ = *a3++;
*local_buffer++ = *a4++;
*local_buffer++ = *a5++;
*local_buffer++ = *a6++;
*local_buffer++ = *a7++;
}
}
if (m_tail != 0) {
const float *a0 = &A(i_length, 0);
const float *a1 = a0 + lda;
const float *a2 = a0 + 2 * lda;
const float *a3 = a0 + 3 * lda;
const float *a4 = a0 + 4 * lda;
const float *a5 = a0 + 5 * lda;
const float *a6 = a0 + 6 * lda;
const float *a7 = a0 + 7 * lda;
float *local_buffer = buffer + i_length * k;
switch (m_tail) {
case 1:
a1 = zero;
case 2:
a2 = zero;
case 3:
a3 = zero;
case 4:
a4 = zero;
case 5:
a5 = zero;
case 6:
a6 = zero;
case 7:
a7 = zero;
break;
default:
break;
}
for (int j = 0; j < k; ++j) {
*local_buffer++ = *a0++;
*local_buffer++ = *a1++;
*local_buffer++ = *a2++;
*local_buffer++ = *a3++;
*local_buffer++ = *a4++;
*local_buffer++ = *a5++;
*local_buffer++ = *a6++;
*local_buffer++ = *a7++;
}
}
}
@@ -252,48 +386,102 @@ void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
// Pack blocks of matrix B into contiguous memory (RowMajor)
void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) {
const int j_length = n - n_tail;
for (int j = 0; j < j_length; j += NR) {
float *local_buffer = buffer + j * k;
for (int i = 0; i < k; ++i) {
const float *b0 = &B(i, j);
#if __ARM_NEON
#if __aarch64__
asm volatile(
"prfm pldl1keep, [%[b0]] \n\t"
"ld1 {v0.4s, v1.4s}, [%[b0]] \n\t"
"st1 {v0.4s, v1.4s}, [%[local_buffer]], #32 \n\t"
: [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0)
: "memory", "v0", "v1");
#else
asm volatile(
"pld [%[b0]] \n\t"
"vld1.32 {q0, q1}, [%[b0]] \n\t"
"vst1.32 {q0, q1}, [%[local_buffer]]! \n\t"
: [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0)
: "memory", "q0", "q1");
#endif // __aarch64__
#else
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
#endif // __ARM_NEON
}
}
if (n_tail != 0) {
float *local_buffer = buffer + j_length * k;
for (int i = 0; i < k; ++i) {
const float *b0 = &B(i, j_length);
for (int j = j_length; j < n; ++j) {
*local_buffer++ = *b0++;
}
for (int j = n; j < j_length + NR; ++j) {
*local_buffer++ = 0;
}
}
}
}
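A note on the B-side packing for readers of this hunk: PackMatrixB_8c copies NR = 8 consecutive columns per strip (one 8-float NEON load/store per row of the strip) and zero-pads the last partial strip out to a full NR columns, so the micro-kernel never has to special-case the right edge. A scalar sketch of the same behaviour, with a hypothetical helper name:

void pack_cols_8_sketch(int k, int n, const float *B, int ldb, float *buffer) {
  for (int j = 0; j < n; j += 8) {
    float *out = buffer + j * k;  // disjoint slice per column strip
    for (int i = 0; i < k; ++i) {
      for (int c = 0; c < 8; ++c) {
        // Copy real columns, pad the tail of the last strip with zeros.
        out[i * 8 + c] = (j + c < n) ? B[i * ldb + j + c] : 0.0f;
      }
    }
  }
}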
void PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) {
const int j_length = n - n_tail;
#pragma omp parallel for
for (int j = 0; j < j_length; j += NR) {
float *local_buffer = buffer + j * k;
for (int i = 0; i < k; ++i) {
const float *b0 = &B(i, j);
#if __ARM_NEON
#if __aarch64__
asm volatile(
"prfm pldl1keep, [%[b0]] \n\t"
"ld1 {v0.4s, v1.4s}, [%[b0]] \n\t"
"st1 {v0.4s, v1.4s}, [%[local_buffer]], #32 \n\t"
: [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0)
: "memory", "v0", "v1");
#else
asm volatile(
"pld [%[b0]] \n\t"
"vld1.32 {q0, q1}, [%[b0]] \n\t"
"vst1.32 {q0, q1}, [%[local_buffer]]! \n\t"
: [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0)
: "memory", "q0", "q1");
#endif // __aarch64__
#else
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
*local_buffer++ = *b0++;
#endif // __ARM_NEON
}
}
if (n_tail != 0) {
float *local_buffer = buffer + j_length * k;
for (int i = 0; i < k; ++i) {
const float *b0 = &B(i, j_length);
for (int j = j_length; j < n; ++j) {
*local_buffer++ = *b0++;
}
for (int j = n; j < j_length + NR; ++j) {
*local_buffer++ = 0;
}
}
}
@@ -302,27 +490,60 @@ void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
#if __aarch64__
void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) {
const int j_length = n - n_tail;
for (int j = 0; j < j_length; j += NR) {
float *local_buffer = buffer + j * k;
for (int i = 0; i < k; ++i) {
const float *b0 = &B(i, j);
asm volatile(
"prfm pldl2keep, [%[b0], #64] \n\t"
"ld1 {v0.4s, v1.4s, v2.4s}, [%[b0]] \n\t"
"st1 {v0.4s, v1.4s, v2.4s}, [%[local_buffer]], #48 \n\t"
: [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0)
: "memory", "v0", "v1", "v2");
}
}
if (n_tail != 0) {
float *local_buffer = buffer + j_length * k;
for (int i = 0; i < k; ++i) {
const float *b0 = &B(i, j_length);
for (int j = j_length; j < n; ++j) {
*local_buffer++ = *b0++;
}
for (int j = n; j < j_length + NR; ++j) {
*local_buffer++ = 0;
}
}
}
}
void PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) {
const int j_length = n - n_tail;
#pragma omp parallel for
for (int j = 0; j < j_length; j += NR) {
float *local_buffer = buffer + j * k;
for (int i = 0; i < k; ++i) {
const float *b0 = &B(i, j);
asm volatile(
"prfm pldl2keep, [%[b0], #64] \n\t"
"ld1 {v0.4s, v1.4s, v2.4s}, [%[b0]] \n\t"
"st1 {v0.4s, v1.4s, v2.4s}, [%[local_buffer]], #48 \n\t"
: [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0)
: "memory", "v0", "v1", "v2");
}
}
if (n_tail != 0) {
float *local_buffer = buffer + j_length * k;
for (int i = 0; i < k; ++i) {
const float *b0 = &B(i, j_length);
for (int j = j_length; j < n; ++j) {
*local_buffer++ = *b0++;
}
for (int j = n; j < j_length + NR; ++j) {
*local_buffer++ = 0;
}
}
}
@@ -330,27 +551,60 @@ void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) {
const int j_length = n - n_tail;
for (int j = 0; j < n - n_tail; j += NR) {
float *local_buffer = buffer + j * k;
for (int i = 0; i < k; ++i) {
const float *b0 = &B(i, j);
asm volatile(
"prfm pldl2keep, [%[b0], #64] \n\t"
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b0]] \n\t"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[local_buffer]], #64 \n\t"
: [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0)
: "memory", "v0", "v1", "v2", "v3");
}
}
if (n_tail != 0) {
float *local_buffer = buffer + j_length * k;
for (int i = 0; i < k; ++i) {
const float *b0 = &B(i, j_length);
for (int j = j_length; j < n; ++j) {
*local_buffer++ = *b0++;
}
for (int j = n; j < j_length + NR; ++j) {
*local_buffer++ = 0;
}
}
}
}
void PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) {
const int j_length = n - n_tail;
#pragma omp parallel for
for (int j = 0; j < n - n_tail; j += NR) {
float *local_buffer = buffer + j * k;
for (int i = 0; i < k; ++i) {
const float *b0 = &B(i, j);
asm volatile(
"prfm pldl2keep, [%[b0], #64] \n\t"
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b0]] \n\t"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[local_buffer]], #64 \n\t"
: [local_buffer] "+r"(local_buffer)
: [b0] "r"(b0)
: "memory", "v0", "v1", "v2", "v3");
}
}
if (n_tail != 0) {
float *local_buffer = buffer + j_length * k;
for (int i = 0; i < k; ++i) {
const float *b0 = &B(i, j_length);
for (int j = j_length; j < n; ++j) {
*local_buffer++ = *b0++;
}
for (int j = n; j < j_length + NR; ++j) {
*local_buffer++ = 0;
}
}
}
@@ -392,6 +646,42 @@ void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
}
}
// Blocked matrix multiplication
void InnerKernelWithBias(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C,
int ldc, bool relu, float *bias) {
#pragma omp parallel for
for (int j = 0; j < nc; j += NR) {
for (int i = 0; i < mc; i += MR) {
#if __aarch64__
// AddDot8x12(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
AddDot6x16(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
#else
// AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
// AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
#endif
}
}
if (alpha != 1) {
WriteWithAlphaBeta(mc, nc, c, C, ldc);
return;
}
if (beta == 0) {
WriteBasic(mc, nc, c, C, ldc);
return;
}
if (beta == 1 && !relu) {
WriteWithAddV1(mc, nc, c, C, ldc, bias);
return;
}
if (beta == 1 && relu) {
WriteWithAddReluV1(mc, nc, c, C, ldc, bias);
return;
}
}
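The new InnerKernelWithBias keeps the alpha/beta/relu dispatch of InnerKernel but routes the beta == 1 cases to the bias-aware writers: WriteWithAddV1 when relu is false and WriteWithAddReluV1 when relu is true, both of which add one bias value per output row. A hypothetical scalar reference of that epilogue (not the library code) for the alpha == 1, beta == 1 case:

void write_with_bias_ref(int mc, int nc, const float *c, int c_stride,
                         float *C, int ldc, const float *bias, bool relu) {
  for (int i = 0; i < mc; ++i) {
    for (int j = 0; j < nc; ++j) {
      float v = c[i * c_stride + j] + bias[i];  // bias is broadcast along the row
      C[i * ldc + j] = relu ? (v > 0.f ? v : 0.f) : v;
    }
  }
}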
// Blocked matrix multiplication
void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C, int ldc,
@@ -577,6 +867,43 @@ void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {
}
}
}
// C = A * B + bias
void WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc, float *bias) {
int nc1 = nc / 4;
int _nc1 = nc % 4;
float *c_ptr, *C_ptr;
float32x4_t cv;
float32x4_t biasv;
for (int i = 0; i < mc; ++i) {
c_ptr = c + i * NC;
C_ptr = C + i * ldc;
biasv = vld1q_dup_f32(bias + i);
for (int j = 0; j < nc1; ++j) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
vst1q_f32(C_ptr, cv);
c_ptr += 4;
C_ptr += 4;
}
if (_nc1 != 0) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
if (_nc1 >= 1) {
vst1q_lane_f32(C_ptr, cv, 0);
C_ptr++;
}
if (_nc1 >= 2) {
vst1q_lane_f32(C_ptr, cv, 1);
C_ptr++;
}
if (_nc1 >= 3) {
vst1q_lane_f32(C_ptr, cv, 2);
C_ptr++;
}
}
}
}
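WriteWithAddV1 broadcasts bias[i] across row i with vld1q_dup_f32 and handles the last nc % 4 columns by storing individual lanes. A small sketch of that tail pattern, assuming a NEON target:

#include <arm_neon.h>
// Store only `rem` (0..3) leading lanes of v to dst, as the _V1 writers do for
// the nc % 4 remainder columns.
static inline void store_tail_lanes(float *dst, float32x4_t v, int rem) {
  if (rem >= 1) vst1q_lane_f32(dst++, v, 0);
  if (rem >= 2) vst1q_lane_f32(dst++, v, 1);
  if (rem >= 3) vst1q_lane_f32(dst, v, 2);
}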
// C = A * B + C, relu(C)
void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
@@ -619,6 +946,48 @@ void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
}
}
// C = A * B + bias, relu(C)
void WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc,
float *bias) {
int nc1 = nc / 4;
int _nc1 = nc % 4;
float *c_ptr, *C_ptr;
float32x4_t cv;
float32x4_t biasv;
float32x4_t zero = vdupq_n_f32(0.0);
for (int i = 0; i < mc; ++i) {
c_ptr = c + i * NC;
C_ptr = C + i * ldc;
biasv = vld1q_dup_f32(bias + i);
for (int j = 0; j < nc1; ++j) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
cv = vmaxq_f32(cv, zero);
vst1q_f32(C_ptr, cv);
c_ptr += 4;
C_ptr += 4;
}
if (_nc1 != 0) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
cv = vmaxq_f32(cv, zero);
if (_nc1 >= 1) {
vst1q_lane_f32(C_ptr, cv, 0);
C_ptr++;
}
if (_nc1 >= 2) {
vst1q_lane_f32(C_ptr, cv, 1);
C_ptr++;
}
if (_nc1 >= 3) {
vst1q_lane_f32(C_ptr, cv, 2);
C_ptr++;
}
}
}
}
// C = A * B, batchnorm(C)
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *new_scale,
float *new_bias) {
@@ -1448,6 +1817,44 @@ void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {
}
}
// C = A * B + bias
void WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc, float *bias) {
int nc1 = nc / 4;
int _nc1 = nc % 4;
float *c_ptr, *C_ptr;
float32x4_t cv;
float32x4_t biasv;
for (int i = 0; i < mc; ++i) {
c_ptr = c + i * NC;
C_ptr = C + i * ldc;
biasv = vld1q_dup_f32(bias + i);
for (int j = 0; j < nc1; ++j) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
vst1q_f32(C_ptr, cv);
c_ptr += 4;
C_ptr += 4;
}
if (_nc1 != 0) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
if (_nc1 >= 1) {
vst1q_lane_f32(C_ptr, cv, 0);
C_ptr++;
}
if (_nc1 >= 2) {
vst1q_lane_f32(C_ptr, cv, 1);
C_ptr++;
}
if (_nc1 >= 3) {
vst1q_lane_f32(C_ptr, cv, 2);
C_ptr++;
}
}
}
}
// C = A * B + C, relu(C)
void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
int nc1 = nc / 16;
@@ -1522,6 +1929,48 @@ void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
}
}
// C = A * B + bias, relu(C)
void WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc,
float *bias) {
int nc1 = nc / 4;
int _nc1 = nc % 4;
float *c_ptr, *C_ptr;
float32x4_t cv;
float32x4_t biasv;
float32x4_t zero = vdupq_n_f32(0.0);
for (int i = 0; i < mc; ++i) {
c_ptr = c + i * NC;
C_ptr = C + i * ldc;
biasv = vld1q_dup_f32(bias + i);
for (int j = 0; j < nc1; ++j) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
cv = vmaxq_f32(cv, zero);
vst1q_f32(C_ptr, cv);
c_ptr += 4;
C_ptr += 4;
}
if (_nc1 != 0) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
cv = vmaxq_f32(cv, zero);
if (_nc1 >= 1) {
vst1q_lane_f32(C_ptr, cv, 0);
C_ptr++;
}
if (_nc1 >= 2) {
vst1q_lane_f32(C_ptr, cv, 1);
C_ptr++;
}
if (_nc1 >= 3) {
vst1q_lane_f32(C_ptr, cv, 2);
C_ptr++;
}
}
}
}
// C = A * B, batchnorm(C)
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale,
float *bias) {
@@ -2049,11 +2498,33 @@ void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) {
}
}
void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) {}
void WriteBasic(int mc, int nc, float *c, float *C, int ldc) {}
void WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc) {}
void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {}
void WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc, float *bias) {}
void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {}
void WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc,
float *bias) {}
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *new_scale,
float *new_bias) {}
void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias) {}
#endif // __ARM_NEON
// 32-bit float matrix multiplication
void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc, bool relu,
float *bias) {
// L1 data cache is 32 KiB (per Cortex-A57, Cortex-A72, Cortex-A73)
// L2 cache is 0.5~4 MiB (Cortex-A72 cluster)
int L1 = 32 * 1024;
@@ -2103,8 +2574,8 @@ void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
#else
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#endif
InnerKernelWithBias(mc, nc, alpha, packedA, packedB, beta, packedC,
&C(i, j), ldc, relu, bias + i);
}
}
@@ -2177,6 +2648,221 @@ void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
paddle_mobile::memory::Free(zero);
}
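For context on the blocking constants used by Sgemm and the new Sgemm_omp: MC (respectively NC) is chosen so that one packed panel of A (respectively B) fits in the 32 KiB L1 data cache, then balanced across blocks and rounded up to a multiple of MR (respectively NR). A worked example with hypothetical sizes (not taken from the patch):

#include <cstdio>

int main() {
  const int L1 = 32 * 1024;             // assumed L1 data cache size in bytes
  const int k = 256, m = 1000, MR = 6;  // hypothetical problem sizes
  int MC = static_cast<int>(L1 / (k * sizeof(float)));  // rows of A whose packed panel fits in L1 -> 32
  int mblock_num = (m + MC - 1) / MC;       // number of row blocks -> 32
  MC = (m + mblock_num - 1) / mblock_num;   // balance block sizes -> 32
  MC = (MC + MR - 1) / MR * MR;             // round up to a multiple of MR -> 36
  printf("MC = %d\n", MC);                  // prints MC = 36
  return 0;
}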
// 32-bit float matrix multiplication
void Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *bias) {
#ifdef _OPENMP
int max_threads = omp_get_max_threads();
#else
int max_threads = 1;
#endif
int L1 = 32 * 1024;
KC = k;
if (m > n) {
// Tile matrix A
MC = L1 / (KC * sizeof(float));
int mblock_num = (m + MC - 1) / MC;
MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + MR - 1) / MR * MR;
// Pad B
NC = (n + NR - 1) / NR * NR;
#if __aarch64__
procPackA = PackMatrixA_6r;
procPackB = PackMatrixB_omp_16c;
procAddDot = AddDot6x16;
#else
procPackA = PackMatrixA_6r;
procPackB = PackMatrixB_omp_8c;
procAddDot = AddDot6x8;
#endif
packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
procPackB(KC, NC, NC % NR, B, ldb, packedB);
packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads));
} else {
// Tile matrix B
NC = L1 / (KC * sizeof(float));
int nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR - 1) / NR * NR;
// Pad A
MC = (m + MR - 1) / MR * MR;
#if __aarch64__
procPackA = PackMatrixA_omp_6r;
procPackB = PackMatrixB_16c;
procAddDot = AddDot6x16;
#else
procPackA = PackMatrixA_omp_6r;
procPackB = PackMatrixB_8c;
procAddDot = AddDot6x8;
#endif
packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
procPackA(MC, KC, MC % MR, A, lda, packedA);
packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads));
}
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
packedC = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads));
if (m > n) {
#pragma omp parallel for
for (int i = 0; i < m; i += MC) {
#ifdef _OPENMP
int local_threads = omp_get_thread_num();
#else
int local_threads = 0;
#endif
int mc;
mc = s_min(m - i, MC);
float *local_A = packedA + MC * KC * local_threads;
float *local_C = packedC + MC * NC * local_threads;
procPackA(mc, KC, mc % MR, &A(i, 0), lda, local_A);
InnerKernelWithBias(mc, n, alpha, local_A, packedB, beta, local_C,
&C(i, 0), ldc, relu, bias + i);
}
} else {
#pragma omp parallel for
for (int j = 0; j < n; j += NC) {
#ifdef _OPENMP
int local_threads = omp_get_thread_num();
#else
int local_threads = 0;
#endif
int nc;
nc = s_min(n - j, NC);
float *local_B = packedB + KC * NC * local_threads;
float *local_C = packedC + MC * NC * local_threads;
procPackB(KC, nc, nc % NR, &B(0, j), ldb, local_B);
InnerKernelWithBias(m, nc, alpha, packedA, local_B, beta, local_C,
&C(0, j), ldc, relu, bias);
}
}
paddle_mobile::memory::Free(packedA);
paddle_mobile::memory::Free(packedB);
paddle_mobile::memory::Free(packedC);
paddle_mobile::memory::Free(zero);
}
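In Sgemm_omp the shared operand (B when m > n, A otherwise) is packed once up front, while each OpenMP thread packs its private block into `packedA + MC * KC * omp_get_thread_num()` and accumulates into `packedC + MC * NC * omp_get_thread_num()`, so the parallel loop needs no locking. A minimal sketch of that per-thread-scratch pattern, assuming an OpenMP toolchain:

#include <omp.h>
#include <vector>

void per_thread_scratch_demo(int blocks, int block_elems) {
  int max_threads = omp_get_max_threads();
  std::vector<float> scratch(static_cast<size_t>(max_threads) * block_elems);
  #pragma omp parallel for
  for (int b = 0; b < blocks; ++b) {
    // Each thread only ever touches its own slice of the shared scratch buffer.
    float *local = scratch.data() +
                   static_cast<size_t>(omp_get_thread_num()) * block_elems;
    for (int e = 0; e < block_elems; ++e) {
      local[e] = static_cast<float>(b);
    }
  }
}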
void SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias) {
#ifdef _OPENMP
int max_threads = omp_get_max_threads();
#else
int max_threads = 1;
#endif
int L1 = 32 * 1024;
KC = k;
if (m > n) {
// Tile matrix A
MC = L1 / (KC * sizeof(float));
int mblock_num = (m + MC - 1) / MC;
MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + MR - 1) / MR * MR;
// Pad B
NC = (n + NR - 1) / NR * NR;
#if __aarch64__
procPackA = PackMatrixA_6r;
procPackB = PackMatrixB_omp_16c;
procAddDot = AddDot6x16;
#else
procPackA = PackMatrixA_6r;
procPackB = PackMatrixB_omp_8c;
procAddDot = AddDot6x8;
#endif
packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
procPackB(KC, NC, NC % NR, B, ldb, packedB);
packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads));
} else {
// Tile matrix B
NC = L1 / (KC * sizeof(float));
int nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR - 1) / NR * NR;
// Pad A
MC = (m + MR - 1) / MR * MR;
#if __aarch64__
procPackA = PackMatrixA_omp_6r;
procPackB = PackMatrixB_16c;
procAddDot = AddDot6x16;
#else
procPackA = PackMatrixA_omp_6r;
procPackB = PackMatrixB_8c;
procAddDot = AddDot6x8;
#endif
packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
procPackA(MC, KC, MC % MR, A, lda, packedA);
packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads));
}
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
packedC = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads));
if (m > n) {
#pragma omp parallel for
for (int i = 0; i < m; i += MC) {
#ifdef _OPENMP
int local_threads = omp_get_thread_num();
#else
int local_threads = 0;
#endif
int mc;
mc = s_min(m - i, MC);
float *local_A = packedA + MC * KC * local_threads;
float *local_C = packedC + MC * NC * local_threads;
procPackA(mc, KC, mc % MR, &A(i, 0), lda, local_A);
InnerKernelWithBn(mc, n, alpha, local_A, packedB, beta, local_C, &C(i, 0),
ldc, relu, new_scale + i, new_bias + i);
}
} else {
#pragma omp parallel for
for (int j = 0; j < n; j += NC) {
#ifdef _OPENMP
int local_threads = omp_get_thread_num();
#else
int local_threads = 0;
#endif
int nc;
nc = s_min(n - j, NC);
float *local_B = packedB + KC * NC * local_threads;
float *local_C = packedC + MC * NC * local_threads;
procPackB(KC, nc, nc % NR, &B(0, j), ldb, local_B);
InnerKernelWithBn(m, nc, alpha, packedA, local_B, beta, local_C, &C(0, j),
ldc, relu, new_scale, new_bias);
}
}
paddle_mobile::memory::Free(packedA);
paddle_mobile::memory::Free(packedB);
paddle_mobile::memory::Free(packedC);
paddle_mobile::memory::Free(zero);
}
void AddDot6x8(int k, const float *a, const float *b, float *c, int ldc) {
#if __ARM_NEON
#if __aarch64__
...
@@ -50,6 +50,10 @@ void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_omp_6r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
// Pack blocks of matrix B into contiguous memory (RowMajor)
void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
@@ -58,10 +62,19 @@ void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
// Blocked matrix multiplication
void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
float beta, float *c, float *C, int ldc, bool relu);
void InnerKernelWithBias(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C,
int ldc, bool relu, float *bias);
void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C, int ldc,
@@ -91,8 +104,13 @@ void WriteBasic(int mc, int nc, float *c, float *C, int ldc);
void WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + C
void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + bias
void WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc, float *bias);
// C = A * B + C, relu(C)
void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + bias, relu(C)
void WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc,
float *bias);
// C = A * B, batchnorm(C)
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *new_scale,
float *new_bias);
@@ -120,13 +138,24 @@ void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *new_scale,
// 32-bit float matrix multiplication
void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc, bool relu,
float *bias);
// 32-bit float matrix multiplication, applying batchnorm to the result
void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias);
// 32-bit float matrix multiplication (OpenMP multi-threaded version)
void Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *bias);
// 32-bit float matrix multiplication, applying batchnorm to the result (OpenMP multi-threaded version)
void SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias);
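For reference, a call through the extended interface might look like the following; the include path, buffer setup, and sizes are assumptions for illustration, not part of the patch. Passing beta = 1 with relu = false selects the WriteWithAddV1 epilogue, so one bias value per row of C is added to the result.

#include "operators/math/gemm.h"  // assumed include path

// C (M x N) = A (M x K) * B (K x N) + bias, one bias value per row of C.
// All pointers are assumed to be allocated and initialized by the caller.
void run_sgemm_with_bias(int M, int N, int K, const float *A, const float *B,
                         float *C, float *bias) {
  paddle_mobile::operators::math::Sgemm(M, N, K, 1.f, A, K, B, N, 1.f, C, N,
                                        false, bias);
}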
} // namespace math
} // namespace operators
} // namespace paddle_mobile
@@ -22,7 +22,8 @@ namespace math {
template <>
void matmul<float>(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, float alpha,
framework::Tensor *matrix_out, float beta, bool relu,
float *bias) {
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
@@ -41,8 +42,13 @@ void matmul<float>(const framework::Tensor &matrix_a, bool trans_a,
int N = dim_out[1];
int K = (!trans_a) ? dim_a[1] : dim_a[0];
#ifdef _OPENMP
Sgemm_omp(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(),
N, beta, matrix_out->data<float>(), N, relu, bias);
#else
Sgemm(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(), N,
beta, matrix_out->data<float>(), N, relu, bias);
#endif
}
template <>
@@ -69,10 +75,17 @@ void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a,
int N = dim_out[1];
int K = (!trans_a) ? dim_a[1] : dim_a[0];
#ifdef _OPENMP
SgemmWithBn_omp(M, N, K, alpha, matrix_a.data<float>(), K,
matrix_b.data<float>(), N, beta, matrix_out->data<float>(), N,
relu, new_scale->data<float>() + group,
new_bias->data<float>() + group);
#else
SgemmWithBn(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(),
N, beta, matrix_out->data<float>(), N, relu,
new_scale->data<float>() + group,
new_bias->data<float>() + group);
#endif
}
} // namespace math
...
@@ -21,11 +21,11 @@ namespace paddle_mobile {
namespace operators {
namespace math {
// matrix multiply with continuous memory
template <typename T>
void matmul(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, T alpha,
framework::Tensor *matrix_out, T beta, bool relu = false,
float *bias = nullptr);
template <typename T>
void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
...
@@ -665,6 +665,16 @@ class FeedParam : public OpParam {
Tensor *input_x_;
Tensor *out_;
int batch_size;
#ifdef PADDLE_MOBILE_FPGA
private:
fpga::BypassArgs fpga_bypass_args;
public:
const fpga::BypassArgs &FpgaArgs() const { return fpga_bypass_args; }
void SetFpgaArgs(const fpga::BypassArgs &args) { fpga_bypass_args = args; }
#endif
};
class FetchParam : public OpParam {
...
@@ -49,9 +49,9 @@ int main() {
auto time1 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<float>(
aa, false, bb, false, static_cast<float>(1), &cc, static_cast<float>(0),
false, biasptr);
// paddle_mobile::operators::math::matmulWithBn<float>(
// aa, false, bb, false, static_cast<float>(1), &cc,
...