diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp index bb91adcc4db412db137fdc12831bad75e069e38c..22a351866dd62d8bd93b3e97970af54180b878b7 100644 --- a/src/operators/math/gemm.cpp +++ b/src/operators/math/gemm.cpp @@ -1214,6 +1214,21 @@ void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) { // C = A * B, batchnorm(C) void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale, float *bias) { + if (nc < 4) { + for (int i = 0; i < mc; ++i) { + for (int j = 0; j < nc; ++j) { + *C = (*c) * (*scale) + (*bias); + C++; + c++; + } + C += (ldc - nc); + c += (NC - nc); + scale++; + bias++; + } + return; + } + int volatile nc1 = nc / 16; int _nc1 = nc % 16; int volatile nc2 = _nc1 / 4; @@ -1300,6 +1315,24 @@ void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale, // C = A * B, batchnorm(C), relu(C) void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *scale, float *bias) { + if (nc < 4) { + for (int i = 0; i < mc; ++i) { + for (int j = 0; j < nc; ++j) { + *C = (*c) * (*scale) + (*bias); + if (*C < 0) { + *C = 0; + } + C++; + c++; + } + C += (ldc - nc); + c += (NC - nc); + scale++; + bias++; + } + return; + } + int nc1 = nc / 16; int _nc1 = nc % 16; int nc2 = _nc1 / 4;