diff --git a/mace/ops/matmul.cc b/mace/ops/matmul.cc index 9740a257ae112abc1c9ec93b82ac035e2e61e5a8..23db4e5208c20243b404a034dc3aba1ae58d903f 100644 --- a/mace/ops/matmul.cc +++ b/mace/ops/matmul.cc @@ -127,14 +127,6 @@ class MatMulOp : public MatMulOpBase { auto scratch_buffer = context->device()->scratch_buffer(); scratch_buffer->Rewind(); - index_t scratch_size = C->raw_max_size(); - if (!A->is_weight()) { - scratch_size += A->raw_max_size(); - } - if (!B->is_weight()) { - scratch_size += B->raw_max_size(); - } - scratch_buffer->GrowSize(scratch_size); sgemm_.Run(a_ptr_base, b_ptr_base, diff --git a/mace/ops/sgemm.cc b/mace/ops/sgemm.cc index 717cb05676a8629ae0a8815bd24d439db1d512a7..445b9cf6211f935034b2f4e4c74532f96e8708f1 100644 --- a/mace/ops/sgemm.cc +++ b/mace/ops/sgemm.cc @@ -56,7 +56,7 @@ void SGemm::operator()(const MatrixMap &lhs, const MatrixMap &rhs, MatrixMap *result, ScratchBuffer *scratch_buffer) { - if (rhs.col() < lhs.row()) { + if (lhs.is_const() && !rhs.is_const()) { MatrixMap lhs_transpose = lhs.transpose(); MatrixMap rhs_transpose = rhs.transpose(); MatrixMap result_transpose = result->transpose();