// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include "lite/tests/utils/fill_data.h"
#include "lite/tests/utils/naive_math_impl.h"
#ifdef LITE_WITH_ARM
#include "lite/backends/arm/math/funcs.h"
#endif  // LITE_WITH_ARM
#include "lite/core/context.h"
#include "lite/core/profile/timer.h"
#include "lite/core/tensor.h"
#include "lite/tests/utils/tensor_utils.h"

typedef paddle::lite::Tensor Tensor;

DEFINE_int32(cluster, 3, "cluster id");
DEFINE_int32(threads, 1, "number of threads");
DEFINE_int32(warmup, 0, "warmup iterations");
DEFINE_int32(repeats, 1, "timed repeat iterations");
DEFINE_bool(basic_test, true, "do all tests");
DEFINE_bool(check_result, true, "check the result");

DEFINE_int32(M, 512, "sgemv: M");
DEFINE_int32(K, 512, "sgemv: K");

DEFINE_bool(traA, false, "gemv: A transpose");

DEFINE_int32(flag_act, 0, "activation: 0 none, 1 relu, 2 relu6, 4 leaky relu");
DEFINE_bool(flag_bias, false, "with bias");
DEFINE_double(leakey_relu_alpha, 1.0, "leaky relu alpha");
DEFINE_double(clipped_coef, 6.0, "clipped relu coef");
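
// Runs one sgemv case: computes y = A * x (optionally transposing A), adds
// bias and applies the requested activation, then checks the ARM kernel
// against the naive reference when --check_result is set.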
bool test_sgemv(bool tra,
                int m,
                int k,
                bool has_bias,
                int flag_act,
                int cls,
                int ths,
                float six = 6.f,
                float alpha = 1.f) {
  Tensor ta;
  Tensor tb;
  Tensor tc;
  Tensor tc_basic;
  Tensor tbias;

  ta.Resize({m, k});
  tb.Resize({k, 1});
  tc.Resize({m, 1});
  tc_basic.Resize({m, 1});
  tbias.Resize({m});

  ta.set_precision(PRECISION(kFloat));
  tb.set_precision(PRECISION(kFloat));
  tc.set_precision(PRECISION(kFloat));
  tc_basic.set_precision(PRECISION(kFloat));
  tbias.set_precision(PRECISION(kFloat));

  fill_tensor_rand(ta, -1.f, 1.f);
  // fill_tensor_const(ta, 1.f);
  fill_tensor_rand(tb, -1.f, 1.f);
  // fill_tensor_const(tb, 1.f);
  fill_tensor_rand(tbias, -1.f, 1.f);

  LOG(INFO) << "sgemv M: " << m << ", K: " << k
            << ", transA: " << (tra ? "true" : "false") << ", act: " << flag_act
            << ", bias: " << (has_bias ? "true" : "false");
#ifdef LITE_WITH_ARM

  auto da = ta.mutable_data<float>();
  auto db = tb.mutable_data<float>();
  auto dc = tc.mutable_data<float>();
  auto dc_basic = tc_basic.mutable_data<float>();
  // numel() counts elements, so scale by sizeof(float) to zero the whole
  // reference output buffer.
  memset(
      reinterpret_cast<char*>(dc_basic), 0, tc_basic.numel() * sizeof(float));
  auto dbias = tbias.mutable_data<float>();
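  // Map flag_act to the activation applied after the gemv: 0 = none,
  // 1 = relu, 2 = relu6 (clipped at `six`), 4 = leaky relu with slope `alpha`.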
  paddle::lite_api::ActivationType act =
      paddle::lite_api::ActivationType::kIndentity;
  if (flag_act == 1) {
    act = paddle::lite_api::ActivationType::kRelu;
  } else if (flag_act == 2) {
    act = paddle::lite_api::ActivationType::kRelu6;
  } else if (flag_act == 4) {
    act = paddle::lite_api::ActivationType::kLeakyRelu;
  }
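  // Compute the naive reference output once to serve as ground truth.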
  if (FLAGS_check_result) {
    basic_gemv(m,
               k,
               da,
               db,
               dbias,
               dc_basic,
               1.f,
               0.f,
               tra,
               has_bias,
               flag_act,
               six,
               alpha);
  }
  paddle::lite::profile::Timer t0;
  //! compute
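  // One multiply and one add per (m, k) element: 2 * m * k ops in total.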
  double ops = 2.0 * m * k;
  std::unique_ptr<paddle::lite::KernelContext> ctx1(
      new paddle::lite::KernelContext);
  auto& ctx = ctx1->As<paddle::lite::ARMContext>();
  ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), ths);
  /// warmup
  for (int j = 0; j < FLAGS_warmup; ++j) {
    paddle::lite::arm::math::sgemv(da,
                                   db,
                                   dc,
                                   tra,
                                   m,
                                   k,
                                   has_bias,
                                   dbias,
                                   flag_act > 0,
                                   act,
                                   &ctx,
                                   six,
                                   alpha);
  }

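  // Timed runs: each repeat is timed individually so both the average and
  // the minimum latency can be reported below.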
  t0.Reset();
  for (int i = 0; i < FLAGS_repeats; ++i) {
    t0.Start();
    paddle::lite::arm::math::sgemv(da,
                                   db,
                                   dc,
                                   tra,
                                   m,
                                   k,
                                   has_bias,
                                   dbias,
                                   flag_act > 0,
                                   act,
                                   &ctx,
                                   six,
                                   alpha);
    t0.Stop();
  }
  LOG(INFO) << "gemv output: M: " << m << ", K: " << k << ", cluster: " << cls
            << ", threads: " << ths << ", GOPS: " << ops * 1e-9f
            << " GOPS, avg time: " << t0.LapTimes().Avg()
            << " ms, min time: " << t0.LapTimes().Min()
            << " ms, mean GOPs: " << ops * 1e-6f / t0.LapTimes().Avg()
            << " GOPs, max GOPs: " << ops * 1e-6f / t0.LapTimes().Min()
            << " GOPs";

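  // Compare against the naive reference; on a mismatch beyond tolerance,
  // dump both outputs and their element-wise diff.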
  if (FLAGS_check_result) {
    double max_ratio = 0;
    double max_diff = 0;
    /// fp32 result
    tensor_cmp_host(tc_basic, tc, max_ratio, max_diff);
    LOG(INFO) << "compare result, max diff: " << max_diff
              << ", max ratio: " << max_ratio;
    if (std::abs(max_ratio) > 1e-4f && std::abs(max_diff) > 5e-5f) {
      Tensor tdiff;
      tdiff.set_precision(PRECISION(kFloat));
      tdiff.Resize(tc.dims());
      tensor_diff(tc_basic, tc, tdiff);
      LOG(INFO) << "basic result: ";
      print_tensor(tc_basic);
      LOG(INFO) << "lite result: ";
      print_tensor(tc);
      LOG(INFO) << "diff result: ";
      print_tensor(tdiff);
      return false;
    }
  }
#endif
  return true;
}

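// Sweeps a grid of shapes, transpose/bias flags, activation types, and
// thread counts; any failing combination aborts via LOG(FATAL).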
TEST(TestLiteSgemv, Sgemv) {
  if (FLAGS_basic_test) {
#ifdef LITE_WITH_ARM
    paddle::lite::DeviceInfo::Init();
#endif
    LOG(INFO) << "run basic sgemv test";
    for (auto& m : {1, 3, 8, 21, 32, 397}) {
      for (auto& k : {1, 3, 8, 17, 59, 234}) {
        for (auto& tra : {false, true}) {
          for (auto& has_bias : {false, true}) {
            for (auto& flag_act : {0, 1, 2, 4}) {
              for (auto& th : {1, 2, 4}) {
                float six = 6.f;      // relu6 clip value
                float alpha = 8.88f;  // leaky relu slope
                auto flag = test_sgemv(tra,
                                       m,
                                       k,
                                       has_bias,
                                       flag_act,
                                       FLAGS_cluster,
                                       th,
                                       six,
                                       alpha);
                if (flag) {
                  LOG(INFO) << "test m = " << m << ", k=" << k
                            << ", bias: " << (has_bias ? "true" : "false")
                            << ", flag act: " << flag_act
                            << ", trans A: " << (tra ? "true" : "false")
                            << ", threads: " << th << " passed\n";
                } else {
                  LOG(FATAL) << "test m = " << m << ", k=" << k
                             << ", bias: " << (has_bias ? "true" : "false")
                             << ", flag_act: " << flag_act
                             << ", trans A: " << (tra ? "true" : "false")
                             << ", threads: " << th << " failed\n";
                }
              }
            }
          }
        }
      }
    }
  }
}

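// Runs a single case configured from the command-line flags.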
TEST(TestSgemvCustom, Sgemv_custom) {
#ifdef LITE_WITH_ARM
  paddle::lite::DeviceInfo::Init();
#endif
  auto flag = test_sgemv(FLAGS_traA,
                         FLAGS_M,
                         FLAGS_K,
                         FLAGS_flag_bias,
                         FLAGS_flag_act,
                         FLAGS_cluster,
                         FLAGS_threads,
                         FLAGS_clipped_coef,
                         FLAGS_leakey_relu_alpha);
  if (!flag) {
    LOG(FATAL) << "test m = " << FLAGS_M << ", k=" << FLAGS_K
               << ", trans A: " << FLAGS_traA << ", bias: " << FLAGS_flag_bias
               << ", act: " << FLAGS_flag_act << " failed!!";
  }
  LOG(INFO) << "test m = " << FLAGS_M << ", k=" << FLAGS_K
            << ", trans A: " << FLAGS_traA << ", bias: " << FLAGS_flag_bias
            << ", act: " << FLAGS_flag_act << " passed!!";
}