Unverified · Commit ca9ec692 authored by zhupengyang, committed by GitHub

[xpu] update bert, ernie unittests (#4357)

Parent 1d3754aa
@@ -63,6 +63,7 @@ if (WITH_TESTING)
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "VGG19.tar.gz")
# data
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "ILSVRC2012_small.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "bert_data.tar.gz")
endif()
endif()
......
@@ -9,11 +9,18 @@ if(LITE_WITH_ARM)
endif()
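# Helper summarized from the definition below: builds a lite_cc_test from
# ${TARGET}.cc against the downloaded model archive ${MODEL}; DATA names an
# optional data archive and may be "" when the test takes no --data_dir.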
function(xpu_x86_without_xtcl_test TARGET MODEL DATA)
-  lite_cc_test(${TARGET} SRCS ${TARGET}.cc
-      DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
-      ${ops} ${host_kernels} ${x86_kernels} ${xpu_kernels}
-      ARGS --model_dir=${LITE_MODEL_DIR}/${MODEL}
-      --data_dir=${LITE_MODEL_DIR}/${DATA})
+  if(${DATA} STREQUAL "")
+    lite_cc_test(${TARGET} SRCS ${TARGET}.cc
+        DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
+        ${ops} ${host_kernels} ${x86_kernels} ${xpu_kernels}
+        ARGS --model_dir=${LITE_MODEL_DIR}/${MODEL})
+  else()
+    lite_cc_test(${TARGET} SRCS ${TARGET}.cc
+        DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
+        ${ops} ${host_kernels} ${x86_kernels} ${xpu_kernels}
+        ARGS --model_dir=${LITE_MODEL_DIR}/${MODEL} --data_dir=${LITE_MODEL_DIR}/${DATA})
+  endif()
if(WITH_TESTING)
add_dependencies(${TARGET} extern_lite_download_${MODEL}_tar_gz)
if(NOT ${DATA} STREQUAL "")
@@ -26,8 +33,8 @@ if(LITE_WITH_XPU AND NOT LITE_WITH_XTCL)
xpu_x86_without_xtcl_test(test_resnet50_fp32_xpu resnet50 ILSVRC2012_small)
xpu_x86_without_xtcl_test(test_googlenet_fp32_xpu GoogLeNet ILSVRC2012_small)
xpu_x86_without_xtcl_test(test_vgg19_fp32_xpu VGG19 ILSVRC2012_small)
-  xpu_x86_without_xtcl_test(test_ernie_fp32_xpu ernie "")
-  xpu_x86_without_xtcl_test(test_bert_fp32_xpu bert "")
+  xpu_x86_without_xtcl_test(test_ernie_fp32_xpu ernie bert_data)
+  xpu_x86_without_xtcl_test(test_bert_fp32_xpu bert bert_data)
endif()
if(LITE_WITH_RKNPU)
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <cmath>    // std::fabs in CalErnieOutAccuracy
#include <cstring>  // memcpy in FillTensor
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "lite/api/paddle_api.h"
#include "lite/utils/cp_logging.h"
#include "lite/utils/io.h"
#include "lite/utils/string.h"
namespace paddle {
namespace lite {
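// Reads one test sample per line from the raw-data file. Each line is assumed
// to hold the four model inputs as "shape:data" fields separated by ';', with
// space-separated values, e.g. (hypothetical token ids):
//   1 64 1:101 2054 2003 ... ;1 64 1:0 0 0 ... ;1 64 1:1 1 1 ... ;1 64 1:0 1 2 ...
// The shape is parsed from the first field only and is shared by all four
// inputs of that sample.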
template <class T = int64_t>
void ReadRawData(const std::string& input_data_dir,
std::vector<std::vector<T>>* input0,
std::vector<std::vector<T>>* input1,
std::vector<std::vector<T>>* input2,
std::vector<std::vector<T>>* input3,
std::vector<std::vector<int64_t>>* input_shapes) {
auto lines = ReadLines(input_data_dir);
for (auto line : lines) {
std::vector<std::string> shape_and_data = Split(line, ";");
std::vector<int64_t> input_shape =
Split<int64_t>(Split(shape_and_data[0], ":")[0], " ");
input_shapes->emplace_back(input_shape);
std::vector<T> input0_data =
Split<T>(Split(shape_and_data[0], ":")[1], " ");
input0->emplace_back(input0_data);
std::vector<T> input1_data =
Split<T>(Split(shape_and_data[1], ":")[1], " ");
input1->emplace_back(input1_data);
std::vector<T> input2_data =
Split<T>(Split(shape_and_data[2], ":")[1], " ");
input2->emplace_back(input2_data);
std::vector<T> input3_data =
Split<T>(Split(shape_and_data[3], ":")[1], " ");
input3->emplace_back(input3_data);
}
}
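// Resizes the predictor's input `tensor_id` to `tensor_shape`, then copies
// `tensor_value` into it; the value count must match the shape's volume.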
template <class T = int64_t>
void FillTensor(const std::shared_ptr<lite_api::PaddlePredictor>& predictor,
int tensor_id,
const std::vector<int64_t>& tensor_shape,
const std::vector<T>& tensor_value) {
predictor->GetInput(tensor_id)->Resize(tensor_shape);
int64_t tensor_size = 1;
for (size_t i = 0; i < tensor_shape.size(); i++) {
tensor_size *= tensor_shape[i];
}
CHECK_EQ(static_cast<size_t>(tensor_size), tensor_value.size());
memcpy(predictor->GetInput(tensor_id)->mutable_data<T>(),
tensor_value.data(),
sizeof(T) * tensor_size);
}
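// Top-ranking accuracy for the 3-class bert output: a sample counts as right
// when the descending order of its three scores matches the reference order
// read from `out_file`.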
float CalBertOutAccuracy(const std::vector<std::vector<float>>& out,
const std::string& out_file) {
auto lines = ReadLines(out_file);
std::vector<std::vector<float>> ref_out;
for (auto line : lines) {
ref_out.emplace_back(Split<float>(line, " "));
}
int right_num = 0;
for (size_t i = 0; i < out.size(); i++) {
std::vector<size_t> out_index{0, 1, 2};
std::vector<size_t> ref_out_index{0, 1, 2};
std::sort(out_index.begin(),
out_index.end(),
[&out, i](size_t a, size_t b) { return out[i][a] > out[i][b]; });
std::sort(ref_out_index.begin(),
ref_out_index.end(),
[&ref_out, i](size_t a, size_t b) {
return ref_out[i][a] > ref_out[i][b];
});
right_num += (out_index == ref_out_index);
}
return static_cast<float>(right_num) / static_cast<float>(out.size());
}
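// Accuracy for the single-value ernie output: a sample counts as right when
// its first output is within an absolute tolerance of 0.01 of the reference
// value read from `out_file`.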
float CalErnieOutAccuracy(const std::vector<std::vector<float>>& out,
const std::string& out_file) {
auto lines = ReadLines(out_file);
std::vector<std::vector<float>> ref_out;
for (auto line : lines) {
ref_out.emplace_back(Split<float>(line, " "));
}
int right_num = 0;
for (size_t i = 0; i < out.size(); i++) {
right_num += (std::fabs(out[i][0] - ref_out[i][0]) < 0.01f);
}
return static_cast<float>(right_num) / static_cast<float>(out.size());
}
} // namespace lite
} // namespace paddle
@@ -21,23 +21,16 @@
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/tests/api/bert_utility.h"
#include "lite/utils/cp_logging.h"
+DEFINE_string(data_dir, "", "data dir");
+DEFINE_int32(iteration, 9, "iteration times to run");
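+// Note: the accuracy loop below indexes the inputs parsed from bert_in.txt
+// with i < FLAGS_iteration, so the data file must provide at least
+// `iteration` lines (9 by default).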
namespace paddle {
namespace lite {
-template <typename T>
-lite::Tensor GetTensorWithShape(std::vector<int64_t> shape) {
-  lite::Tensor ret;
-  ret.Resize(shape);
-  T* ptr = ret.mutable_data<T>();
-  for (int i = 0; i < ret.numel(); ++i) {
-    ptr[i] = (T)1;
-  }
-  return ret;
-}
-TEST(Ernie, test_ernie_fp32_xpu) {
+TEST(Bert, test_bert_fp32_xpu) {
lite_api::CxxConfig config;
config.set_model_dir(FLAGS_model_dir);
config.set_valid_places({lite_api::Place{TARGET(kXPU), PRECISION(kFloat)},
@@ -46,56 +39,58 @@ TEST(Ernie, test_ernie_fp32_xpu) {
config.set_xpu_workspace_l3_size_per_thread();
auto predictor = lite_api::CreatePaddlePredictor(config);
-  int64_t batch_size = 1;
-  int64_t seq_len = 64;
-  Tensor sample_input = GetTensorWithShape<int64_t>({batch_size, seq_len, 1});
-  std::vector<int64_t> input_shape{batch_size, seq_len, 1};
-  predictor->GetInput(0)->Resize(input_shape);
-  predictor->GetInput(1)->Resize(input_shape);
-  predictor->GetInput(2)->Resize(input_shape);
-  predictor->GetInput(3)->Resize(input_shape);
-  memcpy(predictor->GetInput(0)->mutable_data<int64_t>(),
-         sample_input.raw_data(),
-         sizeof(int64_t) * batch_size * seq_len);
-  memcpy(predictor->GetInput(1)->mutable_data<int64_t>(),
-         sample_input.raw_data(),
-         sizeof(int64_t) * batch_size * seq_len);
-  memcpy(predictor->GetInput(2)->mutable_data<int64_t>(),
-         sample_input.raw_data(),
-         sizeof(int64_t) * batch_size * seq_len);
-  memcpy(predictor->GetInput(3)->mutable_data<int64_t>(),
-         sample_input.raw_data(),
-         sizeof(int64_t) * batch_size * seq_len);
+  std::string input_data_file = FLAGS_data_dir + std::string("/bert_in.txt");
+  std::vector<std::vector<int64_t>> input0;
+  std::vector<std::vector<int64_t>> input1;
+  std::vector<std::vector<int64_t>> input2;
+  std::vector<std::vector<int64_t>> input3;
+  std::vector<std::vector<int64_t>> input_shapes;
+  ReadRawData(
+      input_data_file, &input0, &input1, &input2, &input3, &input_shapes);
   for (int i = 0; i < FLAGS_warmup; ++i) {
+    std::vector<int64_t> shape = {1, 64, 1};
+    std::vector<int64_t> fill_value(64, 0);
+    for (int j = 0; j < 4; j++) {
+      FillTensor(predictor, j, shape, fill_value);
+    }
     predictor->Run();
   }
-  auto start = GetCurrentUS();
-  for (int i = 0; i < FLAGS_repeats; ++i) {
+  std::vector<std::vector<float>> out_rets;
+  out_rets.resize(FLAGS_iteration);
+  double cost_time = 0;
+  for (int i = 0; i < FLAGS_iteration; ++i) {
+    FillTensor(predictor, 0, input_shapes[i], input0[i]);
+    FillTensor(predictor, 1, input_shapes[i], input1[i]);
+    FillTensor(predictor, 2, input_shapes[i], input2[i]);
+    FillTensor(predictor, 3, input_shapes[i], input3[i]);
+    double start = GetCurrentUS();
     predictor->Run();
+    cost_time += GetCurrentUS() - start;
+    auto output_tensor = predictor->GetOutput(0);
+    auto output_shape = output_tensor->shape();
+    auto output_data = output_tensor->data<float>();
+    ASSERT_EQ(output_shape.size(), 2UL);
+    ASSERT_EQ(output_shape[0], 1);
+    ASSERT_EQ(output_shape[1], 3);
+    int output_size = output_shape[0] * output_shape[1];
+    out_rets[i].resize(output_size);
+    memcpy(&(out_rets[i].at(0)), output_data, sizeof(float) * output_size);
   }
LOG(INFO) << "================== Speed Report ===================";
LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
<< ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
<< ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
<< " ms in average.";
<< ", warmup: " << FLAGS_warmup
<< ", iteration: " << FLAGS_iteration << ", spend "
<< cost_time / FLAGS_iteration / 1000.0 << " ms in average.";
-  std::vector<std::vector<float>> results;
-  results.emplace_back(std::vector<float>({0.278893, 0.330888, 0.39022}));
-  auto out = predictor->GetOutput(0);
-  ASSERT_EQ(out->shape().size(), 2);
-  ASSERT_EQ(out->shape()[0], 1);
-  ASSERT_EQ(out->shape()[1], 3);
-  for (size_t i = 0; i < results.size(); ++i) {
-    for (size_t j = 0; j < results[i].size(); ++j) {
-      EXPECT_NEAR(
-          out->data<float>()[j + (out->shape()[1] * i)], results[i][j], 3e-5);
-    }
-  }
+  std::string ref_out_file = FLAGS_data_dir + std::string("/bert_out.txt");
+  float out_accuracy = CalBertOutAccuracy(out_rets, ref_out_file);
+  ASSERT_GT(out_accuracy, 0.95f);
}
} // namespace lite
......
@@ -21,8 +21,12 @@
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/tests/api/bert_utility.h"
#include "lite/utils/cp_logging.h"
+DEFINE_string(data_dir, "", "data dir");
+DEFINE_int32(iteration, 9, "iteration times to run");
namespace paddle {
namespace lite {
@@ -46,56 +50,58 @@ TEST(Ernie, test_ernie_fp32_xpu) {
config.set_xpu_workspace_l3_size_per_thread();
auto predictor = lite_api::CreatePaddlePredictor(config);
-  int64_t batch_size = 1;
-  int64_t seq_len = 64;
-  Tensor sample_input = GetTensorWithShape<int64_t>({batch_size, seq_len, 1});
-  std::vector<int64_t> input_shape{batch_size, seq_len, 1};
-  predictor->GetInput(0)->Resize(input_shape);
-  predictor->GetInput(1)->Resize(input_shape);
-  predictor->GetInput(2)->Resize(input_shape);
-  predictor->GetInput(3)->Resize(input_shape);
-  memcpy(predictor->GetInput(0)->mutable_data<int64_t>(),
-         sample_input.raw_data(),
-         sizeof(int64_t) * batch_size * seq_len);
-  memcpy(predictor->GetInput(1)->mutable_data<int64_t>(),
-         sample_input.raw_data(),
-         sizeof(int64_t) * batch_size * seq_len);
-  memcpy(predictor->GetInput(2)->mutable_data<int64_t>(),
-         sample_input.raw_data(),
-         sizeof(int64_t) * batch_size * seq_len);
-  memcpy(predictor->GetInput(3)->mutable_data<int64_t>(),
-         sample_input.raw_data(),
-         sizeof(int64_t) * batch_size * seq_len);
+  std::string input_data_file = FLAGS_data_dir + std::string("/bert_in.txt");
+  std::vector<std::vector<int64_t>> input0;
+  std::vector<std::vector<int64_t>> input1;
+  std::vector<std::vector<int64_t>> input2;
+  std::vector<std::vector<int64_t>> input3;
+  std::vector<std::vector<int64_t>> input_shapes;
+  ReadRawData(
+      input_data_file, &input0, &input1, &input2, &input3, &input_shapes);
   for (int i = 0; i < FLAGS_warmup; ++i) {
+    std::vector<int64_t> shape = {1, 64, 1};
+    std::vector<int64_t> fill_value(64, 0);
+    for (int j = 0; j < 4; j++) {
+      FillTensor(predictor, j, shape, fill_value);
+    }
     predictor->Run();
   }
-  auto start = GetCurrentUS();
-  for (int i = 0; i < FLAGS_repeats; ++i) {
+  std::vector<std::vector<float>> out_rets;
+  out_rets.resize(FLAGS_iteration);
+  double cost_time = 0;
+  for (int i = 0; i < FLAGS_iteration; ++i) {
+    FillTensor(predictor, 0, input_shapes[i], input0[i]);
+    FillTensor(predictor, 1, input_shapes[i], input1[i]);
+    FillTensor(predictor, 2, input_shapes[i], input2[i]);
+    FillTensor(predictor, 3, input_shapes[i], input3[i]);
+    double start = GetCurrentUS();
     predictor->Run();
+    cost_time += GetCurrentUS() - start;
+    auto output_tensor = predictor->GetOutput(0);
+    auto output_shape = output_tensor->shape();
+    auto output_data = output_tensor->data<float>();
+    ASSERT_EQ(output_shape.size(), 2UL);
+    ASSERT_EQ(output_shape[0], 1);
+    ASSERT_EQ(output_shape[1], 1);
+    int output_size = output_shape[0] * output_shape[1];
+    out_rets[i].resize(output_size);
+    memcpy(&(out_rets[i].at(0)), output_data, sizeof(float) * output_size);
   }
LOG(INFO) << "================== Speed Report ===================";
LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
<< ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
<< ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
<< " ms in average.";
std::vector<std::vector<float>> results;
results.emplace_back(std::vector<float>({0.108398}));
auto out = predictor->GetOutput(0);
ASSERT_EQ(out->shape().size(), 2);
ASSERT_EQ(out->shape()[0], 1);
ASSERT_EQ(out->shape()[1], 1);
<< ", warmup: " << FLAGS_warmup
<< ", iteration: " << FLAGS_iteration << ", spend "
<< cost_time / FLAGS_iteration / 1000.0 << " ms in average.";
for (size_t i = 0; i < results.size(); ++i) {
for (size_t j = 0; j < results[i].size(); ++j) {
EXPECT_NEAR(
out->data<float>()[j + (out->shape()[1] * i)], results[i][j], 2e-5);
}
}
std::string ref_out_file = FLAGS_data_dir + std::string("/ernie_out.txt");
float out_accuracy = CalErnieOutAccuracy(out_rets, ref_out_file);
ASSERT_GT(out_accuracy, 0.95f);
}
} // namespace lite
......
@@ -121,9 +121,9 @@ class FcOPTest : public arena::TestCase {
int k = wdims_[0];
int n = wdims_[1];
LOG(INFO) << "M=" << m << ", N=" << n << ", K=" << k
<< ", bias=" << flag_bias << ", with_relu=" << with_relu_
<< ", padding_weights=" << padding_weights_;
VLOG(4) << "M=" << m << ", N=" << n << ", K=" << k << ", bias=" << flag_bias
<< ", with_relu=" << with_relu_
<< ", padding_weights=" << padding_weights_;
if (m == 1) {
basic_gemv(n,
......
@@ -738,7 +738,7 @@ TEST(PriorBox, precision) {
}
TEST(DensityPriorBox, precision) {
-#ifdef LITE_WITH_X86
+#if defined(LITE_WITH_X86) && !defined(LITE_WITH_XPU)
Place place(TARGET(kX86));
test_density_prior_box(place);
#endif
......
@@ -104,11 +104,11 @@ bool test_gemm_int8(bool tra,
scale_merge_int8[j] = scale_merge_fp32[j] / scale_c[0];
}
LOG(INFO) << "gemm_int8 M: " << m << ", N: " << n << ", K: " << k
<< ", transA: " << (tra ? "true" : "false")
<< ", transB: " << (trb ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< ", bias: " << (has_bias ? "true" : "false");
VLOG(4) << "gemm_int8 M: " << m << ", N: " << n << ", K: " << k
<< ", transA: " << (tra ? "true" : "false")
<< ", transB: " << (trb ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< ", bias: " << (has_bias ? "true" : "false");
#ifdef LITE_WITH_ARM
int lda = tra ? m : k;
int ldb = trb ? k : n;
@@ -344,13 +344,12 @@ TEST(TestLiteGemmInt8, gemm_prepacked_int8) {
FLAGS_power_mode,
th);
if (flag) {
LOG(INFO) << "test m = " << m << ", n=" << n
<< ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< ", trans A: " << (tra ? "true" : "false")
<< ", trans B: " << (trb ? "true" : "false")
<< " passed\n";
VLOG(4) << "test m = " << m << ", n=" << n << ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< ", trans A: " << (tra ? "true" : "false")
<< ", trans B: " << (trb ? "true" : "false")
<< " passed\n";
} else {
LOG(FATAL) << "test m = " << m << ", n=" << n
<< ", k=" << k
......
@@ -97,9 +97,9 @@ bool test_gemv_int8(bool tra,
scale_merge_int8[j] = scale_merge_fp32[j] / scale_c[0];
}
LOG(INFO) << "gemv_int8 M: " << m << ", N: " << n
<< ", transA: " << (tra ? "true" : "false") << ", act: " << flag_act
<< ", bias: " << (has_bias ? "true" : "false");
VLOG(4) << "gemv_int8 M: " << m << ", N: " << n
<< ", transA: " << (tra ? "true" : "false") << ", act: " << flag_act
<< ", bias: " << (has_bias ? "true" : "false");
#ifdef LITE_WITH_ARM
auto da = ta.mutable_data<int8_t>();
auto db = tb.mutable_data<int8_t>();
@@ -336,11 +336,11 @@ TEST(TestLiteGemvInt8, gemv_prepacked_int8) {
six,
alpha);
if (flag) {
LOG(INFO) << "test m = " << m << ", n=" << n
<< ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< ", trans A: " << (tra ? "true" : "false")
<< " passed\n";
VLOG(4) << "test m = " << m << ", n=" << n
<< ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< ", trans A: " << (tra ? "true" : "false")
<< " passed\n";
} else {
LOG(FATAL) << "test m = " << m << ", n=" << n
<< ", bias: " << (has_bias ? "true" : "false")
......
@@ -98,9 +98,9 @@ bool test_sgemm_c4(
basic_trans_mat_to_c4(da, da_c4, k, m, k, true);
basic_trans_mat_to_c4(db, db_c4, n, k, n, false);
LOG(INFO) << "sgemm_c4 M: " << m << ", N: " << n << ", K: " << k
<< ", relu: " << (has_relu ? "true" : "false")
<< ", bias: " << (has_bias ? "true" : "false");
VLOG(4) << "sgemm_c4 M: " << m << ", N: " << n << ", K: " << k
<< ", relu: " << (has_relu ? "true" : "false")
<< ", bias: " << (has_bias ? "true" : "false");
if (FLAGS_check_result) {
basic_gemm_c4(false,
@@ -331,10 +331,10 @@ TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) {
auto flag = test_sgemm_c4(
m, n, k, has_bias, has_relu, FLAGS_power_mode, th);
if (flag) {
LOG(INFO) << "test m = " << m << ", n=" << n << ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< " passed\n";
VLOG(4) << "test m = " << m << ", n=" << n << ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< " passed\n";
} else {
LOG(FATAL) << "test m = " << m << ", n=" << n << ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
@@ -364,10 +364,10 @@ TEST(TestSgemmC8, test_func_sgemm_c8_prepacked) {
auto flag = test_sgemm_c8(
m, n, k, has_bias, has_relu, FLAGS_power_mode, th);
if (flag) {
LOG(INFO) << "test m = " << m << ", n=" << n << ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< " passed\n";
VLOG(4) << "test m = " << m << ", n=" << n << ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< " passed\n";
} else {
LOG(FATAL) << "test m = " << m << ", n=" << n << ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
......
@@ -75,9 +75,9 @@ bool test_sgemv(bool tra,
// fill_tensor_const(tb, 1.f);
fill_tensor_rand(tbias, -1.f, 1.f);
LOG(INFO) << "sgemv M: " << m << ", K: " << k
<< ", transA: " << (tra ? "true" : "false") << ", act: " << flag_act
<< ", bias: " << (has_bias ? "true" : "false");
VLOG(4) << "sgemv M: " << m << ", K: " << k
<< ", transA: " << (tra ? "true" : "false") << ", act: " << flag_act
<< ", bias: " << (has_bias ? "true" : "false");
#ifdef LITE_WITH_ARM
auto da = ta.mutable_data<float>();
@@ -209,11 +209,11 @@ TEST(TestLiteSgemv, Sgemv) {
six,
alpha);
if (flag) {
LOG(INFO) << "test m = " << m << ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
<< ", flag act: " << flag_act
<< ", trans A: " << (tra ? "true" : "false")
<< ", threads: " << th << " passed\n";
VLOG(4) << "test m = " << m << ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
<< ", flag act: " << flag_act
<< ", trans A: " << (tra ? "true" : "false")
<< ", threads: " << th << " passed\n";
} else {
LOG(FATAL) << "test m = " << m << ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
......