Unverified commit 87ac1969, authored by H Houjiang Chen, committed by GitHub

Merge branch 'develop' into ocr_attention

@@ -83,6 +83,8 @@ const char *G_OP_TYPE_LOGICAL_NOT = "logical_not";
const char *G_OP_TYPE_LOGICAL_XOR = "logical_xor";
const char *G_OP_TYPE_WRITE_TO_ARRAY = "write_to_array";
const char *G_OP_TYPE_READ_FROM_ARRAY = "read_from_array";
const char *G_OP_TYPE_IS_EMPTY = "is_empty";
const char *G_OP_TYPE_INCREMENT = "increment";
const char *G_OP_TYPE_QUANTIZE = "quantize";
const char *G_OP_TYPE_DEQUANTIZE = "dequantize";
@@ -200,6 +202,8 @@ std::unordered_map<
{G_OP_TYPE_LOGICAL_NOT, {{"X"}, {"Out"}}},
{G_OP_TYPE_WRITE_TO_ARRAY, {{"X", "I"}, {"Out"}}},
{G_OP_TYPE_READ_FROM_ARRAY, {{"X", "I"}, {"Out"}}},
{G_OP_TYPE_IS_EMPTY, {{"X"}, {"Out"}}},
{G_OP_TYPE_INCREMENT, {{"X"}, {"Out"}}},
{G_OP_TYPE_SLICE, {{"Input"}, {"Out"}}},
{G_OP_TYPE_ANCHOR_GENERATOR, {{"Input"}, {"Anchors", "Variances"}}},
{G_OP_TYPE_GENERATE_PROPOSALS,
......
@@ -172,6 +172,8 @@ extern const char *G_OP_TYPE_LOGICAL_NOT;
extern const char *G_OP_TYPE_LOGICAL_XOR;
extern const char *G_OP_TYPE_WRITE_TO_ARRAY;
extern const char *G_OP_TYPE_READ_FROM_ARRAY;
extern const char *G_OP_TYPE_IS_EMPTY;
extern const char *G_OP_TYPE_INCREMENT;
extern const char *G_OP_TYPE_QUANTIZE;
extern const char *G_OP_TYPE_DEQUANTIZE;
......
@@ -306,6 +306,12 @@ LOAD_OP1(write_to_array, CPU);
#ifdef READ_FROM_ARRAY_OP
LOAD_OP1(read_from_array, CPU);
#endif
#ifdef IS_EMPTY_OP
LOAD_OP1(is_empty, CPU);
#endif
#ifdef INCREMENT_OP
LOAD_OP1(increment, CPU);
#endif
#ifdef ANCHOR_GENERATOR_OP
LOAD_OP1(anchor_generator, CPU);
#endif
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef INCREMENT_OP
#include "operators/increment_op.h"
#include "framework/op_proto_maker.h"
#include "framework/op_registry.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void IncrementOp<Dtype, T>::InferShape() const {
auto input = this->param_.InputX();
auto out = this->param_.Out();
PADDLE_MOBILE_ENFORCE(input->numel() == 1, "input's numel should be 1");
out->Resize(input->dims());
out->set_lod(input->lod());
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(increment, ops::IncrementOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#ifdef PADDLE_MOBILE_CL
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef INCREMENT_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/increment_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using std::string;
template <typename DeviceType, typename T>
class IncrementOp
: public framework::OperatorWithKernel<DeviceType,
IncrementParam<DeviceType>,
IncrementKernel<DeviceType, T>> {
public:
IncrementOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, IncrementParam<DeviceType>,
IncrementKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
void InferShape() const override;
protected:
};
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef IS_EMPTY_OP
#include "operators/is_empty_op.h"
#include "framework/op_proto_maker.h"
#include "framework/op_registry.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void IsEmptyOp<Dtype, T>::InferShape() const {
auto out = this->param_.Out();
out->Resize({1});
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(is_empty, ops::IsEmptyOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#ifdef PADDLE_MOBILE_CL
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef IS_EMPTY_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/is_empty_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using std::string;
template <typename DeviceType, typename T>
class IsEmptyOp
: public framework::OperatorWithKernel<DeviceType, IsEmptyParam<DeviceType>,
IsEmptyKernel<DeviceType, T>> {
public:
IsEmptyOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, IsEmptyParam<DeviceType>,
IsEmptyKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
void InferShape() const override;
protected:
};
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef INCREMENT_OP
#include "operators/kernel/increment_kernel.h"
#include <operators/kernel/central-arm-func/increment_arm_func.h>
namespace paddle_mobile {
namespace operators {
template <>
bool IncrementKernel<CPU, float>::Init(IncrementParam<CPU> *param) {
return true;
}
template <>
void IncrementKernel<CPU, float>::Compute(const IncrementParam<CPU> &param) {
IncrementCompute<float>(param);
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef IS_EMPTY_OP
#include "operators/kernel/is_empty_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool IsEmptyKernel<CPU, float>::Init(IsEmptyParam<CPU> *param) {
return true;
}
template <>
void IsEmptyKernel<CPU, float>::Compute(const IsEmptyParam<CPU> &param) {
const framework::Tensor *input = param.InputX();
framework::Tensor *out = param.Out();
out->mutable_data<bool>()[0] = input->numel() == 0;
}
} // namespace operators
} // namespace paddle_mobile
#endif
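For reference, the kernel above reduces to a single comparison written into a one-element bool output. A minimal standalone sketch of that semantics (illustrative only; the helper name is hypothetical and the paddle_mobile tensor types are replaced by a plain element count):

```cpp
// Illustrative sketch of the is_empty semantics: the output flag is true
// exactly when the input tensor holds no elements. `numel` stands in for
// Tensor::numel() from the framework.
#include <cassert>

inline bool IsEmptyReference(long numel) { return numel == 0; }

int main() {
  assert(IsEmptyReference(0) == true);   // empty tensor -> true
  assert(IsEmptyReference(6) == false);  // e.g. a 2x3 tensor -> false
  return 0;
}
```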
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef INCREMENT_OP
#pragma once
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename P>
void IncrementCompute(const IncrementParam<CPU> &param) {
const framework::Tensor *input = param.InputX();
framework::Tensor *out = param.Out();
int step = param.Step();
out->mutable_data<P>();
const P *input_data = input->data<P>();
P *out_data = out->data<P>();
*out_data = *input_data + step;
}
} // namespace operators
} // namespace paddle_mobile
#endif
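The ARM compute function above boils down to one scalar add on a single-element tensor (InferShape enforces numel() == 1). A standalone sketch of what IncrementCompute produces (illustrative only; hypothetical helper name, outside the framework types):

```cpp
// Illustrative sketch of the increment op: the one-element input x is copied
// to the output with the integer attribute `step` added to it.
#include <cassert>

inline float IncrementReference(float x, int step) {
  return x + static_cast<float>(step);
}

int main() {
  // e.g. a loop-counter tensor holding 3.0f with step = 1 becomes 4.0f
  assert(IncrementReference(3.0f, 1) == 4.0f);
  return 0;
}
```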
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef INCREMENT_OP
#pragma once
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class IncrementKernel
: public framework::OpKernelBase<DeviceType, IncrementParam<DeviceType>> {
public:
void Compute(const IncrementParam<DeviceType> &param);
bool Init(IncrementParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef IS_EMPTY_OP
#pragma once
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class IsEmptyKernel
: public framework::OpKernelBase<DeviceType, IsEmptyParam<DeviceType>> {
public:
void Compute(const IsEmptyParam<DeviceType> &param);
bool Init(IsEmptyParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
@@ -415,6 +415,7 @@ void Gemm::PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb,
  }
}
#if __ARM_NEON
#if __aarch64__
void Gemm::PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
                           float *buffer) {
@@ -538,6 +539,7 @@ void Gemm::PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B,
  }
}
#endif  // __aarch64__
#endif  // __ARM_NEON
// Blocked matrix multiplication
void Gemm::InnerKernel(int mc, int nc, float alpha, const float *a,
@@ -688,42 +690,7 @@ void Gemm::InnerKernelWithPRelu(int mc, int nc, const float *a, const float *b,
#if __ARM_NEON
#if __aarch64__
void Gemm::AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) { void Gemm::AddDot6x8(int k, const float *a, const float *b, float *c, int ldc) {
// init C
float32x4_t cv0 = vdupq_n_f32(0.0);
float32x4_t cv1 = vdupq_n_f32(0.0);
float32x4_t cv2 = vdupq_n_f32(0.0);
float32x4_t cv3 = vdupq_n_f32(0.0);
float32x4_t av;
float32x4_t bv;
float32x2_t av01;
float32x2_t av23;
for (int p = 0; p < k; p += 1) {
av = vld1q_f32(a);
bv = vld1q_f32(b);
av01 = vget_low_f32(av);
cv0 = vmlaq_lane_f32(cv0, bv, av01, 0);
cv1 = vmlaq_lane_f32(cv1, bv, av01, 1);
av23 = vget_high_f32(av);
cv2 = vmlaq_lane_f32(cv2, bv, av23, 0);
cv3 = vmlaq_lane_f32(cv3, bv, av23, 1);
a += MR;
b += NR;
}
vst1q_f32(c, cv0);
vst1q_f32(c + ldc, cv1);
vst1q_f32(c + 2 * ldc, cv2);
vst1q_f32(c + 3 * ldc, cv3);
// float32x4x4_t cv = {cv0, cv1, cv2, cv3};
}
void Gemm::AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) {
// init C // init C
float32x4_t cv0 = vdupq_n_f32(0.0); float32x4_t cv0 = vdupq_n_f32(0.0);
float32x4_t cv1 = vdupq_n_f32(0.0); float32x4_t cv1 = vdupq_n_f32(0.0);
...@@ -733,6 +700,10 @@ void Gemm::AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) { ...@@ -733,6 +700,10 @@ void Gemm::AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) {
float32x4_t cv5 = vdupq_n_f32(0.0); float32x4_t cv5 = vdupq_n_f32(0.0);
float32x4_t cv6 = vdupq_n_f32(0.0); float32x4_t cv6 = vdupq_n_f32(0.0);
float32x4_t cv7 = vdupq_n_f32(0.0); float32x4_t cv7 = vdupq_n_f32(0.0);
float32x4_t cv8 = vdupq_n_f32(0.0);
float32x4_t cv9 = vdupq_n_f32(0.0);
float32x4_t cv10 = vdupq_n_f32(0.0);
float32x4_t cv11 = vdupq_n_f32(0.0);
float32x4_t av; float32x4_t av;
float32x4_t bv0; float32x4_t bv0;
...@@ -740,23 +711,31 @@ void Gemm::AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) { ...@@ -740,23 +711,31 @@ void Gemm::AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) {
float32x2_t av01; float32x2_t av01;
float32x2_t av23; float32x2_t av23;
float32x2_t av45;
for (int p = 0; p < k; p += 1) { for (int p = 0; p < k; p += 1) {
av = vld1q_f32(a); av = vld1q_f32(a);
av01 = vget_low_f32(av);
av23 = vget_high_f32(av);
av45 = vld1_f32(a + 4);
bv0 = vld1q_f32(b); bv0 = vld1q_f32(b);
bv1 = vld1q_f32(b + 4); bv1 = vld1q_f32(b + 4);
av01 = vget_low_f32(av);
cv0 = vmlaq_lane_f32(cv0, bv0, av01, 0); cv0 = vmlaq_lane_f32(cv0, bv0, av01, 0);
cv1 = vmlaq_lane_f32(cv1, bv1, av01, 0); cv1 = vmlaq_lane_f32(cv1, bv1, av01, 0);
cv2 = vmlaq_lane_f32(cv2, bv0, av01, 1); cv2 = vmlaq_lane_f32(cv2, bv0, av01, 1);
cv3 = vmlaq_lane_f32(cv3, bv1, av01, 1); cv3 = vmlaq_lane_f32(cv3, bv1, av01, 1);
av23 = vget_high_f32(av);
cv4 = vmlaq_lane_f32(cv4, bv0, av23, 0); cv4 = vmlaq_lane_f32(cv4, bv0, av23, 0);
cv5 = vmlaq_lane_f32(cv5, bv1, av23, 0); cv5 = vmlaq_lane_f32(cv5, bv1, av23, 0);
cv6 = vmlaq_lane_f32(cv6, bv0, av23, 1); cv6 = vmlaq_lane_f32(cv6, bv0, av23, 1);
cv7 = vmlaq_lane_f32(cv7, bv1, av23, 1); cv7 = vmlaq_lane_f32(cv7, bv1, av23, 1);
cv8 = vmlaq_lane_f32(cv8, bv0, av45, 0);
cv9 = vmlaq_lane_f32(cv9, bv1, av45, 0);
cv10 = vmlaq_lane_f32(cv10, bv0, av45, 1);
cv11 = vmlaq_lane_f32(cv11, bv1, av45, 1);
a += MR; a += MR;
b += NR; b += NR;
} }
...@@ -769,131 +748,719 @@ void Gemm::AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) { ...@@ -769,131 +748,719 @@ void Gemm::AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) {
vst1q_f32(c + 2 * ldc + 4, cv5); vst1q_f32(c + 2 * ldc + 4, cv5);
vst1q_f32(c + 3 * ldc, cv6); vst1q_f32(c + 3 * ldc, cv6);
vst1q_f32(c + 3 * ldc + 4, cv7); vst1q_f32(c + 3 * ldc + 4, cv7);
vst1q_f32(c + 4 * ldc, cv8);
vst1q_f32(c + 4 * ldc + 4, cv9);
vst1q_f32(c + 5 * ldc, cv10);
vst1q_f32(c + 5 * ldc + 4, cv11);
} }
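The hunk above widens the aarch64 intrinsics micro-kernel from a 4x8 to a 6x8 tile of C. As a cross-check of what the NEON code accumulates, a plain scalar sketch of the same 6x8 update (illustrative only; it assumes MR == 6, NR == 8, and that the c tile effectively starts from zero, matching the freshly zeroed accumulators that the intrinsics store back):

```cpp
// Scalar sketch of the 6x8 micro-kernel: for a packed A panel (6 values per
// step of p) and packed B panel (8 values per step of p), accumulate
// c[i][j] += a_packed[p*6 + i] * b_packed[p*8 + j] over p = 0..k-1.
void AddDot6x8Reference(int k, const float *a, const float *b, float *c,
                        int ldc) {
  const int MR = 6, NR = 8;  // assumed tile shape
  for (int p = 0; p < k; ++p) {
    for (int i = 0; i < MR; ++i) {
      for (int j = 0; j < NR; ++j) {
        c[i * ldc + j] += a[p * MR + i] * b[p * NR + j];
      }
    }
  }
}
```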
// Write back the result of the blocked matrix multiplication void Gemm::AddDot8x12(int k, const float *a, const float *b, float *c,
// C = A * B int ldc) {
void Gemm::WriteBasic(int mc, int nc, float *c, float *C, int ldc) { const float *a_ptr, *b_ptr;
int nc1 = nc / 4; a_ptr = a;
int _nc1 = nc % 4; b_ptr = b;
int kc1 = k;
int step = 4 * ldc;
asm volatile(
"dup v5.4s, wzr \n\t"
"dup v6.4s, wzr \n\t"
"dup v7.4s, wzr \n\t"
"dup v8.4s, wzr \n\t"
"dup v9.4s, wzr \n\t"
"dup v10.4s, wzr \n\t"
"dup v11.4s, wzr \n\t"
"dup v12.4s, wzr \n\t"
"dup v13.4s, wzr \n\t"
"dup v14.4s, wzr \n\t"
"dup v15.4s, wzr \n\t"
"dup v16.4s, wzr \n\t"
float *c_ptr, *C_ptr; "dup v17.4s, wzr \n\t"
float32x4_t cv; "dup v18.4s, wzr \n\t"
for (int i = 0; i < mc; ++i) { "dup v19.4s, wzr \n\t"
c_ptr = c + i * NC; "dup v20.4s, wzr \n\t"
C_ptr = C + i * ldc; "dup v21.4s, wzr \n\t"
for (int j = 0; j < nc1; ++j) { "dup v22.4s, wzr \n\t"
cv = vld1q_f32(c_ptr); "dup v23.4s, wzr \n\t"
vst1q_f32(C_ptr, cv); "dup v24.4s, wzr \n\t"
c_ptr += 4; "dup v25.4s, wzr \n\t"
C_ptr += 4; "dup v26.4s, wzr \n\t"
} "dup v27.4s, wzr \n\t"
if (_nc1 != 0) { "dup v28.4s, wzr \n\t"
cv = vld1q_f32(c_ptr);
if (_nc1 >= 1) {
vst1q_lane_f32(C_ptr, cv, 0);
C_ptr++;
}
if (_nc1 >= 2) {
vst1q_lane_f32(C_ptr, cv, 1);
C_ptr++;
}
if (_nc1 >= 3) {
vst1q_lane_f32(C_ptr, cv, 2);
}
}
}
}
// C = alpha * A * B + beta * C "subs %[kc1], %[kc1], #1 \n\t"
void Gemm::WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc) {} "blt 2f \n\t"
"1: \n\t"
// C = A * B + C "prfm pldl1keep, [%[a_ptr], #32] \n\t"
void Gemm::WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) { "prfm pldl1keep, [%[b_ptr], #48] \n\t"
int nc1 = nc / 4;
int _nc1 = nc % 4;
float *c_ptr, *C_ptr; "ld1 {v0.4s, v1.4s}, [%[a_ptr]], #32 \n\t"
float32x4_t cv; "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48 \n\t"
float32x4_t cv1;
for (int i = 0; i < mc; ++i) {
c_ptr = c + i * NC;
C_ptr = C + i * ldc;
for (int j = 0; j < nc1; ++j) {
cv = vld1q_f32(c_ptr);
cv1 = vld1q_f32(C_ptr);
cv = vaddq_f32(cv, cv1);
vst1q_f32(C_ptr, cv);
c_ptr += 4;
C_ptr += 4;
}
if (_nc1 != 0) {
cv = vld1q_f32(c_ptr);
cv1 = vld1q_f32(C_ptr);
cv = vaddq_f32(cv, cv1);
if (_nc1 >= 1) {
vst1q_lane_f32(C_ptr, cv, 0);
C_ptr++;
}
if (_nc1 >= 2) {
vst1q_lane_f32(C_ptr, cv, 1);
C_ptr++;
}
if (_nc1 >= 3) {
vst1q_lane_f32(C_ptr, cv, 2);
}
}
}
}
// C = A * B + bias
void Gemm::WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc,
float *bias) {
int nc1 = nc / 4;
int _nc1 = nc % 4;
float *c_ptr, *C_ptr; "fmla v5.4s, v2.4s, v0.s[0] \n\t"
float32x4_t cv; "fmla v6.4s, v3.4s, v0.s[0] \n\t"
float32x4_t biasv; "fmla v7.4s, v4.4s, v0.s[0] \n\t"
for (int i = 0; i < mc; ++i) { "fmla v8.4s, v2.4s, v0.s[1] \n\t"
c_ptr = c + i * NC; "fmla v9.4s, v3.4s, v0.s[1] \n\t"
C_ptr = C + i * ldc; "fmla v10.4s, v4.4s, v0.s[1] \n\t"
biasv = vld1q_dup_f32(bias + i); "fmla v11.4s, v2.4s, v0.s[2] \n\t"
for (int j = 0; j < nc1; ++j) { "fmla v12.4s, v3.4s, v0.s[2] \n\t"
cv = vld1q_f32(c_ptr); "fmla v13.4s, v4.4s, v0.s[2] \n\t"
cv = vaddq_f32(cv, biasv); "fmla v14.4s, v2.4s, v0.s[3] \n\t"
vst1q_f32(C_ptr, cv); "fmla v15.4s, v3.4s, v0.s[3] \n\t"
c_ptr += 4; "fmla v16.4s, v4.4s, v0.s[3] \n\t"
C_ptr += 4;
} "fmla v17.4s, v2.4s, v1.s[0] \n\t"
if (_nc1 != 0) { "fmla v18.4s, v3.4s, v1.s[0] \n\t"
cv = vld1q_f32(c_ptr); "fmla v19.4s, v4.4s, v1.s[0] \n\t"
cv = vaddq_f32(cv, biasv); "fmla v20.4s, v2.4s, v1.s[1] \n\t"
if (_nc1 >= 1) { "fmla v21.4s, v3.4s, v1.s[1] \n\t"
vst1q_lane_f32(C_ptr, cv, 0); "fmla v22.4s, v4.4s, v1.s[1] \n\t"
C_ptr++; "fmla v23.4s, v2.4s, v1.s[2] \n\t"
} "fmla v24.4s, v3.4s, v1.s[2] \n\t"
if (_nc1 >= 2) { "fmla v25.4s, v4.4s, v1.s[2] \n\t"
vst1q_lane_f32(C_ptr, cv, 1); "fmla v26.4s, v2.4s, v1.s[3] \n\t"
C_ptr++; "fmla v27.4s, v3.4s, v1.s[3] \n\t"
} "fmla v28.4s, v4.4s, v1.s[3] \n\t"
if (_nc1 >= 3) {
vst1q_lane_f32(C_ptr, cv, 2); "subs %[kc1], %[kc1], #1 \n\t"
C_ptr++; "bge 1b \n\t"
} "2: \n\t"
}
} "st1 {v5.4s, v6.4s, v7.4s}, [%[c]], %[step] \n\t"
"st1 {v8.4s, v9.4s, v10.4s}, [%[c]], %[step] \n\t"
"st1 {v11.4s, v12.4s, v13.4s}, [%[c]], %[step] \n\t"
"st1 {v14.4s, v15.4s, v16.4s}, [%[c]], %[step] \n\t"
"st1 {v17.4s, v18.4s, v19.4s}, [%[c]], %[step] \n\t"
"st1 {v20.4s, v21.4s, v22.4s}, [%[c]], %[step] \n\t"
"st1 {v23.4s, v24.4s, v25.4s}, [%[c]], %[step] \n\t"
"st1 {v26.4s, v27.4s, v28.4s}, [%[c]], %[step] \n\t"
:
: [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1),
[step] "r"(step)
: "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28");
} }
// C = A * B + C, relu(C) void Gemm::AddDot6x16(int k, const float *a, const float *b, float *c,
void Gemm::WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) { int ldc) {
int nc1 = nc / 4; const float *a_ptr, *b_ptr;
int _nc1 = nc % 4; a_ptr = a;
b_ptr = b;
int kc1 = k;
int step = 4 * ldc;
int step1 = 4 * 6;
asm volatile(
float *c_ptr, *C_ptr; "dup v6.4s, wzr \n\t"
float32x4_t cv; "dup v7.4s, wzr \n\t"
float32x4_t cv1; "dup v8.4s, wzr \n\t"
float32x4_t zero = vdupq_n_f32(0.0); "dup v9.4s, wzr \n\t"
for (int i = 0; i < mc; ++i) { "dup v10.4s, wzr \n\t"
"dup v11.4s, wzr \n\t"
"dup v12.4s, wzr \n\t"
"dup v13.4s, wzr \n\t"
"dup v14.4s, wzr \n\t"
"dup v15.4s, wzr \n\t"
"dup v16.4s, wzr \n\t"
"dup v17.4s, wzr \n\t"
"dup v18.4s, wzr \n\t"
"dup v19.4s, wzr \n\t"
"dup v20.4s, wzr \n\t"
"dup v21.4s, wzr \n\t"
"dup v22.4s, wzr \n\t"
"dup v23.4s, wzr \n\t"
"dup v24.4s, wzr \n\t"
"dup v25.4s, wzr \n\t"
"dup v26.4s, wzr \n\t"
"dup v27.4s, wzr \n\t"
"dup v28.4s, wzr \n\t"
"dup v29.4s, wzr \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"blt 2f \n\t"
"1: \n\t"
"prfm pldl1keep, [%[a_ptr], #24] \n\t"
"prfm pldl1keep, [%[b_ptr], #64] \n\t"
"ld1 {v0.4s, v1.4s}, [%[a_ptr]], %[step1] \n\t"
"ld1 {v2.4s, v3.4s, v4.4s, v5.4s}, [%[b_ptr]], #64 \n\t"
"fmla v6.4s, v2.4s, v0.s[0] \n\t"
"fmla v7.4s, v3.4s, v0.s[0] \n\t"
"fmla v8.4s, v4.4s, v0.s[0] \n\t"
"fmla v9.4s, v5.4s, v0.s[0] \n\t"
"fmla v10.4s, v2.4s, v0.s[1] \n\t"
"fmla v11.4s, v3.4s, v0.s[1] \n\t"
"fmla v12.4s, v4.4s, v0.s[1] \n\t"
"fmla v13.4s, v5.4s, v0.s[1] \n\t"
"fmla v14.4s, v2.4s, v0.s[2] \n\t"
"fmla v15.4s, v3.4s, v0.s[2] \n\t"
"fmla v16.4s, v4.4s, v0.s[2] \n\t"
"fmla v17.4s, v5.4s, v0.s[2] \n\t"
"fmla v18.4s, v2.4s, v0.s[3] \n\t"
"fmla v19.4s, v3.4s, v0.s[3] \n\t"
"fmla v20.4s, v4.4s, v0.s[3] \n\t"
"fmla v21.4s, v5.4s, v0.s[3] \n\t"
"fmla v22.4s, v2.4s, v1.s[0] \n\t"
"fmla v23.4s, v3.4s, v1.s[0] \n\t"
"fmla v24.4s, v4.4s, v1.s[0] \n\t"
"fmla v25.4s, v5.4s, v1.s[0] \n\t"
"fmla v26.4s, v2.4s, v1.s[1] \n\t"
"fmla v27.4s, v3.4s, v1.s[1] \n\t"
"fmla v28.4s, v4.4s, v1.s[1] \n\t"
"fmla v29.4s, v5.4s, v1.s[1] \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"bge 1b \n\t"
"2: \n\t"
"st1 {v6.4s, v7.4s, v8.4s, v9.4s}, [%[c]], %[step] \n\t"
"st1 {v10.4s, v11.4s, v12.4s, v13.4s}, [%[c]], %[step] \n\t"
"st1 {v14.4s, v15.4s, v16.4s, v17.4s}, [%[c]], %[step] \n\t"
"st1 {v18.4s, v19.4s, v20.4s, v21.4s}, [%[c]], %[step] \n\t"
"st1 {v22.4s, v23.4s, v24.4s, v25.4s}, [%[c]], %[step] \n\t"
"st1 {v26.4s, v27.4s, v28.4s, v29.4s}, [%[c]], %[step] \n\t"
:
: [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1),
[step] "r"(step), [step1] "r"(step1)
: "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29");
}
#else
void Gemm::AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) {
const float *a_ptr, *b_ptr;
a_ptr = a;
b_ptr = b;
int kc1 = k / 4;
int kc2 = k % 4;
int step = 4 * ldc;
asm volatile(
"pld [%[a_ptr]] \n\t"
"pld [%[b_ptr]] \n\t"
"vmov.f32 q10, #0.0 \n\t"
"vmov.f32 q11, #0.0 \n\t"
"vmov.f32 q12, #0.0 \n\t"
"vmov.f32 q13, #0.0 \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"blt end_kc1_%= \n\t"
"loop_kc1_%=: \n\t"
"pld [%[a_ptr], #64] \n\t"
"pld [%[b_ptr], #64] \n\t"
"vld1.32 {q0, q1}, [%[a_ptr]]! \n\t"
"vld1.32 {q2, q3}, [%[b_ptr]]! \n\t"
"vmla.f32 q10, q2, d0[0] \n\t"
"vmla.f32 q11, q2, d0[1] \n\t"
"vmla.f32 q12, q2, d1[0] \n\t"
"vmla.f32 q13, q2, d1[1] \n\t"
"vmla.f32 q10, q3, d2[0] \n\t"
"vmla.f32 q11, q3, d2[1] \n\t"
"vmla.f32 q12, q3, d3[0] \n\t"
"vmla.f32 q13, q3, d3[1] \n\t"
"vld1.32 {q4, q5}, [%[a_ptr]]! \n\t"
"vld1.32 {q6, q7}, [%[b_ptr]]! \n\t"
"vmla.f32 q10, q6, d8[0] \n\t"
"vmla.f32 q11, q6, d8[1] \n\t"
"vmla.f32 q12, q6, d9[0] \n\t"
"vmla.f32 q13, q6, d9[1] \n\t"
"vmla.f32 q10, q7, d10[0] \n\t"
"vmla.f32 q11, q7, d10[1] \n\t"
"vmla.f32 q12, q7, d11[0] \n\t"
"vmla.f32 q13, q7, d11[1] \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"bge loop_kc1_%= \n\t"
"end_kc1_%=: \n\t"
"subs %[kc2], %[kc2], #1 \n\t"
"blt end_kc2_%= \n\t"
"loop_kc2_%=: \n\t"
"vld1.32 {q0}, [%[a_ptr]]! \n\t"
"vld1.32 {q1}, [%[b_ptr]]! \n\t"
"vmla.f32 q10, q1, d0[0] \n\t"
"vmla.f32 q11, q1, d0[1] \n\t"
"vmla.f32 q12, q1, d1[0] \n\t"
"vmla.f32 q13, q1, d1[1] \n\t"
"subs %[kc2], %[kc2], #1 \n\t"
"bge loop_kc2_%= \n\t"
"end_kc2_%=: \n\t"
"mov r5, %[c] \n\t"
"mov r6, %[step] \n\t"
"vst1.32 {q10}, [r5], r6 \n\t"
"vst1.32 {q11}, [r5], r6 \n\t"
"vst1.32 {q12}, [r5], r6 \n\t"
"vst1.32 {q13}, [r5] \n\t"
:
: [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1),
[kc2] "r"(kc2), [step] "r"(step)
: "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q10", "q11", "q12", "q13");
}
void Gemm::AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) {
const float *a_ptr, *b_ptr;
a_ptr = a;
b_ptr = b;
int kc1 = k / 4;
int kc2 = k % 4;
int step = 4 * ldc;
asm volatile(
"pld [%[a_ptr]] \n\t"
"pld [%[b_ptr]] \n\t"
"vmov.f32 q8, #0.0 \n\t"
"vmov.f32 q9, #0.0 \n\t"
"vmov.f32 q10, #0.0 \n\t"
"vmov.f32 q11, #0.0 \n\t"
"vmov.f32 q12, #0.0 \n\t"
"vmov.f32 q13, #0.0 \n\t"
"vmov.f32 q14, #0.0 \n\t"
"vmov.f32 q15, #0.0 \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"blt end_kc1_%= \n\t"
"loop_kc1_%=: \n\t"
"pld [%[a_ptr], #64] \n\t"
"pld [%[b_ptr], #64] \n\t"
"vld1.32 {q0, q1}, [%[a_ptr]]! \n\t"
"vld1.32 {q2, q3}, [%[b_ptr]]! \n\t"
"vld1.32 {q4, q5}, [%[b_ptr]]! \n\t"
"vmla.f32 q8, q2, d0[0] \n\t"
"vmla.f32 q9, q3, d0[0] \n\t"
"vmla.f32 q10, q2, d0[1] \n\t"
"vmla.f32 q11, q3, d0[1] \n\t"
"vmla.f32 q12, q2, d1[0] \n\t"
"vmla.f32 q13, q3, d1[0] \n\t"
"vmla.f32 q14, q2, d1[1] \n\t"
"vmla.f32 q15, q3, d1[1] \n\t"
"vmla.f32 q8, q4, d2[0] \n\t"
"vmla.f32 q9, q5, d2[0] \n\t"
"vmla.f32 q10, q4, d2[1] \n\t"
"vmla.f32 q11, q5, d2[1] \n\t"
"vmla.f32 q12, q4, d3[0] \n\t"
"vmla.f32 q13, q5, d3[0] \n\t"
"vmla.f32 q14, q4, d3[1] \n\t"
"vmla.f32 q15, q5, d3[1] \n\t"
"pld [%[b_ptr], #64] \n\t"
"vld1.32 {q0, q1}, [%[a_ptr]]! \n\t"
"vld1.32 {q2, q3}, [%[b_ptr]]! \n\t"
"vld1.32 {q4, q5}, [%[b_ptr]]! \n\t"
"vmla.f32 q8, q2, d0[0] \n\t"
"vmla.f32 q9, q3, d0[0] \n\t"
"vmla.f32 q10, q2, d0[1] \n\t"
"vmla.f32 q11, q3, d0[1] \n\t"
"vmla.f32 q12, q2, d1[0] \n\t"
"vmla.f32 q13, q3, d1[0] \n\t"
"vmla.f32 q14, q2, d1[1] \n\t"
"vmla.f32 q15, q3, d1[1] \n\t"
"vmla.f32 q8, q4, d2[0] \n\t"
"vmla.f32 q9, q5, d2[0] \n\t"
"vmla.f32 q10, q4, d2[1] \n\t"
"vmla.f32 q11, q5, d2[1] \n\t"
"vmla.f32 q12, q4, d3[0] \n\t"
"vmla.f32 q13, q5, d3[0] \n\t"
"vmla.f32 q14, q4, d3[1] \n\t"
"vmla.f32 q15, q5, d3[1] \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"bge loop_kc1_%= \n\t"
"end_kc1_%=: \n\t"
"subs %[kc2], %[kc2], #1 \n\t"
"blt end_kc2_%= \n\t"
"loop_kc2_%=: \n\t"
"vld1.32 {q0}, [%[a_ptr]]! \n\t"
"vld1.32 {q2, q3}, [%[b_ptr]]! \n\t"
"vmla.f32 q8, q2, d0[0] \n\t"
"vmla.f32 q9, q3, d0[0] \n\t"
"vmla.f32 q10, q2, d0[1] \n\t"
"vmla.f32 q11, q3, d0[1] \n\t"
"vmla.f32 q12, q2, d1[0] \n\t"
"vmla.f32 q13, q3, d1[0] \n\t"
"vmla.f32 q14, q2, d1[1] \n\t"
"vmla.f32 q15, q3, d1[1] \n\t"
"subs %[kc2], %[kc2], #1 \n\t"
"bge loop_kc2_%= \n\t"
"end_kc2_%=: \n\t"
"mov r5, %[c] \n\t"
"mov r6, %[step] \n\t"
"vst1.32 {q8, q9}, [r5], r6 \n\t"
"vst1.32 {q10, q11}, [r5], r6 \n\t"
"vst1.32 {q12, q13}, [r5], r6 \n\t"
"vst1.32 {q14, q15}, [r5] \n\t"
:
: [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1),
[kc2] "r"(kc2), [step] "r"(step)
: "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q8", "q9",
"q10", "q11", "q12", "q13", "q14", "q15");
}
void Gemm::AddDot6x8(int k, const float *a, const float *b, float *c, int ldc) {
const float *a_ptr, *b_ptr;
a_ptr = a;
b_ptr = b;
int kc1 = k / 8;
int kc2 = k % 8;
int step = sizeof(float) * ldc;
asm volatile(
"pld [%[a_ptr]] \n\t"
"pld [%[a_ptr], #64] \n\t"
"pld [%[b_ptr]] \n\t"
"pld [%[b_ptr], #64] \n\t"
"vmov.f32 q4, #0.0 \n\t"
"vmov.f32 q5, #0.0 \n\t"
"vmov.f32 q6, #0.0 \n\t"
"vmov.f32 q7, #0.0 \n\t"
"vmov.f32 q8, #0.0 \n\t"
"vmov.f32 q9, #0.0 \n\t"
"vmov.f32 q10, #0.0 \n\t"
"vmov.f32 q11, #0.0 \n\t"
"vmov.f32 q12, #0.0 \n\t"
"vmov.f32 q13, #0.0 \n\t"
"vmov.f32 q14, #0.0 \n\t"
"vmov.f32 q15, #0.0 \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"blt 2f \n\t"
"1: \n\t"
"pld [%[a_ptr], #128] \n\t"
"pld [%[b_ptr], #128] \n\t"
"vld1.32 {d0-d2}, [%[a_ptr]]! \n\t"
"vld1.32 {q2, q3}, [%[b_ptr]]! \n\t"
"vmla.f32 q4, q2, d0[0] \n\t"
"vmla.f32 q5, q3, d0[0] \n\t"
"vmla.f32 q6, q2, d0[1] \n\t"
"vmla.f32 q7, q3, d0[1] \n\t"
"vmla.f32 q8, q2, d1[0] \n\t"
"vmla.f32 q9, q3, d1[0] \n\t"
"vmla.f32 q10, q2, d1[1] \n\t"
"vmla.f32 q11, q3, d1[1] \n\t"
"vmla.f32 q12, q2, d2[0] \n\t"
"vmla.f32 q13, q3, d2[0] \n\t"
"vmla.f32 q14, q2, d2[1] \n\t"
"vmla.f32 q15, q3, d2[1] \n\t"
"vld1.32 {d0-d2}, [%[a_ptr]]! \n\t"
"vld1.32 {q2, q3}, [%[b_ptr]]! \n\t"
"vmla.f32 q4, q2, d0[0] \n\t"
"vmla.f32 q5, q3, d0[0] \n\t"
"vmla.f32 q6, q2, d0[1] \n\t"
"vmla.f32 q7, q3, d0[1] \n\t"
"vmla.f32 q8, q2, d1[0] \n\t"
"vmla.f32 q9, q3, d1[0] \n\t"
"vmla.f32 q10, q2, d1[1] \n\t"
"vmla.f32 q11, q3, d1[1] \n\t"
"vmla.f32 q12, q2, d2[0] \n\t"
"vmla.f32 q13, q3, d2[0] \n\t"
"vmla.f32 q14, q2, d2[1] \n\t"
"vmla.f32 q15, q3, d2[1] \n\t"
"pld [%[a_ptr], #128] \n\t"
"pld [%[b_ptr], #128] \n\t"
"vld1.32 {d0-d2}, [%[a_ptr]]! \n\t"
"vld1.32 {q2, q3}, [%[b_ptr]]! \n\t"
"vmla.f32 q4, q2, d0[0] \n\t"
"vmla.f32 q5, q3, d0[0] \n\t"
"vmla.f32 q6, q2, d0[1] \n\t"
"vmla.f32 q7, q3, d0[1] \n\t"
"vmla.f32 q8, q2, d1[0] \n\t"
"vmla.f32 q9, q3, d1[0] \n\t"
"vmla.f32 q10, q2, d1[1] \n\t"
"vmla.f32 q11, q3, d1[1] \n\t"
"vmla.f32 q12, q2, d2[0] \n\t"
"vmla.f32 q13, q3, d2[0] \n\t"
"vmla.f32 q14, q2, d2[1] \n\t"
"vmla.f32 q15, q3, d2[1] \n\t"
"vld1.32 {d0-d2}, [%[a_ptr]]! \n\t"
"vld1.32 {q2, q3}, [%[b_ptr]]! \n\t"
"vmla.f32 q4, q2, d0[0] \n\t"
"vmla.f32 q5, q3, d0[0] \n\t"
"vmla.f32 q6, q2, d0[1] \n\t"
"vmla.f32 q7, q3, d0[1] \n\t"
"vmla.f32 q8, q2, d1[0] \n\t"
"vmla.f32 q9, q3, d1[0] \n\t"
"vmla.f32 q10, q2, d1[1] \n\t"
"vmla.f32 q11, q3, d1[1] \n\t"
"vmla.f32 q12, q2, d2[0] \n\t"
"vmla.f32 q13, q3, d2[0] \n\t"
"vmla.f32 q14, q2, d2[1] \n\t"
"vmla.f32 q15, q3, d2[1] \n\t"
"pld [%[a_ptr], #128] \n\t"
"pld [%[b_ptr], #128] \n\t"
"vld1.32 {d0-d2}, [%[a_ptr]]! \n\t"
"vld1.32 {q2, q3}, [%[b_ptr]]! \n\t"
"vmla.f32 q4, q2, d0[0] \n\t"
"vmla.f32 q5, q3, d0[0] \n\t"
"vmla.f32 q6, q2, d0[1] \n\t"
"vmla.f32 q7, q3, d0[1] \n\t"
"vmla.f32 q8, q2, d1[0] \n\t"
"vmla.f32 q9, q3, d1[0] \n\t"
"vmla.f32 q10, q2, d1[1] \n\t"
"vmla.f32 q11, q3, d1[1] \n\t"
"vmla.f32 q12, q2, d2[0] \n\t"
"vmla.f32 q13, q3, d2[0] \n\t"
"vmla.f32 q14, q2, d2[1] \n\t"
"vmla.f32 q15, q3, d2[1] \n\t"
"vld1.32 {d0-d2}, [%[a_ptr]]! \n\t"
"vld1.32 {q2, q3}, [%[b_ptr]]! \n\t"
"vmla.f32 q4, q2, d0[0] \n\t"
"vmla.f32 q5, q3, d0[0] \n\t"
"vmla.f32 q6, q2, d0[1] \n\t"
"vmla.f32 q7, q3, d0[1] \n\t"
"vmla.f32 q8, q2, d1[0] \n\t"
"vmla.f32 q9, q3, d1[0] \n\t"
"vmla.f32 q10, q2, d1[1] \n\t"
"vmla.f32 q11, q3, d1[1] \n\t"
"vmla.f32 q12, q2, d2[0] \n\t"
"vmla.f32 q13, q3, d2[0] \n\t"
"vmla.f32 q14, q2, d2[1] \n\t"
"vmla.f32 q15, q3, d2[1] \n\t"
"pld [%[a_ptr], #128] \n\t"
"pld [%[b_ptr], #128] \n\t"
"vld1.32 {d0-d2}, [%[a_ptr]]! \n\t"
"vld1.32 {q2, q3}, [%[b_ptr]]! \n\t"
"vmla.f32 q4, q2, d0[0] \n\t"
"vmla.f32 q5, q3, d0[0] \n\t"
"vmla.f32 q6, q2, d0[1] \n\t"
"vmla.f32 q7, q3, d0[1] \n\t"
"vmla.f32 q8, q2, d1[0] \n\t"
"vmla.f32 q9, q3, d1[0] \n\t"
"vmla.f32 q10, q2, d1[1] \n\t"
"vmla.f32 q11, q3, d1[1] \n\t"
"vmla.f32 q12, q2, d2[0] \n\t"
"vmla.f32 q13, q3, d2[0] \n\t"
"vmla.f32 q14, q2, d2[1] \n\t"
"vmla.f32 q15, q3, d2[1] \n\t"
"vld1.32 {d0-d2}, [%[a_ptr]]! \n\t"
"vld1.32 {q2, q3}, [%[b_ptr]]! \n\t"
"vmla.f32 q4, q2, d0[0] \n\t"
"vmla.f32 q5, q3, d0[0] \n\t"
"vmla.f32 q6, q2, d0[1] \n\t"
"vmla.f32 q7, q3, d0[1] \n\t"
"vmla.f32 q8, q2, d1[0] \n\t"
"vmla.f32 q9, q3, d1[0] \n\t"
"vmla.f32 q10, q2, d1[1] \n\t"
"vmla.f32 q11, q3, d1[1] \n\t"
"vmla.f32 q12, q2, d2[0] \n\t"
"vmla.f32 q13, q3, d2[0] \n\t"
"vmla.f32 q14, q2, d2[1] \n\t"
"vmla.f32 q15, q3, d2[1] \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"bge 1b \n\t"
"2: \n\t"
"subs %[kc2], %[kc2], #1 \n\t"
"blt 4f \n\t"
"3: \n\t"
"vld1.32 {d0-d2}, [%[a_ptr]]! \n\t"
"vld1.32 {q2, q3}, [%[b_ptr]]! \n\t"
"vmla.f32 q4, q2, d0[0] \n\t"
"vmla.f32 q5, q3, d0[0] \n\t"
"vmla.f32 q6, q2, d0[1] \n\t"
"vmla.f32 q7, q3, d0[1] \n\t"
"vmla.f32 q8, q2, d1[0] \n\t"
"vmla.f32 q9, q3, d1[0] \n\t"
"vmla.f32 q10, q2, d1[1] \n\t"
"vmla.f32 q11, q3, d1[1] \n\t"
"vmla.f32 q12, q2, d2[0] \n\t"
"vmla.f32 q13, q3, d2[0] \n\t"
"vmla.f32 q14, q2, d2[1] \n\t"
"vmla.f32 q15, q3, d2[1] \n\t"
"subs %[kc2], %[kc2], #1 \n\t"
"bge 3b \n\t"
"4: \n\t"
"mov r5, %[c] \n\t"
"mov r6, %[step] \n\t"
"vst1.32 {q4, q5}, [r5], r6 \n\t"
"vst1.32 {q6, q7}, [r5], r6 \n\t"
"vst1.32 {q8, q9}, [r5], r6 \n\t"
"vst1.32 {q10, q11}, [r5], r6 \n\t"
"vst1.32 {q12, q13}, [r5], r6 \n\t"
"vst1.32 {q14, q15}, [r5] \n\t"
:
: [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1),
[kc2] "r"(kc2), [step] "r"(step)
: "cc", "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
"q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
}
#endif // __aarch64__
#endif // __ARM_NEON
#if __ARM_NEON
#if __aarch64__
// Write back the result of the blocked matrix multiplication
// C = A * B
void Gemm::WriteBasic(int mc, int nc, float *c, float *C, int ldc) {
int nc1 = nc / 4;
int _nc1 = nc % 4;
float *c_ptr, *C_ptr;
float32x4_t cv;
for (int i = 0; i < mc; ++i) {
c_ptr = c + i * NC;
C_ptr = C + i * ldc;
for (int j = 0; j < nc1; ++j) {
cv = vld1q_f32(c_ptr);
vst1q_f32(C_ptr, cv);
c_ptr += 4;
C_ptr += 4;
}
if (_nc1 != 0) {
cv = vld1q_f32(c_ptr);
if (_nc1 >= 1) {
vst1q_lane_f32(C_ptr, cv, 0);
C_ptr++;
}
if (_nc1 >= 2) {
vst1q_lane_f32(C_ptr, cv, 1);
C_ptr++;
}
if (_nc1 >= 3) {
vst1q_lane_f32(C_ptr, cv, 2);
}
}
}
}
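These write-back routines split each row into full 4-float chunks (nc1) and a 0-3 element tail (_nc1); for nc = 11, for example, two full vectors are copied and the remaining three floats are stored lane by lane. A scalar sketch of the same split (illustrative only; hypothetical helper name):

```cpp
// Scalar sketch of the per-row copy done by WriteBasic: whole 4-wide chunks
// first, then the 0-3 leftover columns one at a time.
void CopyRowReference(int nc, const float *src, float *dst) {
  int nc1 = nc / 4;   // number of full 4-float chunks
  int tail = nc % 4;  // leftover columns (the _nc1 path above)
  int j = 0;
  for (; j < nc1 * 4; ++j) dst[j] = src[j];
  for (int t = 0; t < tail; ++t) dst[j + t] = src[j + t];
}
```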
// C = alpha * A * B + beta * C
void Gemm::WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc) {}
// C = A * B + C
void Gemm::WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {
int nc1 = nc / 4;
int _nc1 = nc % 4;
float *c_ptr, *C_ptr;
float32x4_t cv;
float32x4_t cv1;
for (int i = 0; i < mc; ++i) {
c_ptr = c + i * NC;
C_ptr = C + i * ldc;
for (int j = 0; j < nc1; ++j) {
cv = vld1q_f32(c_ptr);
cv1 = vld1q_f32(C_ptr);
cv = vaddq_f32(cv, cv1);
vst1q_f32(C_ptr, cv);
c_ptr += 4;
C_ptr += 4;
}
if (_nc1 != 0) {
cv = vld1q_f32(c_ptr);
cv1 = vld1q_f32(C_ptr);
cv = vaddq_f32(cv, cv1);
if (_nc1 >= 1) {
vst1q_lane_f32(C_ptr, cv, 0);
C_ptr++;
}
if (_nc1 >= 2) {
vst1q_lane_f32(C_ptr, cv, 1);
C_ptr++;
}
if (_nc1 >= 3) {
vst1q_lane_f32(C_ptr, cv, 2);
}
}
}
}
// C = A * B + bias
void Gemm::WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc,
float *bias) {
int nc1 = nc / 4;
int _nc1 = nc % 4;
float *c_ptr, *C_ptr;
float32x4_t cv;
float32x4_t biasv;
for (int i = 0; i < mc; ++i) {
c_ptr = c + i * NC;
C_ptr = C + i * ldc;
biasv = vld1q_dup_f32(bias + i);
for (int j = 0; j < nc1; ++j) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
vst1q_f32(C_ptr, cv);
c_ptr += 4;
C_ptr += 4;
}
if (_nc1 != 0) {
cv = vld1q_f32(c_ptr);
cv = vaddq_f32(cv, biasv);
if (_nc1 >= 1) {
vst1q_lane_f32(C_ptr, cv, 0);
C_ptr++;
}
if (_nc1 >= 2) {
vst1q_lane_f32(C_ptr, cv, 1);
C_ptr++;
}
if (_nc1 >= 3) {
vst1q_lane_f32(C_ptr, cv, 2);
C_ptr++;
}
}
}
}
// C = A * B + C, relu(C)
void Gemm::WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
int nc1 = nc / 4;
int _nc1 = nc % 4;
float *c_ptr, *C_ptr;
float32x4_t cv;
float32x4_t cv1;
float32x4_t zero = vdupq_n_f32(0.0);
for (int i = 0; i < mc; ++i) {
c_ptr = c + i * NC; c_ptr = c + i * NC;
C_ptr = C + i * ldc; C_ptr = C + i * ldc;
for (int j = 0; j < nc1; ++j) { for (int j = 0; j < nc1; ++j) {
@@ -1188,82 +1755,8 @@ void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc,
  }
}
void Gemm::VectorKernel(int m, int n, int k, float alpha, const float *A,
                        int lda, const float *B, int ldb, float beta, float *C,
                        int ldc, bool relu) {}
#else
void Gemm::AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) {
const float *a_ptr, *b_ptr;
a_ptr = a;
b_ptr = b;
int kc1 = k / 4;
int kc2 = k % 4;
int step = 4 * ldc;
asm volatile(
"pld [%[a_ptr]] \n\t"
"pld [%[b_ptr]] \n\t"
"vmov.f32 q10, #0.0 \n\t"
"vmov.f32 q11, #0.0 \n\t"
"vmov.f32 q12, #0.0 \n\t"
"vmov.f32 q13, #0.0 \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"blt end_kc1_%= \n\t"
"loop_kc1_%=: \n\t"
"pld [%[a_ptr], #64] \n\t"
"pld [%[b_ptr], #64] \n\t"
"vld1.32 {q0, q1}, [%[a_ptr]]! \n\t"
"vld1.32 {q2, q3}, [%[b_ptr]]! \n\t"
"vmla.f32 q10, q2, d0[0] \n\t"
"vmla.f32 q11, q2, d0[1] \n\t"
"vmla.f32 q12, q2, d1[0] \n\t"
"vmla.f32 q13, q2, d1[1] \n\t"
"vmla.f32 q10, q3, d2[0] \n\t"
"vmla.f32 q11, q3, d2[1] \n\t"
"vmla.f32 q12, q3, d3[0] \n\t"
"vmla.f32 q13, q3, d3[1] \n\t"
"vld1.32 {q4, q5}, [%[a_ptr]]! \n\t"
"vld1.32 {q6, q7}, [%[b_ptr]]! \n\t"
"vmla.f32 q10, q6, d8[0] \n\t"
"vmla.f32 q11, q6, d8[1] \n\t"
"vmla.f32 q12, q6, d9[0] \n\t"
"vmla.f32 q13, q6, d9[1] \n\t"
"vmla.f32 q10, q7, d10[0] \n\t"
"vmla.f32 q11, q7, d10[1] \n\t"
"vmla.f32 q12, q7, d11[0] \n\t"
"vmla.f32 q13, q7, d11[1] \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"bge loop_kc1_%= \n\t"
"end_kc1_%=: \n\t"
"subs %[kc2], %[kc2], #1 \n\t"
"blt end_kc2_%= \n\t"
"loop_kc2_%=: \n\t"
"vld1.32 {q0}, [%[a_ptr]]! \n\t"
"vld1.32 {q1}, [%[b_ptr]]! \n\t"
"vmla.f32 q10, q1, d0[0] \n\t"
"vmla.f32 q11, q1, d0[1] \n\t"
"vmla.f32 q12, q1, d1[0] \n\t"
"vmla.f32 q13, q1, d1[1] \n\t"
"subs %[kc2], %[kc2], #1 \n\t"
"bge loop_kc2_%= \n\t"
"end_kc2_%=: \n\t"
"mov r5, %[c] \n\t"
"mov r6, %[step] \n\t"
"vst1.32 {q10}, [r5], r6 \n\t"
"vst1.32 {q11}, [r5], r6 \n\t"
"vst1.32 {q12}, [r5], r6 \n\t"
"vst1.32 {q13}, [r5] \n\t"
:
: [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1),
[kc2] "r"(kc2), [step] "r"(step)
: "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q10", "q11", "q12", "q13");
}
void Gemm::VectorKernel(int m, int n, int k, float alpha, const float *A, void Gemm::VectorKernel(int m, int n, int k, float alpha, const float *A,
int lda, const float *B, int ldb, float beta, float *C, int lda, const float *B, int ldb, float beta, float *C,
int ldc, bool relu) { int ldc, bool relu) {
@@ -1486,10 +1979,10 @@ void Gemm::VectorKernel(int m, int n, int k, float alpha, const float *A,
  }
}
/*
void Gemm::VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
                              int lda, const float *B, int ldb, float beta,
                              float *C, int ldc, bool relu, float *new_scale,
                              float *new_bias) {
  float *bufferC = static_cast<float *>(memory::Alloc(sizeof(float) * n));
  const float *a0, *b0, *b1, *b2, *b3;
@@ -1697,114 +2190,6 @@ void Gemm::VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
    VecWriteWithBn(n, bufferC, C, ldc, new_scale, new_bias);
  }
}
*/
void Gemm::AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) {
const float *a_ptr, *b_ptr;
a_ptr = a;
b_ptr = b;
int kc1 = k / 4;
int kc2 = k % 4;
int step = 4 * ldc;
asm volatile(
"pld [%[a_ptr]] \n\t"
"pld [%[b_ptr]] \n\t"
"vmov.f32 q8, #0.0 \n\t"
"vmov.f32 q9, #0.0 \n\t"
"vmov.f32 q10, #0.0 \n\t"
"vmov.f32 q11, #0.0 \n\t"
"vmov.f32 q12, #0.0 \n\t"
"vmov.f32 q13, #0.0 \n\t"
"vmov.f32 q14, #0.0 \n\t"
"vmov.f32 q15, #0.0 \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"blt end_kc1_%= \n\t"
"loop_kc1_%=: \n\t"
"pld [%[a_ptr], #64] \n\t"
"pld [%[b_ptr], #64] \n\t"
"vld1.32 {q0, q1}, [%[a_ptr]]! \n\t"
"vld1.32 {q2, q3}, [%[b_ptr]]! \n\t"
"vld1.32 {q4, q5}, [%[b_ptr]]! \n\t"
"vmla.f32 q8, q2, d0[0] \n\t"
"vmla.f32 q9, q3, d0[0] \n\t"
"vmla.f32 q10, q2, d0[1] \n\t"
"vmla.f32 q11, q3, d0[1] \n\t"
"vmla.f32 q12, q2, d1[0] \n\t"
"vmla.f32 q13, q3, d1[0] \n\t"
"vmla.f32 q14, q2, d1[1] \n\t"
"vmla.f32 q15, q3, d1[1] \n\t"
"vmla.f32 q8, q4, d2[0] \n\t"
"vmla.f32 q9, q5, d2[0] \n\t"
"vmla.f32 q10, q4, d2[1] \n\t"
"vmla.f32 q11, q5, d2[1] \n\t"
"vmla.f32 q12, q4, d3[0] \n\t"
"vmla.f32 q13, q5, d3[0] \n\t"
"vmla.f32 q14, q4, d3[1] \n\t"
"vmla.f32 q15, q5, d3[1] \n\t"
"pld [%[b_ptr], #64] \n\t"
"vld1.32 {q0, q1}, [%[a_ptr]]! \n\t"
"vld1.32 {q2, q3}, [%[b_ptr]]! \n\t"
"vld1.32 {q4, q5}, [%[b_ptr]]! \n\t"
"vmla.f32 q8, q2, d0[0] \n\t"
"vmla.f32 q9, q3, d0[0] \n\t"
"vmla.f32 q10, q2, d0[1] \n\t"
"vmla.f32 q11, q3, d0[1] \n\t"
"vmla.f32 q12, q2, d1[0] \n\t"
"vmla.f32 q13, q3, d1[0] \n\t"
"vmla.f32 q14, q2, d1[1] \n\t"
"vmla.f32 q15, q3, d1[1] \n\t"
"vmla.f32 q8, q4, d2[0] \n\t"
"vmla.f32 q9, q5, d2[0] \n\t"
"vmla.f32 q10, q4, d2[1] \n\t"
"vmla.f32 q11, q5, d2[1] \n\t"
"vmla.f32 q12, q4, d3[0] \n\t"
"vmla.f32 q13, q5, d3[0] \n\t"
"vmla.f32 q14, q4, d3[1] \n\t"
"vmla.f32 q15, q5, d3[1] \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"bge loop_kc1_%= \n\t"
"end_kc1_%=: \n\t"
"subs %[kc2], %[kc2], #1 \n\t"
"blt end_kc2_%= \n\t"
"loop_kc2_%=: \n\t"
"vld1.32 {q0}, [%[a_ptr]]! \n\t"
"vld1.32 {q2, q3}, [%[b_ptr]]! \n\t"
"vmla.f32 q8, q2, d0[0] \n\t"
"vmla.f32 q9, q3, d0[0] \n\t"
"vmla.f32 q10, q2, d0[1] \n\t"
"vmla.f32 q11, q3, d0[1] \n\t"
"vmla.f32 q12, q2, d1[0] \n\t"
"vmla.f32 q13, q3, d1[0] \n\t"
"vmla.f32 q14, q2, d1[1] \n\t"
"vmla.f32 q15, q3, d1[1] \n\t"
"subs %[kc2], %[kc2], #1 \n\t"
"bge loop_kc2_%= \n\t"
"end_kc2_%=: \n\t"
"mov r5, %[c] \n\t"
"mov r6, %[step] \n\t"
"vst1.32 {q8, q9}, [r5], r6 \n\t"
"vst1.32 {q10, q11}, [r5], r6 \n\t"
"vst1.32 {q12, q13}, [r5], r6 \n\t"
"vst1.32 {q14, q15}, [r5] \n\t"
:
: [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1),
[kc2] "r"(kc2), [step] "r"(step)
: "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q8", "q9",
"q10", "q11", "q12", "q13", "q14", "q15");
}
// C = A * B
void Gemm::WriteBasic(int mc, int nc, float *c, float *C, int ldc) {
@@ -2567,162 +2952,25 @@ void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc,
      cv = vld1q_f32(c_ptr);
      biasv = vld1q_f32(bias_ptr);
      cv = vmlaq_n_f32(nbias, cv, scale0);
      cv = vaddq_f32(cv, biasv);
      cv = vmaxq_f32(cv, zero);
      if (_nc1 >= 1) {
        vst1q_lane_f32(C_ptr, cv, 0);
        C_ptr++;
      }
      if (_nc1 >= 2) {
        vst1q_lane_f32(C_ptr, cv, 1);
        C_ptr++;
      }
      if (_nc1 >= 3) {
        vst1q_lane_f32(C_ptr, cv, 2);
}
}
}
}
// C = A * B
void Gemm::VecWriteBasic(int n, float *c, float *C, int ldc) {
int nc1 = n / 16;
int _nc1 = n % 16;
int nc2 = _nc1 / 4;
int nc3 = 16 - 4 * (_nc1 % 4);
asm volatile(
"subs %[nc1], %[nc1], #1 \n\t"
"blt end_nc1_%= \n\t"
"loop_nc1_%=: \n\t"
"vld1.32 {q0, q1}, [%[c]]! \n\t"
"vst1.32 {q0, q1}, [%[C]]! \n\t"
"vld1.32 {q2, q3}, [%[c]]! \n\t"
"vst1.32 {q2, q3}, [%[C]]! \n\t"
"subs %[nc1], %[nc1], #1 \n\t"
"bge loop_nc1_%= \n\t"
"end_nc1_%=: \n\t"
"subs %[nc2], %[nc2], #1 \n\t"
"blt end_nc2_%= \n\t"
"loop_nc2_%=: \n\t"
"vld1.32 {q4}, [%[c]]! \n\t"
"vst1.32 {q4}, [%[C]]! \n\t"
"subs %[nc2], %[nc2], #1 \n\t"
"bge loop_nc2_%= \n\t"
"end_nc2_%=: \n\t"
"cmp %[nc3], #16 \n\t"
"beq end_nc3_%= \n\t"
"sub %[c], %[c], %[nc3] \n\t"
"sub %[C], %[C], %[nc3] \n\t"
"vld1.32 {q5}, [%[c]]! \n\t"
"vst1.32 {q5}, [%[C]]! \n\t"
"end_nc3_%=: \n\t"
:
: [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] "r"(nc3)
: "memory", "q0", "q1", "q2", "q3", "q4", "q5");
}
// C = alpha * A * B + beta * C
void Gemm::VecWriteWithAlphaBeta(int n, float *c, float *C, int ldc) {}
// C = A * B + C
void Gemm::VecWriteWithAdd(int n, float *c, float *C, int ldc) {
int nc1 = n / 16;
int _nc1 = n % 16;
asm volatile(
"subs %[nc1], %[nc1], #1 \n\t"
"blt end_nc1_%= \n\t"
"loop_nc1_%=: \n\t"
"vld1.32 {q0, q1}, [%[c]]! \n\t"
"vld1.32 {q2, q3}, [%[C]] \n\t"
"vadd.f32 q10, q0, q2 \n\t"
"vadd.f32 q11, q1, q3 \n\t"
"vst1.32 {q10, q11}, [%[C]]! \n\t"
"vld1.32 {q4, q5}, [%[c]]! \n\t"
"vld1.32 {q6, q7}, [%[C]] \n\t"
"vadd.f32 q12, q4, q6 \n\t"
"vadd.f32 q13, q5, q7 \n\t"
"vst1.32 {q12, q13}, [%[C]]! \n\t"
"subs %[nc1], %[nc1], #1 \n\t"
"bge loop_nc1_%= \n\t"
"end_nc1_%=: \n\t"
: [C] "+r"(C), [c] "+r"(c)
: [nc1] "r"(nc1)
: "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", "q11",
"q12", "q13");
if (_nc1 != 0) {
for (int j = 0; j < _nc1; j++) {
*C++ += *c++;
}
}
}
// C = A * B + C, relu(C)
void Gemm::VecWriteWithAddRelu(int n, float *c, float *C, int ldc) {
int nc1 = n / 16;
int _nc1 = n % 16;
asm volatile(
"vmov.f32 q14, #0.0 \n\t"
"subs %[nc1], %[nc1], #1 \n\t"
"blt end_nc1_%= \n\t"
"loop_nc1_%=: \n\t"
"vld1.32 {q0, q1}, [%[c]]! \n\t"
"vld1.32 {q2, q3}, [%[C]] \n\t"
"vadd.f32 q10, q0, q2 \n\t"
"vadd.f32 q11, q1, q3 \n\t"
"vmax.f32 q10, q10, q14 \n\t"
"vmax.f32 q11, q11, q14 \n\t"
"vst1.32 {q10, q11}, [%[C]]! \n\t"
"vld1.32 {q4, q5}, [%[c]]! \n\t"
"vld1.32 {q6, q7}, [%[C]] \n\t"
"vadd.f32 q12, q4, q6 \n\t"
"vadd.f32 q13, q5, q7 \n\t"
"vmax.f32 q12, q12, q14 \n\t"
"vmax.f32 q13, q13, q14 \n\t"
"vst1.32 {q12, q13}, [%[C]]! \n\t"
"subs %[nc1], %[nc1], #1 \n\t"
"bge loop_nc1_%= \n\t"
"end_nc1_%=: \n\t"
: [C] "+r"(C), [c] "+r"(c)
: [nc1] "r"(nc1)
: "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", "q11",
"q12", "q13");
if (_nc1 != 0) {
for (int j = 0; j < _nc1; j++) {
*C += *c;
if (*C < 0) {
*C = 0;
} }
C++;
c++;
} }
} }
} }
/* // C = A * B
// C = A * B, batchnorm(C) void Gemm::VecWriteBasic(int n, float *c, float *C, int ldc) {
void Gemm::VecWriteWithBn(int n, float *c, float *C, int ldc, float *scale,
float *bias) {
int nc1 = n / 16; int nc1 = n / 16;
int _nc1 = n % 16; int _nc1 = n % 16;
int nc2 = _nc1 / 4; int nc2 = _nc1 / 4;
...@@ -2734,18 +2982,10 @@ void Gemm::VecWriteWithAddRelu(int n, float *c, float *C, int ldc) { ...@@ -2734,18 +2982,10 @@ void Gemm::VecWriteWithAddRelu(int n, float *c, float *C, int ldc) {
"loop_nc1_%=: \n\t" "loop_nc1_%=: \n\t"
"vld1.32 {q0, q1}, [%[c]]! \n\t" "vld1.32 {q0, q1}, [%[c]]! \n\t"
"vld1.32 {q2, q3}, [%[scale]]! \n\t" "vst1.32 {q0, q1}, [%[C]]! \n\t"
"vld1.32 {q10, q11}, [%[bias]]! \n\t"
"vmla.f32 q10, q0, q2 \n\t"
"vmla.f32 q11, q1, q3 \n\t"
"vst1.32 {q10, q11}, [%[C]]! \n\t"
"vld1.32 {q4, q5}, [%[c]]! \n\t" "vld1.32 {q2, q3}, [%[c]]! \n\t"
"vld1.32 {q6, q7}, [%[scale]]! \n\t" "vst1.32 {q2, q3}, [%[C]]! \n\t"
"vld1.32 {q12, q13}, [%[bias]]! \n\t"
"vmla.f32 q12, q4, q6 \n\t"
"vmla.f32 q13, q5, q7 \n\t"
"vst1.32 {q12, q13}, [%[C]]! \n\t"
"subs %[nc1], %[nc1], #1 \n\t" "subs %[nc1], %[nc1], #1 \n\t"
"bge loop_nc1_%= \n\t" "bge loop_nc1_%= \n\t"
...@@ -2755,11 +2995,8 @@ void Gemm::VecWriteWithAddRelu(int n, float *c, float *C, int ldc) { ...@@ -2755,11 +2995,8 @@ void Gemm::VecWriteWithAddRelu(int n, float *c, float *C, int ldc) {
"blt end_nc2_%= \n\t" "blt end_nc2_%= \n\t"
"loop_nc2_%=: \n\t" "loop_nc2_%=: \n\t"
"vld1.32 {q0}, [%[c]]! \n\t" "vld1.32 {q4}, [%[c]]! \n\t"
"vld1.32 {q1}, [%[scale]]! \n\t" "vst1.32 {q4}, [%[C]]! \n\t"
"vld1.32 {q10}, [%[bias]]! \n\t"
"vmla.f32 q10, q0, q1 \n\t"
"vst1.32 {q10}, [%[C]]! \n\t"
"subs %[nc2], %[nc2], #1 \n\t" "subs %[nc2], %[nc2], #1 \n\t"
"bge loop_nc2_%= \n\t" "bge loop_nc2_%= \n\t"
...@@ -2767,663 +3004,264 @@ void Gemm::VecWriteWithAddRelu(int n, float *c, float *C, int ldc) { ...@@ -2767,663 +3004,264 @@ void Gemm::VecWriteWithAddRelu(int n, float *c, float *C, int ldc) {
"cmp %[nc3], #16 \n\t" "cmp %[nc3], #16 \n\t"
"beq end_nc3_%= \n\t" "beq end_nc3_%= \n\t"
"sub %[c], %[c], %[nc3] \n\t" "sub %[c], %[c], %[nc3] \n\t"
"sub %[scale], %[scale], %[nc3] \n\t"
"sub %[bias], %[bias], %[nc3] \n\t"
"sub %[C], %[C], %[nc3] \n\t" "sub %[C], %[C], %[nc3] \n\t"
"vld1.32 {q5}, [%[c]]! \n\t"
"vld1.32 {q0}, [%[c]]! \n\t" "vst1.32 {q5}, [%[C]]! \n\t"
"vld1.32 {q1}, [%[scale]]! \n\t"
"vld1.32 {q10}, [%[bias]]! \n\t"
"vmla.f32 q10, q0, q1 \n\t"
"vst1.32 {q10}, [%[C]]! \n\t"
"end_nc3_%=: \n\t" "end_nc3_%=: \n\t"
: :
: [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] : [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] "r"(nc3)
"r"(nc3), [scale] "r"(scale), [bias] "r"(bias) : "memory", "q0", "q1", "q2", : "memory", "q0", "q1", "q2", "q3", "q4", "q5");
"q3", "q4", "q5", "q6", "q7", "q10", "q11", "q12", "q13"); }
}
// C = alpha * A * B + beta * C
void Gemm::VecWriteWithAlphaBeta(int n, float *c, float *C, int ldc) {}
// C = A * B, batchnorm(C), relu(C) // C = A * B + C
void Gemm::VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float void Gemm::VecWriteWithAdd(int n, float *c, float *C, int ldc) {
*scale, float *bias) { int nc1 = n / 16; int _nc1 = n % 16; int nc2 = _nc1 / int nc1 = n / 16;
4; int nc3 = 16 - 4 * (_nc1 % 4); int _nc1 = n % 16;
asm volatile( asm volatile(
"vmov.f32 q14, #0.0 \n\t"
"subs %[nc1], %[nc1], #1 \n\t" "subs %[nc1], %[nc1], #1 \n\t"
"blt end_nc1_%= \n\t" "blt end_nc1_%= \n\t"
"loop_nc1_%=: \n\t" "loop_nc1_%=: \n\t"
"vld1.32 {q0, q1}, [%[c]]! \n\t" "vld1.32 {q0, q1}, [%[c]]! \n\t"
"vld1.32 {q2, q3}, [%[scale]]! \n\t" "vld1.32 {q2, q3}, [%[C]] \n\t"
"vld1.32 {q10, q11}, [%[bias]]! \n\t" "vadd.f32 q10, q0, q2 \n\t"
"vmla.f32 q10, q0, q2 \n\t" "vadd.f32 q11, q1, q3 \n\t"
"vmla.f32 q11, q1, q3 \n\t"
"vmax.f32 q10, q10, q14 \n\t"
"vmax.f32 q11, q11, q14 \n\t"
"vst1.32 {q10, q11}, [%[C]]! \n\t" "vst1.32 {q10, q11}, [%[C]]! \n\t"
"vld1.32 {q4, q5}, [%[c]]! \n\t" "vld1.32 {q4, q5}, [%[c]]! \n\t"
"vld1.32 {q6, q7}, [%[scale]]! \n\t" "vld1.32 {q6, q7}, [%[C]] \n\t"
"vld1.32 {q12, q13}, [%[bias]]! \n\t" "vadd.f32 q12, q4, q6 \n\t"
"vmla.f32 q12, q4, q6 \n\t" "vadd.f32 q13, q5, q7 \n\t"
"vmla.f32 q13, q5, q7 \n\t"
"vmax.f32 q12, q12, q14 \n\t"
"vmax.f32 q13, q13, q14 \n\t"
"vst1.32 {q12, q13}, [%[C]]! \n\t" "vst1.32 {q12, q13}, [%[C]]! \n\t"
"subs %[nc1], %[nc1], #1 \n\t" "subs %[nc1], %[nc1], #1 \n\t"
"bge loop_nc1_%= \n\t" "bge loop_nc1_%= \n\t"
"end_nc1_%=: \n\t" "end_nc1_%=: \n\t"
"subs %[nc2], %[nc2], #1 \n\t" : [C] "+r"(C), [c] "+r"(c)
"blt end_nc2_%= \n\t" : [nc1] "r"(nc1)
"loop_nc2_%=: \n\t" : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", "q11",
"q12", "q13");
"vld1.32 {q0}, [%[c]]! \n\t"
"vld1.32 {q1}, [%[scale]]! \n\t"
"vld1.32 {q10}, [%[bias]]! \n\t"
"vmla.f32 q10, q0, q1 \n\t"
"vmax.f32 q10, q10, q14 \n\t"
"vst1.32 {q10}, [%[C]]! \n\t"
"subs %[nc2], %[nc2], #1 \n\t"
"bge loop_nc2_%= \n\t"
"end_nc2_%=: \n\t"
"cmp %[nc3], #16 \n\t"
"beq end_nc3_%= \n\t"
"sub %[c], %[c], %[nc3] \n\t"
"sub %[scale], %[scale], %[nc3] \n\t"
"sub %[bias], %[bias], %[nc3] \n\t"
"sub %[C], %[C], %[nc3] \n\t"
"vld1.32 {q0}, [%[c]]! \n\t"
"vld1.32 {q1}, [%[scale]]! \n\t"
"vld1.32 {q10}, [%[bias]]! \n\t"
"vmla.f32 q10, q0, q1 \n\t"
"vmax.f32 q10, q10, q14 \n\t"
"vst1.32 {q10}, [%[C]]! \n\t"
"end_nc3_%=: \n\t"
:
: [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3]
"r"(nc3), [scale] "r"(scale), [bias] "r"(bias) : "memory", "q0", "q1", "q2",
"q3", "q4", "q5", "q6", "q7", "q10", "q11", "q12", "q13", "q14");
}
*/
#endif // __aarch64__
#else
void Gemm::AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) {
float *c0, *c1, *c2, *c3;
c0 = c;
c1 = c + ldc;
c2 = c + 2 * ldc;
c3 = c + 3 * ldc;
for (int p = 0; p < k; p += 1) {
// first row
c0[0] += a[0] * b[0];
c0[1] += a[0] * b[1];
c0[2] += a[0] * b[2];
c0[3] += a[0] * b[3];
// second row
c1[0] += a[1] * b[0];
c1[1] += a[1] * b[1];
c1[2] += a[1] * b[2];
c1[3] += a[1] * b[3];
// third row
c2[0] += a[2] * b[0];
c2[1] += a[2] * b[1];
c2[2] += a[2] * b[2];
c2[3] += a[2] * b[3];
// fourth row
c3[0] += a[3] * b[0];
c3[1] += a[3] * b[1];
c3[2] += a[3] * b[2];
c3[3] += a[3] * b[3];
a += 4;
b += 4;
}
}
void Gemm::AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) {
}
void Gemm::WriteBasic(int mc, int nc, float *c, float *C, int ldc) {}
void Gemm::WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc) {}
void Gemm::WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {}
void Gemm::WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc,
float *bias) {}
void Gemm::WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {}
void Gemm::WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc,
float *bias) {}
void Gemm::WriteWithAddPRelu(int mc, int nc, float *c, float *C, int ldc,
float *p, std::string mode, float *bias,
float *bias1) {}
void Gemm::WriteWithBn(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias) {}
void Gemm::WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias) {}
void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias, float *bias1) {
}
#endif // __ARM_NEON
// 32-bit float matrix multiplication
void Gemm::Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *bias) {
// L1 data cache is 32 KiB per core (Cortex-A57, Cortex-A72, Cortex-A73)
// L2 cache is 0.5~4 MiB (Cortex-A72 cluster)
int L1 = 32 * 1024;
int L2 = 512 * 1024;
KC = k;
MC = L1 / (KC * sizeof(float));
NC = L2 / (KC * sizeof(float));
// make sure MC is multiple of MR, and NC is multiple of NR
if (MC == 0) {
MC = MR;
} else {
int mblock_num = (m + MC - 1) / MC;
MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + MR - 1) / MR * MR;
}
// DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n";
if (NC == 0) {
NC = NR;
} else {
int nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR - 1) / NR * NR;
}
// DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n";
packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
packedC = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * NC));
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
int mc, nc;
for (int j = 0; j < n; j += NC) {
nc = s_min(n - j, NC);
#if __aarch64__
// PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
#else
PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
#endif
for (int i = 0; i < m; i += MC) {
mc = s_min(m - i, MC);
#if __aarch64__
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
// PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#else
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#endif
if (bias == nullptr) {
InnerKernelWithBias(mc, nc, alpha, packedA, packedB, beta, packedC,
&C(i, j), ldc, relu, nullptr);
} else {
InnerKernelWithBias(mc, nc, alpha, packedA, packedB, beta, packedC,
&C(i, j), ldc, relu, bias + i);
}
}
}
paddle_mobile::memory::Free(packedA);
paddle_mobile::memory::Free(packedB);
paddle_mobile::memory::Free(packedC);
paddle_mobile::memory::Free(zero);
}
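// Worked example of the blocking arithmetic above (numbers for illustration;
// MR is 6 and NR is 8 on the 32-bit ARM path here, judging by PackMatrixA_6r
// and PackMatrixB_8c):
//   k = 256, L1 = 32 KiB  ->  MC = 32768 / (256 * 4) = 32
//   m = 100               ->  mblock_num = (100 + 31) / 32 = 4
//                             MC = (100 + 3) / 4 = 25, rounded up to 30 (6x)
// so A is packed and consumed in panels of at most 30 rows, and each packed
// panel of MC * KC floats (30 * 256 * 4 bytes, about 30 KiB) stays within the
// L1 budget used to derive it.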
void Gemm::SgemmWithBn(int m, int n, int k, float alpha, const float *A,
int lda, const float *B, int ldb, float beta, float *C,
int ldc, bool relu, float *new_scale, float *new_bias,
float *bias) {
// L1 data cache is 32 KiB per core (Cortex-A57, Cortex-A72, Cortex-A73)
// L2 cache is 0.5~4 MiB (Cortex-A72 cluster)
int L1 = 32 * 1024;
int L2 = 512 * 1024;
KC = k;
MC = L1 / (KC * sizeof(float));
NC = L2 / (KC * sizeof(float));
// make sure MC is a multiple of MR, and NC is a multiple of NR
if (MC == 0) {
MC = MR;
} else {
int mblock_num = (m + MC - 1) / MC;
MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + MR - 1) / MR * MR;
}
// DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n";
if (NC == 0) {
NC = NR;
} else {
int nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR - 1) / NR * NR;
}
// DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n";
packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
packedC = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * NC));
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
int mc, nc; if (_nc1 != 0) {
for (int j = 0; j < n; j += NC) { for (int j = 0; j < _nc1; j++) {
nc = s_min(n - j, NC); *C++ += *c++;
#if __aarch64__
// PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
#else
PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
#endif
for (int i = 0; i < m; i += MC) {
mc = s_min(m - i, MC);
#if __aarch64__
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
// PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#else
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#endif
if (bias == nullptr) {
InnerKernelWithBn(mc, nc, alpha, packedA, packedB, beta, packedC,
&C(i, j), ldc, relu, new_scale + i, new_bias + i);
} else {
InnerKernelWithBnAdd(mc, nc, alpha, packedA, packedB, beta, packedC,
&C(i, j), ldc, relu, new_scale + i, new_bias + i,
bias + i * ldc + j);
}
} }
} }
paddle_mobile::memory::Free(packedA);
paddle_mobile::memory::Free(packedB);
paddle_mobile::memory::Free(packedC);
paddle_mobile::memory::Free(zero);
} }
void Gemm::SgemmWithPRelu(int m, int n, int k, const float *A, int lda, // C = A * B + C, relu(C)
const float *B, int ldb, float *C, int ldc, float *p, void Gemm::VecWriteWithAddRelu(int n, float *c, float *C, int ldc) {
std::string mode, float *bias, float *bias1) { int nc1 = n / 16;
// L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73) int _nc1 = n % 16;
// L2 cache is 0.5~4 Mib (Contex-A72 cluster)
int L1 = 32 * 1024;
int L2 = 0.5 * 1024 * 1024;
KC = k; asm volatile(
MC = L1 / (KC * sizeof(float)); "vmov.f32 q14, #0.0 \n\t"
NC = L2 / (KC * sizeof(float)); "subs %[nc1], %[nc1], #1 \n\t"
"blt end_nc1_%= \n\t"
"loop_nc1_%=: \n\t"
// make sure MC is multiple of MR, and NC is multiple of NR "vld1.32 {q0, q1}, [%[c]]! \n\t"
if (MC == 0) { "vld1.32 {q2, q3}, [%[C]] \n\t"
MC = MR; "vadd.f32 q10, q0, q2 \n\t"
} else { "vadd.f32 q11, q1, q3 \n\t"
int mblock_num = (m + MC - 1) / MC; "vmax.f32 q10, q10, q14 \n\t"
MC = (m + mblock_num - 1) / mblock_num; "vmax.f32 q11, q11, q14 \n\t"
MC = (MC + MR - 1) / MR * MR; "vst1.32 {q10, q11}, [%[C]]! \n\t"
}
// DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n";
if (NC == 0) {
NC = NR;
} else {
int nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR - 1) / NR * NR;
}
// DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n";
packedA = static_cast<float *>( "vld1.32 {q4, q5}, [%[c]]! \n\t"
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); "vld1.32 {q6, q7}, [%[C]] \n\t"
packedB = static_cast<float *>( "vadd.f32 q12, q4, q6 \n\t"
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); "vadd.f32 q13, q5, q7 \n\t"
packedC = static_cast<float *>( "vmax.f32 q12, q12, q14 \n\t"
paddle_mobile::memory::Alloc(sizeof(float) * MC * NC)); "vmax.f32 q13, q13, q14 \n\t"
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC)); "vst1.32 {q12, q13}, [%[C]]! \n\t"
for (int l = 0; l < KC; ++l) { "subs %[nc1], %[nc1], #1 \n\t"
zero[l] = 0; "bge loop_nc1_%= \n\t"
} "end_nc1_%=: \n\t"
int mc, nc; : [C] "+r"(C), [c] "+r"(c)
for (int j = 0; j < n; j += NC) { : [nc1] "r"(nc1)
nc = s_min(n - j, NC); : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", "q11",
#if __aarch64__ "q12", "q13");
// PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB); if (_nc1 != 0) {
#else for (int j = 0; j < _nc1; j++) {
PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB); *C += *c;
#endif if (*C < 0) {
for (int i = 0; i < m; i += MC) { *C = 0;
mc = s_min(m - i, MC);
#if __aarch64__
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
// PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#else
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#endif
if (bias1 == nullptr) {
InnerKernelWithPRelu(mc, nc, packedA, packedB, packedC, &C(i, j), ldc,
p + i, mode, bias + i, nullptr);
} else {
InnerKernelWithPRelu(mc, nc, packedA, packedB, packedC, &C(i, j), ldc,
p + i, mode, bias + i, bias1 + i * ldc + j);
} }
C++;
c++;
} }
} }
paddle_mobile::memory::Free(packedA);
paddle_mobile::memory::Free(packedB);
paddle_mobile::memory::Free(packedC);
paddle_mobile::memory::Free(zero);
} }
// 32位 float 矩阵乘法 // C = A * B, batchnorm(C)
void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, void Gemm::VecWriteWithBn(int n, float *c, float *C, int ldc, float *scale,
const float *B, int ldb, float beta, float *C, int ldc, float *bias) {
bool relu, float *bias) { int nc1 = n / 16;
#ifndef __aarch64__ int _nc1 = n % 16;
if (m == 1 && bias == nullptr) { int nc2 = _nc1 / 4;
return VectorKernel(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, relu); int nc3 = 16 - 4 * (_nc1 % 4);
}
#endif // __aarch64__
#ifdef _OPENMP
int max_threads = omp_get_max_threads();
#else
int max_threads = 1;
#endif
// int L1 = 64 / max_threads * 1024; asm volatile(
int L = (max_threads > 2) ? 64 : 32; "subs %[nc1], %[nc1], #1 \n\t"
int L1 = L / max_threads * 1024; "blt end_nc1_%= \n\t"
KC = k; "loop_nc1_%=: \n\t"
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
if (m > n) {
// tile A into blocks
MC = L1 / (KC * sizeof(float));
if (MC == 0) {
MC = MR;
} else {
int mblock_num = (m + MC - 1) / MC;
MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + MR - 1) / MR * MR;
}
// pad B up to a multiple of NR
NC = (n + NR - 1) / NR * NR;
#if __aarch64__ "vld1.32 {q0, q1}, [%[c]]! \n\t"
procPackA = &Gemm::PackMatrixA_6r; "vld1.32 {q2, q3}, [%[scale]]! \n\t"
procPackB = &Gemm::PackMatrixB_omp_16c; "vld1.32 {q10, q11}, [%[bias]]! \n\t"
procAddDot = &Gemm::AddDot6x16; "vmla.f32 q10, q0, q2 \n\t"
#else "vmla.f32 q11, q1, q3 \n\t"
procPackA = &Gemm::PackMatrixA_6r; "vst1.32 {q10, q11}, [%[C]]! \n\t"
procPackB = &Gemm::PackMatrixB_omp_8c;
procAddDot = &Gemm::AddDot6x8;
#endif
packedB = static_cast<float *>( "vld1.32 {q4, q5}, [%[c]]! \n\t"
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); "vld1.32 {q6, q7}, [%[scale]]! \n\t"
(*this.*procPackB)(KC, n, n % NR, B, ldb, packedB); "vld1.32 {q12, q13}, [%[bias]]! \n\t"
packedA = static_cast<float *>( "vmla.f32 q12, q4, q6 \n\t"
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); "vmla.f32 q13, q5, q7 \n\t"
} else { "vst1.32 {q12, q13}, [%[C]]! \n\t"
// tile B into blocks
NC = L1 / (KC * sizeof(float));
if (NC == 0) {
NC = NR;
} else {
int nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR - 1) / NR * NR;
}
// pad A up to a multiple of MR
MC = (m + MR - 1) / MR * MR;
#if __aarch64__ "subs %[nc1], %[nc1], #1 \n\t"
procPackA = &Gemm::PackMatrixA_omp_6r; "bge loop_nc1_%= \n\t"
procPackB = &Gemm::PackMatrixB_16c; "end_nc1_%=: \n\t"
procAddDot = &Gemm::AddDot6x16;
#else
procPackA = &Gemm::PackMatrixA_omp_6r; "subs %[nc2], %[nc2], #1 \n\t"
procPackB = &Gemm::PackMatrixB_8c; "blt end_nc2_%= \n\t"
procAddDot = &Gemm::AddDot6x8; "loop_nc2_%=: \n\t"
#endif
packedA = static_cast<float *>( "vld1.32 {q0}, [%[c]]! \n\t"
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); "vld1.32 {q1}, [%[scale]]! \n\t"
(*this.*procPackA)(m, KC, m % MR, A, lda, packedA); "vld1.32 {q10}, [%[bias]]! \n\t"
packedB = static_cast<float *>( "vmla.f32 q10, q0, q1 \n\t"
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); "vst1.32 {q10}, [%[C]]! \n\t"
}
packedC = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads));
if (m > n) { "subs %[nc2], %[nc2], #1 \n\t"
#pragma omp parallel for "bge loop_nc2_%= \n\t"
for (int i = 0; i < m; i += MC) { "end_nc2_%=: \n\t"
#ifdef _OPENMP
int local_threads = omp_get_thread_num();
#else
int local_threads = 0;
#endif
int mc; "cmp %[nc3], #16 \n\t"
mc = s_min(m - i, MC); "beq end_nc3_%= \n\t"
float *local_A = packedA + MC * KC * local_threads;
float *local_C = packedC + MC * NC * local_threads;
(*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A);
if (bias == nullptr) {
InnerKernelWithBias(mc, n, alpha, local_A, packedB, beta, local_C,
&C(i, 0), ldc, relu, nullptr);
} else {
InnerKernelWithBias(mc, n, alpha, local_A, packedB, beta, local_C,
&C(i, 0), ldc, relu, bias + i);
}
}
} else {
#pragma omp parallel for
for (int j = 0; j < n; j += NC) {
#ifdef _OPENMP
int local_threads = omp_get_thread_num();
#else
int local_threads = 0;
#endif
int nc; "sub %[c], %[c], %[nc3] \n\t"
nc = s_min(n - j, NC); "sub %[scale], %[scale], %[nc3] \n\t"
float *local_B = packedB + KC * NC * local_threads; "sub %[bias], %[bias], %[nc3] \n\t"
float *local_C = packedC + MC * NC * local_threads; "sub %[C], %[C], %[nc3] \n\t"
(*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B);
InnerKernelWithBias(m, nc, alpha, packedA, local_B, beta, local_C,
&C(0, j), ldc, relu, bias);
}
}
paddle_mobile::memory::Free(packedA); "vld1.32 {q0}, [%[c]]! \n\t"
paddle_mobile::memory::Free(packedB); "vld1.32 {q1}, [%[scale]]! \n\t"
paddle_mobile::memory::Free(packedC); "vld1.32 {q10}, [%[bias]]! \n\t"
paddle_mobile::memory::Free(zero); "vmla.f32 q10, q0, q1 \n\t"
"vst1.32 {q10}, [%[C]]! \n\t"
"end_nc3_%=: \n\t"
:
: [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] "r"(nc3),
[scale] "r"(scale), [bias] "r"(bias)
: "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", "q11",
"q12", "q13");
} }
void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, // C = A * B, batchnorm(C), relu(C)
int lda, const float *B, int ldb, float beta, void Gemm::VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *scale,
float *C, int ldc, bool relu, float *new_scale, float *bias) {
float *new_bias, float *bias) { int nc1 = n / 16;
#ifdef _OPENMP int _nc1 = n % 16;
int max_threads = omp_get_max_threads(); int nc2 = _nc1 / 4;
#else int nc3 = 16 - 4 * (_nc1 % 4);
int max_threads = 1;
#endif
int L1 = 64 / max_threads * 1024; asm volatile(
KC = k; "vmov.f32 q14, #0.0 \n\t"
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC)); "subs %[nc1], %[nc1], #1 \n\t"
memset(static_cast<void *>(zero), 0, sizeof(float) * KC); "blt end_nc1_%= \n\t"
if (m > n) { "loop_nc1_%=: \n\t"
// tile A into blocks
MC = L1 / (KC * sizeof(float));
if (MC == 0) {
MC = MR;
} else {
int mblock_num = (m + MC - 1) / MC;
MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + MR - 1) / MR * MR;
}
// pad B up to a multiple of NR
NC = (n + NR - 1) / NR * NR;
#if __aarch64__ "vld1.32 {q0, q1}, [%[c]]! \n\t"
procPackA = &Gemm::PackMatrixA_6r; "vld1.32 {q2, q3}, [%[scale]]! \n\t"
procPackB = &Gemm::PackMatrixB_omp_16c; "vld1.32 {q10, q11}, [%[bias]]! \n\t"
procAddDot = &Gemm::AddDot6x16; "vmla.f32 q10, q0, q2 \n\t"
#else "vmla.f32 q11, q1, q3 \n\t"
procPackA = &Gemm::PackMatrixA_6r; "vmax.f32 q10, q10, q14 \n\t"
procPackB = &Gemm::PackMatrixB_omp_8c; "vmax.f32 q11, q11, q14 \n\t"
procAddDot = &Gemm::AddDot6x8; "vst1.32 {q10, q11}, [%[C]]! \n\t"
#endif
packedB = static_cast<float *>( "vld1.32 {q4, q5}, [%[c]]! \n\t"
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); "vld1.32 {q6, q7}, [%[scale]]! \n\t"
(*this.*procPackB)(KC, n, n % NR, B, ldb, packedB); "vld1.32 {q12, q13}, [%[bias]]! \n\t"
packedA = static_cast<float *>( "vmla.f32 q12, q4, q6 \n\t"
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); "vmla.f32 q13, q5, q7 \n\t"
} else { "vmax.f32 q12, q12, q14 \n\t"
// 对 B 分块 "vmax.f32 q13, q13, q14 \n\t"
NC = L1 / (KC * sizeof(float)); "vst1.32 {q12, q13}, [%[C]]! \n\t"
if (NC == 0) {
NC = NR;
} else {
int nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR - 1) / NR * NR;
}
// pad A up to a multiple of MR
MC = (m + MR - 1) / MR * MR;
#if __aarch64__ "subs %[nc1], %[nc1], #1 \n\t"
procPackA = &Gemm::PackMatrixA_omp_6r; "bge loop_nc1_%= \n\t"
procPackB = &Gemm::PackMatrixB_16c; "end_nc1_%=: \n\t"
procAddDot = &Gemm::AddDot6x16;
#else
procPackA = &Gemm::PackMatrixA_omp_6r;
procPackB = &Gemm::PackMatrixB_8c;
procAddDot = &Gemm::AddDot6x8;
#endif
packedA = static_cast<float *>( "subs %[nc2], %[nc2], #1 \n\t"
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); "blt end_nc2_%= \n\t"
(*this.*procPackA)(m, KC, m % MR, A, lda, packedA); "loop_nc2_%=: \n\t"
packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads));
}
packedC = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads));
if (m > n) { "vld1.32 {q0}, [%[c]]! \n\t"
#pragma omp parallel for "vld1.32 {q1}, [%[scale]]! \n\t"
for (int i = 0; i < m; i += MC) { "vld1.32 {q10}, [%[bias]]! \n\t"
#ifdef _OPENMP "vmla.f32 q10, q0, q1 \n\t"
int local_threads = omp_get_thread_num(); "vmax.f32 q10, q10, q14 \n\t"
#else "vst1.32 {q10}, [%[C]]! \n\t"
int local_threads = 0;
#endif
int mc; "subs %[nc2], %[nc2], #1 \n\t"
mc = s_min(m - i, MC); "bge loop_nc2_%= \n\t"
float *local_A = packedA + MC * KC * local_threads; "end_nc2_%=: \n\t"
float *local_C = packedC + MC * NC * local_threads;
(*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A);
if (bias == nullptr) {
InnerKernelWithBn(mc, n, alpha, local_A, packedB, beta, local_C,
&C(i, 0), ldc, relu, new_scale + i, new_bias + i);
} else {
InnerKernelWithBnAdd(mc, n, alpha, local_A, packedB, beta, local_C,
&C(i, 0), ldc, relu, new_scale + i, new_bias + i,
bias + i * ldc);
}
}
} else {
#pragma omp parallel for
for (int j = 0; j < n; j += NC) {
#ifdef _OPENMP
int local_threads = omp_get_thread_num();
#else
int local_threads = 0;
#endif
int nc; "cmp %[nc3], #16 \n\t"
nc = s_min(n - j, NC); "beq end_nc3_%= \n\t"
float *local_B = packedB + KC * NC * local_threads;
float *local_C = packedC + MC * NC * local_threads;
(*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B);
if (bias == nullptr) {
InnerKernelWithBn(m, nc, alpha, packedA, local_B, beta, local_C,
&C(0, j), ldc, relu, new_scale, new_bias);
} else {
InnerKernelWithBnAdd(m, nc, alpha, packedA, local_B, beta, local_C,
&C(0, j), ldc, relu, new_scale, new_bias,
bias + j);
}
}
}
paddle_mobile::memory::Free(packedA); "sub %[c], %[c], %[nc3] \n\t"
paddle_mobile::memory::Free(packedB); "sub %[scale], %[scale], %[nc3] \n\t"
paddle_mobile::memory::Free(packedC); "sub %[bias], %[bias], %[nc3] \n\t"
paddle_mobile::memory::Free(zero); "sub %[C], %[C], %[nc3] \n\t"
"vld1.32 {q0}, [%[c]]! \n\t"
"vld1.32 {q1}, [%[scale]]! \n\t"
"vld1.32 {q10}, [%[bias]]! \n\t"
"vmla.f32 q10, q0, q1 \n\t"
"vmax.f32 q10, q10, q14 \n\t"
"vst1.32 {q10}, [%[C]]! \n\t"
"end_nc3_%=: \n\t"
:
: [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] "r"(nc3),
[scale] "r"(scale), [bias] "r"(bias)
: "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", "q11",
"q12", "q13", "q14");
} }
void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, #endif // __aarch64__
const float *B, int ldb, float *C, int ldc, #endif // __ARM_NEON
float *p, std::string mode, float *bias,
float *bias1) { // 32位 float 矩阵乘法
#ifdef _OPENMP void Gemm::Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
int max_threads = omp_get_max_threads(); const float *B, int ldb, float beta, float *C, int ldc,
#else bool relu, float *bias) {
int max_threads = 1; // L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73)
#endif // L2 cache is 0.5~4 Mib (Contex-A72 cluster)
int L1 = 32 * 1024;
int L2 = 512 * 1024;
int L1 = 8 * 1024;
KC = k; KC = k;
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
if (m > n) {
// tile A into blocks
MC = L1 / (KC * sizeof(float)); MC = L1 / (KC * sizeof(float));
NC = L2 / (KC * sizeof(float));
// make sure MC is a multiple of MR, and NC is a multiple of NR
if (MC == 0) { if (MC == 0) {
MC = MR; MC = MR;
} else { } else {
...@@ -3431,27 +3269,7 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, ...@@ -3431,27 +3269,7 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda,
MC = (m + mblock_num - 1) / mblock_num; MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + MR - 1) / MR * MR; MC = (MC + MR - 1) / MR * MR;
} }
// 补齐 B // DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n";
NC = (n + NR - 1) / NR * NR;
#if __aarch64__
procPackA = &Gemm::PackMatrixA_6r;
procPackB = &Gemm::PackMatrixB_omp_16c;
procAddDot = &Gemm::AddDot6x16;
#else
procPackA = &Gemm::PackMatrixA_6r;
procPackB = &Gemm::PackMatrixB_omp_8c;
procAddDot = &Gemm::AddDot6x8;
#endif
packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
(*this.*procPackB)(KC, n, n % NR, B, ldb, packedB);
packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads));
} else {
// tile B into blocks
NC = L1 / (KC * sizeof(float));
if (NC == 0) { if (NC == 0) {
NC = NR; NC = NR;
} else { } else {
...@@ -3459,553 +3277,582 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, ...@@ -3459,553 +3277,582 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda,
NC = (n + nblock_num - 1) / nblock_num; NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR - 1) / NR * NR; NC = (NC + NR - 1) / NR * NR;
} }
// 补齐 A // DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n";
MC = (m + MR - 1) / MR * MR;
#if __aarch64__
procPackA = &Gemm::PackMatrixA_omp_6r;
procPackB = &Gemm::PackMatrixB_16c;
procAddDot = &Gemm::AddDot6x16;
#else
procPackA = &Gemm::PackMatrixA_omp_6r;
procPackB = &Gemm::PackMatrixB_8c;
procAddDot = &Gemm::AddDot6x8;
#endif
packedA = static_cast<float *>( packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
(*this.*procPackA)(m, KC, m % MR, A, lda, packedA);
packedB = static_cast<float *>( packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
}
packedC = static_cast<float *>( packedC = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads)); paddle_mobile::memory::Alloc(sizeof(float) * MC * NC));
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
if (m > n) { int mc, nc;
#pragma omp parallel for for (int j = 0; j < n; j += NC) {
for (int i = 0; i < m; i += MC) { nc = s_min(n - j, NC);
#ifdef _OPENMP #if __aarch64__
int local_threads = omp_get_thread_num(); // PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
#else #else
int local_threads = 0; PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
#endif #endif
for (int i = 0; i < m; i += MC) {
int mc;
mc = s_min(m - i, MC); mc = s_min(m - i, MC);
float *local_A = packedA + MC * KC * local_threads; #if __aarch64__
float *local_C = packedC + MC * NC * local_threads; PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
(*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A); // PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
if (bias1 == nullptr) {
InnerKernelWithPRelu(mc, n, local_A, packedB, local_C, &C(i, 0), ldc,
p + i, mode, bias + i, nullptr);
} else {
InnerKernelWithPRelu(mc, n, local_A, packedB, local_C, &C(i, 0), ldc,
p + i, mode, bias + i, bias1 + i * ldc);
}
}
} else {
#pragma omp parallel for
for (int j = 0; j < n; j += NC) {
#ifdef _OPENMP
int local_threads = omp_get_thread_num();
#else #else
int local_threads = 0; PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#endif #endif
if (bias == nullptr) {
int nc; InnerKernelWithBias(mc, nc, alpha, packedA, packedB, beta, packedC,
nc = s_min(n - j, NC); &C(i, j), ldc, relu, nullptr);
float *local_B = packedB + KC * NC * local_threads;
float *local_C = packedC + MC * NC * local_threads;
(*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B);
if (bias1 == nullptr) {
InnerKernelWithPRelu(m, nc, packedA, local_B, local_C, &C(0, j), ldc, p,
mode, bias, nullptr);
} else { } else {
InnerKernelWithPRelu(m, nc, packedA, local_B, local_C, &C(0, j), ldc, p, InnerKernelWithBias(mc, nc, alpha, packedA, packedB, beta, packedC,
mode, bias, bias1 + j); &C(i, j), ldc, relu, bias + i);
}
} }
} }
paddle_mobile::memory::Free(packedA);
paddle_mobile::memory::Free(packedB);
paddle_mobile::memory::Free(packedC);
paddle_mobile::memory::Free(zero);
}
void Gemm::AddDot6x8(int k, const float *a, const float *b, float *c, int ldc) {
#if __ARM_NEON
#if __aarch64__
// init C
float32x4_t cv0 = vdupq_n_f32(0.0);
float32x4_t cv1 = vdupq_n_f32(0.0);
float32x4_t cv2 = vdupq_n_f32(0.0);
float32x4_t cv3 = vdupq_n_f32(0.0);
float32x4_t cv4 = vdupq_n_f32(0.0);
float32x4_t cv5 = vdupq_n_f32(0.0);
float32x4_t cv6 = vdupq_n_f32(0.0);
float32x4_t cv7 = vdupq_n_f32(0.0);
float32x4_t cv8 = vdupq_n_f32(0.0);
float32x4_t cv9 = vdupq_n_f32(0.0);
float32x4_t cv10 = vdupq_n_f32(0.0);
float32x4_t cv11 = vdupq_n_f32(0.0);
float32x4_t av;
float32x4_t bv0;
float32x4_t bv1;
float32x2_t av01;
float32x2_t av23;
float32x2_t av45;
for (int p = 0; p < k; p += 1) {
av = vld1q_f32(a);
av01 = vget_low_f32(av);
av23 = vget_high_f32(av);
av45 = vld1_f32(a + 4);
bv0 = vld1q_f32(b);
bv1 = vld1q_f32(b + 4);
cv0 = vmlaq_lane_f32(cv0, bv0, av01, 0);
cv1 = vmlaq_lane_f32(cv1, bv1, av01, 0);
cv2 = vmlaq_lane_f32(cv2, bv0, av01, 1);
cv3 = vmlaq_lane_f32(cv3, bv1, av01, 1);
cv4 = vmlaq_lane_f32(cv4, bv0, av23, 0);
cv5 = vmlaq_lane_f32(cv5, bv1, av23, 0);
cv6 = vmlaq_lane_f32(cv6, bv0, av23, 1);
cv7 = vmlaq_lane_f32(cv7, bv1, av23, 1);
cv8 = vmlaq_lane_f32(cv8, bv0, av45, 0);
cv9 = vmlaq_lane_f32(cv9, bv1, av45, 0);
cv10 = vmlaq_lane_f32(cv10, bv0, av45, 1);
cv11 = vmlaq_lane_f32(cv11, bv1, av45, 1);
a += MR;
b += NR;
} }
vst1q_f32(c, cv0); paddle_mobile::memory::Free(packedA);
vst1q_f32(c + 4, cv1); paddle_mobile::memory::Free(packedB);
vst1q_f32(c + ldc, cv2); paddle_mobile::memory::Free(packedC);
vst1q_f32(c + ldc + 4, cv3); paddle_mobile::memory::Free(zero);
vst1q_f32(c + 2 * ldc, cv4); }
vst1q_f32(c + 2 * ldc + 4, cv5);
vst1q_f32(c + 3 * ldc, cv6);
vst1q_f32(c + 3 * ldc + 4, cv7);
vst1q_f32(c + 4 * ldc, cv8);
vst1q_f32(c + 4 * ldc + 4, cv9);
vst1q_f32(c + 5 * ldc, cv10);
vst1q_f32(c + 5 * ldc + 4, cv11);
#else
const float *a_ptr, *b_ptr;
a_ptr = a;
b_ptr = b;
int kc1 = k / 8;
int kc2 = k % 8;
int step = sizeof(float) * ldc;
asm volatile(
"pld [%[a_ptr]] \n\t"
"pld [%[a_ptr], #64] \n\t"
"pld [%[b_ptr]] \n\t"
"pld [%[b_ptr], #64] \n\t"
"vmov.f32 q4, #0.0 \n\t"
"vmov.f32 q5, #0.0 \n\t"
"vmov.f32 q6, #0.0 \n\t"
"vmov.f32 q7, #0.0 \n\t"
"vmov.f32 q8, #0.0 \n\t"
"vmov.f32 q9, #0.0 \n\t"
"vmov.f32 q10, #0.0 \n\t"
"vmov.f32 q11, #0.0 \n\t"
"vmov.f32 q12, #0.0 \n\t"
"vmov.f32 q13, #0.0 \n\t"
"vmov.f32 q14, #0.0 \n\t"
"vmov.f32 q15, #0.0 \n\t"
"subs %[kc1], %[kc1], #1 \n\t"
"blt 2f \n\t"
"1: \n\t"
"pld [%[a_ptr], #128] \n\t"
"pld [%[b_ptr], #128] \n\t"
"vld1.32 {d0-d2}, [%[a_ptr]]! \n\t"
"vld1.32 {q2, q3}, [%[b_ptr]]! \n\t"
"vmla.f32 q4, q2, d0[0] \n\t"
"vmla.f32 q5, q3, d0[0] \n\t"
"vmla.f32 q6, q2, d0[1] \n\t"
"vmla.f32 q7, q3, d0[1] \n\t"
"vmla.f32 q8, q2, d1[0] \n\t"
"vmla.f32 q9, q3, d1[0] \n\t"
"vmla.f32 q10, q2, d1[1] \n\t"
"vmla.f32 q11, q3, d1[1] \n\t"
"vmla.f32 q12, q2, d2[0] \n\t"
"vmla.f32 q13, q3, d2[0] \n\t"
"vmla.f32 q14, q2, d2[1] \n\t"
"vmla.f32 q15, q3, d2[1] \n\t"
"vld1.32 {d0-d2}, [%[a_ptr]]! \n\t"
"vld1.32 {q2, q3}, [%[b_ptr]]! \n\t"
"vmla.f32 q4, q2, d0[0] \n\t"
"vmla.f32 q5, q3, d0[0] \n\t"
"vmla.f32 q6, q2, d0[1] \n\t"
"vmla.f32 q7, q3, d0[1] \n\t"
"vmla.f32 q8, q2, d1[0] \n\t"
"vmla.f32 q9, q3, d1[0] \n\t"
"vmla.f32 q10, q2, d1[1] \n\t"
"vmla.f32 q11, q3, d1[1] \n\t"
"vmla.f32 q12, q2, d2[0] \n\t"
"vmla.f32 q13, q3, d2[0] \n\t"
"vmla.f32 q14, q2, d2[1] \n\t"
"vmla.f32 q15, q3, d2[1] \n\t"
"pld [%[a_ptr], #128] \n\t"
"pld [%[b_ptr], #128] \n\t"
"vld1.32 {d0-d2}, [%[a_ptr]]! \n\t"
"vld1.32 {q2, q3}, [%[b_ptr]]! \n\t"
"vmla.f32 q4, q2, d0[0] \n\t"
"vmla.f32 q5, q3, d0[0] \n\t"
"vmla.f32 q6, q2, d0[1] \n\t"
"vmla.f32 q7, q3, d0[1] \n\t"
"vmla.f32 q8, q2, d1[0] \n\t"
"vmla.f32 q9, q3, d1[0] \n\t"
"vmla.f32 q10, q2, d1[1] \n\t"
"vmla.f32 q11, q3, d1[1] \n\t"
"vmla.f32 q12, q2, d2[0] \n\t"
"vmla.f32 q13, q3, d2[0] \n\t"
"vmla.f32 q14, q2, d2[1] \n\t"
"vmla.f32 q15, q3, d2[1] \n\t"
"vld1.32 {d0-d2}, [%[a_ptr]]! \n\t"
"vld1.32 {q2, q3}, [%[b_ptr]]! \n\t"
"vmla.f32 q4, q2, d0[0] \n\t" void Gemm::SgemmWithBn(int m, int n, int k, float alpha, const float *A,
"vmla.f32 q5, q3, d0[0] \n\t" int lda, const float *B, int ldb, float beta, float *C,
"vmla.f32 q6, q2, d0[1] \n\t" int ldc, bool relu, float *new_scale, float *new_bias,
"vmla.f32 q7, q3, d0[1] \n\t" float *bias) {
"vmla.f32 q8, q2, d1[0] \n\t" // L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73)
"vmla.f32 q9, q3, d1[0] \n\t" // L2 cache is 0.5~4 Mib (Contex-A72 cluster)
"vmla.f32 q10, q2, d1[1] \n\t" int L1 = 32 * 1024;
"vmla.f32 q11, q3, d1[1] \n\t" int L2 = 512 * 1024;
"vmla.f32 q12, q2, d2[0] \n\t"
"vmla.f32 q13, q3, d2[0] \n\t"
"vmla.f32 q14, q2, d2[1] \n\t"
"vmla.f32 q15, q3, d2[1] \n\t"
"pld [%[a_ptr], #128] \n\t" KC = k;
"pld [%[b_ptr], #128] \n\t" MC = L1 / (KC * sizeof(float));
NC = L2 / (KC * sizeof(float));
"vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" // make sure MC is multiple of MR, and NC is multiple of NR
"vld1.32 {q2, q3}, [%[b_ptr]]! \n\t" if (MC == 0) {
MC = MR;
} else {
int mblock_num = (m + MC - 1) / MC;
MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + MR - 1) / MR * MR;
}
// DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n";
if (NC == 0) {
NC = NR;
} else {
int nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR - 1) / NR * NR;
}
// DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n";
"vmla.f32 q4, q2, d0[0] \n\t" packedA = static_cast<float *>(
"vmla.f32 q5, q3, d0[0] \n\t" paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
"vmla.f32 q6, q2, d0[1] \n\t" packedB = static_cast<float *>(
"vmla.f32 q7, q3, d0[1] \n\t" paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
"vmla.f32 q8, q2, d1[0] \n\t" packedC = static_cast<float *>(
"vmla.f32 q9, q3, d1[0] \n\t" paddle_mobile::memory::Alloc(sizeof(float) * MC * NC));
"vmla.f32 q10, q2, d1[1] \n\t" zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
"vmla.f32 q11, q3, d1[1] \n\t" memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
"vmla.f32 q12, q2, d2[0] \n\t"
"vmla.f32 q13, q3, d2[0] \n\t"
"vmla.f32 q14, q2, d2[1] \n\t"
"vmla.f32 q15, q3, d2[1] \n\t"
"vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" int mc, nc;
"vld1.32 {q2, q3}, [%[b_ptr]]! \n\t" for (int j = 0; j < n; j += NC) {
nc = s_min(n - j, NC);
#if __aarch64__
// PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
#else
PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
#endif
for (int i = 0; i < m; i += MC) {
mc = s_min(m - i, MC);
#if __aarch64__
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
// PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#else
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#endif
if (bias == nullptr) {
InnerKernelWithBn(mc, nc, alpha, packedA, packedB, beta, packedC,
&C(i, j), ldc, relu, new_scale + i, new_bias + i);
} else {
InnerKernelWithBnAdd(mc, nc, alpha, packedA, packedB, beta, packedC,
&C(i, j), ldc, relu, new_scale + i, new_bias + i,
bias + i * ldc + j);
}
}
}
"vmla.f32 q4, q2, d0[0] \n\t" paddle_mobile::memory::Free(packedA);
"vmla.f32 q5, q3, d0[0] \n\t" paddle_mobile::memory::Free(packedB);
"vmla.f32 q6, q2, d0[1] \n\t" paddle_mobile::memory::Free(packedC);
"vmla.f32 q7, q3, d0[1] \n\t" paddle_mobile::memory::Free(zero);
"vmla.f32 q8, q2, d1[0] \n\t" }
"vmla.f32 q9, q3, d1[0] \n\t"
"vmla.f32 q10, q2, d1[1] \n\t"
"vmla.f32 q11, q3, d1[1] \n\t"
"vmla.f32 q12, q2, d2[0] \n\t"
"vmla.f32 q13, q3, d2[0] \n\t"
"vmla.f32 q14, q2, d2[1] \n\t"
"vmla.f32 q15, q3, d2[1] \n\t"
"pld [%[a_ptr], #128] \n\t" void Gemm::SgemmWithPRelu(int m, int n, int k, const float *A, int lda,
"pld [%[b_ptr], #128] \n\t" const float *B, int ldb, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1) {
// L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73)
// L2 cache is 0.5~4 Mib (Contex-A72 cluster)
int L1 = 32 * 1024;
int L2 = 0.5 * 1024 * 1024;
"vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" KC = k;
"vld1.32 {q2, q3}, [%[b_ptr]]! \n\t" MC = L1 / (KC * sizeof(float));
NC = L2 / (KC * sizeof(float));
"vmla.f32 q4, q2, d0[0] \n\t" // make sure MC is multiple of MR, and NC is multiple of NR
"vmla.f32 q5, q3, d0[0] \n\t" if (MC == 0) {
"vmla.f32 q6, q2, d0[1] \n\t" MC = MR;
"vmla.f32 q7, q3, d0[1] \n\t" } else {
"vmla.f32 q8, q2, d1[0] \n\t" int mblock_num = (m + MC - 1) / MC;
"vmla.f32 q9, q3, d1[0] \n\t" MC = (m + mblock_num - 1) / mblock_num;
"vmla.f32 q10, q2, d1[1] \n\t" MC = (MC + MR - 1) / MR * MR;
"vmla.f32 q11, q3, d1[1] \n\t" }
"vmla.f32 q12, q2, d2[0] \n\t" // DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n";
"vmla.f32 q13, q3, d2[0] \n\t" if (NC == 0) {
"vmla.f32 q14, q2, d2[1] \n\t" NC = NR;
"vmla.f32 q15, q3, d2[1] \n\t" } else {
int nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR - 1) / NR * NR;
}
// DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n";
"vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" packedA = static_cast<float *>(
"vld1.32 {q2, q3}, [%[b_ptr]]! \n\t" paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
packedC = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * NC));
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
"vmla.f32 q4, q2, d0[0] \n\t" for (int l = 0; l < KC; ++l) {
"vmla.f32 q5, q3, d0[0] \n\t" zero[l] = 0;
"vmla.f32 q6, q2, d0[1] \n\t" }
"vmla.f32 q7, q3, d0[1] \n\t"
"vmla.f32 q8, q2, d1[0] \n\t"
"vmla.f32 q9, q3, d1[0] \n\t"
"vmla.f32 q10, q2, d1[1] \n\t"
"vmla.f32 q11, q3, d1[1] \n\t"
"vmla.f32 q12, q2, d2[0] \n\t"
"vmla.f32 q13, q3, d2[0] \n\t"
"vmla.f32 q14, q2, d2[1] \n\t"
"vmla.f32 q15, q3, d2[1] \n\t"
"subs %[kc1], %[kc1], #1 \n\t" int mc, nc;
"bge 1b \n\t" for (int j = 0; j < n; j += NC) {
"2: \n\t" nc = s_min(n - j, NC);
#if __aarch64__
// PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
#else
PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
#endif
for (int i = 0; i < m; i += MC) {
mc = s_min(m - i, MC);
#if __aarch64__
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
// PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#else
PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
#endif
if (bias1 == nullptr) {
InnerKernelWithPRelu(mc, nc, packedA, packedB, packedC, &C(i, j), ldc,
p + i, mode, bias + i, nullptr);
} else {
InnerKernelWithPRelu(mc, nc, packedA, packedB, packedC, &C(i, j), ldc,
p + i, mode, bias + i, bias1 + i * ldc + j);
}
}
}
"subs %[kc2], %[kc2], #1 \n\t" paddle_mobile::memory::Free(packedA);
"blt 4f \n\t" paddle_mobile::memory::Free(packedB);
"3: \n\t" paddle_mobile::memory::Free(packedC);
paddle_mobile::memory::Free(zero);
}
"vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" // 32位 float 矩阵乘法
"vld1.32 {q2, q3}, [%[b_ptr]]! \n\t" void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *bias) {
#ifndef __aarch64__
if (m == 1 && bias == nullptr) {
return VectorKernel(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, relu);
}
#endif // __aarch64__
#ifdef _OPENMP
int max_threads = omp_get_max_threads();
#else
int max_threads = 1;
#endif
"vmla.f32 q4, q2, d0[0] \n\t" // int L1 = 64 / max_threads * 1024;
"vmla.f32 q5, q3, d0[0] \n\t" int L = (max_threads > 2) ? 64 : 32;
"vmla.f32 q6, q2, d0[1] \n\t" int L1 = L / max_threads * 1024;
"vmla.f32 q7, q3, d0[1] \n\t" KC = k;
"vmla.f32 q8, q2, d1[0] \n\t" zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
"vmla.f32 q9, q3, d1[0] \n\t" memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
"vmla.f32 q10, q2, d1[1] \n\t" if (m > n) {
"vmla.f32 q11, q3, d1[1] \n\t" // 对 A 分块
"vmla.f32 q12, q2, d2[0] \n\t" MC = L1 / (KC * sizeof(float));
"vmla.f32 q13, q3, d2[0] \n\t" if (MC == 0) {
"vmla.f32 q14, q2, d2[1] \n\t" MC = MR;
"vmla.f32 q15, q3, d2[1] \n\t" } else {
int mblock_num = (m + MC - 1) / MC;
MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + MR - 1) / MR * MR;
}
// 补齐 B
NC = (n + NR - 1) / NR * NR;
"subs %[kc2], %[kc2], #1 \n\t" #if __aarch64__
"bge 3b \n\t" procPackA = &Gemm::PackMatrixA_6r;
"4: \n\t" procPackB = &Gemm::PackMatrixB_omp_16c;
procAddDot = &Gemm::AddDot6x16;
#else
procPackA = &Gemm::PackMatrixA_6r;
procPackB = &Gemm::PackMatrixB_omp_8c;
procAddDot = &Gemm::AddDot6x8;
#endif
"mov r5, %[c] \n\t" packedB = static_cast<float *>(
"mov r6, %[step] \n\t" paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
"vst1.32 {q4, q5}, [r5], r6 \n\t" (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB);
"vst1.32 {q6, q7}, [r5], r6 \n\t" packedA = static_cast<float *>(
"vst1.32 {q8, q9}, [r5], r6 \n\t" paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads));
"vst1.32 {q10, q11}, [r5], r6 \n\t" } else {
"vst1.32 {q12, q13}, [r5], r6 \n\t" // 对 B 分块
"vst1.32 {q14, q15}, [r5] \n\t" NC = L1 / (KC * sizeof(float));
if (NC == 0) {
NC = NR;
} else {
int nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR - 1) / NR * NR;
}
// 补齐 A
MC = (m + MR - 1) / MR * MR;
: #if __aarch64__
: [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1), procPackA = &Gemm::PackMatrixA_omp_6r;
[kc2] "r"(kc2), [step] "r"(step) procPackB = &Gemm::PackMatrixB_16c;
: "cc", "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", procAddDot = &Gemm::AddDot6x16;
"q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); #else
#endif // __aarch64__ procPackA = &Gemm::PackMatrixA_omp_6r;
procPackB = &Gemm::PackMatrixB_8c;
procAddDot = &Gemm::AddDot6x8;
#endif
#endif // __ARM_NEON packedA = static_cast<float *>(
} paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
(*this.*procPackA)(m, KC, m % MR, A, lda, packedA);
packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads));
}
packedC = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads));
#if __aarch64__ if (m > n) {
void Gemm::AddDot8x12(int k, const float *a, const float *b, float *c, #pragma omp parallel for
int ldc) { for (int i = 0; i < m; i += MC) {
const float *a_ptr, *b_ptr; #ifdef _OPENMP
a_ptr = a; int local_threads = omp_get_thread_num();
b_ptr = b; #else
int kc1 = k; int local_threads = 0;
int step = 4 * ldc; #endif
asm volatile(
"dup v5.4s, wzr \n\t"
"dup v6.4s, wzr \n\t"
"dup v7.4s, wzr \n\t"
"dup v8.4s, wzr \n\t"
"dup v9.4s, wzr \n\t"
"dup v10.4s, wzr \n\t"
"dup v11.4s, wzr \n\t"
"dup v12.4s, wzr \n\t"
"dup v13.4s, wzr \n\t"
"dup v14.4s, wzr \n\t"
"dup v15.4s, wzr \n\t"
"dup v16.4s, wzr \n\t"
"dup v17.4s, wzr \n\t" int mc;
"dup v18.4s, wzr \n\t" mc = s_min(m - i, MC);
"dup v19.4s, wzr \n\t" float *local_A = packedA + MC * KC * local_threads;
"dup v20.4s, wzr \n\t" float *local_C = packedC + MC * NC * local_threads;
"dup v21.4s, wzr \n\t" (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A);
"dup v22.4s, wzr \n\t" if (bias == nullptr) {
"dup v23.4s, wzr \n\t" InnerKernelWithBias(mc, n, alpha, local_A, packedB, beta, local_C,
"dup v24.4s, wzr \n\t" &C(i, 0), ldc, relu, nullptr);
"dup v25.4s, wzr \n\t" } else {
"dup v26.4s, wzr \n\t" InnerKernelWithBias(mc, n, alpha, local_A, packedB, beta, local_C,
"dup v27.4s, wzr \n\t" &C(i, 0), ldc, relu, bias + i);
"dup v28.4s, wzr \n\t" }
}
} else {
#pragma omp parallel for
for (int j = 0; j < n; j += NC) {
#ifdef _OPENMP
int local_threads = omp_get_thread_num();
#else
int local_threads = 0;
#endif
"subs %[kc1], %[kc1], #1 \n\t" int nc;
"blt 2f \n\t" nc = s_min(n - j, NC);
"1: \n\t" float *local_B = packedB + KC * NC * local_threads;
float *local_C = packedC + MC * NC * local_threads;
(*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B);
InnerKernelWithBias(m, nc, alpha, packedA, local_B, beta, local_C,
&C(0, j), ldc, relu, bias);
}
}
"prfm pldl1keep, [%[a_ptr], #32] \n\t" paddle_mobile::memory::Free(packedA);
"prfm pldl1keep, [%[b_ptr], #48] \n\t" paddle_mobile::memory::Free(packedB);
paddle_mobile::memory::Free(packedC);
paddle_mobile::memory::Free(zero);
}
"ld1 {v0.4s, v1.4s}, [%[a_ptr]], #32 \n\t" void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A,
"ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48 \n\t" int lda, const float *B, int ldb, float beta,
float *C, int ldc, bool relu, float *new_scale,
float *new_bias, float *bias) {
#ifdef _OPENMP
int max_threads = omp_get_max_threads();
#else
int max_threads = 1;
#endif
"fmla v5.4s, v2.4s, v0.s[0] \n\t" int L1 = 64 / max_threads * 1024;
"fmla v6.4s, v3.4s, v0.s[0] \n\t" KC = k;
"fmla v7.4s, v4.4s, v0.s[0] \n\t" zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
"fmla v8.4s, v2.4s, v0.s[1] \n\t" memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
"fmla v9.4s, v3.4s, v0.s[1] \n\t" if (m > n) {
"fmla v10.4s, v4.4s, v0.s[1] \n\t" // 对 A 分块
"fmla v11.4s, v2.4s, v0.s[2] \n\t" MC = L1 / (KC * sizeof(float));
"fmla v12.4s, v3.4s, v0.s[2] \n\t" if (MC == 0) {
"fmla v13.4s, v4.4s, v0.s[2] \n\t" MC = MR;
"fmla v14.4s, v2.4s, v0.s[3] \n\t" } else {
"fmla v15.4s, v3.4s, v0.s[3] \n\t" int mblock_num = (m + MC - 1) / MC;
"fmla v16.4s, v4.4s, v0.s[3] \n\t" MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + MR - 1) / MR * MR;
}
// 补齐 B
NC = (n + NR - 1) / NR * NR;
"fmla v17.4s, v2.4s, v1.s[0] \n\t" #if __aarch64__
"fmla v18.4s, v3.4s, v1.s[0] \n\t" procPackA = &Gemm::PackMatrixA_6r;
"fmla v19.4s, v4.4s, v1.s[0] \n\t" procPackB = &Gemm::PackMatrixB_omp_16c;
"fmla v20.4s, v2.4s, v1.s[1] \n\t" procAddDot = &Gemm::AddDot6x16;
"fmla v21.4s, v3.4s, v1.s[1] \n\t" #else
"fmla v22.4s, v4.4s, v1.s[1] \n\t" procPackA = &Gemm::PackMatrixA_6r;
"fmla v23.4s, v2.4s, v1.s[2] \n\t" procPackB = &Gemm::PackMatrixB_omp_8c;
"fmla v24.4s, v3.4s, v1.s[2] \n\t" procAddDot = &Gemm::AddDot6x8;
"fmla v25.4s, v4.4s, v1.s[2] \n\t" #endif
"fmla v26.4s, v2.4s, v1.s[3] \n\t"
"fmla v27.4s, v3.4s, v1.s[3] \n\t"
"fmla v28.4s, v4.4s, v1.s[3] \n\t"
"subs %[kc1], %[kc1], #1 \n\t" packedB = static_cast<float *>(
"bge 1b \n\t" paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
"2: \n\t" (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB);
packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads));
} else {
// 对 B 分块
NC = L1 / (KC * sizeof(float));
if (NC == 0) {
NC = NR;
} else {
int nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR - 1) / NR * NR;
}
// 补齐 A
MC = (m + MR - 1) / MR * MR;
"st1 {v5.4s, v6.4s, v7.4s}, [%[c]], %[step] \n\t" #if __aarch64__
"st1 {v8.4s, v9.4s, v10.4s}, [%[c]], %[step] \n\t" procPackA = &Gemm::PackMatrixA_omp_6r;
"st1 {v11.4s, v12.4s, v13.4s}, [%[c]], %[step] \n\t" procPackB = &Gemm::PackMatrixB_16c;
"st1 {v14.4s, v15.4s, v16.4s}, [%[c]], %[step] \n\t" procAddDot = &Gemm::AddDot6x16;
"st1 {v17.4s, v18.4s, v19.4s}, [%[c]], %[step] \n\t" #else
"st1 {v20.4s, v21.4s, v22.4s}, [%[c]], %[step] \n\t" procPackA = &Gemm::PackMatrixA_omp_6r;
"st1 {v23.4s, v24.4s, v25.4s}, [%[c]], %[step] \n\t" procPackB = &Gemm::PackMatrixB_8c;
"st1 {v26.4s, v27.4s, v28.4s}, [%[c]], %[step] \n\t" procAddDot = &Gemm::AddDot6x8;
: #endif
: [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1),
[step] "r"(step)
: "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28");
}
void Gemm::AddDot6x16(int k, const float *a, const float *b, float *c, packedA = static_cast<float *>(
int ldc) { paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
const float *a_ptr, *b_ptr; (*this.*procPackA)(m, KC, m % MR, A, lda, packedA);
a_ptr = a; packedB = static_cast<float *>(
b_ptr = b; paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads));
int kc1 = k; }
int step = 4 * ldc; packedC = static_cast<float *>(
int step1 = 4 * 6; paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads));
asm volatile(
"dup v6.4s, wzr \n\t" if (m > n) {
"dup v7.4s, wzr \n\t" #pragma omp parallel for
"dup v8.4s, wzr \n\t" for (int i = 0; i < m; i += MC) {
"dup v9.4s, wzr \n\t" #ifdef _OPENMP
"dup v10.4s, wzr \n\t" int local_threads = omp_get_thread_num();
"dup v11.4s, wzr \n\t" #else
"dup v12.4s, wzr \n\t" int local_threads = 0;
"dup v13.4s, wzr \n\t" #endif
"dup v14.4s, wzr \n\t" int mc;
"dup v15.4s, wzr \n\t" mc = s_min(m - i, MC);
"dup v16.4s, wzr \n\t" float *local_A = packedA + MC * KC * local_threads;
"dup v17.4s, wzr \n\t" float *local_C = packedC + MC * NC * local_threads;
"dup v18.4s, wzr \n\t" (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A);
"dup v19.4s, wzr \n\t" if (bias == nullptr) {
"dup v20.4s, wzr \n\t" InnerKernelWithBn(mc, n, alpha, local_A, packedB, beta, local_C,
"dup v21.4s, wzr \n\t" &C(i, 0), ldc, relu, new_scale + i, new_bias + i);
} else {
InnerKernelWithBnAdd(mc, n, alpha, local_A, packedB, beta, local_C,
&C(i, 0), ldc, relu, new_scale + i, new_bias + i,
bias + i * ldc);
}
}
} else {
#pragma omp parallel for
for (int j = 0; j < n; j += NC) {
#ifdef _OPENMP
int local_threads = omp_get_thread_num();
#else
int local_threads = 0;
#endif
"dup v22.4s, wzr \n\t" int nc;
"dup v23.4s, wzr \n\t" nc = s_min(n - j, NC);
"dup v24.4s, wzr \n\t" float *local_B = packedB + KC * NC * local_threads;
"dup v25.4s, wzr \n\t" float *local_C = packedC + MC * NC * local_threads;
"dup v26.4s, wzr \n\t" (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B);
"dup v27.4s, wzr \n\t" if (bias == nullptr) {
"dup v28.4s, wzr \n\t" InnerKernelWithBn(m, nc, alpha, packedA, local_B, beta, local_C,
"dup v29.4s, wzr \n\t" &C(0, j), ldc, relu, new_scale, new_bias);
} else {
InnerKernelWithBnAdd(m, nc, alpha, packedA, local_B, beta, local_C,
&C(0, j), ldc, relu, new_scale, new_bias,
bias + j);
}
}
}
"subs %[kc1], %[kc1], #1 \n\t" paddle_mobile::memory::Free(packedA);
"blt 2f \n\t" paddle_mobile::memory::Free(packedB);
"1: \n\t" paddle_mobile::memory::Free(packedC);
paddle_mobile::memory::Free(zero);
}
"prfm pldl1keep, [%[a_ptr], #24] \n\t" void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda,
"prfm pldl1keep, [%[b_ptr], #64] \n\t" const float *B, int ldb, float *C, int ldc,
float *p, std::string mode, float *bias,
float *bias1) {
#ifdef _OPENMP
int max_threads = omp_get_max_threads();
#else
int max_threads = 1;
#endif
"ld1 {v0.4s, v1.4s}, [%[a_ptr]], %[step1] \n\t" int L1 = 8 * 1024;
"ld1 {v2.4s, v3.4s, v4.4s, v5.4s}, [%[b_ptr]], #64 \n\t" KC = k;
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
memset(static_cast<void *>(zero), 0, sizeof(float) * KC);
if (m > n) {
// 对 A 分块
MC = L1 / (KC * sizeof(float));
if (MC == 0) {
MC = MR;
} else {
int mblock_num = (m + MC - 1) / MC;
MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + MR - 1) / MR * MR;
}
// 补齐 B
NC = (n + NR - 1) / NR * NR;
"fmla v6.4s, v2.4s, v0.s[0] \n\t" #if __aarch64__
"fmla v7.4s, v3.4s, v0.s[0] \n\t" procPackA = &Gemm::PackMatrixA_6r;
"fmla v8.4s, v4.4s, v0.s[0] \n\t" procPackB = &Gemm::PackMatrixB_omp_16c;
"fmla v9.4s, v5.4s, v0.s[0] \n\t" procAddDot = &Gemm::AddDot6x16;
#else
procPackA = &Gemm::PackMatrixA_6r;
procPackB = &Gemm::PackMatrixB_omp_8c;
procAddDot = &Gemm::AddDot6x8;
#endif
"fmla v10.4s, v2.4s, v0.s[1] \n\t" packedB = static_cast<float *>(
"fmla v11.4s, v3.4s, v0.s[1] \n\t" paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
"fmla v12.4s, v4.4s, v0.s[1] \n\t" (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB);
"fmla v13.4s, v5.4s, v0.s[1] \n\t" packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads));
} else {
// 对 B 分块
NC = L1 / (KC * sizeof(float));
if (NC == 0) {
NC = NR;
} else {
int nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + NR - 1) / NR * NR;
}
// 补齐 A
MC = (m + MR - 1) / MR * MR;
"fmla v14.4s, v2.4s, v0.s[2] \n\t" #if __aarch64__
"fmla v15.4s, v3.4s, v0.s[2] \n\t" procPackA = &Gemm::PackMatrixA_omp_6r;
"fmla v16.4s, v4.4s, v0.s[2] \n\t" procPackB = &Gemm::PackMatrixB_16c;
"fmla v17.4s, v5.4s, v0.s[2] \n\t" procAddDot = &Gemm::AddDot6x16;
#else
procPackA = &Gemm::PackMatrixA_omp_6r;
procPackB = &Gemm::PackMatrixB_8c;
procAddDot = &Gemm::AddDot6x8;
#endif
"fmla v18.4s, v2.4s, v0.s[3] \n\t" packedA = static_cast<float *>(
"fmla v19.4s, v3.4s, v0.s[3] \n\t" paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
"fmla v20.4s, v4.4s, v0.s[3] \n\t" (*this.*procPackA)(m, KC, m % MR, A, lda, packedA);
"fmla v21.4s, v5.4s, v0.s[3] \n\t" packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads));
}
packedC = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads));
"fmla v22.4s, v2.4s, v1.s[0] \n\t" if (m > n) {
"fmla v23.4s, v3.4s, v1.s[0] \n\t" #pragma omp parallel for
"fmla v24.4s, v4.4s, v1.s[0] \n\t" for (int i = 0; i < m; i += MC) {
"fmla v25.4s, v5.4s, v1.s[0] \n\t" #ifdef _OPENMP
int local_threads = omp_get_thread_num();
#else
int local_threads = 0;
#endif
"fmla v26.4s, v2.4s, v1.s[1] \n\t" int mc;
"fmla v27.4s, v3.4s, v1.s[1] \n\t" mc = s_min(m - i, MC);
"fmla v28.4s, v4.4s, v1.s[1] \n\t" float *local_A = packedA + MC * KC * local_threads;
"fmla v29.4s, v5.4s, v1.s[1] \n\t" float *local_C = packedC + MC * NC * local_threads;
(*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A);
if (bias1 == nullptr) {
InnerKernelWithPRelu(mc, n, local_A, packedB, local_C, &C(i, 0), ldc,
p + i, mode, bias + i, nullptr);
} else {
InnerKernelWithPRelu(mc, n, local_A, packedB, local_C, &C(i, 0), ldc,
p + i, mode, bias + i, bias1 + i * ldc);
}
}
} else {
#pragma omp parallel for
for (int j = 0; j < n; j += NC) {
#ifdef _OPENMP
int local_threads = omp_get_thread_num();
#else
int local_threads = 0;
#endif
"subs %[kc1], %[kc1], #1 \n\t" int nc;
"bge 1b \n\t" nc = s_min(n - j, NC);
"2: \n\t" float *local_B = packedB + KC * NC * local_threads;
float *local_C = packedC + MC * NC * local_threads;
(*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B);
if (bias1 == nullptr) {
InnerKernelWithPRelu(m, nc, packedA, local_B, local_C, &C(0, j), ldc, p,
mode, bias, nullptr);
} else {
InnerKernelWithPRelu(m, nc, packedA, local_B, local_C, &C(0, j), ldc, p,
mode, bias, bias1 + j);
}
}
}
"st1 {v6.4s, v7.4s, v8.4s, v9.4s}, [%[c]], %[step] \n\t" paddle_mobile::memory::Free(packedA);
"st1 {v10.4s, v11.4s, v12.4s, v13.4s}, [%[c]], %[step] \n\t" paddle_mobile::memory::Free(packedB);
"st1 {v14.4s, v15.4s, v16.4s, v17.4s}, [%[c]], %[step] \n\t" paddle_mobile::memory::Free(packedC);
"st1 {v18.4s, v19.4s, v20.4s, v21.4s}, [%[c]], %[step] \n\t" paddle_mobile::memory::Free(zero);
"st1 {v22.4s, v23.4s, v24.4s, v25.4s}, [%[c]], %[step] \n\t"
"st1 {v26.4s, v27.4s, v28.4s, v29.4s}, [%[c]], %[step] \n\t"
:
: [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1),
[step] "r"(step), [step1] "r"(step1)
: "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29");
} }
#endif // __aarch64__
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
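The Sgemm, SgemmWithBn, SgemmWithPRelu and their _omp variants above all share one blocked scheme; a compact sketch of the control flow (names follow the functions in this diff, packing details elided):

// for each column panel j of width up to NC:
//   PackMatrixB_8c / PackMatrixB_16c copies B(0..k, j..j+nc) into packedB
//   for each row panel i of height up to MC:
//     PackMatrixA_6r copies A(i..i+mc, 0..k) into packedA
//     InnerKernelWith{Bias,Bn,BnAdd,PRelu} multiplies the packed panels with
//     an AddDot micro-kernel and writes the mc x nc tile to C(i, j) through
//     the matching Write* / Vec* helpers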
...@@ -46,15 +46,6 @@ namespace math { ...@@ -46,15 +46,6 @@ namespace math {
class Gemm { class Gemm {
public: public:
/*
// 将 A 矩阵分块复制到连续内存(ColMajor)
void PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
// 将 B 矩阵分块复制到连续内存(ColMajor)
void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
*/
typedef void (Gemm::*FnPack)(int, int, int, const float *, int, float *); typedef void (Gemm::*FnPack)(int, int, int, const float *, int, float *);
typedef void (Gemm::*FnAddDot)(int, const float *, const float *, float *, typedef void (Gemm::*FnAddDot)(int, const float *, const float *, float *,
int); int);
...@@ -62,31 +53,31 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, ...@@ -62,31 +53,31 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
FnPack procPackB; FnPack procPackB;
FnAddDot procAddDot; FnAddDot procAddDot;
// 将 A 矩阵分块复制到连续内存(RowMajor) // 将 A\B 矩阵分块复制到连续内存(RowMajor)
void PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda, void PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda,
float *buffer); float *buffer);
void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda, void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
float *buffer); float *buffer);
void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_omp_6r(int m, int k, int m_tail, const float *A, int lda, void PackMatrixA_omp_6r(int m, int k, int m_tail, const float *A, int lda,
float *buffer); float *buffer);
void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
void PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda, void PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda,
float *buffer); float *buffer);
// 将 B 矩阵分块复制到连续内存(RowMajor)
void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb, void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer); float *buffer);
void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb, void PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer); float *buffer);
#if __aarch64__
void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B, int ldb, void PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer); float *buffer);
void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
void PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, int ldb, void PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, int ldb,
float *buffer); float *buffer);
#endif
// 分块矩阵乘法 // 分块矩阵乘法
void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b, void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
...@@ -106,22 +97,16 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, ...@@ -106,22 +97,16 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
float *c, float *C, int ldc, float *p, float *c, float *C, int ldc, float *p,
std::string mode, float *bias, float *bias1); std::string mode, float *bias, float *bias1);
// 向量矩阵乘法 (M = 1)
void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu);
/*
void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
int lda, const float *B, int ldb, float beta, float
*C, int ldc, bool relu, float *new_scale, float *new_bias);
*/
// 计算一个更小的 C 矩阵分块 // 计算一个更小的 C 矩阵分块
void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc); #if __aarch64__
void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc);
void AddDot6x8(int k, const float *a, const float *b, float *c, int ldc); void AddDot6x8(int k, const float *a, const float *b, float *c, int ldc);
void AddDot8x12(int k, const float *a, const float *b, float *c, int ldc); void AddDot8x12(int k, const float *a, const float *b, float *c, int ldc);
void AddDot6x16(int k, const float *a, const float *b, float *c, int ldc); void AddDot6x16(int k, const float *a, const float *b, float *c, int ldc);
#else
void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc);
void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc);
void AddDot6x8(int k, const float *a, const float *b, float *c, int ldc);
#endif
// 分块矩阵乘法结果回写 // 分块矩阵乘法结果回写
// C = A * B // C = A * B
...@@ -149,6 +134,18 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, ...@@ -149,6 +134,18 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
void WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, void WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias, float *bias1); float *new_scale, float *new_bias, float *bias1);
// 向量矩阵乘法 (M = 1)
#if __aarch64__
#else
void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu);
void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
int lda, const float *B, int ldb, float beta,
float *C, int ldc, bool relu, float *new_scale,
float *new_bias);
// 向量矩阵乘法结果回写 // 向量矩阵乘法结果回写
// C = A * B // C = A * B
void VecWriteBasic(int n, float *c, float *C, int ldc); void VecWriteBasic(int n, float *c, float *C, int ldc);
...@@ -158,14 +155,13 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, ...@@ -158,14 +155,13 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
void VecWriteWithAdd(int n, float *c, float *C, int ldc); void VecWriteWithAdd(int n, float *c, float *C, int ldc);
// C = A * B + C, relu(C) // C = A * B + C, relu(C)
void VecWriteWithAddRelu(int n, float *c, float *C, int ldc); void VecWriteWithAddRelu(int n, float *c, float *C, int ldc);
/*
// C = A * B, batchnorm(C) // C = A * B, batchnorm(C)
void VecWriteWithBn(int n, float *c, float *C, int ldc, float *new_scale, void VecWriteWithBn(int n, float *c, float *C, int ldc, float *new_scale,
float *new_bias); float *new_bias);
// C = A * B, batchnorm(C), relu(C) // C = A * B, batchnorm(C), relu(C)
void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *new_scale,
*new_scale, float *new_bias); float *new_bias);
*/ #endif
// 32位 float 矩阵乘法 // 32位 float 矩阵乘法
void Sgemm(int m, int n, int k, float alpha, const float *A, int lda, void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
......
...@@ -3066,5 +3066,52 @@ class ReadFromArrayParam : public OpParam {
};
#endif
#ifdef IS_EMPTY_OP
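// The is_empty op writes a single bool to Out: true when the input tensor
// has zero elements, false otherwise.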
template <typename Dtype>
class IsEmptyParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
IsEmptyParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_x_ = InputXFrom<GType>(inputs, scope);
output_ = OutFrom<GType>(outputs, scope);
}
const GType *InputX() const { return input_x_; }
GType *Out() const { return output_; }
public:
GType *input_x_;
GType *output_;
};
#endif // IS_EMPTY_OP
#ifdef INCREMENT_OP
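// The increment op computes Out = X + step, where "step" is an integer
// attribute read from the op description.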
template <typename Dtype>
class IncrementParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
IncrementParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_x_ = InputXFrom<GType>(inputs, scope);
output_ = OutFrom<GType>(outputs, scope);
step_ = OpParam::GetAttr<int>("step", attrs);
}
const GType *InputX() const { return input_x_; }
GType *Out() const { return output_; }
int Step() const { return step_; }
public:
GType *input_x_;
GType *output_;
int step_;
};
#endif // INCREMENT_OP
}  // namespace operators
}  // namespace paddle_mobile
...@@ -437,6 +437,14 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-logical-xor-op operators/test_logical_xor_op.cpp test_helper.h test_include.h)
target_link_libraries(test-logical-xor-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-increment-op operators/test_increment_op.cpp test_helper.h test_include.h)
target_link_libraries(test-increment-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-is-empty-op operators/test_is_empty_op.cpp test_helper.h test_include.h)
target_link_libraries(test-is-empty-op paddle-mobile)
ADD_EXECUTABLE(test-conv-bn-relu-op operators/test_conv_bn_relu_op.cpp test_helper.h test_include.h)
target_link_libraries(test-conv-bn-relu-op paddle-mobile)
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../test_include.h"
#include "operators/increment_op.h"
namespace paddle_mobile {
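// Reference implementation used to verify the operator output: adds `step`
// to the single-element input tensor.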
template <typename T>
void Increment(const framework::Tensor *input, framework::Tensor *out,
int step) {
auto input_data = input->data<T>();
auto out_data = out->data<T>();
*out_data = *input_data + step;
}
int TestIncrementOp(const std::vector<int> input_shape, int step) {
framework::DDim input_dims = framework::make_ddim(input_shape);
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"inputX"});
outputs["Out"] = std::vector<std::string>({"output"});
auto x_var = scope.get()->Var("inputX");
auto x = x_var->template GetMutable<framework::LoDTensor>();
SetupTensor<float>(x, input_dims, 0, 100);
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
attrs["step"].Set<int>(step);
auto *op = new operators::IncrementOp<CPU, float>("increment", inputs,
outputs, attrs, scope);
op->InferShape();
op->Init();
op->Run();
auto output = output_var->template Get<framework::LoDTensor>();
framework::Tensor output_cmp;
float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
Increment<float>(x, &output_cmp, step);
const float *output_data = output->data<float>();
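// Compare the operator output against the reference with a relative
// tolerance of 1e-3.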
for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i];
if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
}
}
delete op;
return 0;
}
} // namespace paddle_mobile
int main() {
paddle_mobile::TestIncrementOp({1}, 4);
paddle_mobile::TestIncrementOp({1}, 10);
DLOG << "test increment op pass.";
return 0;
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../test_include.h"
#include "operators/is_empty_op.h"
namespace paddle_mobile {
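// Reference implementation used to verify the operator output: Out is true
// when the input tensor has no elements.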
void IsEmpty(const framework::Tensor *input, framework::Tensor *out) {
out->data<bool>()[0] = input->numel() == 0;
}
int TestIsEmptyOp(const std::vector<int> input_shape) {
framework::DDim input_dims = framework::make_ddim(input_shape);
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"inputX"});
outputs["Out"] = std::vector<std::string>({"output"});
auto x_var = scope.get()->Var("inputX");
auto x = x_var->template GetMutable<framework::LoDTensor>();
SetupTensor<float>(x, input_dims, 0, 100);
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
auto *op = new operators::IsEmptyOp<CPU, float>("is_empty", inputs, outputs,
attrs, scope);
op->InferShape();
op->Init();
op->Run();
auto output = output_var->template Get<framework::LoDTensor>();
framework::Tensor output_cmp;
bool *output_cmp_data = output_cmp.mutable_data<bool>(output->dims());
IsEmpty(x, &output_cmp);
const bool *output_data = output->data<bool>();
for (int i = 0; i < output->numel(); ++i) {
if (output_data[i] != output_cmp_data[i]) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
}
}
delete op;
return 0;
}
} // namespace paddle_mobile
int main() {
paddle_mobile::TestIsEmptyOp({1, 3, 100, 100});
paddle_mobile::TestIsEmptyOp({0});
DLOG << "test is_empty op pass.";
return 0;
}
...@@ -288,6 +288,8 @@ if(NOT FOUND_MATCH)
set(WHILE_OP ON)
set(WRITE_TO_ARRAY_OP ON)
set(READ_FROM_ARRAY_OP ON)
set(IS_EMPTY_OP ON)
set(INCREMENT_OP ON)
set(ANCHOR_GENERATOR_OP ON)
set(PROPOSAL_OP ON)
set(PSROI_POOL_OP ON)
...@@ -576,6 +578,12 @@ endif()
if (READ_FROM_ARRAY_OP)
add_definitions(-DREAD_FROM_ARRAY_OP)
endif()
if (IS_EMPTY_OP)
add_definitions(-DIS_EMPTY_OP)
endif()
if (INCREMENT_OP)
add_definitions(-DINCREMENT_OP)
endif()
if (ANCHOR_GENERATOR_OP)
add_definitions(-DANCHOR_GENERATOR_OP)
......