From 9a30bf0be92e12ef4c27e7f1e0390cb2aaefbcb9 Mon Sep 17 00:00:00 2001 From: ZhenWang Date: Mon, 3 Dec 2018 14:26:22 +0800 Subject: [PATCH] update the usage of matmul --- src/operators/kernel/central-arm-func/mul_arm_func.h | 9 +++++---- test/common/test_gemm_perf.cpp | 8 ++++---- test/operators/test_fusion_conv_add_relu_int8_op.cpp | 2 +- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/operators/kernel/central-arm-func/mul_arm_func.h b/src/operators/kernel/central-arm-func/mul_arm_func.h index 60f6bca611..8b9dad90a0 100644 --- a/src/operators/kernel/central-arm-func/mul_arm_func.h +++ b/src/operators/kernel/central-arm-func/mul_arm_func.h @@ -73,13 +73,14 @@ void MulCompute(const MulParam ¶m) { } if (param.InputX()->type() == typeid(int8_t)) { out->mutable_data(); - math::matmul(x_matrix, false, y_matrix, false, static_cast(1), out, - static_cast(0), false, static_cast(nullptr)); + math::matmul(x_matrix, false, y_matrix, false, + static_cast(1), out, + static_cast(0)); } else { out->mutable_data(); - math::matmul(x_matrix, false, y_matrix, false, static_cast(1), out, - static_cast(0), false, static_cast(nullptr)); + math::matmul(x_matrix, false, y_matrix, false, static_cast(1), + out, static_cast(0)); } if (out_dim.size() != 2) { out->Resize(out_dim); diff --git a/test/common/test_gemm_perf.cpp b/test/common/test_gemm_perf.cpp index c8081e2d47..f25a290aef 100644 --- a/test/common/test_gemm_perf.cpp +++ b/test/common/test_gemm_perf.cpp @@ -85,16 +85,16 @@ int main() { // int8_t without bias // warm-up 10 times for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( + paddle_mobile::operators::math::matmul( aa_int8, false, bb_int8, false, static_cast(1), &cc_int32, - static_cast(0), false, static_cast(nullptr)); + static_cast(0)); } auto time3 = time(); for (int j = 0; j < 10; ++j) { - paddle_mobile::operators::math::matmul( + paddle_mobile::operators::math::matmul( aa_int8, false, bb_int8, false, static_cast(1), &cc_int32, - static_cast(0), false, static_cast(nullptr)); + static_cast(0)); } auto time4 = time(); std::cout << "int8_t gemm cost :" << time_diff(time3, time4) / 10 << "ms\n"; diff --git a/test/operators/test_fusion_conv_add_relu_int8_op.cpp b/test/operators/test_fusion_conv_add_relu_int8_op.cpp index 8d7067898e..add38b34f1 100644 --- a/test/operators/test_fusion_conv_add_relu_int8_op.cpp +++ b/test/operators/test_fusion_conv_add_relu_int8_op.cpp @@ -14,8 +14,8 @@ limitations under the License. */ #ifdef FUSION_CONVADDRELU_INT8_OP -#include #include +#include #include "../test_helper.h" #include "../test_include.h" #include "operators/fusion_conv_add_relu_int8_op.h" -- GitLab