提交 9a30bf0b 编写于 作者: Z ZhenWang

update the usage of matmul

上级 e333964f
......@@ -73,13 +73,14 @@ void MulCompute(const MulParam<CPU> &param) {
}
if (param.InputX()->type() == typeid(int8_t)) {
out->mutable_data<int32_t>();
math::matmul(x_matrix, false, y_matrix, false, static_cast<float>(1), out,
static_cast<float>(0), false, static_cast<int32_t *>(nullptr));
math::matmul<float, int32_t>(x_matrix, false, y_matrix, false,
static_cast<float>(1), out,
static_cast<float>(0));
} else {
out->mutable_data<float>();
math::matmul(x_matrix, false, y_matrix, false, static_cast<float>(1), out,
static_cast<float>(0), false, static_cast<float *>(nullptr));
math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
out, static_cast<float>(0));
}
if (out_dim.size() != 2) {
out->Resize(out_dim);
......
......@@ -85,16 +85,16 @@ int main() {
// int8_t without bias
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul(
paddle_mobile::operators::math::matmul<float, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
static_cast<float>(0), false, static_cast<int32_t*>(nullptr));
static_cast<float>(0));
}
auto time3 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul(
paddle_mobile::operators::math::matmul<float, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
static_cast<float>(0), false, static_cast<int32_t*>(nullptr));
static_cast<float>(0));
}
auto time4 = time();
std::cout << "int8_t gemm cost :" << time_diff(time3, time4) / 10 << "ms\n";
......
......@@ -14,8 +14,8 @@ limitations under the License. */
#ifdef FUSION_CONVADDRELU_INT8_OP
#include <limits>
#include <iostream>
#include <limits>
#include "../test_helper.h"
#include "../test_include.h"
#include "operators/fusion_conv_add_relu_int8_op.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册