提交 9a30bf0b 编写于 作者: Z ZhenWang

update the usage of matmul

上级 e333964f
...@@ -73,13 +73,14 @@ void MulCompute(const MulParam<CPU> &param) { ...@@ -73,13 +73,14 @@ void MulCompute(const MulParam<CPU> &param) {
} }
if (param.InputX()->type() == typeid(int8_t)) { if (param.InputX()->type() == typeid(int8_t)) {
out->mutable_data<int32_t>(); out->mutable_data<int32_t>();
math::matmul(x_matrix, false, y_matrix, false, static_cast<float>(1), out, math::matmul<float, int32_t>(x_matrix, false, y_matrix, false,
static_cast<float>(0), false, static_cast<int32_t *>(nullptr)); static_cast<float>(1), out,
static_cast<float>(0));
} else { } else {
out->mutable_data<float>(); out->mutable_data<float>();
math::matmul(x_matrix, false, y_matrix, false, static_cast<float>(1), out, math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
static_cast<float>(0), false, static_cast<float *>(nullptr)); out, static_cast<float>(0));
} }
if (out_dim.size() != 2) { if (out_dim.size() != 2) {
out->Resize(out_dim); out->Resize(out_dim);
......
...@@ -85,16 +85,16 @@ int main() { ...@@ -85,16 +85,16 @@ int main() {
// int8_t without bias // int8_t without bias
// warm-up 10 times // warm-up 10 times
for (int j = 0; j < 10; ++j) { for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul( paddle_mobile::operators::math::matmul<float, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32, aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
static_cast<float>(0), false, static_cast<int32_t*>(nullptr)); static_cast<float>(0));
} }
auto time3 = time(); auto time3 = time();
for (int j = 0; j < 10; ++j) { for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul( paddle_mobile::operators::math::matmul<float, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32, aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
static_cast<float>(0), false, static_cast<int32_t*>(nullptr)); static_cast<float>(0));
} }
auto time4 = time(); auto time4 = time();
std::cout << "int8_t gemm cost :" << time_diff(time3, time4) / 10 << "ms\n"; std::cout << "int8_t gemm cost :" << time_diff(time3, time4) / 10 << "ms\n";
......
...@@ -14,8 +14,8 @@ limitations under the License. */ ...@@ -14,8 +14,8 @@ limitations under the License. */
#ifdef FUSION_CONVADDRELU_INT8_OP #ifdef FUSION_CONVADDRELU_INT8_OP
#include <limits>
#include <iostream> #include <iostream>
#include <limits>
#include "../test_helper.h" #include "../test_helper.h"
#include "../test_include.h" #include "../test_include.h"
#include "operators/fusion_conv_add_relu_int8_op.h" #include "operators/fusion_conv_add_relu_int8_op.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册