Commit bf078a01 authored by Zhen Wang

Remove the ugly matmul_int8 code.

Parent a80b04b9
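This commit folds the separate `matmul_int8` entry point into the templated `matmul`: a second template parameter `S` (the bias element type) now distinguishes the float path (`float *` bias) from the int8 path (`int32_t *` bias), so every call site uses one name. A minimal, self-contained sketch of that dispatch pattern, with a stand-in `Tensor` and kernel bodies reduced to prints rather than the real paddle_mobile GEMM:

```cpp
#include <cstdint>
#include <iostream>

struct Tensor {};  // stand-in for framework::Tensor

// Single entry point: T is the scalar type of alpha/beta, S the bias
// element type; together they select the kernel.
template <typename T, typename S>
void matmul(const Tensor &a, bool trans_a, const Tensor &b, bool trans_b,
            T alpha, Tensor *out, T beta, bool relu = false,
            S *bias = nullptr);

// Full specialization for the float path.
template <>
void matmul<float, float>(const Tensor &, bool, const Tensor &, bool,
                          float alpha, Tensor *, float beta, bool relu,
                          float *bias) {
  std::cout << "float GEMM\n";
}

// Full specialization for the int8 path: alpha/beta stay float
// (requantization scales), but the bias elements are int32.
template <>
void matmul<float, int32_t>(const Tensor &, bool, const Tensor &, bool,
                            float alpha, Tensor *, float beta, bool relu,
                            int32_t *bias) {
  std::cout << "int8 GEMM\n";
}

int main() {
  Tensor a, b, out;
  matmul(a, false, b, false, 1.0f, &out, 0.0f, false,
         static_cast<float *>(nullptr));    // dispatches to the float GEMM
  matmul(a, false, b, false, 1.0f, &out, 0.0f, false,
         static_cast<int32_t *>(nullptr));  // dispatches to the int8 GEMM
  return 0;
}
```

The pay-off shows in GemmConv and MulCompute below: both branches now call the same name and differ only in the pointer type they pass.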
......@@ -113,8 +113,8 @@ void ConvAddReluInt8Compute(const FusionConvAddReluInt8Param<CPU> &param) {
       Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
       Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-      math::matmul_int8(filter_slice, false, col_matrix, false, scale_v,
-                        &out_slice, static_cast<float>(0), true, biase_data);
+      math::matmul(filter_slice, false, col_matrix, false, scale_v, &out_slice,
+                   static_cast<float>(0), true, biase_data);
     }
   }
 }
......
......@@ -108,13 +108,13 @@ inline void GemmConv(const ConvParam<CPU> &param) {
       Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
       if (param.Input()->type() == typeid(int8_t)) {
-        math::matmul_int8(filter_slice, false, col_matrix, false,
-                          static_cast<float>(1), &out_slice,
-                          static_cast<float>(0));
+        math::matmul(filter_slice, false, col_matrix, false,
+                     static_cast<float>(1), &out_slice, static_cast<float>(0),
+                     false, static_cast<int32_t *>(nullptr));
       } else {
-        math::matmul<float>(filter_slice, false, col_matrix, false,
-                            static_cast<float>(1), &out_slice,
-                            static_cast<float>(0));
+        math::matmul(filter_slice, false, col_matrix, false,
+                     static_cast<float>(1), &out_slice, static_cast<float>(0),
+                     false, static_cast<float *>(nullptr));
       }
     }
   }
......
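Note the new `static_cast<int32_t *>(nullptr)` and `static_cast<float *>(nullptr)` arguments: the bias type `S` is deduced from the bias pointer, and a bare `nullptr` (type `std::nullptr_t`) gives deduction nothing to work with. A minimal, self-contained illustration of that constraint (hypothetical `matmul_like`, not the project's function):

```cpp
#include <cstdint>

// Toy stand-in with the same deduction shape as the new matmul:
// T comes from alpha/beta, S from the bias pointer.
template <typename T, typename S>
void matmul_like(T alpha, T beta, S *bias) {}

int main() {
  // matmul_like(1.0f, 0.0f, nullptr);  // error: cannot deduce S from nullptr_t
  matmul_like(1.0f, 0.0f, static_cast<int32_t *>(nullptr));  // S = int32_t
  matmul_like(1.0f, 0.0f, static_cast<float *>(nullptr));    // S = float
  return 0;
}
```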
......@@ -73,13 +73,13 @@ void MulCompute(const MulParam<CPU> &param) {
   }
   if (param.InputX()->type() == typeid(int8_t)) {
     out->mutable_data<int32_t>();
-    math::matmul_int8(x_matrix, false, y_matrix, false, static_cast<float>(1),
-                      out, static_cast<float>(0));
+    math::matmul(x_matrix, false, y_matrix, false, static_cast<float>(1), out,
+                 static_cast<float>(0), false, static_cast<int32_t *>(nullptr));
   } else {
     out->mutable_data<float>();
-    math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
-                        out, static_cast<float>(0));
+    math::matmul(x_matrix, false, y_matrix, false, static_cast<float>(1), out,
+                 static_cast<float>(0), false, static_cast<float *>(nullptr));
   }
   if (out_dim.size() != 2) {
     out->Resize(out_dim);
......
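A note on why the int8 branch in MulCompute allocates int32 output (standard int8-GEMM reasoning, assumed to be the rationale here): each product of two int8 values can reach 127 × 127 = 16129, and a dot product of length K sums K such terms, so an int16 accumulator (max 32767) would overflow as soon as K > 2, while int32 (max 2147483647) is safe up to roughly 133,000 accumulated terms. Requantizing back to int8, when wanted, happens afterwards through the float alpha scale, as in the bias-and-relu benchmark near the end of this diff.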
......@@ -30,10 +30,11 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a,
             framework::Tensor *matrix_out, T beta, bool relu = false,
             float *bias = nullptr);
-void matmul_int8(const framework::Tensor &matrix_a, bool trans_a,
-                 const framework::Tensor &matrix_b, bool trans_b, float alpha,
-                 framework::Tensor *matrix_out, float beta, bool relu = false,
-                 int32_t *bias = nullptr);
+template <typename T, typename S>
+void matmul(const framework::Tensor &matrix_a, bool trans_a,
+            const framework::Tensor &matrix_b, bool trans_b, T alpha,
+            framework::Tensor *matrix_out, T beta, bool relu = false,
+            S *bias = nullptr);
 
 template <typename T>
 void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
......
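A side note on this header change: the int8 path is now a full specialization defined out of line in the .cpp shown next, and each translation unit that calls it must see a declaration of that specialization before use, otherwise the compiler may instantiate the primary template instead. One conventional layout, assuming the definitions stay in a .cpp file (signatures mirror the diff; this is a sketch, not necessarily what the repository does):

```cpp
// Hypothetical header excerpt: the primary template plus a forward
// declaration of the int8 specialization whose definition lives in a
// .cpp file. framework::Tensor is the project's tensor type.
template <typename T, typename S>
void matmul(const framework::Tensor &matrix_a, bool trans_a,
            const framework::Tensor &matrix_b, bool trans_b, T alpha,
            framework::Tensor *matrix_out, T beta, bool relu = false,
            S *bias = nullptr);

// Declared here so callers in other translation units link against the
// out-of-line definition instead of instantiating the primary template.
template <>
void matmul<float, int32_t>(const framework::Tensor &matrix_a, bool trans_a,
                            const framework::Tensor &matrix_b, bool trans_b,
                            float alpha, framework::Tensor *matrix_out,
                            float beta, bool relu, int32_t *bias);
```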
......@@ -20,10 +20,12 @@ limitations under the License. */
 namespace paddle_mobile {
 namespace operators {
 namespace math {
-void matmul_int8(const framework::Tensor &matrix_a, bool trans_a,
-                 const framework::Tensor &matrix_b, bool trans_b, float alpha,
-                 framework::Tensor *matrix_out, float beta, bool relu,
-                 int32_t *bias) {
+template <>
+void matmul(const framework::Tensor &matrix_a, bool trans_a,
+            const framework::Tensor &matrix_b, bool trans_b, float alpha,
+            framework::Tensor *matrix_out, float beta, bool relu,
+            int32_t *bias) {
   auto dim_a = matrix_a.dims();
   auto dim_b = matrix_b.dims();
   auto dim_out = matrix_out->dims();
......
......@@ -85,16 +85,16 @@ int main() {
   // int8_t without bias
   // warm-up 10 times
   for (int j = 0; j < 10; ++j) {
-    paddle_mobile::operators::math::matmul_int8(
+    paddle_mobile::operators::math::matmul(
         aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
-        static_cast<float>(0), false, nullptr);
+        static_cast<float>(0), false, static_cast<int32_t*>(nullptr));
   }
   auto time3 = time();
   for (int j = 0; j < 10; ++j) {
-    paddle_mobile::operators::math::matmul_int8(
+    paddle_mobile::operators::math::matmul(
         aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
-        static_cast<float>(0), false, nullptr);
+        static_cast<float>(0), false, static_cast<int32_t*>(nullptr));
   }
   auto time4 = time();
   std::cout << "int8_t gemm cost :" << time_diff(time3, time4) / 10 << "ms\n";
......@@ -102,15 +102,15 @@ int main() {
   // int8_t with bias&relu
   // warm-up 10 times
   for (int j = 0; j < 10; ++j) {
-    paddle_mobile::operators::math::matmul_int8(
+    paddle_mobile::operators::math::matmul(
         aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
-        static_cast<float>(0), true, &bias_data[0]);
+        static_cast<float>(0), true, bias_data);
   }
   auto time5 = time();
   for (int j = 0; j < 10; ++j) {
-    paddle_mobile::operators::math::matmul_int8(
+    paddle_mobile::operators::math::matmul(
         aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
-        static_cast<float>(0), true, &bias_data[0]);
+        static_cast<float>(0), true, bias_data);
   }
   auto time6 = time();
   std::cout << "int8_t gemm_with_bias_relu cost :"
......