提交 7b06a19a 编写于 作者: 李寅

Fix matrix reorder case for various input shape

上级 97355004
......@@ -127,14 +127,6 @@ class MatMulOp<CPU, float> : public MatMulOpBase {
auto scratch_buffer = context->device()->scratch_buffer();
scratch_buffer->Rewind();
index_t scratch_size = C->raw_max_size();
if (!A->is_weight()) {
scratch_size += A->raw_max_size();
}
if (!B->is_weight()) {
scratch_size += B->raw_max_size();
}
scratch_buffer->GrowSize(scratch_size);
sgemm_.Run(a_ptr_base,
b_ptr_base,
......
......@@ -56,7 +56,7 @@ void SGemm::operator()(const MatrixMap<const float> &lhs,
const MatrixMap<const float> &rhs,
MatrixMap<float> *result,
ScratchBuffer *scratch_buffer) {
if (rhs.col() < lhs.row()) {
if (lhs.is_const() && !rhs.is_const()) {
MatrixMap<const float> lhs_transpose = lhs.transpose();
MatrixMap<const float> rhs_transpose = rhs.transpose();
MatrixMap<float> result_transpose = result->transpose();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册