resize.cpp 4.2 KB
Newer Older
1
#include "megdnn/handle.h"
2
#include "megdnn/opr_param_defs.h"
3 4 5 6 7 8
#include "megdnn/oprs.h"

#include "src/common/utils.h"

namespace megdnn {

M
Megvii Engine Team 已提交
9
void ResizeBase::check_layout_fwd(const TensorLayout& src, const TensorLayout& dst) {
10 11 12 13 14
    auto errmsg = [&]() {
        return megdnn_layout_msg(src) + ", " + ", " + megdnn_layout_msg(dst);
    };
    MEGDNN_MARK_USED_VAR(errmsg);

M
Megvii Engine Team 已提交
15 16 17
    megdnn_assert(
            dst.dtype == src.dtype && dst.shape[0] == src.shape[0], "%s",
            errmsg().c_str());
18 19
    if (param().format == Param::Format::NCHW) {
        megdnn_assert(dst.shape[1] == src.shape[1], "%s", errmsg().c_str());
20
        auto imode = param().imode;
21
        using IMode = param::Resize::InterpolationMode;
M
Megvii Engine Team 已提交
22 23 24
        megdnn_assert(
                imode == IMode::INTER_LINEAR || imode == IMode::NEAREST ||
                imode == IMode::INTER_CUBIC);
25 26 27 28 29 30 31
    } else if (param().format == Param::Format::NHWC) {
        megdnn_assert(dst.shape[3] == src.shape[3], "%s", errmsg().c_str());
    } else if (param().format == Param::Format::NCHW4) {
        megdnn_assert(src.ndim == 5);
        megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS8);
        megdnn_assert(src.shape[4] == 4);
        megdnn_assert(dst.shape[4] == 4);
32 33 34 35
    } else if (param().format == Param::Format::NCHW44) {
        megdnn_assert(src.ndim == 5);
        megdnn_assert(src.shape[4] == 4);
        megdnn_assert(dst.shape[4] == 4);
M
Megvii Engine Team 已提交
36 37 38
        megdnn_assert(
                param().imode == param::Resize::InterpolationMode::INTER_LINEAR ||
                param().imode == param::Resize::InterpolationMode::INTER_NEAREST);
39 40 41 42
    } else if (param().format == Param::Format::NCHW88) {
        megdnn_assert(src.ndim == 5);
        megdnn_assert(src.shape[4] == 8);
        megdnn_assert(dst.shape[4] == 8);
M
Megvii Engine Team 已提交
43 44 45
        megdnn_assert(
                param().imode == param::Resize::InterpolationMode::INTER_LINEAR ||
                param().imode == param::Resize::InterpolationMode::INTER_NEAREST);
46
    } else {
M
Megvii Engine Team 已提交
47 48 49 50 51 52
        megdnn_assert(
                param().format == Param::Format::NHWCD4,
                "invalid resize tensor format");
        megdnn_assert(
                param().imode == param::Resize::InterpolationMode::INTER_LINEAR ||
                param().imode == param::Resize::InterpolationMode::INTER_NEAREST);
53 54 55 56
        megdnn_assert(dst.shape[2] == src.shape[2], "%s", errmsg().c_str());
    }
}

M
Megvii Engine Team 已提交
57 58
void Resize::check_exec(
        const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) {
59 60 61 62 63
    check_layout_fwd(src, dst);
    auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

M
Megvii Engine Team 已提交
64 65
void ResizeBackward::check_exec(
        const TensorLayout& diff, const TensorLayout& grad, size_t workspace_in_bytes) {
66 67 68
    check_layout_fwd(grad, diff);
    auto required_workspace_in_bytes = get_workspace_in_bytes(diff, grad);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
M
Megvii Engine Team 已提交
69 70 71
    megdnn_assert(
            param().format == Param::Format::NCHW && grad.dtype == dtype::Float32(),
            "Backward resize only supports Float32 and NCHW.");
72 73
}

74
/*!
 * \brief Map index \p idx to a fractional position and its floor index.
 *
 * The position is computed as (idx + 0.5) / scale - 0.5.
 *
 * \param scale divisor applied to the shifted index
 * \param idx index to map
 * \return {fraction, base}, where base is the floor of the position and
 *         fraction is the position minus base
 */
std::pair<float, int> ResizeBase::get_cubic_coord(float scale, int idx) {
    const float position = (idx + 0.5f) / scale - 0.5f;
    const int base = static_cast<int>(floor(position));
    const float fraction = position - base;
    return {fraction, base};
}

81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106
/*!
 * \brief Compute two adjacent indices and their weights for index \p idx.
 *
 * For INTER_NEAREST the lower index comes from get_nearest_src() and the
 * upper weight starts at 0. For any other mode the position
 * (idx + 0.5) / scale - 0.5 is split into its floor (lower index) and
 * fractional part (upper weight). The index pair is then clamped so that both
 * indices lie in [0, size - 1].
 *
 * \param imode interpolation mode
 * \param scale divisor applied when mapping \p idx
 * \param size extent that the returned indices are clamped to
 * \param idx index to map
 * \return (lower weight, lower index, upper weight, upper index); the two
 *         weights sum to 1 and upper index is lower index + 1, except when
 *         \p size is 1, where the result is (1, 0, 0, 0)
 */
std::tuple<float, int, float, int> ResizeBase::get_nearest_linear_coord(
        InterpolationMode imode, float scale, int size, int idx) {
    // With a single element there is no neighbour to blend with.
    if (size == 1) {
        return std::make_tuple(1.0f, 0, 0.0f, 0);
    }

    int lower = 0;
    float upper_weight = 0;
    if (imode == InterpolationMode::INTER_NEAREST) {
        lower = get_nearest_src(scale, size, idx);
        upper_weight = 0;
    } else {
        const float position = (idx + 0.5f) / scale - 0.5f;
        lower = static_cast<int>(floor(position));
        upper_weight = position - lower;
    }

    if (lower < 0) {
        // Left of the first element: all weight goes to index 0.
        lower = 0;
        upper_weight = 0;
    } else if (lower + 1 >= size) {
        // The upper neighbour would be out of range: shift the pair down by
        // pinning it to the last two elements and give all weight to the last.
        lower = size - 2;
        upper_weight = 1;
    }

    return std::make_tuple(1 - upper_weight, lower, upper_weight, lower + 1);
}

107 108 109
/*!
 * \brief Map index \p idx to idx / scale truncated to int, capped at
 *        size - 1.
 *
 * \param scale divisor applied to \p idx
 * \param size extent; the result never exceeds size - 1
 * \param idx index to map
 * \return the truncated quotient, or size - 1 if that is smaller
 */
int ResizeBase::get_nearest_src(float scale, int size, int idx) {
    const int truncated = static_cast<int>(idx / scale);
    const int last = size - 1;
    return last < truncated ? last : truncated;
}
110 111 112
}  // namespace megdnn

// vim: syntax=cpp.doxygen