From c6d648eac214951fd4233a930fb26c67c23dd399 Mon Sep 17 00:00:00 2001 From: liutuo Date: Sat, 29 Sep 2018 16:49:51 +0800 Subject: [PATCH] refactor deconv code --- mace/kernels/deconv_2d.h | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/mace/kernels/deconv_2d.h b/mace/kernels/deconv_2d.h index aeead4e5..3656fff8 100644 --- a/mace/kernels/deconv_2d.h +++ b/mace/kernels/deconv_2d.h @@ -191,20 +191,6 @@ struct Deconv2dFunctor: Deconv2dFunctorBase { const index_t *in_shape, const index_t *out_shape, float *output) { - const index_t kernel_size = kernel_h * kernel_w; - std::vector out_map(kernel_size); - int p0 = 0; - int p1 = 0; - index_t gap = out_shape[3] - kernel_w; - for (int i = 0; i < kernel_h; ++i) { - for (int j = 0; j < kernel_w; ++j) { - out_map[p0] = p1; - p0++; - p1++; - } - p1 += gap; - } - const index_t out_height = out_shape[2]; const index_t out_width = out_shape[3]; const index_t in_height = in_shape[2]; @@ -212,6 +198,14 @@ struct Deconv2dFunctor: Deconv2dFunctorBase { const index_t out_img_size = out_height * out_width; const index_t in_img_size = in_height * in_width; + const int kernel_size = static_cast(kernel_h * kernel_w); + std::vector index_map(kernel_size, 0); + for (index_t i = 0; i < kernel_h; ++i) { + for (index_t j = 0; j < kernel_w; ++j) { + index_map[i * kernel_w + j] = i * out_width + j; + } + } + #pragma omp parallel for for (int b = 0; b < in_shape[0]; ++b) { for (int oc = 0; oc < out_shape[1]; ++oc) { @@ -230,7 +224,7 @@ struct Deconv2dFunctor: Deconv2dFunctorBase { const index_t kernel_offset = (oc * in_shape[1] + ic) * kernel_size; for (int k = 0; k < kernel_size; ++k) { - const index_t out_idx = out_offset + out_map[k]; + const index_t out_idx = out_offset + index_map[k]; const index_t kernel_idx = kernel_offset + k; out_base[out_idx] += val * filter[kernel_idx]; } -- GitLab