/**
 * \file dnn/src/common/padding.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "megdnn/oprs.h"
#include "megdnn/oprs/general.h"
#include "megdnn/thin/small_vector.h"
#include "src/common/opr_param_defs_enumv.cuh"
#include "src/common/utils.h"

namespace megdnn {

using padding_param = megdnn::param_enumv::Padding;

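// forward_check_exec: run the common layout checks and reject dtypes that the
// forward padding kernels do not handle (Bool and the sub-byte IntB types).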
void PaddingForward::forward_check_exec(const TensorLayout& src,
                                        const TensorLayout& dst) {
    check_exec(src, dst);
    megdnn_assert(src.dtype.enumv() != DTypeEnum::Bool &&
                          src.dtype.enumv() != DTypeEnum::IntB1 &&
                          src.dtype.enumv() != DTypeEnum::IntB2 &&
                          src.dtype.enumv() != DTypeEnum::IntB4,
                  "unsupported %s dtype for forward padding opr",
                  src.dtype.name());
}

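// deduce_layout: each output dimension is the input dimension plus the front
// and back padding offsets for that dimension, i.e.
//   dst.shape[i] = src.shape[i] + offsets[2 * i] + offsets[2 * i + 1]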
void PaddingForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
    SmallVector<size_t> offsets(get_offsets());
    TensorShape dst_shape;
    switch (src.ndim) {
        case 1:
            dst_shape = {src.shape[0] + offsets[0] + offsets[1]};
            break;
        case 2:
            dst_shape = {src.shape[0] + offsets[0] + offsets[1],
                         src.shape[1] + offsets[2] + offsets[3]};
            break;
        case 3:
            dst_shape = {src.shape[0] + offsets[0] + offsets[1],
                         src.shape[1] + offsets[2] + offsets[3],
                         src.shape[2] + offsets[4] + offsets[5]};
            break;
        case 4:
            dst_shape = {src.shape[0] + offsets[0] + offsets[1],
                         src.shape[1] + offsets[2] + offsets[3],
                         src.shape[2] + offsets[4] + offsets[5],
                         src.shape[3] + offsets[6] + offsets[7]};
            break;
        case 5:
            dst_shape = {src.shape[0] + offsets[0] + offsets[1],
                         src.shape[1] + offsets[2] + offsets[3],
                         src.shape[2] + offsets[4] + offsets[5],
                         src.shape[3] + offsets[6] + offsets[7],
                         src.shape[4] + offsets[8] + offsets[9]};
            break;
        case 6:
            dst_shape = {src.shape[0] + offsets[0] + offsets[1],
                         src.shape[1] + offsets[2] + offsets[3],
                         src.shape[2] + offsets[4] + offsets[5],
                         src.shape[3] + offsets[6] + offsets[7],
                         src.shape[4] + offsets[8] + offsets[9],
                         src.shape[5] + offsets[10] + offsets[11]};
            break;
        case 7:
            dst_shape = {src.shape[0] + offsets[0] + offsets[1],
                         src.shape[1] + offsets[2] + offsets[3],
                         src.shape[2] + offsets[4] + offsets[5],
                         src.shape[3] + offsets[6] + offsets[7],
                         src.shape[4] + offsets[8] + offsets[9],
                         src.shape[5] + offsets[10] + offsets[11],
                         src.shape[6] + offsets[12] + offsets[13]};
            break;
        default:
            megdnn_assert(false, "invalid tensor ndim %zu", src.ndim);
            break;
    }
    dst = TensorLayout(dst_shape, src.dtype);
}

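// backward_check_exec: backward only supports floating-point dtypes. The
// arguments are passed to check_exec in swapped order because here `src` is
// the (padded) output gradient and `dst` is the (unpadded) input gradient.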
void PaddingBackward::backward_check_exec(const TensorLayout& src,
                                          const TensorLayout& dst) {
    check_exec(dst, src);
    megdnn_assert(src.dtype.enumv() ==
                          DTypeEnum::Float32 DNN_INC_FLOAT16(
                                  || src.dtype.enumv() == DTypeEnum::Float16 ||
                                  src.dtype.enumv() == DTypeEnum::BFloat16),
                  "unsupported %s dtype for forward padding opr",
                  src.dtype.name());
}

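// get_offsets: flatten the per-dimension padding parameters into a vector of
// 14 entries, laid out as (front, back) pairs for dimensions 0..6.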
SmallVector<size_t> PaddingBase::get_offsets() {
    SmallVector<size_t> offsets = {
            param().front_offset_dim0, param().back_offset_dim0,
            param().front_offset_dim1, param().back_offset_dim1,
            param().front_offset_dim2, param().back_offset_dim2,
            param().front_offset_dim3, param().back_offset_dim3,
            param().front_offset_dim4, param().back_offset_dim4,
            param().front_offset_dim5, param().back_offset_dim5,
            param().front_offset_dim6, param().back_offset_dim6};
    return offsets;
}

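// check_exec: layout validation shared by the forward and backward paths.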
void PaddingBase::check_exec(const TensorLayout& src, const TensorLayout& dst) {
    SmallVector<size_t> offsets(get_offsets());
    // make sure the src and dst tensors are not empty
    megdnn_assert(src.ndim != 0 && dst.ndim != 0);
    // make sure src and dst have the same dtype
    megdnn_assert_eq_dtype(src, dst);
    // make sure src and dst have the same ndim
    megdnn_assert(src.ndim == dst.ndim, "the src.ndim = %zu the dst.ndim = %zu",
                  src.ndim, dst.ndim);
    // make sure every dst dimension equals the src dimension plus its front
    // and back offsets
    for (size_t i = 0; i < src.ndim; ++i) {
        megdnn_assert(dst.shape[i] ==
                      src.shape[i] + offsets[i * 2] + offsets[i * 2 + 1]);
    }
    // check the padding mode is valid
    megdnn_assert(static_cast<uint32_t>(param().padding_mode) ==
                                  padding_param::PaddingMode::REFLECT ||
                          static_cast<uint32_t>(param().padding_mode) ==
                                  padding_param::PaddingMode::REPLICATE ||
                          static_cast<uint32_t>(param().padding_mode) ==
                                  padding_param::PaddingMode::CONSTANT,
                  "unsupported padding mode");
    // additional check for reflect padding: make sure the reflected indices
    // are valid
    if (static_cast<uint32_t>(param().padding_mode) ==
        padding_param::PaddingMode::REFLECT) {
        for (size_t i = 0; i < src.ndim; ++i) {
            megdnn_assert(offsets[i * 2] < src.shape[i] &&
                          dst.shape[i] - offsets[i * 2] - src.shape[i] <
                                  src.shape[i]);
        }
    }
}

}  // namespace megdnn