提交 6c8a856b 编写于 作者: Y yanantao78

do trim

上级 19042491
...@@ -25,7 +25,6 @@ limitations under the License. */ ...@@ -25,7 +25,6 @@ limitations under the License. */
#define c(i, j) c[(i)*ldc + (j)] #define c(i, j) c[(i)*ldc + (j)]
#define c1(i, j) c1[(i)*ldc + (j)] #define c1(i, j) c1[(i)*ldc + (j)]
void print_matirx(int m, int n, int ldc, float *c) { void print_matirx(int m, int n, int ldc, float *c) {
for (int i = 0; i < m; ++i) { for (int i = 0; i < m; ++i) {
std::cout << c(i, 0); std::cout << c(i, 0);
...@@ -48,7 +47,7 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) { ...@@ -48,7 +47,7 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) {
float *c1 = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m * n)); float *c1 = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m * n));
float* scale = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m)); float* scale = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m));
float* bias = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m)); float* bias = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m));
srand(unsigned(time(0))); srand(unsigned(time(0)));
for (int i = 0; i < m * k; ++i) { for (int i = 0; i < m * k; ++i) {
a[i] = t1 + rand() % t2; a[i] = t1 + rand() % t2;
...@@ -62,7 +61,7 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) { ...@@ -62,7 +61,7 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) {
for (int i = 0; i < m; ++i) { for (int i = 0; i < m; ++i) {
bias[i] = t1 + rand() % t2; bias[i] = t1 + rand() % t2;
} }
for (int i = 0; i < m; ++i) { for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) { for (int j = 0; j < n; ++j) {
float r = 0; float r = 0;
...@@ -77,7 +76,7 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) { ...@@ -77,7 +76,7 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) {
c1(i, j) = r; c1(i, j) = r;
} }
} }
paddle_mobile::operators::math::SgemmWithBn(m, n, k, 0.9, a, lda, paddle_mobile::operators::math::SgemmWithBn(m, n, k, 0.9, a, lda,
b, ldb, 0.3, c, ldc, relu, scale, bias); b, ldb, 0.3, c, ldc, relu, scale, bias);
int eq = 0; int eq = 0;
...@@ -89,22 +88,19 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) { ...@@ -89,22 +88,19 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) {
++neq; ++neq;
} }
} }
if (pr > 0) { if (pr > 0) {
std::cout << "A:" << std::endl; std::cout << "A:" << std::endl;
print_matirx(m, k, lda, a); print_matirx(m, k, lda, a);
std::cout << "B:" << std::endl; std::cout << "B:" << std::endl;
print_matirx(k, n, ldb, b); print_matirx(k, n, ldb, b);
std::cout << "C:" << std::endl; std::cout << "C:" << std::endl;
print_matirx(m, n, ldc, c); print_matirx(m, n, ldc, c);
std::cout << "C1:" << std::endl; std::cout << "C1:" << std::endl;
print_matirx(m, n, ldc, c1); print_matirx(m, n, ldc, c1);
} }
std::cout << "mnk=" << m << " " << n << " " << k << std::cout << "mnk=" << m << " " << n << " " << k <<
" relu=" << relu << " relu=" << relu <<
" eq=" << eq << " neq=" << neq << std::endl; " eq=" << eq << " neq=" << neq << std::endl;
...@@ -114,19 +110,19 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) { ...@@ -114,19 +110,19 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) {
paddle_mobile::memory::Free(c1); paddle_mobile::memory::Free(c1);
paddle_mobile::memory::Free(scale); paddle_mobile::memory::Free(scale);
paddle_mobile::memory::Free(bias); paddle_mobile::memory::Free(bias);
return 0; return 0;
} }
int main() { int main() {
do_sgemm(9, 9, 9, true, 10, 10, 10); do_sgemm(9, 9, 9, true, 10, 10, 10);
do_sgemm(10, 6, 12, false, 10, 10, 0); do_sgemm(10, 6, 12, false, 10, 10, 0);
do_sgemm(512, 256, 384, false, 10, 10, 0); do_sgemm(512, 256, 384, false, 10, 10, 0);
do_sgemm(1366, 768, 256, false, 10, 10, 0); do_sgemm(1366, 768, 256, false, 10, 10, 0);
do_sgemm(1255, 755, 333, false, 10, 10, 0); do_sgemm(1255, 755, 333, false, 10, 10, 0);
do_sgemm(555, 777, 999, false, 10, 10, 0); do_sgemm(555, 777, 999, false, 10, 10, 0);
do_sgemm(10, 6, 12, true, -4, 10, 0); do_sgemm(10, 6, 12, true, -4, 10, 0);
do_sgemm(512, 256, 384, true, -4, 10, 0); do_sgemm(512, 256, 384, true, -4, 10, 0);
do_sgemm(1366, 768, 256, true, -4, 10, 0); do_sgemm(1366, 768, 256, true, -4, 10, 0);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册