From c66469316e9f3cac2b6e8a413eef70adfd72adc5 Mon Sep 17 00:00:00 2001 From: zhaojiaying01 Date: Wed, 18 Jul 2018 11:12:00 +0800 Subject: [PATCH] fix gemm --- src/operators/math/gemm.cpp | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp index bb91adcc4d..22a351866d 100644 --- a/src/operators/math/gemm.cpp +++ b/src/operators/math/gemm.cpp @@ -1214,6 +1214,21 @@ void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) { // C = A * B, batchnorm(C) void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale, float *bias) { + if (nc < 4) { + for (int i = 0; i < mc; ++i) { + for (int j = 0; j < nc; ++j) { + *C = (*c) * (*scale) + (*bias); + C++; + c++; + } + C += (ldc - nc); + c += (NC - nc); + scale++; + bias++; + } + return; + } + int volatile nc1 = nc / 16; int _nc1 = nc % 16; int volatile nc2 = _nc1 / 4; @@ -1300,6 +1315,24 @@ void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale, // C = A * B, batchnorm(C), relu(C) void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *scale, float *bias) { + if (nc < 4) { + for (int i = 0; i < mc; ++i) { + for (int j = 0; j < nc; ++j) { + *C = (*c) * (*scale) + (*bias); + if (*C < 0) { + *C = 0; + } + C++; + c++; + } + C += (ldc - nc); + c += (NC - nc); + scale++; + bias++; + } + return; + } + int nc1 = nc / 16; int _nc1 = nc % 16; int nc2 = _nc1 / 4; -- GitLab