Commit c02023ab authored by hjchen2

Revert the lhs packing in sgemm and the depthwise conv5x5 path so that they no longer cause problems on iOS

Parent c02076f0
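For context, the "lhs packing" being reverted is the packing of the A (left-hand side) matrix that feeds the sgemm micro-kernel: rows are grouped in blocks of six and each block is written out column by column, with a partial last block padded by zero rows. A toy sketch of that layout (illustration only; pack_6r_reference, the matrix sizes and main() below are made up for the example and are not part of this patch):

#include <cstdio>
#include <vector>

// Toy re-statement of the 6-row lhs packing: inside each block of six rows,
// column j of all six rows is stored back to back, so the micro-kernel can
// stream one packed column per step. A partial last block is zero-padded.
static void pack_6r_reference(int m, int k, const float *A, int lda,
                              float *out) {
  for (int i = 0; i < m; i += 6) {
    for (int j = 0; j < k; ++j) {
      for (int r = 0; r < 6; ++r) {
        *out++ = (i + r < m) ? A[(i + r) * lda + j] : 0.f;
      }
    }
  }
}

int main() {
  const int m = 7, k = 2;
  std::vector<float> A(m * k);
  for (int i = 0; i < m * k; ++i) A[i] = static_cast<float>(i);  // rows {0,1},{2,3},...

  std::vector<float> packed(((m + 5) / 6) * 6 * k);
  pack_6r_reference(m, k, A.data(), k, packed.data());

  // Block 0 (rows 0..5): 0 2 4 6 8 10 | 1 3 5 7 9 11
  // Block 1 (row 6 plus zero padding): 12 0 0 0 0 0 | 13 0 0 0 0 0
  for (float v : packed) std::printf("%g ", v);
  std::printf("\n");
  return 0;
}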
@@ -42,6 +42,7 @@ inline DataLayout StringToDataLayout(const std::string &str) {
  } else {
    PADDLE_MOBILE_THROW_EXCEPTION("Unknown storage order string: %s", s.c_str())
  }
  return DataLayout::kNCHW;
}

inline std::string DataLayoutToString(const DataLayout &data_layout) {
......
@@ -82,6 +82,8 @@ struct Dim<0> {
  int64_t &operator[](int idx);
  int64_t operator[](int idx) const;

  int64_t head;
};

namespace {
@@ -131,6 +133,7 @@ int64_t &indexer(Dim<D> &dim, int idx) {
template <>
int64_t &indexer<0>(Dim<0> &dim, int idx) {
  PADDLE_MOBILE_THROW_EXCEPTION("Invalid index")
  return dim.head;
}

template <int D>
@@ -147,6 +150,7 @@ int64_t indexer(const Dim<D> &dim, int idx) {
template <>
int64_t indexer<0>(const Dim<0> &dim, int idx) {
  PADDLE_MOBILE_THROW_EXCEPTION("Invalid index")
  return dim.head;
}
}  // namespace
......
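The two hunks above follow one pattern: PADDLE_MOBILE_THROW_EXCEPTION normally never returns, but a dummy value (DataLayout::kNCHW, dim.head) is now returned after it anyway, presumably so that compilers which cannot prove the macro is noreturn stop warning about a missing return value. A generic, hypothetical illustration of the idiom (not the real macro or enum):

#include <cstdio>
#include <string>

enum class DataLayout { kNCHW, kNHWC };

// Hypothetical stand-in for PADDLE_MOBILE_THROW_EXCEPTION; depending on build
// flags the real macro may expand to something the compiler cannot prove is
// noreturn, so every path out of the function still needs a value.
#define DEMO_THROW(msg) std::fprintf(stderr, "%s\n", msg)

inline DataLayout StringToDataLayoutDemo(const std::string &s) {
  if (s == "NCHW") return DataLayout::kNCHW;
  if (s == "NHWC") return DataLayout::kNHWC;
  DEMO_THROW("Unknown storage order string");
  return DataLayout::kNCHW;  // fallback value that keeps -Wreturn-type quiet
}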
@@ -201,16 +201,16 @@ inline void DepthwiseConv5x5(const ConvParam<CPU> &param) {
  Tensor *output = param.Output();
  output->mutable_data<Otype>();
-  if (strides[0] == 1) {
-    for (int i = 0; i < batch_size; i++) {
-      Tensor in_batch = input->Slice(i, i + 1);
-      Tensor out_batch = output->Slice(i, i + 1);
-      math::DepthwiseConv5x5S1<Itype, Otype>(in_batch, *filter, paddings,
-                                             &out_batch);
-    }
-  } else {
  // if (strides[0] == 1) {
  //   for (int i = 0; i < batch_size; i++) {
  //     Tensor in_batch = input->Slice(i, i + 1);
  //     Tensor out_batch = output->Slice(i, i + 1);
  //     math::DepthwiseConv5x5S1<Itype, Otype>(in_batch, *filter, paddings,
  //                                            &out_batch);
  //   }
  // } else {
  GemmConv<Itype, Otype>(param);
-  }
  // }
}

template <typename ParamType>
......
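With the specialized stride-1 5x5 kernel commented out above, every depthwise 5x5 convolution now goes through GemmConv, i.e. the generic im2col + GEMM fallback. A generic sketch of that fallback idea for a single channel (illustration only, not the library's GemmConv; names and shapes are made up):

#include <cstddef>
#include <vector>

// im2col for one channel, stride 1: each output pixel becomes a row of
// kh*kw unrolled input values, so the 5x5 depthwise convolution reduces to a
// matrix-vector product with the flattened 25-element filter.
std::vector<float> im2col_1ch(const std::vector<float> &in, int h, int w,
                              int kh, int kw, int pad) {
  const int out_h = h + 2 * pad - kh + 1;
  const int out_w = w + 2 * pad - kw + 1;
  std::vector<float> cols(static_cast<size_t>(out_h) * out_w * kh * kw, 0.f);
  for (int oy = 0; oy < out_h; ++oy) {
    for (int ox = 0; ox < out_w; ++ox) {
      float *row = &cols[(static_cast<size_t>(oy) * out_w + ox) * kh * kw];
      for (int ky = 0; ky < kh; ++ky) {
        for (int kx = 0; kx < kw; ++kx) {
          const int iy = oy + ky - pad;
          const int ix = ox + kx - pad;
          if (iy >= 0 && iy < h && ix >= 0 && ix < w) {
            row[ky * kw + kx] = in[static_cast<size_t>(iy) * w + ix];
          }
        }
      }
    }
  }
  return cols;  // multiply each row by the flattened filter to get the output map
}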
@@ -144,20 +144,21 @@ void DepthwiseConv5x5S1<float, float>(const framework::Tensor &input,
  const float *input_data = input.data<float>();
  const float *filter_data = filter.data<float>();
  float *out_data = output->mutable_data<float>();
-  int input_h = input.dims()[2];
-  int input_w = input.dims()[3];
-  int output_h = output->dims()[2];
-  int output_w = output->dims()[3];
-  int padding_h = paddings[0];
-  int padding_w = paddings[1];
-  int image_size = input_h * input_w;
-  int out_image_size = output_h * output_w;
-  int valid_h_start = padding_h;
-  int valid_h_end = output_h - valid_h_start;
-  int valid_h = valid_h_end - valid_h_start;
-  int valid_w_start = padding_w;
-  int valid_w_end = output_w - valid_w_start;
-  int valid_w = valid_w_end - valid_w_start;

  const int input_h = input.dims()[2];
  const int input_w = input.dims()[3];
  const int output_h = output->dims()[2];
  const int output_w = output->dims()[3];
  const int padding_h = paddings[0];
  const int padding_w = paddings[1];
  const int image_size = input_h * input_w;
  const int out_image_size = output_h * output_w;
  const int valid_h_start = padding_h;
  const int valid_h_end = output_h - valid_h_start;
  const int valid_h = valid_h_end - valid_h_start;
  const int valid_w_start = padding_w;
  const int valid_w_end = output_w - valid_w_start;
  const int valid_w = valid_w_end - valid_w_start;
  #pragma omp parallel for
  for (int g = 0; g < input.dims()[1]; ++g) {
......
@@ -18,7 +18,8 @@ limitations under the License. */
#ifdef _OPENMP
#include <omp.h>
#endif
-#include <sys/time.h>
// #include <sys/time.h>
// #include <iostream>
#include "common/log.h"
#include "memory/t_malloc.h"
#include "operators/math/gemm/cpu_info.h"
@@ -158,7 +159,8 @@ class GemmExecutor : public Executor {
          }
        }
      }
-      strategy_.write(lhs_range, N_, local_C, ldc_, C + lhs_block * ldc, ldc);
      strategy_.write(lhs_range, N_, alpha, local_C, ldc_, beta,
                      C + lhs_block * ldc, ldc);
    }
  } else {
    strategy_.pack_lhs(M_, K_, A, lda, lhs_workspace_, true);
@@ -188,7 +190,8 @@ class GemmExecutor : public Executor {
          }
        }
      }
-      strategy_.write(M_, rhs_range, local_C, ldc_, C + rhs_block, ldc);
      strategy_.write(M_, rhs_range, alpha, local_C, ldc_, beta,
                      C + rhs_block, ldc);
    }
  }
......
@@ -31,275 +31,345 @@ inline float32x4_t vandq_f32_u32(float32x4_t x, uint32x4_t mask) {
void pack_lhs_6r(const int m, const int k, const float *A, const int lda,
                 float *output, const bool unroll) {
-  [the previous NEON-optimized packing body is removed here; it is preserved
-   verbatim in the commented-out block below]
  float *zero = new float[k];
  memset(zero, 0, k * sizeof(float));

  const int m_tail = m % 6;
  const int i_length = m - m_tail;
  for (int i = 0; i < i_length; i += 6) {
    const float *a0 = A + i * lda;
    const float *a1 = A + (i + 1) * lda;
    const float *a2 = A + (i + 2) * lda;
    const float *a3 = A + (i + 3) * lda;
    const float *a4 = A + (i + 4) * lda;
    const float *a5 = A + (i + 5) * lda;
    float *local_buffer = output + i * k;
    for (int j = 0; j < k; ++j) {
      *local_buffer++ = *a0++;
      *local_buffer++ = *a1++;
      *local_buffer++ = *a2++;
      *local_buffer++ = *a3++;
      *local_buffer++ = *a4++;
      *local_buffer++ = *a5++;
    }
  }
  if (m_tail != 0) {
    const float *a0 = A + i_length * lda;
    const float *a1 = a0 + lda;
    const float *a2 = a0 + 2 * lda;
    const float *a3 = a0 + 3 * lda;
    const float *a4 = a0 + 4 * lda;
    const float *a5 = a0 + 5 * lda;
    float *local_buffer = output + i_length * k;
    switch (m_tail) {
      case 1:
        a1 = zero;
      case 2:
        a2 = zero;
      case 3:
        a3 = zero;
      case 4:
        a4 = zero;
      case 5:
        a5 = zero;
        break;
      default:
        break;
    }
    for (int j = 0; j < k; ++j) {
      *local_buffer++ = *a0++;
      *local_buffer++ = *a1++;
      *local_buffer++ = *a2++;
      *local_buffer++ = *a3++;
      *local_buffer++ = *a4++;
      *local_buffer++ = *a5++;
    }
    delete[] zero;
  }
// uint32_t mask[8] = {0, 1, 2, 3, 4, 5, 4, 5};
// int remain_k = k & 0x3;
// uint32x4_t vzero = vdupq_n_u32(0);
// uint32x4_t vmask1 = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_k));
//
// std::cout << "m: " << m << ", k: " << k << std::endl;
// #pragma omp parallel for if (unroll)
// for (int i = 0; i < m - 5; i += 6) {
// std::cout << "i: " << i << std::endl;
// const float *a0 = A + i * lda;
// const float *a1 = A + (i + 1) * lda;
// const float *a2 = A + (i + 2) * lda;
// const float *a3 = A + (i + 3) * lda;
// const float *a4 = A + (i + 4) * lda;
// const float *a5 = A + (i + 5) * lda;
// float *out_ptr = output + i * k;
//
// int loops = k >> 2;
// if (loops > 0) {
// #if __aarch64__
// for (int l = 0; l < loops; ++l) {
// float32x4_t _d0 = vld1q_f32(a0);
// float32x4_t _d1 = vld1q_f32(a1);
// float32x4_t _d2 = vld1q_f32(a2);
// float32x4_t _d3 = vld1q_f32(a3);
// float32x4_t _d4 = vld1q_f32(a4);
// float32x4_t _d5 = vld1q_f32(a5);
//
// float32x4x2_t _q0 = vtrnq_f32(_d0, _d1);
// float32x4x2_t _q1 = vtrnq_f32(_d2, _d3);
// float32x4x2_t _q3 = vtrnq_f32(_d4, _d5);
// _d0 = vcombine_f32(vget_low_f32(_q0.val[0]),
// vget_low_f32(_q1.val[0])); _d1 =
// vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1]));
// _d2 =
// vcombine_f32(vget_high_f32(_q0.val[0]),
// vget_high_f32(_q1.val[0]));
// _d3 =
// vcombine_f32(vget_high_f32(_q0.val[1]),
// vget_high_f32(_q1.val[1]));
//
// vst1q_f32(out_ptr, _d0);
// vst1_f32(out_ptr + 4, vget_low_f32(_q3.val[0]));
// vst1q_f32(out_ptr + 6, _d1);
// vst1_f32(out_ptr + 10, vget_low_f32(_q3.val[1]));
// vst1q_f32(out_ptr + 12, _d2);
// vst1_f32(out_ptr + 16, vget_high_f32(_q3.val[0]));
// vst1q_f32(out_ptr + 18, _d3);
// vst1_f32(out_ptr + 22, vget_high_f32(_q3.val[1]));
//
// a0 += 4;
// a1 += 4;
// a2 += 4;
// a3 += 4;
// a4 += 4;
// a5 += 4;
// out_ptr += 24;
// }
// #else
// asm volatile(
// "loop_4k_%=: \n"
// "vld1.32 {d0-d1}, [%[a0]]! \n"
// "vld1.32 {d2-d3}, [%[a1]]! \n"
// "vld1.32 {d4-d5}, [%[a2]]! \n"
// "vld1.32 {d6-d7}, [%[a3]]! \n"
// "vld1.32 {d8-d9}, [%[a4]]! \n"
// "vld1.32 {d10-d11}, [%[a5]]! \n"
// "vtrn.32 q0, q1 \n"
// "vtrn.32 q2, q3 \n"
// "vtrn.32 q4, q5 \n"
// "vswp.32 d1, d4 \n"
// "vswp.32 d3, d6 \n"
//
// "vst1.32 {q0}, [%[out]]! \n"
// "vst1.32 {d8}, [%[out]]! \n"
// "vst1.32 {q1}, [%[out]]! \n"
// "vst1.32 {d10}, [%[out]]! \n"
// "vst1.32 {q2}, [%[out]]! \n"
// "vst1.32 {d9}, [%[out]]! \n"
// "vst1.32 {q3}, [%[out]]! \n"
// "vst1.32 {d11}, [%[out]]! \n"
//
// "subs %[loops], #1 \n"
// "bne loop_4k_%= \n"
// : [out] "+r"(out_ptr), [a0] "+r"(a0), [a1] "+r"(a1), [a2]
// "+r"(a2),
// [a3] "+r"(a3), [a4] "+r"(a4), [a5] "+r"(a5), [loops] "+r"(loops)
// :
// : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5");
// #endif
// }
//
// if (remain_k > 0) {
// float32x4_t _d0 = vld1q_f32(a0);
// float32x4_t _d1 = vld1q_f32(a1);
// float32x4_t _d2 = vld1q_f32(a2);
// float32x4_t _d3 = vld1q_f32(a3);
// float32x4_t _d4 = vld1q_f32(a4);
// float32x4_t _d5 = vld1q_f32(a5);
//
// _d0 = vandq_f32_u32(_d0, vmask1);
// _d1 = vandq_f32_u32(_d1, vmask1);
// _d2 = vandq_f32_u32(_d2, vmask1);
// _d3 = vandq_f32_u32(_d3, vmask1);
// _d4 = vandq_f32_u32(_d4, vmask1);
// _d5 = vandq_f32_u32(_d5, vmask1);
//
// float32x4x2_t _q0 = vtrnq_f32(_d0, _d1);
// float32x4x2_t _q1 = vtrnq_f32(_d2, _d3);
// float32x4x2_t _q3 = vtrnq_f32(_d4, _d5);
// _d0 = vcombine_f32(vget_low_f32(_q0.val[0]),
// vget_low_f32(_q1.val[0])); _d1 =
// vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1])); _d2
// = vcombine_f32(vget_high_f32(_q0.val[0]), vget_high_f32(_q1.val[0]));
//
// switch (remain_k) {
// case 3:
// vst1q_f32(out_ptr + 12, _d2);
// vst1_f32(out_ptr + 16, vget_high_f32(_q3.val[0]));
// case 2:
// vst1q_f32(out_ptr + 6, _d1);
// vst1_f32(out_ptr + 10, vget_low_f32(_q3.val[1]));
// case 1:
// vst1q_f32(out_ptr, _d0);
// vst1_f32(out_ptr + 4, vget_low_f32(_q3.val[0]));
// default:
// break;
// }
// }
// }
//
// int remain_m = m % 6;
// if (remain_m) {
// int remain_m_start = m - remain_m;
// std::cout << "remain_m_start: " << remain_m_start << std::endl;
// const float *a0 = A + remain_m_start * lda;
// const float *a1 = a0 + lda;
// const float *a2 = a0 + 2 * lda;
// const float *a3 = a0 + 3 * lda;
// const float *a4 = a0 + 4 * lda;
// const float *a5 = a0 + 5 * lda;
// float *out_ptr = output + remain_m_start * k;
//
// uint32x4_t vmask2 = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_m));
// uint32x4_t vmask3 = vcltq_u32(vld1q_u32(mask + 4),
// vdupq_n_u32(remain_m));
//
// int loops = k >> 2;
// if (loops > 0) {
// #if __aarch64__
// for (int l = 0; l < loops; ++l) {
// float32x4_t _d0 = vld1q_f32(a0);
// float32x4_t _d1 = vld1q_f32(a1);
// float32x4_t _d2 = vld1q_f32(a2);
// float32x4_t _d3 = vld1q_f32(a3);
// float32x4_t _d4 = vld1q_f32(a4);
// float32x4_t _d5 = vld1q_f32(a5);
//
// float32x4x2_t _q0 = vtrnq_f32(_d0, _d1);
// float32x4x2_t _q1 = vtrnq_f32(_d2, _d3);
// float32x4x2_t _q3 = vtrnq_f32(_d4, _d5);
// _d0 = vcombine_f32(vget_low_f32(_q0.val[0]),
// vget_low_f32(_q1.val[0])); _d1 =
// vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1]));
// _d2 =
// vcombine_f32(vget_high_f32(_q0.val[0]),
// vget_high_f32(_q1.val[0]));
// _d3 =
// vcombine_f32(vget_high_f32(_q0.val[1]),
// vget_high_f32(_q1.val[1]));
//
// _d0 = vandq_f32_u32(_d0, vmask2);
// _d1 = vandq_f32_u32(_d1, vmask2);
// _d2 = vandq_f32_u32(_d2, vmask2);
// _d3 = vandq_f32_u32(_d3, vmask2);
// _d4 = vandq_f32_u32(_q3.val[0], vmask3);
// _d5 = vandq_f32_u32(_q3.val[1], vmask3);
//
// vst1q_f32(out_ptr, _d0);
// vst1_f32(out_ptr + 4, vget_low_f32(_d4));
// vst1q_f32(out_ptr + 6, _d1);
// vst1_f32(out_ptr + 10, vget_low_f32(_d5));
// vst1q_f32(out_ptr + 12, _d2);
// vst1_f32(out_ptr + 16, vget_high_f32(_d4));
// vst1q_f32(out_ptr + 18, _d3);
// vst1_f32(out_ptr + 22, vget_high_f32(_d5));
//
// a0 += 4;
// a1 += 4;
// a2 += 4;
// a3 += 4;
// a4 += 4;
// a5 += 4;
// out_ptr += 24;
// }
// #else
// asm volatile(
// "loop_4k_%=: \n"
// "vld1.32 {d0-d1}, [%[a0]]! \n"
// "vld1.32 {d2-d3}, [%[a1]]! \n"
// "vld1.32 {d4-d5}, [%[a2]]! \n"
// "vld1.32 {d6-d7}, [%[a3]]! \n"
// "vld1.32 {d8-d9}, [%[a4]]! \n"
// "vld1.32 {d10-d11}, [%[a5]]! \n"
// "vtrn.32 q0, q1 \n"
// "vtrn.32 q2, q3 \n"
// "vtrn.32 q4, q5 \n"
// "vswp.32 d1, d4 \n"
// "vswp.32 d3, d6 \n"
//
// "vbif q0, %q[vzero], %q[vmask2] \n"
// "vbif q1, %q[vzero], %q[vmask2] \n"
// "vbif q2, %q[vzero], %q[vmask2] \n"
// "vbif q3, %q[vzero], %q[vmask2] \n"
// "vbif q4, %q[vzero], %q[vmask3] \n"
// "vbif q5, %q[vzero], %q[vmask3] \n"
//
// "vst1.32 {q0}, [%[out]]! \n"
// "vst1.32 {d8}, [%[out]]! \n"
// "vst1.32 {q1}, [%[out]]! \n"
// "vst1.32 {d10}, [%[out]]! \n"
// "vst1.32 {q2}, [%[out]]! \n"
// "vst1.32 {d9}, [%[out]]! \n"
// "vst1.32 {q3}, [%[out]]! \n"
// "vst1.32 {d11}, [%[out]]! \n"
//
// "subs %[loops], #1 \n"
// "bne loop_4k_%= \n"
// : [out] "+r"(out_ptr), [a0] "+r"(a0), [a1] "+r"(a1), [a2]
// "+r"(a2),
// [a3] "+r"(a3), [a4] "+r"(a4), [a5] "+r"(a5), [loops] "+r"(loops)
// : [vmask2] "w"(vmask2), [vmask3] "w"(vmask3), [vzero] "w"(vzero)
// : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5");
// #endif
// }
//
// if (remain_k > 0) {
// float32x4_t _d0 = vld1q_f32(a0);
// float32x4_t _d1 = vld1q_f32(a1);
// float32x4_t _d2 = vld1q_f32(a2);
// float32x4_t _d3 = vld1q_f32(a3);
// float32x4_t _d4 = vld1q_f32(a4);
// float32x4_t _d5 = vld1q_f32(a5);
//
// _d0 = vandq_f32_u32(_d0, vmask1);
// _d1 = vandq_f32_u32(_d1, vmask1);
// _d2 = vandq_f32_u32(_d2, vmask1);
// _d3 = vandq_f32_u32(_d3, vmask1);
// _d4 = vandq_f32_u32(_d4, vmask1);
// _d5 = vandq_f32_u32(_d5, vmask1);
//
// float32x4x2_t _q0 = vtrnq_f32(_d0, _d1);
// float32x4x2_t _q1 = vtrnq_f32(_d2, _d3);
// float32x4x2_t _q3 = vtrnq_f32(_d4, _d5);
// _d0 = vcombine_f32(vget_low_f32(_q0.val[0]),
// vget_low_f32(_q1.val[0])); _d1 =
// vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1])); _d2
// = vcombine_f32(vget_high_f32(_q0.val[0]), vget_high_f32(_q1.val[0]));
// // _d3 = vcombine_f32(vget_high_f32(_q0.val[1]),
// // vget_high_f32(_q1.val[1]));
//
// _d0 = vandq_f32_u32(_d0, vmask2);
// _d1 = vandq_f32_u32(_d1, vmask2);
// _d2 = vandq_f32_u32(_d2, vmask2);
// // _d3 = vandq_f32_u32(_d3, vmask2);
// _d4 = vandq_f32_u32(_q3.val[0], vmask3);
// _d5 = vandq_f32_u32(_q3.val[1], vmask3);
//
// switch (remain_k) {
// case 3:
// vst1q_f32(out_ptr + 12, _d2);
// vst1_f32(out_ptr + 16, vget_high_f32(_d4));
// case 2:
// vst1q_f32(out_ptr + 6, _d1);
// vst1_f32(out_ptr + 10, vget_low_f32(_d5));
// case 1:
// vst1q_f32(out_ptr, _d0);
// vst1_f32(out_ptr + 4, vget_low_f32(_d4));
// default:
// break;
// }
// }
// }
}

#if __aarch64__
@@ -575,12 +645,52 @@ void pack_rhs_8c(int k, int n, const float *B, int ldb, float *output,
}
#endif  // __aarch64__

-#if __aarch64__
-void write_back(const int mc, const int nc, const float *c, const int ldc1,
-                float *C, const int ldc2) {
void write_back_alpha_beta(const int mc, const int nc, const float alpha,
                           const float *c, const int ldc1, const float beta,
                           float *C, const int ldc2) {
  int nc1 = nc / 4;
  int _nc1 = nc % 4;
float32x4_t _alpha = vdupq_n_f32(alpha);
float32x4_t _beta = vdupq_n_f32(beta);
float32x4_t cv, cv2;
for (int i = 0; i < mc; ++i) {
const float *c_ptr = c + i * ldc1;
float *C_ptr = C + i * ldc2;
for (int j = 0; j < nc1; ++j) {
cv = vld1q_f32(c_ptr);
cv = vmulq_f32(_alpha, cv);
cv2 = vld1q_f32(C_ptr);
cv = vmlaq_f32(cv, _beta, cv2);
vst1q_f32(C_ptr, cv);
c_ptr += 4;
C_ptr += 4;
}
if (_nc1 != 0) {
cv = vld1q_f32(c_ptr);
cv = vmulq_f32(_alpha, cv);
cv2 = vld1q_f32(C_ptr);
cv = vmlaq_f32(cv, _beta, cv2);
switch (_nc1) {
case 3:
vst1q_lane_f32(C_ptr + 2, cv, 2);
case 2:
vst1_f32(C_ptr, vget_low_f32(cv));
break;
case 1:
vst1q_lane_f32(C_ptr, cv, 0);
break;
}
}
}
}
#if __aarch64__
void write_back_alpha1_beta0(const int mc, const int nc, const float *c,
const int ldc1, float *C, const int ldc2) {
int nc1 = nc / 4;
int _nc1 = nc % 4;
  const float *c_ptr;
  float *C_ptr;
  float32x4_t cv;

@@ -595,23 +705,60 @@ void write_back(const int mc, const int nc, const float *c, const int ldc1,
    }
    if (_nc1 != 0) {
      cv = vld1q_f32(c_ptr);
-      if (_nc1 >= 1) {
-        vst1q_lane_f32(C_ptr, cv, 0);
-        C_ptr++;
-      }
-      if (_nc1 >= 2) {
-        vst1q_lane_f32(C_ptr, cv, 1);
-        C_ptr++;
-      }
-      if (_nc1 >= 3) {
-        vst1q_lane_f32(C_ptr, cv, 2);
-      }
      switch (_nc1) {
        case 3:
          vst1q_lane_f32(C_ptr + 2, cv, 2);
        case 2:
          vst1_f32(C_ptr, vget_low_f32(cv));
          break;
        case 1:
          vst1q_lane_f32(C_ptr, cv, 0);
          break;
      }
    }
  }
}
void write_back_alpha1_beta1(const int mc, const int nc, const float *c,
const int ldc1, float *C, const int ldc2) {
int nc1 = nc / 4;
int _nc1 = nc % 4;
const float *c_ptr;
float *C_ptr;
float32x4_t cv, cv2;
for (int i = 0; i < mc; ++i) {
c_ptr = c + i * ldc1;
C_ptr = C + i * ldc2;
for (int j = 0; j < nc1; ++j) {
cv = vld1q_f32(c_ptr);
cv2 = vld1q_f32(C_ptr);
cv = vaddq_f32(cv, cv2);
vst1q_f32(C_ptr, cv);
c_ptr += 4;
C_ptr += 4;
}
if (_nc1 != 0) {
cv = vld1q_f32(c_ptr);
cv2 = vld1q_f32(C_ptr);
cv = vaddq_f32(cv, cv2);
switch (_nc1) {
case 3:
vst1q_lane_f32(C_ptr + 2, cv, 2);
case 2:
vst1_f32(C_ptr, vget_low_f32(cv));
break;
case 1:
vst1q_lane_f32(C_ptr, cv, 0);
break;
      }
    }
  }
}

#else
-void write_back(const int mc, const int nc, const float *c, const int ldc1,
-                float *C, const int ldc2) {
void write_back_alpha1_beta0(const int mc, const int nc, const float *c,
                             const int ldc1, float *C, const int ldc2) {
  int nc1 = nc / 16;
  int nc2 = nc % 16;
  int step1 = 4 * (ldc1 - 16 * nc1);
@@ -663,7 +810,78 @@ void write_back(const int mc, const int nc, const float *c, const int ldc1,
    }
  }
}
-#endif
void write_back_alpha1_beta1(const int mc, const int nc, const float *c,
const int ldc1, float *C, const int ldc2) {
int nc1 = nc / 16;
int nc2 = nc % 16;
int step1 = 4 * (ldc1 - 16 * nc1);
int step2 = 4 * ldc2;
int volatile m = mc;
const float *volatile c_ptr = c;
float *volatile C_ptr = C;
if (nc1 > 0) {
asm volatile(
"subs %[mc], %[mc], #1 \n\t"
"blt end_mc_%= \n\t"
"loop_mc_%=: \n\t"
"mov r6, %[C_ptr] \n\t"
"mov r5, %[nc1] \n\t"
"subs r5, r5, #1 \n\t"
"blt end_nc1_%= \n\t"
"loop_nc1_%=: \n\t"
"vld1.32 {q0, q1}, [%[c_ptr]]! \n\t"
"vld1.32 {q2, q3}, [r6] \n\t"
"vadd.f32 q0, q0, q2 \n\t"
"vadd.f32 q1, q1, q3 \n\t"
"vst1.32 {q0, q1}, [r6]! \n\t"
"vld1.32 {q0, q1}, [%[c_ptr]]! \n\t"
"vld1.32 {q2, q3}, [r6] \n\t"
"vadd.f32 q0, q0, q2 \n\t"
"vadd.f32 q1, q1, q3 \n\t"
"vst1.32 {q0, q1}, [r6]! \n\t"
"subs r5, r5, #1 \n\t"
"bge loop_nc1_%= \n\t"
"end_nc1_%=: \n\t"
"add %[c_ptr], %[c_ptr], %[step1] \n\t"
"add %[C_ptr], %[C_ptr], %[step2] \n\t"
"subs %[mc], %[mc], #1 \n\t"
"bge loop_mc_%= \n\t"
"end_mc_%=: \n\t"
:
: [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(nc1),
[step1] "r"(step1), [step2] "r"(step2)
: "memory", "r5", "r6", "q0", "q1", "q2", "q3");
}
if (nc2 != 0) {
for (int i = 0; i < mc; i++) {
const float *c0 = c_ptr + nc1 * 16 + i * ldc1;
float *C0 = C_ptr + nc1 * 16 + i * ldc2;
for (int j = 0; j < nc2; j++) {
*C0++ += *c0++;
}
}
}
}
#endif // __aarch64__
void write_back(const int mc, const int nc, const float alpha, const float *c,
const int ldc1, const float beta, float *C, const int ldc2) {
if (alpha == 1.f && beta == 0.f) {
write_back_alpha1_beta0(mc, nc, c, ldc1, C, ldc2);
} else if (alpha == 1.f && beta == 1.f) {
write_back_alpha1_beta1(mc, nc, c, ldc1, C, ldc2);
} else {
write_back_alpha_beta(mc, nc, alpha, c, ldc1, beta, C, ldc2);
}
}
}  // namespace math
}  // namespace operators
......
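Taken together, the functions above implement the standard GEMM epilogue C = alpha * (A * B) + beta * C, with NEON fast paths for alpha == 1, beta == 0 (plain copy) and alpha == 1, beta == 1 (accumulate). A plain scalar reference of the same dispatch, handy for checking the vectorized variants (write_back_reference is a made-up name, not part of the patch):

// Scalar reference: C[i][j] = alpha * c[i][j] + beta * C[i][j], where c (ldc1)
// is the freshly computed block and C (ldc2) is the destination matrix.
void write_back_reference(int mc, int nc, float alpha, const float *c, int ldc1,
                          float beta, float *C, int ldc2) {
  for (int i = 0; i < mc; ++i) {
    const float *src = c + i * ldc1;
    float *dst = C + i * ldc2;
    if (alpha == 1.f && beta == 0.f) {          // plain copy
      for (int j = 0; j < nc; ++j) dst[j] = src[j];
    } else if (alpha == 1.f && beta == 1.f) {   // accumulate
      for (int j = 0; j < nc; ++j) dst[j] += src[j];
    } else {                                    // general alpha/beta
      for (int j = 0; j < nc; ++j) dst[j] = alpha * src[j] + beta * dst[j];
    }
  }
}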
@@ -31,8 +31,9 @@ struct SgemmStrategy {
                              Itype *, const bool);
  typedef void (*kernelFunc)(const Itype *, const Itype *, const int, Otype *,
                             const int);
-  typedef void (*WriteFunc)(const int, const int, const Otype *, const int,
-                            Otype *, const int);
  typedef void (*WriteFunc)(const int, const int, const float alpha,
                            const Otype *, const int, const float beta, Otype *,
                            const int);

  packLhsFunc pack_lhs;
  packRhsFunc pack_rhs;
......
@@ -17,7 +17,7 @@ limitations under the License. */
#include "operators/math/gru_compute.h"
#include "common/types.h"
#include "operators/math/activation.h"
-#include "operators/math/gemm.h"
#include "operators/math/gemm/cblas.h"
#include "operators/math/gru_cpu_kernel.h"

namespace paddle_mobile {

@@ -29,35 +29,19 @@ struct GRUUnitFunctor<CPU, T> {
  static void compute(GRUMetaValue<T> value, int frame_size, int batch_size,
                      const ActivationType active_node,
                      const ActivationType active_gate) {
-    Gemm gemm;
    if (value.prev_out_value) {
-#ifdef _OPENMP
-      gemm.Sgemm_omp(batch_size, frame_size * 2, frame_size, 1,
-                     value.prev_out_value, frame_size, value.gate_weight,
-                     frame_size * 2, 1, value.gate_value, frame_size * 3, false,
-                     static_cast<float *>(nullptr));
-#else
-      gemm.Sgemm(batch_size, frame_size * 2, frame_size, 1,
-                 value.prev_out_value, frame_size, value.gate_weight,
-                 frame_size * 2, 1, value.gate_value, frame_size * 3, false,
-                 static_cast<float *>(nullptr));
-#endif
      cblas_sgemm(false, false, batch_size, frame_size * 2, frame_size, 1.f,
                  value.prev_out_value, frame_size, value.gate_weight,
                  frame_size * 2, 1.f, value.gate_value, frame_size * 3);
    }
    forward_reset_output(value, frame_size, batch_size, active_gate);

    if (value.prev_out_value) {
-#ifdef _OPENMP
-      gemm.Sgemm_omp(batch_size, frame_size, frame_size, 1,
-                     value.reset_output_value, frame_size, value.state_weight,
-                     frame_size, 1, value.gate_value + frame_size * 2,
-                     frame_size * 3, false, static_cast<float *>(nullptr));
-#else
-      gemm.Sgemm(batch_size, frame_size, frame_size, 1,
-                 value.reset_output_value, frame_size, value.state_weight,
-                 frame_size, 1, value.gate_value + frame_size * 2,
-                 frame_size * 3, false, static_cast<float *>(nullptr));
-#endif
      cblas_sgemm(false, false, batch_size, frame_size, frame_size, 1.f,
                  value.reset_output_value, frame_size, value.state_weight,
                  frame_size, 1.f, value.gate_value + frame_size * 2,
                  frame_size * 3);
    }
    forward_final_output(value, frame_size, batch_size, active_node);

@@ -65,6 +49,7 @@ struct GRUUnitFunctor<CPU, T> {
};

template struct GRUUnitFunctor<CPU, float>;

}  // namespace math
}  // namespace operators
}  // namespace paddle_mobile
......
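The GRU path now issues plain cblas_sgemm calls instead of branching between Gemm::Sgemm and Gemm::Sgemm_omp. The argument order used above is trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc; a scalar stand-in with that order (inferred from the call sites; the real wrapper in operators/math/gemm/cblas.h is NEON-backed) shows how the first call accumulates into the first two gate blocks of gate_value:

// Scalar stand-in with the argument order used above; transposition is not
// handled here, this is an illustration of the call convention only.
static void cblas_sgemm_ref(bool trans_a, bool trans_b, int M, int N, int K,
                            float alpha, const float *A, int lda,
                            const float *B, int ldb, float beta, float *C,
                            int ldc) {
  (void)trans_a;
  (void)trans_b;
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float acc = 0.f;
      for (int p = 0; p < K; ++p) acc += A[i * lda + p] * B[p * ldb + j];
      C[i * ldc + j] = alpha * acc + beta * C[i * ldc + j];
    }
  }
}

// Mirrors the first GRU call: gate_value[:, 0:2*frame_size] +=
//   prev_out (batch x frame) * gate_weight (frame x 2*frame);
// ldc is 3 * frame_size because each gate_value row holds three gate blocks.
void gru_gate_matmul(const float *prev_out, const float *gate_weight,
                     float *gate_value, int batch_size, int frame_size) {
  cblas_sgemm_ref(false, false, batch_size, frame_size * 2, frame_size, 1.f,
                  prev_out, frame_size, gate_weight, frame_size * 2, 1.f,
                  gate_value, frame_size * 3);
}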
@@ -71,34 +71,11 @@ void MatMul<float, float>(const framework::Tensor &matrix_a, bool trans_a,
        a[index++] = tmp[i * n + j];
      }
    }
-    if (M == 1) {
-#ifdef _OPENMP
-      gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<float>(), N, beta,
-                     matrix_out->data<float>(), N, relu, bias);
-#else
-      gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<float>(), N, beta,
-                 matrix_out->data<float>(), N, relu, bias);
-#endif
-    } else {
-      cblas_sgemm(false, false, M, N, K, alpha, a, K, matrix_b.data<float>(), N,
-                  beta, matrix_out->data<float>(), N);
-    }
-  } else {
-    if (M == 1) {
-#ifdef _OPENMP
-      gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data<float>(), K,
-                     matrix_b.data<float>(), N, beta, matrix_out->data<float>(),
-                     N, relu, bias);
-#else
-      gemm.Sgemm(M, N, K, alpha, matrix_a.data<float>(), K,
-                 matrix_b.data<float>(), N, beta, matrix_out->data<float>(), N,
-                 relu, bias);
-#endif
-    } else {
-      cblas_sgemm(false, false, M, N, K, alpha, matrix_a.data<float>(), K,
-                  matrix_b.data<float>(), N, beta, matrix_out->data<float>(),
-                  N);
-    }
-  }
    cblas_sgemm(false, false, M, N, K, alpha, a, K, matrix_b.data<float>(), N,
                beta, matrix_out->data<float>(), N);
  } else {
    cblas_sgemm(false, false, M, N, K, alpha, matrix_a.data<float>(), K,
                matrix_b.data<float>(), N, beta, matrix_out->data<float>(), N);
  }
}
......
@@ -803,9 +803,9 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
        "dup v15.4s, wzr \n"
        "cmp %[inter], #0 \n"
-        "ble loop_1c_%= \n"
        "ble 2f \n"
        // loop 2 channels
-        "loop_2c_%=: \n"
        "1: \n"
        "ld1 {v0.4s, v1.4s}, [%[w_ptr]], #32 \n"
        "ld1 {v2.4s, v3.4s}, [%[in_ptr]], #32 \n"
        "ld1 {v4.4s, v5.4s}, [%[in_ptr]], #32 \n"

@@ -829,12 +829,12 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
        "fmla v15.4s, v5.4s, v1.s[3] \n"
        "subs %[inter], %[inter], #1 \n"
-        "bne loop_2c_%= \n"
        "bne 1b \n"

        // loop 1 channel
-        "loop_1c_%=: \n"
        "2: \n"
        "cmp %[remain], #0 \n"
-        "ble store_res_%= \n"
        "ble 3f \n"
        "ld1 {v0.4s, v1.4s}, [%[w_ptr]], #32 \n"
        "ld1 {v2.4s, v3.4s}, [%[in_ptr]], #32 \n"

@@ -847,7 +847,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
        "fmla v14.4s, v2.4s, v0.s[3] \n"
        "fmla v15.4s, v3.4s, v0.s[3] \n"
-        "store_res_%=: \n"
        "3: \n"
        "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[uv_ptr]], #64 \n"
        "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[uv_ptr]], #64 \n"
        : [w_ptr] "+r"(w_ptr), [in_ptr] "+r"(in_ptr), [uv_ptr] "+r"(uv_ptr),
......
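The hunk above swaps the named asm labels (loop_2c_%=, loop_1c_%=, store_res_%=) for GNU-style numeric local labels (1:, 2:, 3:, referenced as 1b / 2f / 3f). Named labels in inline assembly have been reported to cause assembler or duplicate-symbol trouble with Apple's toolchain, which is presumably the iOS issue this part of the commit works around. A minimal standalone illustration of the numeric-label idiom (not code from the patch):

#include <cstdint>

// Sums 1..n in an inline-asm loop; "1:" defines a numeric local label and
// "bne 1b" branches backwards to the nearest preceding "1:".
int64_t sum_to_n(int64_t n) {
  if (n <= 0) return 0;
  int64_t acc = 0;
#if defined(__aarch64__)
  asm volatile(
      "1:                          \n"
      "add  %[acc], %[acc], %[n]   \n"
      "subs %[n], %[n], #1         \n"
      "bne  1b                     \n"
      : [acc] "+r"(acc), [n] "+r"(n)
      :
      : "cc");
#else
  for (int64_t i = 1; i <= n; ++i) acc += i;  // portable fallback
#endif
  return acc;
}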
@@ -5,7 +5,7 @@ TOTAL_ERRORS=0
# The trick to remove deleted files: https://stackoverflow.com/a/2413151
for file in $(git diff --cached --name-status | awk '$1 != "D" {print $2}' | \
    grep -v ".pb.cpp" | grep -v ".pb.h" | grep -v ".pb-c.h" | grep -v ".pb-c.c" | \
-    grep -v "protobuf-c.h" | grep -v "protobuf-c.c"); do
    grep -v "protobuf-c.h" | grep -v "protobuf-c.c" | grep -v "dim.h"); do
  cpplint $file;
  TOTAL_ERRORS=$(expr $TOTAL_ERRORS + $?);
done
......