diff --git a/lite/kernels/arm/conv_winograd.cc b/lite/kernels/arm/conv_winograd.cc index c5cf0b237fc0548ac2bb7549d3950b3cead2b74c..90fe63662bd4ff7ea147707ae2e91ffbaff4478d 100644 --- a/lite/kernels/arm/conv_winograd.cc +++ b/lite/kernels/arm/conv_winograd.cc @@ -69,10 +69,8 @@ void WinogradConv::ReInitWhenNeeded() { (tile_block * ((ic + 3) / 4 + (oc + 3) / 4) * 4 * wino_iw * wino_iw + 8 * wino_iw * wino_iw) * threads; - ctx.ExtendWorkspace((temp_size + new_input_size) * sizeof(float)); - + workspace_size_ = (temp_size + new_input_size) * sizeof(float); weights_.Resize({1, 1, 1, wino_iw * wino_iw * oc_pad * ic_pad}); - ctx.ExtendWorkspace((temp_size + new_input_size) * sizeof(float)); void* trans_tmp_ptr = malloc(sizeof(float) * wino_iw * wino_iw * oc * ic); auto weights_data_ = weights_.mutable_data(); if (!choose_small_) { @@ -96,6 +94,7 @@ template <> void WinogradConv::Run() { auto& param = this->Param(); auto& ctx = this->ctx_->template As(); + ctx.ExtendWorkspace(workspace_size_); const auto* i_data = param.x->data(); const auto* w_data = weights_.data(); const auto* b_data = param.bias ? param.bias->data() : nullptr;