diff --git a/mace/ops/arm/fp32/gemm.cc b/mace/ops/arm/fp32/gemm.cc index 4b593c015e30204e8bf4fc26831bf61397d3d696..ff26052ffae16a064f4873151ef675c83d1ecbb3 100644 --- a/mace/ops/arm/fp32/gemm.cc +++ b/mace/ops/arm/fp32/gemm.cc @@ -88,11 +88,11 @@ MaceStatus Gemm::Compute(const OpContext *context, } else if (cached_ == kCacheRhs) { packed_rhs_data = pack_cache_.mutable_data(); } else if (should_cache_pack_) { - if (lhs->is_weight() && !lhs_batched) { + if (lhs->is_weight() && (!lhs_batched || batch == 1)) { cache_side = kCacheLhs; pack_cache_.Resize(packed_lhs_size); packed_lhs_data = pack_cache_.mutable_data(); - } else if (rhs->is_weight() && !rhs_batched) { + } else if (rhs->is_weight() && (!rhs_batched || batch == 1)) { cache_side = kCacheRhs; pack_cache_.Resize(packed_rhs_size); packed_rhs_data = pack_cache_.mutable_data();