未验证 提交 e70eade1 编写于 作者: Y yiicy 提交者: GitHub

[ARM] fix winograd conv reinit bug, test=develop (#3281)

上级 f4d678ee
......@@ -35,7 +35,8 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
if (last_shape_ == x_dims) {
return;
}
last_shape_ = x_dims;
//! update workspace size
int ic = x_dims[1];
int ih = x_dims[2];
int iw = x_dims[3];
......@@ -43,6 +44,20 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
int oh = o_dims[2];
int ow = o_dims[3];
int tile_block = 8;
auto pad = *(param.paddings);
int pad_h = pad[0];
int pad_w = pad[2];
int oc_pad = (oc + 3) / 4 * 4;
int ic_pad = (ic + 3) / 4 * 4;
const int new_input_size =
(ic + 3) / 4 * 4 * (ih + pad_h * 2) * (iw + pad_w * 2);
const int temp_size =
(tile_block * ((ic + 3) / 4 + (oc + 3) / 4) * 4 * wino_iw * wino_iw +
8 * wino_iw * wino_iw) *
threads;
workspace_size_ = (temp_size + new_input_size) * sizeof(float);
//! update trans weights impl
choose_small_ = ow * oh / (tile_block * threads) < 36 ? true : false;
if (choose_small_) {
wino_iw = 4;
......@@ -58,18 +73,7 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
}
last_function_ = 1;
}
auto pad = *(param.paddings);
int pad_h = pad[0];
int pad_w = pad[2];
int oc_pad = (oc + 3) / 4 * 4;
int ic_pad = (ic + 3) / 4 * 4;
const int new_input_size =
(ic + 3) / 4 * 4 * (ih + pad_h * 2) * (iw + pad_w * 2);
const int temp_size =
(tile_block * ((ic + 3) / 4 + (oc + 3) / 4) * 4 * wino_iw * wino_iw +
8 * wino_iw * wino_iw) *
threads;
workspace_size_ = (temp_size + new_input_size) * sizeof(float);
weights_.Resize({1, 1, 1, wino_iw * wino_iw * oc_pad * ic_pad});
void* trans_tmp_ptr = malloc(sizeof(float) * wino_iw * wino_iw * oc * ic);
auto weights_data_ = weights_.mutable_data<float>();
......@@ -81,8 +85,6 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
weights_data_, param.filter->data<float>(), ic, oc, trans_tmp_ptr);
}
free(trans_tmp_ptr);
last_shape_ = x_dims;
}
template <>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册