From 781d5fe36280f300313b28f671b85e8ea9cc83ee Mon Sep 17 00:00:00 2001
From: zhaojiaying01
Date: Sat, 26 May 2018 17:34:00 +0800
Subject: [PATCH] update conv op kernel

---
 src/operators/kernel/arm/conv_kernel.cpp | 30 ++++++------------------
 1 file changed, 7 insertions(+), 23 deletions(-)

diff --git a/src/operators/kernel/arm/conv_kernel.cpp b/src/operators/kernel/arm/conv_kernel.cpp
index 03558141f9..c8ac141f9c 100644
--- a/src/operators/kernel/arm/conv_kernel.cpp
+++ b/src/operators/kernel/arm/conv_kernel.cpp
@@ -35,14 +35,9 @@ void ConvKernel<CPU, float>::Compute(const ConvParam &param) const {
   LOG(kLOG_DEBUG) << param;
 
   const Tensor *input = param.Input();
-
-  // The filter will be reshaped in the calculations,
-  // so here use an assignment operation,
-  // that avoids modifying the variable in the Scope.
   Tensor filter = *param.Filter();
-
   Tensor *output = param.Output();
-  // output->mutable_data<T>(context.GetPlace());
+  output->mutable_data<float>();
 
   int groups = param.Groups();
   std::vector<int> strides = param.Strides();
@@ -53,17 +48,9 @@ void ConvKernel<CPU, float>::Compute(const ConvParam &param) const {
 
   const int batch_size = static_cast<int>(input->dims()[0]);
 
-  // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h,
-  // k_w}
   std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
-  // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h,
-  // o_w}
   std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
 
-  // use col_shape in the im2col calculation
-  // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h,
-  // k_w, o_d,
-  // o_h, o_w}
   size_t data_dim = filter_shape_vec.size() - 2;
   std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
   col_shape_vec[0] = input->dims()[1] / groups;
@@ -73,24 +60,19 @@ void ConvKernel<CPU, float>::Compute(const ConvParam &param) const {
   }
   framework::DDim col_shape(framework::make_ddim(col_shape_vec));
 
-  // use col_matrix_shape in the gemm calculation
-  // size: (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w,
-  // o_d *
-  // o_h * o_w)
   framework::DDim col_matrix_shape =
       framework::flatten_to_2d(col_shape, data_dim + 1);
 
   bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
   Tensor col;
-  // col_matrix shares the same piece of data with col,
-  // but will be reshaped into a two-dimensional matrix shape
-  // to call the matrix multiplication interface.
   Tensor col_matrix;
   if (is_expand) {
     col.mutable_data<float>(col_shape);
     col_matrix.ShareDataWith(col);
     col_matrix.Resize(col_matrix_shape);
   }
+  DLOG << " col_shape = " << col_shape;
+  DLOG << " col_matrix_shape = " << col_matrix_shape;
 
   framework::DDim input_shape = framework::slice_ddim(
       input->dims(), 1, static_cast<int>(input->dims().size()));
@@ -98,6 +80,7 @@ void ConvKernel<CPU, float>::Compute(const ConvParam &param) const {
   framework::DDim filter_matrix_shape = {filter.dims()[0],
                                          filter.numel() / filter.dims()[0]};
   filter.Resize(filter_matrix_shape);
+  DLOG << " filter.dims() = " << filter.dims();
 
   framework::DDim output_matrix_shape = {
       output->dims()[1],
@@ -110,8 +93,6 @@ void ConvKernel<CPU, float>::Compute(const ConvParam &param) const {
   math::Vol2ColFunctor<CPU, float> vol2col;
   math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
 
-  // auto& dev_ctx = context.template
-  // device_context<DeviceContext>();
   for (int i = 0; i < batch_size; i++) {
     Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
     Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
@@ -137,6 +118,9 @@ void ConvKernel<CPU, float>::Compute(const ConvParam &param) const {
       // gemm
       Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
       Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
+      DLOG << " out_slice " << out_slice.dims();
+      DLOG << " filter_slice " << filter_slice.dims();
+      DLOG << " col_matrix " << col_matrix.dims();
       math::matmul<float>(filter_slice, false, col_matrix, false,
                           static_cast<float>(1), &out_slice,
                           static_cast<float>(0));
--
GitLab
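Note (not part of the patch above): the kernel keeps the usual im2col + GEMM scheme, so for reference here is a minimal, self-contained C++ sketch of that scheme for one image with stride 1, no padding, and groups = 1. All names and sizes below are hypothetical illustrations; the real kernel uses paddle-mobile's math::Im2ColFunctor / math::Vol2ColFunctor and math::matmul rather than these hand-written loops.

// Illustrative sketch only: im2col followed by a plain GEMM.
#include <cstdio>
#include <vector>

// Unfold (ic, kh, kw) patches into an (ic*kh*kw) x (oh*ow) column matrix.
std::vector<float> im2col(const std::vector<float> &in, int ic, int ih, int iw,
                          int kh, int kw, int oh, int ow) {
  std::vector<float> col(static_cast<size_t>(ic) * kh * kw * oh * ow);
  for (int c = 0; c < ic; ++c)
    for (int p = 0; p < kh; ++p)
      for (int q = 0; q < kw; ++q)
        for (int y = 0; y < oh; ++y)
          for (int x = 0; x < ow; ++x) {
            int row = (c * kh + p) * kw + q;  // row in the col matrix
            int idx = y * ow + x;             // column in the col matrix
            col[row * (oh * ow) + idx] =
                in[(c * ih + (y + p)) * iw + (x + q)];
          }
  return col;
}

int main() {
  // Hypothetical sizes: 2 input channels, 4x4 image, 3 output channels, 3x3 kernel.
  const int ic = 2, ih = 4, iw = 4, oc = 3, kh = 3, kw = 3;
  const int oh = ih - kh + 1, ow = iw - kw + 1;  // stride 1, no padding

  std::vector<float> input(ic * ih * iw, 1.0f);        // all ones
  std::vector<float> filter(oc * ic * kh * kw, 0.5f);  // all 0.5

  // col: (ic*kh*kw) x (oh*ow); filter viewed as oc x (ic*kh*kw).
  std::vector<float> col = im2col(input, ic, ih, iw, kh, kw, oh, ow);
  std::vector<float> output(oc * oh * ow, 0.0f);

  // GEMM: output = filter_matrix * col, the role math::matmul plays in the kernel.
  const int k = ic * kh * kw, n = oh * ow;
  for (int m = 0; m < oc; ++m)
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int t = 0; t < k; ++t) acc += filter[m * k + t] * col[t * n + j];
      output[m * n + j] = acc;
    }

  // Every output element should be ic*kh*kw * 1.0 * 0.5 = 9.0.
  printf("output[0] = %.1f (expected %.1f)\n", output[0], ic * kh * kw * 0.5f);
  return 0;
}

The col matrix has one row per (input channel, kernel offset) pair and one column per output position, which is why the kernel's col_matrix_shape is (i_c/g * k_h * k_w, o_h * o_w) and the GEMM against the (o_c, i_c/g * k_h * k_w) filter matrix produces the (o_c, o_h * o_w) output slice for each group.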