diff --git a/lite/kernels/arm/conv_winograd.cc b/lite/kernels/arm/conv_winograd.cc index 90fe63662bd4ff7ea147707ae2e91ffbaff4478d..e433a3f4bb4a7aa553fbb1193ff82779d9af3242 100644 --- a/lite/kernels/arm/conv_winograd.cc +++ b/lite/kernels/arm/conv_winograd.cc @@ -35,7 +35,8 @@ void WinogradConv::ReInitWhenNeeded() { if (last_shape_ == x_dims) { return; } - + last_shape_ = x_dims; + //! update workspace size int ic = x_dims[1]; int ih = x_dims[2]; int iw = x_dims[3]; @@ -43,6 +44,20 @@ void WinogradConv::ReInitWhenNeeded() { int oh = o_dims[2]; int ow = o_dims[3]; int tile_block = 8; + auto pad = *(param.paddings); + int pad_h = pad[0]; + int pad_w = pad[2]; + int oc_pad = (oc + 3) / 4 * 4; + int ic_pad = (ic + 3) / 4 * 4; + const int new_input_size = + (ic + 3) / 4 * 4 * (ih + pad_h * 2) * (iw + pad_w * 2); + const int temp_size = + (tile_block * ((ic + 3) / 4 + (oc + 3) / 4) * 4 * wino_iw * wino_iw + + 8 * wino_iw * wino_iw) * + threads; + workspace_size_ = (temp_size + new_input_size) * sizeof(float); + + //! update trans weights impl choose_small_ = ow * oh / (tile_block * threads) < 36 ? true : false; if (choose_small_) { wino_iw = 4; @@ -58,18 +73,7 @@ void WinogradConv::ReInitWhenNeeded() { } last_function_ = 1; } - auto pad = *(param.paddings); - int pad_h = pad[0]; - int pad_w = pad[2]; - int oc_pad = (oc + 3) / 4 * 4; - int ic_pad = (ic + 3) / 4 * 4; - const int new_input_size = - (ic + 3) / 4 * 4 * (ih + pad_h * 2) * (iw + pad_w * 2); - const int temp_size = - (tile_block * ((ic + 3) / 4 + (oc + 3) / 4) * 4 * wino_iw * wino_iw + - 8 * wino_iw * wino_iw) * - threads; - workspace_size_ = (temp_size + new_input_size) * sizeof(float); + weights_.Resize({1, 1, 1, wino_iw * wino_iw * oc_pad * ic_pad}); void* trans_tmp_ptr = malloc(sizeof(float) * wino_iw * wino_iw * oc * ic); auto weights_data_ = weights_.mutable_data(); @@ -81,8 +85,6 @@ void WinogradConv::ReInitWhenNeeded() { weights_data_, param.filter->data(), ic, oc, trans_tmp_ptr); } free(trans_tmp_ptr); - - last_shape_ = x_dims; } template <>