未验证 提交 e70eade1 编写于 作者: Y yiicy 提交者: GitHub

[ARM] fix winograd conv reinit bug, test=develop (#3281)

上级 f4d678ee
...@@ -35,7 +35,8 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() { ...@@ -35,7 +35,8 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
if (last_shape_ == x_dims) { if (last_shape_ == x_dims) {
return; return;
} }
last_shape_ = x_dims;
//! update workspace size
int ic = x_dims[1]; int ic = x_dims[1];
int ih = x_dims[2]; int ih = x_dims[2];
int iw = x_dims[3]; int iw = x_dims[3];
...@@ -43,6 +44,20 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() { ...@@ -43,6 +44,20 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
int oh = o_dims[2]; int oh = o_dims[2];
int ow = o_dims[3]; int ow = o_dims[3];
int tile_block = 8; int tile_block = 8;
auto pad = *(param.paddings);
int pad_h = pad[0];
int pad_w = pad[2];
int oc_pad = (oc + 3) / 4 * 4;
int ic_pad = (ic + 3) / 4 * 4;
const int new_input_size =
(ic + 3) / 4 * 4 * (ih + pad_h * 2) * (iw + pad_w * 2);
const int temp_size =
(tile_block * ((ic + 3) / 4 + (oc + 3) / 4) * 4 * wino_iw * wino_iw +
8 * wino_iw * wino_iw) *
threads;
workspace_size_ = (temp_size + new_input_size) * sizeof(float);
//! update trans weights impl
choose_small_ = ow * oh / (tile_block * threads) < 36 ? true : false; choose_small_ = ow * oh / (tile_block * threads) < 36 ? true : false;
if (choose_small_) { if (choose_small_) {
wino_iw = 4; wino_iw = 4;
...@@ -58,18 +73,7 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() { ...@@ -58,18 +73,7 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
} }
last_function_ = 1; last_function_ = 1;
} }
auto pad = *(param.paddings);
int pad_h = pad[0];
int pad_w = pad[2];
int oc_pad = (oc + 3) / 4 * 4;
int ic_pad = (ic + 3) / 4 * 4;
const int new_input_size =
(ic + 3) / 4 * 4 * (ih + pad_h * 2) * (iw + pad_w * 2);
const int temp_size =
(tile_block * ((ic + 3) / 4 + (oc + 3) / 4) * 4 * wino_iw * wino_iw +
8 * wino_iw * wino_iw) *
threads;
workspace_size_ = (temp_size + new_input_size) * sizeof(float);
weights_.Resize({1, 1, 1, wino_iw * wino_iw * oc_pad * ic_pad}); weights_.Resize({1, 1, 1, wino_iw * wino_iw * oc_pad * ic_pad});
void* trans_tmp_ptr = malloc(sizeof(float) * wino_iw * wino_iw * oc * ic); void* trans_tmp_ptr = malloc(sizeof(float) * wino_iw * wino_iw * oc * ic);
auto weights_data_ = weights_.mutable_data<float>(); auto weights_data_ = weights_.mutable_data<float>();
...@@ -81,8 +85,6 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() { ...@@ -81,8 +85,6 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
weights_data_, param.filter->data<float>(), ic, oc, trans_tmp_ptr); weights_data_, param.filter->data<float>(), ic, oc, trans_tmp_ptr);
} }
free(trans_tmp_ptr); free(trans_tmp_ptr);
last_shape_ = x_dims;
} }
template <> template <>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册