// forward.cpp: CUDA forward implementations of the Resize and Resize3D operators.
#include "src/common/cv/common.h"
#include "src/common/cv/enums.h"
#include "src/cuda/handle.h"
#include "src/cuda/resize/common.h"
#include "src/cuda/resize/helper.h"
#include "src/cuda/resize/opr_impl.h"
#include "src/cuda/resize/resize_cv.cuh"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;

namespace {

M
Megvii Engine Team 已提交
15 16 17
void resize_cv_proxy(
        _megdnn_tensor_in src, _megdnn_tensor_out dst, InterpolationMode imode,
        void* workspace, cudaStream_t stream) {
18 19 20 21 22 23
    using namespace megcv;
    for (size_t i = 0; i < src.layout.shape[0]; ++i) {
        if (dst.layout.dtype == dtype::Float32()) {
            Mat<float> src_mat = TensorND2Mat<float>(src, i);
            Mat<float> dst_mat = TensorND2Mat<float>(dst, i);
            resize::resize_cv<float>(
M
Megvii Engine Team 已提交
24 25 26
                    src_mat.ptr(), dst_mat.ptr(), src_mat.rows(), src_mat.cols(),
                    dst_mat.rows(), dst_mat.cols(), src_mat.step(), dst_mat.step(),
                    src_mat.channels(), imode, workspace, stream);
27 28 29 30
        } else if (dst.layout.dtype == dtype::Uint8()) {
            Mat<uchar> src_mat = TensorND2Mat<uchar>(src, i);
            Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, i);
            resize::resize_cv<uchar>(
M
Megvii Engine Team 已提交
31 32 33
                    src_mat.ptr(), dst_mat.ptr(), src_mat.rows(), src_mat.cols(),
                    dst_mat.rows(), dst_mat.cols(), src_mat.step(), dst_mat.step(),
                    src_mat.channels(), imode, workspace, stream);
34
        } else {
M
Megvii Engine Team 已提交
35
            megdnn_throw("Unsupported datatype of WarpAffine optr.");
36 37 38 39 40 41
        }
    }
}

}  // anonymous namespace

size_t ResizeImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& dst) {
44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94
    InterpolationMode imode = param().imode;
    if (param().format == Param::Format::NCHW ||
        (imode != Param::InterpolationMode::CUBIC &&
         imode != Param::InterpolationMode::LANCZOS4)) {
        return 0;
    }

    size_t src_rows = src.shape[1];
    size_t dst_rows = dst.shape[1];
    size_t src_cols = src.shape[2];
    size_t dst_cols = dst.shape[2];
    size_t ch = src.shape[3];

    size_t dst_area_size = dst_rows * dst_cols;
    size_t src_area_size = src_rows * src_cols;

    bool enlarge = dst_area_size > src_area_size;
    bool shrink = dst_area_size <= src_area_size;
    bool U8 = src.dtype == dtype::Uint8();
    megdnn_assert(src.dtype == dtype::Uint8() || src.dtype == dtype::Float32());
    bool F32_1 = !U8 && ch == 1;
    bool F32_3 = !U8 && ch == 3;

    bool use_vector = (enlarge && (dst_area_size <= 500 * 500)) ||
                      (shrink && (F32_3 || (U8 && dst_area_size <= 500 * 500) ||
                                  (F32_1 && dst_area_size <= 1000 * 1000)));

    if (!use_vector) {
        int coef_size = 0;
        if (imode == Param::InterpolationMode::CUBIC) {
            coef_size = 4;
        } else {
            coef_size = 8;
            megdnn_assert(imode == Param::InterpolationMode::LANCZOS4);
        }
        if (U8) {
            return dst_rows * coef_size * sizeof(short) +  //! dev_coef_row
                   dst_rows * sizeof(int) +                //! dev_sr
                   dst_cols * coef_size * sizeof(short) +  //! dev_coef_col
                   dst_cols * sizeof(int);                 //! dev_sc
        } else {
            return dst_rows * coef_size * sizeof(float) +  //! dev_coef_row
                   dst_rows * sizeof(int) +                //! dev_sr
                   dst_cols * coef_size * sizeof(float) +  //! dev_coef_col
                   dst_cols * sizeof(int);                 //! dev_sc
        }
    }

    return 0;
}

/*!
 * \brief Run the 2D resize forward pass on the handle's CUDA stream.
 *
 * Dispatch:
 *  - NHWC, non-LINEAR mode, and is_nhwc_contig_wc(src): go through the megcv
 *    path (resize_cv_proxy). This is the only consumer of \p workspace.
 *  - NCHW4: go through resize::forward_proxy_nchw4. This requires QuantizedS8.
 *  - Otherwise (NHWC or NCHW): go through resize::forward_proxy. This requires
 *    LINEAR, NEAREST or INTER_CUBIC mode and Float32 / Float16 / Uint8 / Int8
 *    data.
 *
 * \throws on an unsupported format, interpolation mode or dtype.
 */
void ResizeImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_workspace workspace) {
    check_exec(src.layout, dst.layout, workspace.size);
    auto stream = cuda_stream(this->handle());
    bool is_nhwc = param().format == param::Resize::Format::NHWC;
    size_t C, IH, IW, OH, OW;
    // Input strides are only filled in for NCHW. They stay 0 on the NHWC
    // path, where forward_proxy is told the layout via `is_nhwc`.
    ptrdiff_t S_IN = 0, S_IC = 0, S_IH = 0, S_IW = 0;
    if (is_nhwc) {
        if (param().imode != Param::InterpolationMode::LINEAR &&
            is_nhwc_contig_wc(src.layout)) {
            resize_cv_proxy(
                    src, dst, resize::get_imode(param().imode), workspace.raw_ptr,
                    stream);
            return;
        }
        C = src.layout.shape[3];
        IH = src.layout.shape[1];
        IW = src.layout.shape[2];
        OH = dst.layout.shape[1];
        OW = dst.layout.shape[2];
    } else if (param().format == param::Resize::Format::NCHW) {
        C = src.layout.shape[1];
        IH = src.layout.shape[2];
        IW = src.layout.shape[3];
        OH = dst.layout.shape[2];
        OW = dst.layout.shape[3];
        S_IN = src.layout.stride[0];
        S_IC = src.layout.stride[1];
        S_IH = src.layout.stride[2];
        S_IW = src.layout.stride[3];
    } else {
        megdnn_assert(
                param().format == param::Resize::Format::NCHW4,
                "invalid resize format");
        megdnn_assert(src.layout.dtype.enumv() == DTypeEnum::QuantizedS8);
        // NCHW4 packs 4 channels into the innermost dimension.
        C = src.layout.shape[1] * 4;
        IH = src.layout.shape[2];
        IW = src.layout.shape[3];
        OH = dst.layout.shape[2];
        OW = dst.layout.shape[3];
        resize::forward_proxy_nchw4(
                src.compatible_ptr<int8_t>(), dst.compatible_ptr<int8_t>(),
                src.layout[0], C, IH, IW, OH, OW, stream);
        return;
    }
    // This check is reached on both the NCHW path and the non-megcv NHWC path.
    megdnn_assert(
            param().imode == Param::InterpolationMode::LINEAR ||
                    param().imode == Param::InterpolationMode::NEAREST ||
                    param().imode == Param::InterpolationMode::INTER_CUBIC,
            "unsupported interpolation mode for NCHW/NHWC format");

    auto imode = resize::get_imode(param().imode);
    if (src.layout.dtype == dtype::Float32{}) {
        resize::forward_proxy(
                is_nhwc, imode, src.ptr<dt_float32>(), dst.ptr<dt_float32>(),
                src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC, S_IH, S_IW, stream);
#if !MEGDNN_DISABLE_FLOAT16
    } else if (src.layout.dtype == dtype::Float16{}) {
        // Guarded the same way as Resize3DImpl::exec in this file.
        resize::forward_proxy(
                is_nhwc, imode, src.ptr<dt_float16>(), dst.ptr<dt_float16>(),
                src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC, S_IH, S_IW, stream);
#endif
    } else if (src.layout.dtype == dtype::Uint8()) {
        resize::forward_proxy(
                is_nhwc, imode, src.ptr<dt_uint8>(), dst.ptr<dt_uint8>(),
                src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC, S_IH, S_IW, stream);
    } else if (src.layout.dtype == dtype::Int8()) {
        resize::forward_proxy(
                is_nhwc, imode, src.ptr<dt_int8>(), dst.ptr<dt_int8>(),
                src.layout[0], C, IH, IW, OH, OW, S_IN, S_IC, S_IH, S_IW, stream);
    } else {
        megdnn_throw(ssprintf("unsupported dtype: %s", src.layout.dtype.name()));
    }
}

/*!
 * \brief Scratch memory needed by Resize3DImpl::exec: always none.
 *
 * exec hands no workspace pointer to resize3d_forward, so the layouts are
 * irrelevant here.
 */
size_t Resize3DImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& dst) {
    static_cast<void>(src);
    static_cast<void>(dst);
    constexpr size_t kNoWorkspace = 0;
    return kNoWorkspace;
}

/*!
 * \brief Run the 3D resize forward pass on the handle's CUDA stream.
 *
 * Reads batch and channels from dims 0-1 and the spatial extents from dims
 * 2-4 of each layout (presumably N, C, D, H, W; TODO confirm against the
 * layout deduction). Supports Float32, plus Float16 unless
 * MEGDNN_DISABLE_FLOAT16 is set. Any other dtype throws.
 */
void Resize3DImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_workspace workspace) {
    check_exec(src.layout, dst.layout, workspace.size);

    const size_t od = dst.layout.shape[2];
    const size_t oh = dst.layout.shape[3];
    const size_t ow = dst.layout.shape[4];
    const size_t id = src.layout.shape[2];
    const size_t ih = src.layout.shape[3];
    const size_t iw = src.layout.shape[4];

    const bool corners = param().align_corners;
    auto cu_stream = cuda_stream(this->handle());
    const auto& dt = src.layout.dtype;

    if (dt == dtype::Float32{}) {
        resize3d::resize3d_forward(
                corners, src.ptr<dt_float32>(), dst.ptr<dt_float32>(),
                src.layout[0], src.layout[1], id, ih, iw, od, oh, ow, cu_stream);
        return;
    }
#if !MEGDNN_DISABLE_FLOAT16
    if (dt == dtype::Float16{}) {
        resize3d::resize3d_forward(
                corners, src.ptr<dt_float16>(), dst.ptr<dt_float16>(),
                src.layout[0], src.layout[1], id, ih, iw, od, oh, ow, cu_stream);
        return;
    }
#endif
    megdnn_throw(ssprintf(
            "unsupported dtype: %s for Resize3D", src.layout.dtype.name()));
}

// vim: syntax=cpp.doxygen