未验证 提交 66d2ae25 编写于 作者: Y yiicy 提交者: GitHub

[ARM] add sgemmc4 common and small kernel, support for winograd, test=develop (#2471)

* unfinish sgemmc4

* finish armv8 sgemmc4

* arm add sgemmc4 with deal with remain

* [ARM] add sgemmc4 small kernel, test=develop
上级 a7f7d49b
...@@ -60,6 +60,7 @@ if (NOT HAS_ARM_MATH_LIB_DIR) ...@@ -60,6 +60,7 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
cc_library(math_arm SRCS cc_library(math_arm SRCS
funcs.cc funcs.cc
packed_sgemm.cc packed_sgemm.cc
packed_sgemm_c4.cc
sgemm.cc sgemm.cc
gemm_prepacked_int8.cc gemm_prepacked_int8.cc
gemm_s8.cc gemm_s8.cc
......
...@@ -43,6 +43,7 @@ ...@@ -43,6 +43,7 @@
#include "lite/backends/arm/math/negative.h" #include "lite/backends/arm/math/negative.h"
#include "lite/backends/arm/math/norm.h" #include "lite/backends/arm/math/norm.h"
#include "lite/backends/arm/math/packed_sgemm.h" #include "lite/backends/arm/math/packed_sgemm.h"
#include "lite/backends/arm/math/packed_sgemm_c4.h"
#include "lite/backends/arm/math/pad2d.h" #include "lite/backends/arm/math/pad2d.h"
#include "lite/backends/arm/math/pooling.h" #include "lite/backends/arm/math/pooling.h"
#include "lite/backends/arm/math/power.h" #include "lite/backends/arm/math/power.h"
......
此差异已折叠。
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include "lite/core/context.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
constexpr int MBLOCK_C4 = 4;
constexpr int NBLOCK_C4 = 8;
constexpr int KBLOCK_C4 = 4;
void sgemm_prepack_c4(int M,
int N,
int K,
const float* A_packed,
const float* B,
float* C,
const float* bias,
bool has_bias,
bool has_relu,
ARMContext* ctx);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH_ARM)) if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH_ARM))
lite_cc_test(sgemm_compute_test SRCS sgemm_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(sgemm_compute_test SRCS sgemm_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(sgemv_compute_test SRCS sgemv_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(sgemv_compute_test SRCS sgemv_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(sgemm_c4_compute_test SRCS sgemm_c4_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(gemm_int8_compute_test SRCS gemm_int8_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(gemm_int8_compute_test SRCS gemm_int8_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(gemv_int8_compute_test SRCS gemv_int8_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(gemv_int8_compute_test SRCS gemv_int8_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(conv_compute_test SRCS conv_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(conv_compute_test SRCS conv_compute_test.cc DEPS arena_framework ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "lite/tests/utils/fill_data.h"
#include "lite/tests/utils/naive_math_impl.h"
#ifdef LITE_WITH_ARM
#include "lite/backends/arm/math/funcs.h"
#endif // LITE_WITH_ARM
#include "lite/core/context.h"
#include "lite/core/tensor.h"
#include "lite/tests/utils/tensor_utils.h"
#include "lite/tests/utils/timer.h"
typedef paddle::lite::Tensor Tensor;
using paddle::lite::Timer;
DEFINE_int32(power_mode,
3,
"power mode: "
"0 for POWER_HIGH;"
"1 for POWER_LOW;"
"2 for POWER_FULL;"
"3 for NO_BIND");
DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times");
DEFINE_bool(basic_test, false, "do all tests");
DEFINE_bool(check_result, true, "check the result");
DEFINE_int32(M, 512, "gemm_c4: M");
DEFINE_int32(N, 512, "gemm_c4: N");
DEFINE_int32(K, 512, "gemm_c4: K");
DEFINE_bool(flag_relu, false, "do relu");
DEFINE_bool(flag_bias, false, "with bias");
bool test_sgemm_c4(
int m, int n, int k, bool has_bias, bool has_relu, int cls, int ths) {
int m_round = (m + 3) / 4 * 4;
int k_round = (k + 3) / 4 * 4;
int size_a = m * k;
int size_b = n * k;
int size_a_c4 = m_round * k_round;
int size_b_c4 = k_round * n;
Tensor ta;
Tensor tb;
Tensor ta_c4;
Tensor tb_c4;
Tensor tc;
Tensor tc_basic;
Tensor tc_backup;
Tensor tbias;
ta.Resize({size_a});
tb.Resize({size_b});
ta_c4.Resize({size_a_c4});
tb_c4.Resize({size_b_c4});
tc.Resize({m_round * n});
tc_basic.Resize({m_round * n});
tbias.Resize({m});
ta.set_precision(PRECISION(kFloat));
tb.set_precision(PRECISION(kFloat));
ta_c4.set_precision(PRECISION(kFloat));
tb_c4.set_precision(PRECISION(kFloat));
tc.set_precision(PRECISION(kFloat));
tc_basic.set_precision(PRECISION(kFloat));
tbias.set_precision(PRECISION(kFloat));
fill_tensor_rand(ta, -1.f, 1.f);
fill_tensor_rand(tb, -1.f, 1.f);
fill_tensor_rand(tbias, -1.f, 1.f);
fill_tensor_rand(tc, -1.f, 1.f);
auto da = ta.mutable_data<float>();
auto db = tb.mutable_data<float>();
auto da_c4 = ta_c4.mutable_data<float>();
auto db_c4 = tb_c4.mutable_data<float>();
auto dc_basic = tc_basic.mutable_data<float>();
auto dbias = tbias.mutable_data<float>();
// trans A, B to c4
basic_trans_mat_to_c4(da, da_c4, k, m, k, true);
basic_trans_mat_to_c4(db, db_c4, n, k, n, false);
LOG(INFO) << "sgemm_c4 M: " << m << ", N: " << n << ", K: " << k
<< ", relu: " << (has_relu ? "true" : "false")
<< ", bias: " << (has_bias ? "true" : "false");
if (FLAGS_check_result) {
basic_gemm_c4(false,
false,
m,
n,
k,
1.f,
da,
k,
db,
n,
0.f,
dc_basic,
n,
dbias,
has_bias,
has_relu);
}
Timer t0;
#ifdef LITE_WITH_ARM
//! compute
double ops = 2.0 * m_round * n * k_round;
std::unique_ptr<paddle::lite::KernelContext> ctx1(
new paddle::lite::KernelContext);
auto& ctx = ctx1->As<paddle::lite::ARMContext>();
ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), ths);
auto dc = tc.mutable_data<float>();
for (int j = 0; j < FLAGS_warmup; ++j) {
paddle::lite::arm::math::sgemm_prepack_c4(
m, n, k, da_c4, db_c4, dc, dbias, has_bias, has_relu, &ctx);
}
for (int i = 0; i < FLAGS_repeats; ++i) {
t0.start();
paddle::lite::arm::math::sgemm_prepack_c4(
m, n, k, da_c4, db_c4, dc, dbias, has_bias, has_relu, &ctx);
t0.end();
}
LOG(INFO) << "M: " << m << ", N: " << n << ", K: " << k
<< ", power_mode: " << cls << ", threads: " << ths
<< ", GOPS: " << ops * 1e-9f
<< " GOPS, avg time: " << t0.get_average_ms()
<< " ms, min time: " << t0.get_min_time()
<< " ms, mean GOPs: " << ops * 1e-6f / t0.get_average_ms()
<< " GOPs, max GOPs: " << ops * 1e-6f / t0.get_min_time()
<< " GOPs";
if (FLAGS_check_result) {
double max_ratio = 0;
double max_diff = 0;
tensor_cmp_host(tc_basic, tc, max_ratio, max_diff);
LOG(INFO) << "compare result, max diff: " << max_diff
<< ", max ratio: " << max_ratio;
if (std::abs(max_ratio) > 1e-4f && std::abs(max_diff) > 5e-5f) {
Tensor tdiff;
tdiff.set_precision(PRECISION(kFloat));
tdiff.Resize(tc.dims());
tensor_diff(tc_basic, tc, tdiff);
LOG(INFO) << "a: ";
print_tensor(ta);
LOG(INFO) << "a_c4: ";
print_tensor(ta_c4);
LOG(INFO) << "b: ";
print_tensor(tb);
LOG(INFO) << "b_c4: ";
print_tensor(tb_c4);
LOG(INFO) << "basic result: ";
print_tensor(tc_basic);
LOG(INFO) << "lite result: ";
print_tensor(tc);
LOG(INFO) << "diff result: ";
print_tensor(tdiff);
return false;
}
}
#endif
return true;
}
TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) {
if (FLAGS_basic_test) {
#ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init();
#endif
LOG(INFO) << "run basic sgemm_c4 test";
for (auto& m : {1, 3, 8, 32, 397}) {
for (auto& n : {1, 2, 3, 4, 13, 141, 789}) {
for (auto& k : {1, 3, 8, 59, 234}) {
for (auto& has_bias : {false, true}) {
for (auto& has_relu : {false, true}) {
for (auto& th : {1, 2, 4}) {
auto flag = test_sgemm_c4(
m, n, k, has_bias, has_relu, FLAGS_power_mode, th);
if (flag) {
LOG(INFO) << "test m = " << m << ", n=" << n << ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< " passed\n";
} else {
LOG(FATAL) << "test m = " << m << ", n=" << n << ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< " failed\n";
}
}
}
}
}
}
}
}
}
TEST(TestSgemmC4Custom, test_func_sgemm_c4_prepacked_custom) {
#ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init();
#endif
auto flag = test_sgemm_c4(FLAGS_M,
FLAGS_N,
FLAGS_K,
FLAGS_flag_bias,
FLAGS_flag_relu,
FLAGS_power_mode,
FLAGS_threads);
if (!flag) {
LOG(FATAL) << "test m = " << FLAGS_M << ", n=" << FLAGS_N
<< ", k=" << FLAGS_K << ", bias: " << FLAGS_flag_bias
<< ", relu: " << FLAGS_flag_relu << " failed!!";
}
LOG(INFO) << "test m = " << FLAGS_M << ", n=" << FLAGS_N << ", k=" << FLAGS_K
<< ", bias: " << FLAGS_flag_bias << ", relu: " << FLAGS_flag_relu
<< " passed!!";
}
...@@ -14,6 +14,108 @@ ...@@ -14,6 +14,108 @@
#pragma once #pragma once
template <typename type>
static void basic_trans_mat_to_c4(const type* input,
type* output,
const int ldin,
const int M,
const int K,
bool pack_k) {
const int m_round = (M + 3) / 4 * 4;
int k_round = (K + 3) / 4 * 4;
if (!pack_k) {
k_round = K;
}
const int m_loop = m_round / 4;
type zero_buf[K];
memset(zero_buf, 0, K * sizeof(type));
for (int i = 0; i < m_loop; ++i) {
const type* in0 = input + i * 4 * ldin;
const type* in1 = in0 + ldin;
const type* in2 = in1 + ldin;
const type* in3 = in2 + ldin;
if (4 * (i + 1) - M > 0) {
switch (4 * (i + 1) - M) {
case 3:
in1 = zero_buf;
case 2:
in2 = zero_buf;
case 1:
in3 = zero_buf;
default:
break;
}
}
for (int j = 0; j < K; ++j) {
*output++ = *in0++;
*output++ = *in1++;
*output++ = *in2++;
*output++ = *in3++;
}
for (int j = K; j < k_round; ++j) {
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
}
}
}
template <typename type, typename type2>
static void basic_gemm_c4(bool trans_a,
bool trans_b,
int m,
int n,
int k,
type2 alpha,
const type* a,
int lda,
const type* b,
int ldb,
type2 beta,
type2* c,
int ldc,
const type2* bias,
bool flag_bias = false,
bool flag_relu = false) {
type2* tmp_c = reinterpret_cast<type2*>(malloc(m * ldc * sizeof(type2)));
memset(tmp_c, 0, m * ldc * sizeof(type2));
#pragma omp parallel for
for (int i = 0; i < m; ++i) {
auto bias_data = static_cast<type2>(0);
if (flag_bias) {
bias_data = bias[i];
}
for (int j = 0; j < n; ++j) {
auto sum = static_cast<type2>(0);
for (int l = 0; l < k; ++l) {
type av;
type bv;
if (trans_a) {
av = a[l * lda + i];
} else {
av = a[i * lda + l];
}
if (trans_b) {
bv = b[j * ldb + l];
} else {
bv = b[l * ldb + j];
}
sum += av * bv;
}
type2 tmp = alpha * sum + beta * tmp_c[i * ldc + j] + bias_data;
if (flag_relu) {
tmp_c[i * ldc + j] = tmp > (type2)0 ? tmp : (type2)0;
} else {
tmp_c[i * ldc + j] = tmp;
}
}
}
//! trans c to c4
basic_trans_mat_to_c4(tmp_c, c, ldc, m, n, false);
free(tmp_c);
}
template <typename type, typename type2> template <typename type, typename type2>
static void basic_gemm(bool trans_a, static void basic_gemm(bool trans_a,
bool trans_b, bool trans_b,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册