// forward.cpp
#include "src/common/cv/common.h"
#include "src/common/cv/enums.h"
#include "src/cuda/handle.h"
#include "src/cuda/resize/common.h"
#include "src/cuda/resize/helper.h"
#include "src/cuda/resize/opr_impl.h"
#include "src/cuda/resize/resize_cv.cuh"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;

namespace {

M
Megvii Engine Team 已提交
15 16 17
void resize_cv_proxy(
        _megdnn_tensor_in src, _megdnn_tensor_out dst, InterpolationMode imode,
        void* workspace, cudaStream_t stream) {
18 19 20 21 22 23
    using namespace megcv;
    for (size_t i = 0; i < src.layout.shape[0]; ++i) {
        if (dst.layout.dtype == dtype::Float32()) {
            Mat<float> src_mat = TensorND2Mat<float>(src, i);
            Mat<float> dst_mat = TensorND2Mat<float>(dst, i);
            resize::resize_cv<float>(
M
Megvii Engine Team 已提交
24 25 26
                    src_mat.ptr(), dst_mat.ptr(), src_mat.rows(), src_mat.cols(),
                    dst_mat.rows(), dst_mat.cols(), src_mat.step(), dst_mat.step(),
                    src_mat.channels(), imode, workspace, stream);
27 28 29 30
        } else if (dst.layout.dtype == dtype::Uint8()) {
            Mat<uchar> src_mat = TensorND2Mat<uchar>(src, i);
            Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, i);
            resize::resize_cv<uchar>(
M
Megvii Engine Team 已提交
31 32 33
                    src_mat.ptr(), dst_mat.ptr(), src_mat.rows(), src_mat.cols(),
                    dst_mat.rows(), dst_mat.cols(), src_mat.step(), dst_mat.step(),
                    src_mat.channels(), imode, workspace, stream);
34
        } else {
M
Megvii Engine Team 已提交
35
            megdnn_throw("Unsupported datatype of WarpAffine optr.");
36 37 38 39 40 41
        }
    }
}

}  // anonymous namespace

size_t ResizeImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& dst) {
44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94
    InterpolationMode imode = param().imode;
    if (param().format == Param::Format::NCHW ||
        (imode != Param::InterpolationMode::CUBIC &&
         imode != Param::InterpolationMode::LANCZOS4)) {
        return 0;
    }

    size_t src_rows = src.shape[1];
    size_t dst_rows = dst.shape[1];
    size_t src_cols = src.shape[2];
    size_t dst_cols = dst.shape[2];
    size_t ch = src.shape[3];

    size_t dst_area_size = dst_rows * dst_cols;
    size_t src_area_size = src_rows * src_cols;

    bool enlarge = dst_area_size > src_area_size;
    bool shrink = dst_area_size <= src_area_size;
    bool U8 = src.dtype == dtype::Uint8();
    megdnn_assert(src.dtype == dtype::Uint8() || src.dtype == dtype::Float32());
    bool F32_1 = !U8 && ch == 1;
    bool F32_3 = !U8 && ch == 3;

    bool use_vector = (enlarge && (dst_area_size <= 500 * 500)) ||
                      (shrink && (F32_3 || (U8 && dst_area_size <= 500 * 500) ||
                                  (F32_1 && dst_area_size <= 1000 * 1000)));

    if (!use_vector) {
        int coef_size = 0;
        if (imode == Param::InterpolationMode::CUBIC) {
            coef_size = 4;
        } else {
            coef_size = 8;
            megdnn_assert(imode == Param::InterpolationMode::LANCZOS4);
        }
        if (U8) {
            return dst_rows * coef_size * sizeof(short) +  //! dev_coef_row
                   dst_rows * sizeof(int) +                //! dev_sr
                   dst_cols * coef_size * sizeof(short) +  //! dev_coef_col
                   dst_cols * sizeof(int);                 //! dev_sc
        } else {
            return dst_rows * coef_size * sizeof(float) +  //! dev_coef_row
                   dst_rows * sizeof(int) +                //! dev_sr
                   dst_cols * coef_size * sizeof(float) +  //! dev_coef_col
                   dst_cols * sizeof(int);                 //! dev_sc
        }
    }

    return 0;
}

void ResizeImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_workspace workspace) {
97 98 99 100 101 102 103 104
    check_exec(src.layout, dst.layout, workspace.size);
    auto stream = cuda_stream(this->handle());
    bool is_nhwc = param().format == param::Resize::Format::NHWC;
    size_t C, IH, IW, OH, OW;
    ptrdiff_t S_IN = 0, S_IC = 0, S_IH = 0, S_IW = 0;
    if (is_nhwc) {
        if (param().imode != Param::InterpolationMode::LINEAR &&
            is_nhwc_contig_wc(src.layout)) {
M
Megvii Engine Team 已提交
105 106 107
            resize_cv_proxy(
                    src, dst, resize::get_imode(param().imode), workspace.raw_ptr,
                    stream);
108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
            return;
        }
        C = src.layout.shape[3];
        IH = src.layout.shape[1];
        IW = src.layout.shape[2];
        OH = dst.layout.shape[1];
        OW = dst.layout.shape[2];
    } else if (param().format == param::Resize::Format::NCHW) {
        C = src.layout.shape[1];
        IH = src.layout.shape[2];
        IW = src.layout.shape[3];
        OH = dst.layout.shape[2];
        OW = dst.layout.shape[3];
        S_IN = src.layout.stride[0];
        S_IC = src.layout.stride[1];
        S_IH = src.layout.stride[2];
        S_IW = src.layout.stride[3];
    } else {
M
Megvii Engine Team 已提交
126 127 128
        megdnn_assert(
                param().format == param::Resize::Format::NCHW4,
                "invalid resize format");
129 130 131 132 133 134
        megdnn_assert(src.layout.dtype.enumv() == DTypeEnum::QuantizedS8);
        C = src.layout.shape[1] * 4;
        IH = src.layout.shape[2];
        IW = src.layout.shape[3];
        OH = dst.layout.shape[2];
        OW = dst.layout.shape[3];
M
Megvii Engine Team 已提交
135 136 137
        resize::forward_proxy_nchw4(
                src.compatible_ptr<int8_t>(), dst.compatible_ptr<int8_t>(),
                src.layout[0], C, IH, IW, OH, OW, stream);
138 139
        return;
    }
140 141 142 143 144
    megdnn_assert(
            param().imode == Param::InterpolationMode::LINEAR ||
                    param().imode == Param::InterpolationMode::NEAREST ||
                    param().imode == Param::InterpolationMode::INTER_CUBIC,
            "unsupported interpolation mode for NCHW format");
145 146

    if (src.layout.dtype == dtype::Float32{}) {
M
Megvii Engine Team 已提交
147 148 149 150
        resize::forward_proxy(
                is_nhwc, resize::get_imode((param().imode)), src.ptr<dt_float32>(),
                dst.ptr<dt_float32>(), src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC,
                S_IH, S_IW, stream);
151
    } else if (src.layout.dtype == dtype::Uint8()) {
M
Megvii Engine Team 已提交
152 153 154 155
        resize::forward_proxy(
                is_nhwc, resize::get_imode((param().imode)), src.ptr<dt_uint8>(),
                dst.ptr<dt_uint8>(), src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC, S_IH,
                S_IW, stream);
156
    } else if (src.layout.dtype == dtype::Int8()) {
M
Megvii Engine Team 已提交
157 158 159 160
        resize::forward_proxy(
                is_nhwc, resize::get_imode((param().imode)), src.ptr<dt_int8>(),
                dst.ptr<dt_int8>(), src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC, S_IH,
                S_IW, stream);
161
    } else {
M
Megvii Engine Team 已提交
162
        megdnn_throw(ssprintf("unsupported dtype: %s", src.layout.dtype.name()));
163 164 165 166
    }
}

// vim: syntax=cpp.doxygen