未验证 提交 c061e83f 编写于 作者: R Ruilong Liu 提交者: GitHub

Merge pull request #599 from smilejames/develop

add macro definition:__ARM_NEON, __aarch64__
...@@ -15,7 +15,7 @@ file(GLOB_RECURSE PADDLE_MOBILE_H src/*.h) ...@@ -15,7 +15,7 @@ file(GLOB_RECURSE PADDLE_MOBILE_H src/*.h)
include_directories(src/) include_directories(src/)
if(IS_IOS) if(IS_IOS)
set(CMAKE_CXX_FLAGS "-fobjc-abi-version=2 -fobjc-arc -std=gnu++11 -stdlib=libc++ -O3 -s -isysroot ${CMAKE_OSX_SYSROOT} ${CMAKE_CXX_FLAGS}") set(CMAKE_CXX_FLAGS "-mfpu=neon -fobjc-abi-version=2 -fobjc-arc -std=gnu++11 -stdlib=libc++ -O3 -s -isysroot ${CMAKE_OSX_SYSROOT} ${CMAKE_CXX_FLAGS}")
else() else()
set(CMAKE_CXX_FLAGS "-std=c++14 -O3 -s ${CMAKE_CXX_FLAGS}") set(CMAKE_CXX_FLAGS "-std=c++14 -O3 -s ${CMAKE_CXX_FLAGS}")
endif() endif()
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "operators/math/depthwise_conv_3x3.h" #include "operators/math/depthwise_conv_3x3.h"
#ifdef __ARM_NEON #if __ARM_NEON
#include <arm_neon.h> #include <arm_neon.h>
#endif #endif
#include <vector> #include <vector>
...@@ -23,7 +23,6 @@ namespace math { ...@@ -23,7 +23,6 @@ namespace math {
void DepthwiseConv3x3(const Tensor *input, vector<int> strides, void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
vector<int> paddings, const Tensor *filter, Tensor *bias, vector<int> paddings, const Tensor *filter, Tensor *bias,
Tensor *output, bool if_bias) { Tensor *output, bool if_bias) {
#ifdef __ARM_NEON
const int batch_size = input->dims()[0]; const int batch_size = input->dims()[0];
const int input_height = input->dims()[2]; const int input_height = input->dims()[2];
...@@ -181,7 +180,27 @@ void DepthwiseConv3x3(const Tensor *input, vector<int> strides, ...@@ -181,7 +180,27 @@ void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
} }
} else { } else {
#if defined(ARMV17) #if __ARM_NEON
#if __aarch64__
const float32x4_t data1 = vld1q_f32(pos1);
const float32x4_t data2 = vld1q_f32(pos2);
const float32x4_t data3 = vld1q_f32(pos3);
const float32x4_t v_filter1 = vld1q_f32(filter1);
const float32x4_t v_filter2 = vld1q_f32(filter2);
const float32x4_t v_filter3 = vld1q_f32(filter3);
float32x4_t mula = vmulq_f32(data1, v_filter1);
mula = vmlaq_f32(mula, data2, v_filter2);
mula = vmlaq_f32(mula, data3, v_filter3);
float32x2_t res = vpadd_f32(
vget_high_f32(vsetq_lane_f32(0, mula, 3)), vget_low_f32(mula));
res = vpadd_f32(res, res);
if (if_bias) {
output_data[ph * output_width + pw] += vget_lane_f32(res, 0);
} else {
output_data[ph * output_width + pw] = vget_lane_f32(res, 0);
}
#else
asm volatile( asm volatile(
"vld1.32 {q1}, [%[pos1]] \n\t" "vld1.32 {q1}, [%[pos1]] \n\t"
...@@ -209,26 +228,10 @@ void DepthwiseConv3x3(const Tensor *input, vector<int> strides, ...@@ -209,26 +228,10 @@ void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
[filter2] "r"(filter2), [filter3] "r"(filter3), [filter2] "r"(filter2), [filter3] "r"(filter3),
[output_ptr] "r"(output_ptr), [zero] "r"(zero) [output_ptr] "r"(output_ptr), [zero] "r"(zero)
: "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6"); : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
#endif // __aarch64__
#else #else
const float32x4_t data1 = vld1q_f32(pos1);
const float32x4_t data2 = vld1q_f32(pos2);
const float32x4_t data3 = vld1q_f32(pos3);
const float32x4_t v_filter1 = vld1q_f32(filter1); #endif // __ARM_NEON
const float32x4_t v_filter2 = vld1q_f32(filter2);
const float32x4_t v_filter3 = vld1q_f32(filter3);
float32x4_t mula = vmulq_f32(data1, v_filter1);
mula = vmlaq_f32(mula, data2, v_filter2);
mula = vmlaq_f32(mula, data3, v_filter3);
float32x2_t res = vpadd_f32(
vget_high_f32(vsetq_lane_f32(0, mula, 3)), vget_low_f32(mula));
res = vpadd_f32(res, res);
if (if_bias) {
output_data[ph * output_width + pw] += vget_lane_f32(res, 0);
} else {
output_data[ph * output_width + pw] = vget_lane_f32(res, 0);
}
#endif
} }
} }
} }
...@@ -239,12 +242,11 @@ void DepthwiseConv3x3(const Tensor *input, vector<int> strides, ...@@ -239,12 +242,11 @@ void DepthwiseConv3x3(const Tensor *input, vector<int> strides,
input_data += input_batch_stride; input_data += input_batch_stride;
output_data += output_batch_stride; output_data += output_batch_stride;
} }
#endif
} }
void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
Tensor *output, Tensor *bias, bool if_bias) { Tensor *output, Tensor *bias, bool if_bias) {
#ifdef __ARM_NEON #if __ARM_NEON
const float *input_data = input->data<float>(); const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>(); const float *filter_data = filter->data<float>();
float *output_data = output->data<float>(); float *output_data = output->data<float>();
...@@ -520,7 +522,7 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, ...@@ -520,7 +522,7 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
Tensor *output, const Tensor *new_scale, Tensor *output, const Tensor *new_scale,
const Tensor *new_bias, bool if_relu) { const Tensor *new_bias, bool if_relu) {
#ifdef __ARM_NEON #if __ARM_NEON
const float *input_data = input->data<float>(); const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>(); const float *filter_data = filter->data<float>();
float *output_data = output->data<float>(); float *output_data = output->data<float>();
...@@ -824,7 +826,7 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, ...@@ -824,7 +826,7 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter,
void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter, void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter,
Tensor *output, const Tensor *new_scale, Tensor *output, const Tensor *new_scale,
const Tensor *new_bias, bool if_relu) { const Tensor *new_bias, bool if_relu) {
#ifdef __ARM_NEON #if __ARM_NEON
const int batch_size = input->dims()[0]; const int batch_size = input->dims()[0];
...@@ -1022,7 +1024,7 @@ void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter, ...@@ -1022,7 +1024,7 @@ void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter,
void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter, void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter,
Tensor *output, Tensor bias, bool if_bias) { Tensor *output, Tensor bias, bool if_bias) {
#ifdef __ARM_NEON #if __ARM_NEON
const float *input_data = input->data<float>(); const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>(); const float *filter_data = filter->data<float>();
float *output_data = output->data<float>(); float *output_data = output->data<float>();
...@@ -1225,7 +1227,7 @@ void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter, ...@@ -1225,7 +1227,7 @@ void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter,
void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
Tensor *output, const Tensor *new_scale, Tensor *output, const Tensor *new_scale,
const Tensor *new_bias, bool if_relu) { const Tensor *new_bias, bool if_relu) {
#ifdef __ARM_NEON #if __ARM_NEON
const float *input_data = input->data<float>(); const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>(); const float *filter_data = filter->data<float>();
float *output_data = output->data<float>(); float *output_data = output->data<float>();
......
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#include "operators/math/gemm.h" #include "operators/math/gemm.h"
#include "common/log.h" #include "common/log.h"
#include "memory/t_malloc.h" #include "memory/t_malloc.h"
#ifndef X86 #if __ARM_NEON
#include <arm_neon.h> #include <arm_neon.h>
#endif #endif
#ifdef _OPENMP #ifdef _OPENMP
...@@ -136,6 +136,10 @@ void PackMatrixB_(int k, int n, int n_tail, const float *B, int ldb, ...@@ -136,6 +136,10 @@ void PackMatrixB_(int k, int n, int n_tail, const float *B, int ldb,
for (int j = 0; j < n - n_tail; j += NR) { for (int j = 0; j < n - n_tail; j += NR) {
for (int i = 0; i < k; ++i) { for (int i = 0; i < k; ++i) {
b0 = &B(i, j); b0 = &B(i, j);
#if __ARM_NEON
#if __aarch64__
#else
asm volatile( asm volatile(
"pld [%[b0]] \n\t" "pld [%[b0]] \n\t"
"vld1.32 {q0, q1}, [%[b0]] \n\t" "vld1.32 {q0, q1}, [%[b0]] \n\t"
...@@ -143,6 +147,10 @@ void PackMatrixB_(int k, int n, int n_tail, const float *B, int ldb, ...@@ -143,6 +147,10 @@ void PackMatrixB_(int k, int n, int n_tail, const float *B, int ldb,
: [buffer] "+r"(buffer) : [buffer] "+r"(buffer)
: [b0] "r"(b0) : [b0] "r"(b0)
: "memory", "q0", "q0"); : "memory", "q0", "q0");
#endif // __aarch64__
#else
#endif // __ARM_NEON
} }
} }
if (n_tail != 0) { if (n_tail != 0) {
...@@ -206,7 +214,9 @@ void InnerKernelWithBn(int mc, int nc, float alpha, const float *a, ...@@ -206,7 +214,9 @@ void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
} }
} }
#if defined(IOS) #if __ARM_NEON
#if __aarch64__
void AddDot4x4(int k, const float *a, const float *b, float *C, int ldc) { void AddDot4x4(int k, const float *a, const float *b, float *C, int ldc) {
// init C // init C
float32x4_t cv0 = vdupq_n_f32(0.0); float32x4_t cv0 = vdupq_n_f32(0.0);
...@@ -255,9 +265,9 @@ void AddDot4x4(int k, const float *a, const float *b, float *C, int ldc) { ...@@ -255,9 +265,9 @@ void AddDot4x4(int k, const float *a, const float *b, float *C, int ldc) {
} }
} }
} }
} // namespace math
#elif defined(ARMV7) #else
void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) { void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) {
const float *a_ptr, *b_ptr; const float *a_ptr, *b_ptr;
a_ptr = a; a_ptr = a;
...@@ -328,151 +338,6 @@ void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) { ...@@ -328,151 +338,6 @@ void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) {
"q10", "q11", "q12", "q13"); "q10", "q11", "q12", "q13");
} }
#else
void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) {
float *c0, *c1, *c2, *c3;
c0 = c;
c1 = c + ldc;
c2 = c + 2 * ldc;
c3 = c + 3 * ldc;
for (int p = 0; p < k; p += 1) {
// first row
c0[0] += a[0] * b[0];
c0[1] += a[0] * b[1];
c0[2] += a[0] * b[2];
c0[3] += a[0] * b[3];
// second row
c1[0] += a[1] * b[0];
c1[1] += a[1] * b[1];
c1[2] += a[1] * b[2];
c1[3] += a[1] * b[3];
// third row
c2[0] += a[2] * b[0];
c2[1] += a[2] * b[1];
c2[2] += a[2] * b[2];
c2[3] += a[2] * b[3];
// fourth row
c3[0] += a[3] * b[0];
c3[1] += a[3] * b[1];
c3[2] += a[3] * b[2];
c3[3] += a[3] * b[3];
a += 4;
b += 4;
}
}
#endif
// 32位 float 矩阵乘法
void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc, bool relu) {
// L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73)
// L2 cache is 0.5~4 Mib (Contex-A72 cluster)
int L1 = 30 * 1024;
int L2 = 1 * 1024 * 1024;
KC = k;
MC = L2 / (2 * KC * sizeof(float));
NC = MC;
// make sure MC is multiple of 4, and NC is multiple of 8
int mblock_num = (m + MC - 1) / MC;
MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + 4 - 1) / 4 * 4;
// DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n";
int nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + 8 - 1) / 8 * 8;
// DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n";
packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
packedC = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * NC));
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
for (int l = 0; l < KC; ++l) {
zero[l] = 0;
}
int mc, nc;
for (int j = 0; j < n; j += NC) {
nc = s_min(n - j, NC);
PackMatrixB_(KC, nc, nc % NR, &B(0, j), ldb, packedB);
for (int i = 0; i < m; i += MC) {
mc = s_min(m - i, MC);
PackMatrixA_(mc, KC, mc % MR, &A(i, 0), lda, packedA);
InnerKernel(mc, nc, alpha, packedA, packedB, beta, packedC, &C(i, j), ldc,
relu);
}
}
paddle_mobile::memory::Free(packedA);
paddle_mobile::memory::Free(packedB);
paddle_mobile::memory::Free(packedC);
paddle_mobile::memory::Free(zero);
}
void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias) {
// L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73)
// L2 cache is 0.5~4 Mib (Contex-A72 cluster)
int L1 = 30 * 1024;
int L2 = 1 * 1024 * 1024;
KC = k;
MC = L2 / (2 * KC * sizeof(float));
NC = MC;
// make sure MC is multiple of 4, and NC is multiple of 8
int mblock_num = (m + MC - 1) / MC;
MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + 4 - 1) / 4 * 4;
// DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n";
int nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + 8 - 1) / 8 * 8;
// DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n";
packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
packedC = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * NC));
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
for (int l = 0; l < KC; ++l) {
zero[l] = 0;
}
int mc, nc;
for (int j = 0; j < n; j += NC) {
nc = s_min(n - j, NC);
PackMatrixB_(KC, nc, nc % NR, &B(0, j), ldb, packedB);
for (int i = 0; i < m; i += MC) {
mc = s_min(m - i, MC);
PackMatrixA_(mc, KC, mc % MR, &A(i, 0), lda, packedA);
InnerKernelWithBn(mc, nc, alpha, packedA, packedB, beta, packedC,
&C(i, j), ldc, relu, new_scale + i, new_bias + i);
}
}
paddle_mobile::memory::Free(packedA);
paddle_mobile::memory::Free(packedB);
paddle_mobile::memory::Free(packedC);
paddle_mobile::memory::Free(zero);
}
void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda, void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc, const float *B, int ldb, float beta, float *C, int ldc,
bool relu) { bool relu) {
...@@ -1699,6 +1564,153 @@ void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *scale, ...@@ -1699,6 +1564,153 @@ void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *scale,
"q12", "q13", "q14"); "q12", "q13", "q14");
} }
#endif // __aarch64__
#else
void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) {
float *c0, *c1, *c2, *c3;
c0 = c;
c1 = c + ldc;
c2 = c + 2 * ldc;
c3 = c + 3 * ldc;
for (int p = 0; p < k; p += 1) {
// first row
c0[0] += a[0] * b[0];
c0[1] += a[0] * b[1];
c0[2] += a[0] * b[2];
c0[3] += a[0] * b[3];
// second row
c1[0] += a[1] * b[0];
c1[1] += a[1] * b[1];
c1[2] += a[1] * b[2];
c1[3] += a[1] * b[3];
// third row
c2[0] += a[2] * b[0];
c2[1] += a[2] * b[1];
c2[2] += a[2] * b[2];
c2[3] += a[2] * b[3];
// fourth row
c3[0] += a[3] * b[0];
c3[1] += a[3] * b[1];
c3[2] += a[3] * b[2];
c3[3] += a[3] * b[3];
a += 4;
b += 4;
}
}
#endif // __ARM_NEON
// 32位 float 矩阵乘法
void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc, bool relu) {
// L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73)
// L2 cache is 0.5~4 Mib (Contex-A72 cluster)
int L1 = 30 * 1024;
int L2 = 1 * 1024 * 1024;
KC = k;
MC = L2 / (2 * KC * sizeof(float));
NC = MC;
// make sure MC is multiple of 4, and NC is multiple of 8
int mblock_num = (m + MC - 1) / MC;
MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + 4 - 1) / 4 * 4;
// DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n";
int nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + 8 - 1) / 8 * 8;
// DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n";
packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
packedC = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * NC));
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
for (int l = 0; l < KC; ++l) {
zero[l] = 0;
}
int mc, nc;
for (int j = 0; j < n; j += NC) {
nc = s_min(n - j, NC);
PackMatrixB_(KC, nc, nc % NR, &B(0, j), ldb, packedB);
for (int i = 0; i < m; i += MC) {
mc = s_min(m - i, MC);
PackMatrixA_(mc, KC, mc % MR, &A(i, 0), lda, packedA);
InnerKernel(mc, nc, alpha, packedA, packedB, beta, packedC, &C(i, j), ldc,
relu);
}
}
paddle_mobile::memory::Free(packedA);
paddle_mobile::memory::Free(packedB);
paddle_mobile::memory::Free(packedC);
paddle_mobile::memory::Free(zero);
}
void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias) {
// L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73)
// L2 cache is 0.5~4 Mib (Contex-A72 cluster)
int L1 = 30 * 1024;
int L2 = 1 * 1024 * 1024;
KC = k;
MC = L2 / (2 * KC * sizeof(float));
NC = MC;
// make sure MC is multiple of 4, and NC is multiple of 8
int mblock_num = (m + MC - 1) / MC;
MC = (m + mblock_num - 1) / mblock_num;
MC = (MC + 4 - 1) / 4 * 4;
// DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n";
int nblock_num = (n + NC - 1) / NC;
NC = (n + nblock_num - 1) / nblock_num;
NC = (NC + 8 - 1) / 8 * 8;
// DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n";
packedA = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
packedB = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
packedC = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * MC * NC));
zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC));
for (int l = 0; l < KC; ++l) {
zero[l] = 0;
}
int mc, nc;
for (int j = 0; j < n; j += NC) {
nc = s_min(n - j, NC);
PackMatrixB_(KC, nc, nc % NR, &B(0, j), ldb, packedB);
for (int i = 0; i < m; i += MC) {
mc = s_min(m - i, MC);
PackMatrixA_(mc, KC, mc % MR, &A(i, 0), lda, packedA);
InnerKernelWithBn(mc, nc, alpha, packedA, packedB, beta, packedC,
&C(i, j), ldc, relu, new_scale + i, new_bias + i);
}
}
paddle_mobile::memory::Free(packedA);
paddle_mobile::memory::Free(packedB);
paddle_mobile::memory::Free(packedC);
paddle_mobile::memory::Free(zero);
}
} // namespace math
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
} // namespace paddle_mobile
...@@ -159,7 +159,7 @@ set (CMAKE_OSX_SYSROOT ${CMAKE_IOS_SDK_ROOT} CACHE PATH "Sysroot used for iOS su ...@@ -159,7 +159,7 @@ set (CMAKE_OSX_SYSROOT ${CMAKE_IOS_SDK_ROOT} CACHE PATH "Sysroot used for iOS su
# set the architecture for iOS # set the architecture for iOS
if (${IOS_PLATFORM} STREQUAL "OS") if (${IOS_PLATFORM} STREQUAL "OS")
set (IOS_ARCH armv7 armv7s arm64) set (IOS_ARCH armv7 armv7s)
elseif (${IOS_PLATFORM} STREQUAL "SIMULATOR") elseif (${IOS_PLATFORM} STREQUAL "SIMULATOR")
set (IOS_ARCH i386) set (IOS_ARCH i386)
elseif (${IOS_PLATFORM} STREQUAL "SIMULATOR64") elseif (${IOS_PLATFORM} STREQUAL "SIMULATOR64")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册