提交 5efbfbff 编写于 作者: 李滨

Merge branch 'fix_nmt' into 'master'

Fix nmt

See merge request !975
...@@ -127,14 +127,6 @@ class MatMulOp<CPU, float> : public MatMulOpBase { ...@@ -127,14 +127,6 @@ class MatMulOp<CPU, float> : public MatMulOpBase {
auto scratch_buffer = context->device()->scratch_buffer(); auto scratch_buffer = context->device()->scratch_buffer();
scratch_buffer->Rewind(); scratch_buffer->Rewind();
index_t scratch_size = C->raw_max_size();
if (!A->is_weight()) {
scratch_size += A->raw_max_size();
}
if (!B->is_weight()) {
scratch_size += B->raw_max_size();
}
scratch_buffer->GrowSize(scratch_size);
sgemm_.Run(a_ptr_base, sgemm_.Run(a_ptr_base,
b_ptr_base, b_ptr_base,
......
...@@ -56,7 +56,7 @@ void SGemm::operator()(const MatrixMap<const float> &lhs, ...@@ -56,7 +56,7 @@ void SGemm::operator()(const MatrixMap<const float> &lhs,
const MatrixMap<const float> &rhs, const MatrixMap<const float> &rhs,
MatrixMap<float> *result, MatrixMap<float> *result,
ScratchBuffer *scratch_buffer) { ScratchBuffer *scratch_buffer) {
if (rhs.col() < lhs.row()) { if (lhs.is_const() && !rhs.is_const()) {
MatrixMap<const float> lhs_transpose = lhs.transpose(); MatrixMap<const float> lhs_transpose = lhs.transpose();
MatrixMap<const float> rhs_transpose = rhs.transpose(); MatrixMap<const float> rhs_transpose = rhs.transpose();
MatrixMap<float> result_transpose = result->transpose(); MatrixMap<float> result_transpose = result->transpose();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册