#include "megdnn/oprs.h" #include "src/common/utils.h" namespace megdnn { void ROIAlignBase::deduce_layout_fwd( const TensorLayout& src, const TensorLayout& rois, TensorLayout& dst, TensorLayout& index) { if (!src.is_empty()) megdnn_assert_contiguous(src); if (!rois.is_empty()) megdnn_assert_contiguous(rois); megdnn_assert_contiguous(dst); megdnn_assert_contiguous(index); auto errmsg = [&]() { return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(rois) + ", " + megdnn_layout_msg(dst) + ", " + megdnn_layout_msg(index); }; MEGDNN_MARK_USED_VAR(errmsg); using Format = ROIAlignBase::Param::Format; megdnn_assert(param().format == Format::NCHW); auto src_dtype = src.dtype, rois_dtype = rois.dtype; megdnn_assert( src_dtype == rois_dtype && src_dtype.category() == DTypeCategory::FLOAT); megdnn_assert(src.ndim == 4_z, "%s", errmsg().c_str()); size_t channels = src.shape[1]; megdnn_assert(rois.ndim == 2_z, "%s", errmsg().c_str()); // rois shape: bid, x0, y0, x1, y1 megdnn_assert(rois[1] == 5_z, "%s", errmsg().c_str()); size_t M = rois[0]; size_t pooled_height = param().pooled_height; size_t pooled_width = param().pooled_width; dst = TensorLayout{{M, channels, pooled_height, pooled_width}, src.dtype}; index = dst; index.dtype = dtype::Int32(); } void ROIAlignBase::check_layout_fwd( const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst, const TensorLayout& index) { TensorLayout dst_expected, index_expected; megdnn_assert_eq_dtype(src, dst); deduce_layout_fwd(src, rois, dst_expected, index_expected); megdnn_assert_eq_shape(dst_expected, dst); megdnn_assert_eq_shape(index_expected, index); megdnn_assert(index.dtype == dtype::Int32()); } void ROIAlignForward::deduce_layout( const TensorLayout& src, const TensorLayout& rois, TensorLayout& dst, TensorLayout& index) { deduce_layout_fwd(src, rois, dst, index); } void ROIAlignForward::check_exec( const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst, const TensorLayout& index, size_t workspace_in_bytes) { check_layout_fwd(src, rois, dst, index); auto required_workspace_in_bytes = get_workspace_in_bytes(src, rois, dst, index); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } void ROIAlignBackward::check_exec( const TensorLayout& diff, const TensorLayout& rois, const TensorLayout& index, const TensorLayout& grad, size_t workspace_in_bytes) { check_layout_fwd(grad, rois, diff, index); auto required_workspace_in_bytes = get_workspace_in_bytes(diff, rois, index, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } } // namespace megdnn // vim: syntax=cpp.doxygen