/**
 * \file dnn/src/common/resize.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "megdnn/handle.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs.h"

#include "src/common/utils.h"

#include <cmath>

namespace megdnn {

void ResizeBase::check_layout_fwd(const TensorLayout& src,
                                  const TensorLayout& dst) {
    auto errmsg = [&]() {
        return megdnn_layout_msg(src) + ", " + ", " + megdnn_layout_msg(dst);
    };
    MEGDNN_MARK_USED_VAR(errmsg);

    megdnn_assert(dst.dtype == src.dtype && dst.shape[0] == src.shape[0], "%s",
                  errmsg().c_str());
    if (param().format == Param::Format::NCHW) {
        megdnn_assert(dst.shape[1] == src.shape[1], "%s", errmsg().c_str());
32
        auto imode = param().imode;
33 34 35
        using IMode = param::Resize::InterpolationMode;
        megdnn_assert(imode == IMode::INTER_LINEAR || imode == IMode::NEAREST ||
                      imode == IMode::INTER_CUBIC);
36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
    } else if (param().format == Param::Format::NHWC) {
        megdnn_assert(dst.shape[3] == src.shape[3], "%s", errmsg().c_str());
    } else if (param().format == Param::Format::NCHW4) {
        megdnn_assert(src.ndim == 5);
        megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS8);
        megdnn_assert(src.shape[4] == 4);
        megdnn_assert(dst.shape[4] == 4);
    } else {
        megdnn_assert(param().format == Param::Format::NHWCD4,
                      "invalid resize tensor format");
        megdnn_assert(param().imode ==
                      param::Resize::InterpolationMode::INTER_LINEAR);
        megdnn_assert(dst.shape[2] == src.shape[2], "%s", errmsg().c_str());
    }
}

/*!
 * \brief Validate a forward resize call before it is executed.
 *
 * First checks the src/dst layout pair with check_layout_fwd(), then checks
 * that the caller-provided workspace is at least
 * get_workspace_in_bytes(src, dst) bytes.
 *
 * \param src                source tensor layout
 * \param dst                destination tensor layout
 * \param workspace_in_bytes workspace size supplied by the caller
 */
void Resize::check_exec(const TensorLayout& src, const TensorLayout& dst,
                        size_t workspace_in_bytes) {
    check_layout_fwd(src, dst);

    const auto required_workspace_in_bytes =
            get_workspace_in_bytes(src, dst);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

/*!
 * \brief Validate a backward resize call before it is executed.
 *
 * \p grad is checked in the forward-source position and \p diff in the
 * forward-destination position of check_layout_fwd(). The caller's workspace
 * must cover get_workspace_in_bytes(diff, grad), and only NCHW format with a
 * Float32 \p grad is accepted.
 *
 * \param diff               layout passed as the forward destination
 * \param grad               layout passed as the forward source
 * \param workspace_in_bytes workspace size supplied by the caller
 */
void ResizeBackward::check_exec(const TensorLayout& diff,
                                const TensorLayout& grad,
                                size_t workspace_in_bytes) {
    // Note the argument order: grad fills the forward "src" slot.
    check_layout_fwd(grad, diff);

    const auto required_workspace_in_bytes =
            get_workspace_in_bytes(diff, grad);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);

    megdnn_assert(param().format == Param::Format::NCHW &&
                          grad.dtype == dtype::Float32(),
                  "Backward resize only supports Float32 and NCHW.");
}

/*!
 * \brief Map destination index \p idx back to a source coordinate split into
 * an integer base index and a fractional weight.
 *
 * The source coordinate is (idx + 0.5) / scale - 0.5 (logic copied from
 * resize_cv.cpp). origin_idx is its floor and alpha the remaining fraction.
 *
 * \param scale  divisor applied to the destination coordinate; presumably
 *               dst_size / src_size along this axis -- TODO confirm at callers
 * \param size   number of source elements along this axis
 * \param idx    destination index
 * \param cubic  when true, no border clamping is applied
 * \return {alpha, origin_idx}. When \p cubic is false, a negative origin_idx
 *         is clamped to {0, 0}, and an origin_idx with origin_idx + 1 >= size
 *         is clamped to {1, size - 2}; for size >= 2 both origin_idx and
 *         origin_idx + 1 therefore lie in [0, size - 1].
 */
std::pair<float, int> ResizeBase::get_origin_coord(float scale, int size,
                                                   int idx, bool cubic) {
    //! copy from resize_cv.cpp
    float alpha = (idx + 0.5f) / scale - 0.5f;
    // std::floor keeps the float overload; C floor() would promote to double.
    int origin_idx = static_cast<int>(std::floor(alpha));
    alpha -= origin_idx;

    if (!cubic) {
        if (origin_idx < 0) {
            origin_idx = 0;
            alpha = 0;
        } else if (origin_idx + 1 >= size) {
            // NOTE(review): when size == 1 this yields origin_idx == -1;
            // callers must not rely on this path for single-element axes.
            origin_idx = size - 2;
            alpha = 1;
        }
    }
    return {alpha, origin_idx};
}

/*!
 * \brief Nearest-neighbour source index for destination index \p idx.
 *
 * idx / scale is truncated toward zero and then capped at size - 1, so the
 * result never exceeds the last source index.
 */
int ResizeBase::get_nearest_src(float scale, int size, int idx) {
    const int last_index = size - 1;
    const int mapped = static_cast<int>(idx / scale);
    if (mapped < last_index) {
        return mapped;
    }
    return last_index;
}

}  // namespace megdnn

// vim: syntax=cpp.doxygen