Commit bf078a01 authored by Zhen Wang

Remove the ugly matmul_int8 code.

Parent a80b04b9
...
@@ -113,8 +113,8 @@ void ConvAddReluInt8Compute(const FusionConvAddReluInt8Param<CPU> &param) {
       Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
       Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-      math::matmul_int8(filter_slice, false, col_matrix, false, scale_v,
-                        &out_slice, static_cast<float>(0), true, biase_data);
+      math::matmul(filter_slice, false, col_matrix, false, scale_v, &out_slice,
+                   static_cast<float>(0), true, biase_data);
     }
   }
 }
...
@@ -108,13 +108,13 @@ inline void GemmConv(const ConvParam<CPU> &param) {
       Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
       if (param.Input()->type() == typeid(int8_t)) {
-        math::matmul_int8(filter_slice, false, col_matrix, false,
-                          static_cast<float>(1), &out_slice,
-                          static_cast<float>(0));
+        math::matmul(filter_slice, false, col_matrix, false,
+                     static_cast<float>(1), &out_slice, static_cast<float>(0),
+                     false, static_cast<int32_t *>(nullptr));
       } else {
-        math::matmul<float>(filter_slice, false, col_matrix, false,
-                            static_cast<float>(1), &out_slice,
-                            static_cast<float>(0));
+        math::matmul(filter_slice, false, col_matrix, false,
+                     static_cast<float>(1), &out_slice, static_cast<float>(0),
+                     false, static_cast<float *>(nullptr));
       }
     }
   }
...
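Note the `static_cast<int32_t *>(nullptr)` and `static_cast<float *>(nullptr)` arguments above: with the new two-parameter template, the bias pointer's type is what drives the instantiation, and a bare `nullptr` (type `std::nullptr_t`) gives the compiler nothing to deduce `S` from. A minimal self-contained sketch of this deduction behavior (the `matmul` here is a simplified stand-in, not the actual paddle-mobile signature):

```cpp
#include <cstdint>
#include <iostream>
#include <typeinfo>

// Simplified stand-in for the templated math::matmul: the second
// template parameter S can only be deduced from the bias pointer.
template <typename T, typename S>
void matmul(T alpha, T beta, bool relu, S *bias) {
  std::cout << "deduced bias type: " << typeid(S).name() << '\n';
}

int main() {
  // matmul(1.0f, 0.0f, false, nullptr);  // ill-formed: S cannot be
  //                                      // deduced from std::nullptr_t
  matmul(1.0f, 0.0f, false, static_cast<int32_t *>(nullptr));  // S = int32_t
  matmul(1.0f, 0.0f, false, static_cast<float *>(nullptr));    // S = float
  return 0;
}
```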
@@ -73,13 +73,13 @@ void MulCompute(const MulParam<CPU> &param) {
   }
   if (param.InputX()->type() == typeid(int8_t)) {
     out->mutable_data<int32_t>();
-    math::matmul_int8(x_matrix, false, y_matrix, false, static_cast<float>(1),
-                      out, static_cast<float>(0));
+    math::matmul(x_matrix, false, y_matrix, false, static_cast<float>(1), out,
+                 static_cast<float>(0), false, static_cast<int32_t *>(nullptr));
   } else {
     out->mutable_data<float>();
-    math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
-                        out, static_cast<float>(0));
+    math::matmul(x_matrix, false, y_matrix, false, static_cast<float>(1), out,
+                 static_cast<float>(0), false, static_cast<float *>(nullptr));
   }
   if (out_dim.size() != 2) {
     out->Resize(out_dim);
...
@@ -30,10 +30,11 @@ void matmul(const framework::Tensor &matrix_a, bool trans_a,
             framework::Tensor *matrix_out, T beta, bool relu = false,
             float *bias = nullptr);
-void matmul_int8(const framework::Tensor &matrix_a, bool trans_a,
-                 const framework::Tensor &matrix_b, bool trans_b, float alpha,
-                 framework::Tensor *matrix_out, float beta, bool relu = false,
-                 int32_t *bias = nullptr);
+template <typename T, typename S>
+void matmul(const framework::Tensor &matrix_a, bool trans_a,
+            const framework::Tensor &matrix_b, bool trans_b, T alpha,
+            framework::Tensor *matrix_out, T beta, bool relu = false,
+            S *bias = nullptr);
 
 template <typename T>
 void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
...
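After this change the header carries both the original single-type-parameter `matmul` (whose bias is fixed to `float *`) and the new `<T, S>` overload. For a call whose bias argument is a `float *`, both templates match, but partial ordering prefers the more specialized `float *bias` version; any other bias pointer type can only bind to the `<T, S>` overload. A reduced sketch of that resolution (simplified signatures, not the actual header):

```cpp
#include <cstdint>
#include <iostream>

// #1: mirrors the original template, bias fixed to float*.
template <typename T>
void matmul(T alpha, float *bias) {
  std::cout << "float-bias overload\n";
}

// #2: mirrors the new generic overload with a deduced bias type.
template <typename T, typename S>
void matmul(T alpha, S *bias) {
  std::cout << "generic <T, S> overload\n";
}

int main() {
  matmul(1.0f, static_cast<float *>(nullptr));    // #1 wins: more specialized
  matmul(1.0f, static_cast<int32_t *>(nullptr));  // #2: only viable candidate
  return 0;
}
```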
@@ -20,10 +20,12 @@ limitations under the License. */
 namespace paddle_mobile {
 namespace operators {
 namespace math {
-void matmul_int8(const framework::Tensor &matrix_a, bool trans_a,
-                 const framework::Tensor &matrix_b, bool trans_b, float alpha,
-                 framework::Tensor *matrix_out, float beta, bool relu,
-                 int32_t *bias) {
+
+template <>
+void matmul(const framework::Tensor &matrix_a, bool trans_a,
+            const framework::Tensor &matrix_b, bool trans_b, float alpha,
+            framework::Tensor *matrix_out, float beta, bool relu,
+            int32_t *bias) {
   auto dim_a = matrix_a.dims();
   auto dim_b = matrix_b.dims();
   auto dim_out = matrix_out->dims();
...
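The `template <>` line turns the old `matmul_int8` body into an explicit full specialization of the new template; the arguments `T = float, S = int32_t` are deduced from the parameter list rather than spelled out. A minimal illustration of the same mechanism (simplified signature):

```cpp
#include <cstdint>
#include <iostream>

// Primary template: the generic matmul entry point.
template <typename T, typename S>
void matmul(T alpha, S *bias) {
  std::cout << "generic path\n";
}

// Full specialization; <float, int32_t> is deduced from the parameter
// types, just as in the diff above. This is where the old matmul_int8
// body now lives.
template <>
void matmul(float alpha, int32_t *bias) {
  std::cout << "int8 GEMM path\n";
}

int main() {
  int32_t b = 0;
  matmul(1.0f, &b);  // selects the <float, int32_t> specialization
  return 0;
}
```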
@@ -85,16 +85,16 @@ int main() {
   // int8_t without bias
   // warm-up 10 times
   for (int j = 0; j < 10; ++j) {
-    paddle_mobile::operators::math::matmul_int8(
-        aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
-        static_cast<float>(0), false, nullptr);
+    paddle_mobile::operators::math::matmul(
+        aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
+        static_cast<float>(0), false, static_cast<int32_t*>(nullptr));
   }
   auto time3 = time();
   for (int j = 0; j < 10; ++j) {
-    paddle_mobile::operators::math::matmul_int8(
-        aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
-        static_cast<float>(0), false, nullptr);
+    paddle_mobile::operators::math::matmul(
+        aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
+        static_cast<float>(0), false, static_cast<int32_t*>(nullptr));
   }
   auto time4 = time();
   std::cout << "int8_t gemm cost :" << time_diff(time3, time4) / 10 << "ms\n";
...
@@ -102,15 +102,15 @@ int main() {
   // int8_t with bias&relu
   // warm-up 10 times
   for (int j = 0; j < 10; ++j) {
-    paddle_mobile::operators::math::matmul_int8(
-        aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
-        static_cast<float>(0), true, &bias_data[0]);
+    paddle_mobile::operators::math::matmul(
+        aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
+        static_cast<float>(0), true, bias_data);
   }
   auto time5 = time();
   for (int j = 0; j < 10; ++j) {
-    paddle_mobile::operators::math::matmul_int8(
-        aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
-        static_cast<float>(0), true, &bias_data[0]);
+    paddle_mobile::operators::math::matmul(
+        aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
+        static_cast<float>(0), true, bias_data);
   }
   auto time6 = time();
   std::cout << "int8_t gemm_with_bias_relu cost :"
...
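One more small change in the benchmark: `&bias_data[0]` becomes `bias_data`. For a raw array these are the same pointer; array-to-pointer decay yields an `int32_t *`, which is also the type that deduces `S`. A quick check (the array here is hypothetical, standing in for the benchmark's bias buffer):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  int32_t bias_data[16] = {0};
  // Array-to-pointer decay: both expressions are the same int32_t *.
  assert(bias_data == &bias_data[0]);
  return 0;
}
```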