From 26ea33c6a7bcd60459d4631780a92bb2a0d175f0 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Thu, 31 Mar 2022 22:07:18 +0800
Subject: [PATCH] perf(imperative): improve convbwd performance

GitOrigin-RevId: cfc8623d7a818c84ee8d6ed56e0dea6251ef36d6
---
 dnn/include/megdnn/oprs/nn.h            |   2 +-
 imperative/src/impl/ops/convolution.cpp | 157 ++++++++++++++++++++++++
 2 files changed, 158 insertions(+), 1 deletion(-)

diff --git a/dnn/include/megdnn/oprs/nn.h b/dnn/include/megdnn/oprs/nn.h
index 41728e16a..f45657486 100644
--- a/dnn/include/megdnn/oprs/nn.h
+++ b/dnn/include/megdnn/oprs/nn.h
@@ -310,7 +310,7 @@ public:
             const TensorLayout& filter, const TensorLayout& diff,
             const TensorLayout& grad) = 0;
 
-    void deduce_dtype(DType filter, DType diff, DType& grad);
+    MGE_WIN_DECLSPEC_FUC void deduce_dtype(DType filter, DType diff, DType& grad);
     void deduce_layout(
             const TensorLayout& filter, const TensorLayout& diff,
             TensorLayout& grad);
diff --git a/imperative/src/impl/ops/convolution.cpp b/imperative/src/impl/ops/convolution.cpp
index 3db9df94f..0a5761237 100644
--- a/imperative/src/impl/ops/convolution.cpp
+++ b/imperative/src/impl/ops/convolution.cpp
@@ -250,8 +250,165 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     }
 }
 
+TensorLayout convbwd_do_shape_infer(
+        const OpDef& def, size_t diff_ndim, TensorLayout filter, TensorLayout diff,
+        CompNode cn) {
+    auto&& bwd_conv = static_cast<const ConvolutionBackwardData&>(def);
+    DnnOprCaller<megdnn::ConvolutionBackwardData> caller(cn);
+    auto&& dnn_opr = caller.op;
+    using Param = ::megdnn::param::Convolution;
+    // using Param1 = ::megdnn::param::ConvolutionBackwardData;
+
+    auto img_ndim = diff_ndim - 2;
+    mgb_assert(
+            img_ndim == 2,
+            "only 2D convolution is supported, and input should be 4-dim; "
+            "got input dim = %zu",
+            diff_ndim);
+    size_t group = 1;
+    size_t flt_start, flt_spatial_start, ocpg_pos, icpg_pos;
+    if (bwd_conv.sparse == Param::Sparse::DENSE) {
+        mgb_assert(
+                filter.ndim == img_ndim + 2 || filter.ndim == img_ndim + 4,
+                "bad filter ndim for dense convolution: "
+                "spatial_ndim=%zu filter_ndim=%zu",
+                img_ndim, filter.ndim);
+        group = 1;
+        flt_start = 0;
+    } else {  // Param::Sparse::GROUP
+        mgb_assert(
+                filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 5,
+                "bad filter ndim for group convolution: "
+                "spatial_ndim=%zu filter_ndim=%zu",
+                img_ndim, filter.ndim);
+        // grp, oc, ic, dims[]
+        group = filter[0];
+        flt_start = 1;
+    }
+
+    uint32_t ic_block_size = 1, oc_block_size = 1;
+    size_t src_or_dst_c_pos = 0;
+    size_t src_or_dst_spatial_start = 0;
+    if (bwd_conv.format == Param::Format::NCHW) {
+        // filter should be (oc, ic, fh, fw)
+        flt_spatial_start = 2;
+        ocpg_pos = 0;
+        icpg_pos = 1;
+        src_or_dst_c_pos = 1;
+        src_or_dst_spatial_start = 2;
+    } else {  // Param::Format::NHWC
+        // filter should be (oc, fh, fw, ic)
+        flt_spatial_start = 1;
+        ocpg_pos = 0;
+        icpg_pos = 3;
+        src_or_dst_c_pos = 3;
+        src_or_dst_spatial_start = 1;
+    }
+    size_t ocpg = filter[flt_start + ocpg_pos] * oc_block_size;
+    size_t icpg = filter[flt_start + icpg_pos] * ic_block_size;
+    uint32_t dilation[2], dilated_spatial[2], stride[2], padding[2];
+    dilation[0] = bwd_conv.dilate_h;
+    dilation[1] = bwd_conv.dilate_w;
+    stride[0] = bwd_conv.stride_h;
+    stride[1] = bwd_conv.stride_w;
+    padding[0] = bwd_conv.pad_h;
+    padding[1] = bwd_conv.pad_w;
+    for (size_t i = 0; i < img_ndim; ++i) {
+        mgb_assert(
+                dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i,
+                dilation[i]);
+        dilated_spatial[i] =
+                (filter[i + flt_start + flt_spatial_start] - 1) * dilation[i] + 1;
+    }
+    mgb_assert(ocpg * group == diff[src_or_dst_c_pos], "group conv invalid");
+
+    auto deduce = [](size_t out, size_t filter, size_t stride, size_t pad) {
+        auto i = (out - 1) * stride + filter;
+        mgb_assert(i > pad * 2);
+        return i - pad * 2;
+    };
+
+    DType dst_dtype = bwd_conv.dtype;
+    dnn_opr->deduce_dtype(filter.dtype, diff.dtype, dst_dtype);
+    TensorLayout dst{dst_dtype};
+    dst.ndim = diff_ndim;
+    dst[0] = diff[0];
+    dst[src_or_dst_c_pos] = icpg * group;
+    for (size_t i = 0; i < img_ndim; ++i) {
+        dst[i + src_or_dst_spatial_start] =
+                deduce(diff[i + src_or_dst_spatial_start], dilated_spatial[i],
+                       stride[i], padding[i]);
+    }
+    dst.init_contiguous_stride();
+    return dst;
+}
+
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+    auto&& conv = static_cast<const ConvolutionBackwardData&>(def);
+
+    SmallVector<LogicalTensorDesc> dests(1);
+    auto&& desc = dests[0];
+    desc.comp_node = inputs[0].comp_node;
+
+    TensorLayout filter = inputs[0].layout;
+    TensorLayout diff = inputs[1].layout;
+    size_t filter_ndim = filter.ndim;
+    size_t diff_ndim = diff.ndim;
+    if (filter_ndim == 0) {
+        desc.layout = filter;
+        return {dests, false};
+    }
+
+    desc.layout =
+            convbwd_do_shape_infer(def, diff_ndim, filter, diff, inputs[0].comp_node);
+    return {dests, true};
+}
+
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs,
+        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
+    // create megdnn opr
+    auto&& convbwd = static_cast<const ConvolutionBackwardData&>(def);
+    CompNode cn = inputs[0]->comp_node();
+
+    TensorLayout out_layout = output_descs[0].layout;
+    if (!validated)
+        out_layout = convbwd_do_shape_infer(
+                def, inputs[1]->layout().ndim, inputs[0]->layout(), inputs[1]->layout(),
+                cn);
+
+    DeviceTensorND out =
+            BlobManager::inst()->alloc_workspace_with_defrag(cn, out_layout);
+
+    using TensorND = megdnn::TensorND;
+    SmallVector<TensorND> inp_tensornds(inputs.size());
+    TensorLayoutArray inp_shapes(inputs.size()), oup_shapes(output_descs.size());
+    for (unsigned i = 0; i < inputs.size(); ++i) {
+        inp_tensornds[i] = inputs[i]->dnn_tensor();
+        inp_shapes[i] = inputs[i]->layout();
+    }
+    oup_shapes[0] = out_layout;
+    DnnOprCaller<megdnn::ConvolutionBackwardData> dnn_opr(cn);
+    dnn_opr.op->param() = convbwd.param();
+
+    size_t sz = setup_algo<megdnn::ConvolutionBackwardData>(
+            {inp_shapes[0], inp_shapes[1], oup_shapes[0]}, dnn_opr.op.get(), 0, false,
+            false, cn, convbwd.policy(), false);
+
+    auto wk = Blob::make(cn, sz);
+    auto ptr = wk->storage().get();
+    megdnn::Workspace dnn_wk(ptr, sz);
+
+    // execute
+    dnn_opr.op->exec(inp_tensornds[0], inp_tensornds[1], out.as_megdnn(), dnn_wk);
+    return {Tensor::make(out)};
+}
+
 OP_TRAIT_REG(ConvolutionBackwardData, ConvolutionBackwardData)
         .apply_on_var_node(apply_on_var_node)
+        .infer_output_attrs_fallible(infer_output_attrs_fallible)
+        .apply_on_physical_tensor(apply_on_physical_tensor)
         .fallback();
 } // namespace convolution_backward_data
 } // namespace
-- 
GitLab
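
Note on the shape inference added above: convbwd_do_shape_infer recovers the
grad (input) layout from the diff (output) layout by inverting the forward
convolution shape formula per spatial dim, in = (out - 1) * stride +
dilated_filter - 2 * pad, with dilated_filter = (filter - 1) * dilation + 1.
The standalone C++ sketch below reproduces that arithmetic on a concrete case;
it is illustrative only, and deduce_input_extent is a made-up name, not
MegEngine API.

    #include <cassert>
    #include <cstdio>

    // Same arithmetic as the patch's `deduce` lambda, with dilation folded in.
    static size_t deduce_input_extent(
            size_t out, size_t filter, size_t stride, size_t pad, size_t dilation) {
        size_t dilated_filter = (filter - 1) * dilation + 1;
        size_t i = (out - 1) * stride + dilated_filter;
        assert(i > pad * 2);  // mirrors the mgb_assert in the patch
        return i - pad * 2;
    }

    int main() {
        // A 32x32 input convolved with a 3x3 filter (stride 1, pad 1, dilation 1)
        // gives a 32x32 output, so backward-data must deduce 32 again.
        std::printf("%zu\n", deduce_input_extent(32, 3, 1, 1, 1));  // prints 32
        return 0;
    }

For stride > 1 the formula returns the smallest input extent consistent with
the given output size; when the exact layout is already known (validated), the
patch's apply_on_physical_tensor uses output_descs[0].layout instead of
re-deducing it.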