/**
 * \file dnn/src/common/deformable_conv.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megdnn/oprs/nn.h"

#include "src/common/utils.h"

using namespace megdnn;

using CanonizedFilterMeta = DeformableConvBase::CanonizedFilterMeta;

namespace {

/*!
 * \brief Build a diagnostic string describing the layouts and the conv params.
 *
 * The result concatenates the layout messages of all five tensors with the
 * group / deformable_group / padding / stride / dilation fields of \p param.
 *
 * \tparam Param parameter struct exposing group, deformable_group, pad_*,
 *      stride_* and dilate_* members
 * \return the formatted message
 */
template <typename Param>
std::string get_errmsg(const TensorLayout& src, const TensorLayout& filter,
                       const TensorLayout& offset, const TensorLayout& mask,
                       const TensorLayout& dst, const Param& param) {
    // NOTE(review): presumably megdnn_layout_msg() can be compiled down to
    // something that ignores its argument; mark every layout as used so no
    // unused-parameter warning shows up in that configuration.
    MEGDNN_MARK_USED_VAR(src);
    MEGDNN_MARK_USED_VAR(filter);
    MEGDNN_MARK_USED_VAR(offset);
    MEGDNN_MARK_USED_VAR(mask);
    MEGDNN_MARK_USED_VAR(dst);
    return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter) + ", " +
           megdnn_layout_msg(offset) + ", " + megdnn_layout_msg(mask) + ", " +
           megdnn_layout_msg(dst) + ", " + megdnn_mangle("only support nchw") +
           ", " + megdnn_mangle("group=") + std::to_string(param.group) + ", " +
           megdnn_mangle("deformable_group=") +
           std::to_string(param.deformable_group) + ", " +
           megdnn_mangle("pad_h=") + std::to_string(param.pad_h) + ", " +
           megdnn_mangle("pad_w=") + std::to_string(param.pad_w) + ", " +
           megdnn_mangle("stride_h=") + std::to_string(param.stride_h) + ", " +
           megdnn_mangle("stride_w=") + std::to_string(param.stride_w) + ", " +
           megdnn_mangle("dilate_h=") + std::to_string(param.dilate_h) + ", " +
           megdnn_mangle("dilate_w=") + std::to_string(param.dilate_w);
}

/*!
 * \brief Fill the NCHW-specific fields of a canonized filter meta.
 *
 * Writes spatial_ndim, ocpg, icpg, spatial[] and dilated_spatial[] of \p ret.
 * \p ret.dilation must already be initialised by the caller, because it is
 * read here to compute the dilated filter size.
 *
 * \param src_ndim number of dimensions of the input tensor; must be 4
 * \param filter filter layout, (OC, IC, FH, FW) for dense or
 *      (G, OC/G, IC/G, FH, FW) for group convolution
 * \param param convolution parameters; only CROSS_CORRELATION mode and NCHW
 *      format are accepted
 * \param[in,out] ret meta to be completed
 */
template <typename Param>
void make_canonized_filter_meta_nchw(size_t src_ndim,
                                     const TensorLayout& filter,
                                     const Param& param,
                                     CanonizedFilterMeta& ret) {
    megdnn_assert(param.mode == Param::Mode::CROSS_CORRELATION,
                  "only support CROSS_CORRELATION mode");
    megdnn_assert(param.format == Param::Format::NCHW,
                  "only support nchw input layout");

    // Positions inside the filter shape, relative to flt_start.
    const size_t flt_spatial_start = 2, ocpg_pos = 0, icpg_pos = 1;
    // A group filter carries the group count as an extra leading dimension.
    const size_t flt_start = (param.sparse == Param::Sparse::GROUP) ? 1 : 0;

    ret.spatial_ndim = src_ndim - 2;
    megdnn_assert(
            ret.spatial_ndim == 2,
            "only 2D convolution is supported, and imput should be 4-dim; "
            "got input dim = %zu",
            src_ndim);
    // Without this check a dense call with a 5-dim filter (or a filter with
    // too few dims) would silently read the wrong dimensions below.
    megdnn_assert(
            filter.ndim == flt_start + flt_spatial_start + ret.spatial_ndim,
            "invalid filter layout: %s", megdnn_layout_msg(filter).c_str());

    ret.ocpg = filter[flt_start + ocpg_pos];
    ret.icpg = filter[flt_start + icpg_pos];

    for (size_t i = 0; i < ret.spatial_ndim; ++i) {
        megdnn_assert(ret.dilation[i] > 0,
                      "invalid dilation on spatial dim %zu, %u", i,
                      ret.dilation[i]);
        ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
        ret.dilated_spatial[i] = (ret.spatial[i] - 1) * ret.dilation[i] + 1;
    }
}

}  // namespace

namespace megdnn {

/*!
 * \brief Derive the canonized filter meta from the filter and offset layouts.
 *
 * Stride, padding and dilation are copied from param(); the group count is
 * taken from filter[0] for group convolution and is 1 otherwise; the
 * deformable group count is deduced as offset[1] / (2 * FH * FW).
 *
 * \param src_ndim number of dimensions of the input tensor
 * \param filter contiguous filter layout
 * \param offset offset layout, (N, 2 * dg * FH * FW, OH, OW)
 */
CanonizedFilterMeta DeformableConvBase::make_canonized_filter_meta(
        size_t src_ndim, const TensorLayout& filter,
        const TensorLayout& offset) const {
    megdnn_assert_contiguous(filter);
    // offset[1] is read below; make sure that dimension actually exists even
    // when this function is reached before any other layout validation.
    megdnn_assert(offset.ndim == 4, "invalid offset layout: %s",
                  megdnn_layout_msg(offset).c_str());

    CanonizedFilterMeta ret;
    ret.group = 1;
    ret.dtype = filter.dtype;
    ret.stride[0] = param().stride_h;
    ret.stride[1] = param().stride_w;
    ret.padding[0] = param().pad_h;
    ret.padding[1] = param().pad_w;
    ret.dilation[0] = param().dilate_h;
    ret.dilation[1] = param().dilate_w;

    if (param().sparse == Param::Sparse::GROUP) {
        megdnn_assert(filter.ndim == 5,
                      "filter dim should be 5 for group conv");
        ret.group = filter[0];
    }

    make_canonized_filter_meta_nchw(src_ndim, filter, param(), ret);

    auto fh = ret.spatial[0];
    auto fw = ret.spatial[1];
    // Guard the division below (and the (fh - 1) arithmetic done by callers)
    // against an empty filter window.
    megdnn_assert(fh > 0 && fw > 0, "invalid filter spatial size: %s",
                  megdnn_layout_msg(filter).c_str());
    ret.deformable_group = offset[1] / (2 * fh * fw);
    return ret;
}

/*!
 * \brief Validate the forward input layouts and deduce the output layout.
 *
 * \param im input, (N, IC, IH, IW)
 * \param filter (OC, IC, FH, FW) or (G, OC/G, IC/G, FH, FW)
 * \param offset (N, 2 * dg * FH * FW, OH, OW)
 * \param mask (N, dg * FH * FW, OH, OW)
 * \param[out] dst set to shape (N, OC, OH, OW) with the dtype of \p im
 */
void DeformableConvBase::deduce_layout_fwd(const TensorLayout& im,
                                           const TensorLayout& filter,
                                           const TensorLayout& offset,
                                           const TensorLayout& mask,
                                           TensorLayout& dst) {
    // im shape: (n, IC, IH, IW)
    megdnn_assert(im.ndim == 4, "invalid src layout: %s",
                  megdnn_layout_msg(im).c_str());
    // filter shape: (OC, IC, FH, FW) or (g, OC/g, IC/g, FH, FW)
    megdnn_assert(filter.ndim == 4 || filter.ndim == 5,
                  "invalid filter layout: %s",
                  megdnn_layout_msg(filter).c_str());
    // offset shape: (N, 2*dg*FH*FW, OH, OW)
    megdnn_assert(offset.ndim == 4, "invalid offset layout: %s",
                  megdnn_layout_msg(offset).c_str());
    // mask shape: (N, dg*FH*FW, OH, OW)
    megdnn_assert(mask.ndim == 4, "invalid mask layout: %s",
                  megdnn_layout_msg(mask).c_str());

    const size_t n = im.shape[0], ic = im.shape[1];
    const size_t ih = im.shape[2], iw = im.shape[3];
    const size_t dh = param().dilate_h, dw = param().dilate_w;
    const size_t ph = param().pad_h, pw = param().pad_w;
    const size_t sh = param().stride_h, sw = param().stride_w;

    const auto fm = make_canonized_filter_meta(im.ndim, filter, offset);
    const size_t fh = fm.spatial[0], fw = fm.spatial[1];
    // Extent of the dilated kernel window.
    const size_t kh = 1 + (fh - 1) * dh;
    const size_t kw = 1 + (fw - 1) * dw;

    const size_t group = fm.group;
    const size_t deformable_group = fm.deformable_group;
    const size_t icpg = fm.icpg, ocpg = fm.ocpg;
    const size_t oc = group * ocpg;

    // The output-size formula divides by the stride and subtracts the kernel
    // extent using unsigned arithmetic; reject inputs that would divide by
    // zero or wrap around instead of producing a bogus size.
    megdnn_assert(sh > 0 && sw > 0,
                  "stride should be positive, got stride_h=%zu stride_w=%zu",
                  sh, sw);
    megdnn_assert(ih + ph * 2 >= kh && iw + pw * 2 >= kw,
                  "padded input (%zux%zu) is smaller than dilated kernel "
                  "(%zux%zu)",
                  ih + ph * 2, iw + pw * 2, kh, kw);
    const size_t oh = (ih + ph * 2 - kh) / sh + 1;
    const size_t ow = (iw + pw * 2 - kw) / sw + 1;

    megdnn_assert(group > 0 && deformable_group > 0,
                  "group and deformable group should > 0");
    megdnn_assert(ic == icpg * group, "im ic != group * icpg of filter");
    megdnn_assert(ic % deformable_group == 0, "ic %% deformable_group != 0");
    megdnn_assert(oc % deformable_group == 0, "oc %% deformable_group != 0");

    megdnn_assert(
            (offset[1] % (2 * fh * fw) == 0) && (mask[1] % (fh * fw) == 0),
            "invalid deformable group deduced from offset(%s) or mask(%s)",
            megdnn_layout_msg(offset).c_str(), megdnn_layout_msg(mask).c_str());

    megdnn_assert((offset[1] / (2 * fh * fw)) == (mask[1] / (fh * fw)),
                  "offset(%s) and mask(%s) should have same deformable group",
                  megdnn_layout_msg(offset).c_str(),
                  megdnn_layout_msg(mask).c_str());

    megdnn_assert((offset[2] == mask[2]) && (offset[3] == mask[3]),
                  "offset(%s) and mask(%s) should have same spatial dim",
                  megdnn_layout_msg(offset).c_str(),
                  megdnn_layout_msg(mask).c_str());
    megdnn_assert(oh == offset[2], "deduced oh(%zu) != offset oh(%zu)", oh,
                  offset[2]);
    megdnn_assert(ow == offset[3], "deduced ow(%zu) != offset ow(%zu)", ow,
                  offset[3]);

    dst = {{n, oc, oh, ow}, im.dtype};
}

/*!
 * \brief Check that all forward tensors are float32 with matching dtypes and
 *      that \p dst equals the layout deduced from the inputs.
 */
void DeformableConvBase::check_layout_fwd(const TensorLayout& im,
                                          const TensorLayout& filter,
                                          const TensorLayout& offset,
                                          const TensorLayout& mask,
                                          const TensorLayout& dst) {
    megdnn_assert(im.dtype.enumv() == DTypeEnum::Float32,
                  "DeformableConv only support float32 input");
    megdnn_assert_eq_dtype(im, filter);
    megdnn_assert_eq_dtype(im, offset);
    megdnn_assert_eq_dtype(im, mask);
    megdnn_assert_eq_dtype(im, dst);

    TensorLayout dst_expected;
    deduce_layout_fwd(im, filter, offset, mask, dst_expected);
    megdnn_assert_eq_layout(dst_expected, dst);
}

/*!
 * \brief Deduce the forward output layout; forwards to deduce_layout_fwd().
 */
void DeformableConvForward::deduce_layout(const TensorLayout& im,
                                          const TensorLayout& filter,
                                          const TensorLayout& offset,
                                          const TensorLayout& mask,
                                          TensorLayout& dst) {
    deduce_layout_fwd(im, filter, offset, mask, dst);
}

/*!
 * \brief Validate the arguments of a forward execution.
 *
 * \param workspace_in_bytes size of the workspace supplied by the caller;
 *      must be at least get_workspace_in_bytes()
 * \return the canonized filter meta for the given layouts
 */
CanonizedFilterMeta DeformableConvForward::check_exec(
        const TensorLayout& im, const TensorLayout& filter,
        const TensorLayout& offset, const TensorLayout& mask,
        const TensorLayout& dst, size_t workspace_in_bytes) {
    // Validate the layouts before anything consumes them, matching the order
    // used by the backward operators below.
    check_layout_fwd(im, filter, offset, mask, dst);

    auto ret = make_canonized_filter_meta(im.ndim, filter, offset);
    auto required_workspace_in_bytes =
            get_workspace_in_bytes(im, filter, offset, mask, dst);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
    return ret;
}

/*!
 * \brief Validate the arguments of a backward-filter execution.
 *
 * The filter gradient takes the place of the filter and the output gradient
 * takes the place of the output in the forward layout check.
 *
 * \return the canonized filter meta deduced from \p filter_grad
 */
CanonizedFilterMeta DeformableConvBackwardFilter::check_exec(
        const TensorLayout& im, const TensorLayout& offset,
        const TensorLayout& mask, const TensorLayout& out_grad,
        const TensorLayout& filter_grad, size_t workspace_in_bytes) {
    check_layout_fwd(im, filter_grad, offset, mask, out_grad);
    // check dtype
    megdnn_assert_eq_dtype(im, filter_grad);

    auto ret = make_canonized_filter_meta(im.ndim, filter_grad, offset);
    auto required_workspace_in_bytes =
            get_workspace_in_bytes(im, offset, mask, out_grad, filter_grad);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
    return ret;
}

/*!
 * \brief Validate the arguments of a backward-data execution.
 *
 * Besides the forward layout check (with \p out_grad as the output), each
 * gradient tensor must have the dtype of \p im and the shape of the tensor it
 * is the gradient of.
 *
 * \return the canonized filter meta deduced from \p filter
 */
CanonizedFilterMeta DeformableConvBackwardData::check_exec(
        const TensorLayout& im, const TensorLayout& filter,
        const TensorLayout& offset, const TensorLayout& mask,
        const TensorLayout& out_grad, const TensorLayout& im_grad,
        const TensorLayout& offset_grad, const TensorLayout& mask_grad,
        size_t workspace_in_bytes) {
    check_layout_fwd(im, filter, offset, mask, out_grad);

    // check dtype
    megdnn_assert_eq_dtype(im, im_grad);
    megdnn_assert_eq_dtype(im, offset_grad);
    megdnn_assert_eq_dtype(im, mask_grad);

    // check layout
    // NOTE(review): the `shape` member is indexed like a raw array elsewhere
    // in this file; for a raw array `a.shape == b.shape` compares addresses
    // rather than extents, so compare via eq_shape() instead.
    megdnn_assert(im.eq_shape(im_grad), "invalid im_grad shape: %s",
                  megdnn_layout_msg(im_grad).c_str());
    megdnn_assert(offset.eq_shape(offset_grad),
                  "invalid offset_grad shape: %s",
                  megdnn_layout_msg(offset_grad).c_str());
    megdnn_assert(mask.eq_shape(mask_grad), "invalid mask_grad shape: %s",
                  megdnn_layout_msg(mask_grad).c_str());

    auto ret = make_canonized_filter_meta(im.ndim, filter, offset);
    auto required_workspace_in_bytes =
            get_workspace_in_bytes(im, filter, offset, mask, out_grad, im_grad,
                                   offset_grad, mask_grad);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
    return ret;
}

}  // namespace megdnn

// vim: syntax=cpp.doxygen