/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <iostream>
#include "../test_helper.h"
#include "../test_include.h"
#include "operators/math/gemm.h"
#include "operators/math/math_function.h"

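// Row-major element-access macros (not used by the benchmarks below).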
#define a(i, j) a[(i)*lda + (j)]
#define b(i, j) b[(i)*ldb + (j)]
#define c1(i, j) c1[(i)*ldc + (j)]

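// Square GEMM problem size: C (m x n) = A (m x k) * B (k x n),
// with m = n = k = 1024.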
#define m 1024
#define n 1024
#define k 1024

int main() {
  paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
  paddle_mobile.SetThreadNum(4);
  Tensor aa, bb, cc;
  auto aaptr = aa.mutable_data<float>({m, k});
  auto bbptr = bb.mutable_data<float>({k, n});
  auto ccptr = cc.mutable_data<float>({m, n});

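  // Fill the float matrices with a constant value.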
  for (int i = 0; i < m * k; ++i) {
    aaptr[i] = 2;
  }
  for (int i = 0; i < k * n; ++i) {
    bbptr[i] = 2;
  }
  for (int i = 0; i < m * n; ++i) {
    ccptr[i] = 2;
  }

  Tensor aa_int8, bb_int8, cc_int32, cc_int8;
  auto aaptr_int8 = aa_int8.mutable_data<int8_t>({m, k});
  auto bbptr_int8 = bb_int8.mutable_data<int8_t>({k, n});
  auto ccptr_int32 = cc_int32.mutable_data<int32_t>({m, n});
  auto ccptr_int8 = cc_int8.mutable_data<int8_t>({m, n});
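  // Bias vectors for the int8 GEMM-with-bias benchmarks: bias_data_col has m
  // entries (one per output row, added column-wise), bias_data_row has n
  // entries (one per output column, added row-wise).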
  int32_t* bias_data_col = new int32_t[m];
  int32_t* bias_data_row = new int32_t[n];

  for (int i = 0; i < m * k; ++i) {
    aaptr_int8[i] = static_cast<int8_t>(2);
  }
  for (int i = 0; i < k * n; ++i) {
    bbptr_int8[i] = static_cast<int8_t>(2);
  }
  for (int i = 0; i < m * n; ++i) {
    ccptr_int32[i] = static_cast<int32_t>(2);
  }

  for (int i = 0; i < m; ++i) {
    bias_data_col[i] = 2;
  }

  for (int i = 0; i < n; ++i) {
    bias_data_row[i] = 2;
  }

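  // Each benchmark below runs 10 warm-up iterations, then reports the
  // average latency over 10 timed iterations.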
  // float
  // warm-up 10 times
  for (int j = 0; j < 10; ++j) {
    paddle_mobile::operators::math::matmul<float>(
        aa, false, bb, false, static_cast<float>(1), &cc, static_cast<float>(0),
        false, nullptr);
  }

  auto time_start0 = time();
  for (int j = 0; j < 10; ++j) {
    paddle_mobile::operators::math::matmul<float>(
        aa, false, bb, false, static_cast<float>(1), &cc, static_cast<float>(0),
        false, nullptr);
  }
  auto time_end0 = time();
  std::cout << "float gemm  cost :" << time_diff(time_start0, time_end0) / 10
            << "ms\n";

  // int8_t without bias
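  // (int8_t inputs, accumulated into an int32_t output)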
  // warm-up 10 times
  for (int j = 0; j < 10; ++j) {
    paddle_mobile::operators::math::matmul<float, int32_t>(
        aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
        static_cast<float>(0));
  }

  auto time_start1 = time();
  for (int j = 0; j < 10; ++j) {
    paddle_mobile::operators::math::matmul<float, int32_t>(
        aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
        static_cast<float>(0));
  }
  auto time_end1 = time();
  std::cout << "int8_t gemm  cost :" << time_diff(time_start1, time_end1) / 10
            << "ms\n";

  // int8_t with bias, column element wise add
  // warm-up 10 times
  for (int j = 0; j < 10; ++j) {
    paddle_mobile::operators::math::matmul(
        aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
        static_cast<float>(0), false, bias_data_col, false);
  }
  auto time_start2 = time();
  for (int j = 0; j < 10; ++j) {
    paddle_mobile::operators::math::matmul(
        aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
        static_cast<float>(0), false, bias_data_col, false);
  }
  auto time_end2 = time();
  std::cout << "int8_t gemm_with_bias(column add) cost :"
            << time_diff(time_start2, time_end2) / 10 << "ms\n";

  // int8_t with bias, row element wise add
  // warm-up 10 times
  for (int j = 0; j < 10; ++j) {
    paddle_mobile::operators::math::matmul(
        aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
        static_cast<float>(0), false, bias_data_row, true);
  }
  auto time_start3 = time();
  for (int j = 0; j < 10; ++j) {
    paddle_mobile::operators::math::matmul(
        aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
        static_cast<float>(0), false, bias_data_row, true);
  }
  auto time_end3 = time();
  std::cout << "int8_t gemm_with_bias(row add) cost :"
            << time_diff(time_start3, time_end3) / 10 << "ms\n";

  // int8_t with bias&relu
  // warm-up 10 times
  for (int j = 0; j < 10; ++j) {
    paddle_mobile::operators::math::matmul(
        aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
        static_cast<float>(0), true, bias_data_col, false);
  }
  auto time_start4 = time();
  for (int j = 0; j < 10; ++j) {
    paddle_mobile::operators::math::matmul(
        aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
        static_cast<float>(0), true, bias_data_col, false);
  }
  auto time_end4 = time();
  std::cout << "int8_t gemm_with_bias_relu cost :"
            << time_diff(time_start4, time_end4) / 10 << "ms\n";

  delete[] bias_data_row;
  delete[] bias_data_col;

  return 0;
}