/**
 * \file dnn/include/megdnn/oprs/cv.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#pragma once
#include "megdnn/internal/opr_header_prologue.h"

namespace megdnn {

/**
 * \brief This file contains CV operators. The layout is NHWC.
 */

class FlipBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(FlipBase, OperatorBase);
    DEF_OPR_PARAM(Flip);

25 26 27
protected:
    void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
28 29 30 31 32
};

class FlipForward : public FlipBase {
    DEF_OPR_IMPL(FlipForward, FlipBase, 1, 1);

33
public:
M
Megvii Engine Team 已提交
34 35 36
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
37
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
M
Megvii Engine Team 已提交
38 39
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
40

41
protected:
M
Megvii Engine Team 已提交
42 43 44
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
45 46 47 48 49 50 51
};
using Flip = FlipForward;

class RotateBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(RotateBase, OperatorBase);
    DEF_OPR_PARAM(Rotate);

52 53 54
protected:
    void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
55 56 57 58 59
};

class RotateForward : public RotateBase {
    DEF_OPR_IMPL(RotateForward, RotateBase, 1, 1);

60
public:
M
Megvii Engine Team 已提交
61 62 63
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
64
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
M
Megvii Engine Team 已提交
65 66
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
67

68
protected:
M
Megvii Engine Team 已提交
69 70 71
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
72 73 74 75 76 77 78
};
using Rotate = RotateForward;

class ROICopyBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ROICopyBase, OperatorBase);
    DEF_OPR_PARAM(ROICopy);

79 80 81
protected:
    void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
82 83 84 85 86
};

class ROICopyForward : public ROICopyBase {
    DEF_OPR_IMPL(ROICopyForward, ROICopyBase, 1, 1);

87
public:
M
Megvii Engine Team 已提交
88 89 90
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
91
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
M
Megvii Engine Team 已提交
92 93
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
94

95
protected:
M
Megvii Engine Team 已提交
96 97 98
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
99 100 101 102 103 104 105
};
using ROICopy = ROICopyForward;

class CvtColorBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(CvtColorBase, OperatorBase);
    DEF_OPR_PARAM(CvtColor);

106 107 108
protected:
    void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
109 110 111 112 113
};

class CvtColorForward : public CvtColorBase {
    DEF_OPR_IMPL(CvtColorForward, CvtColorBase, 1, 1);

114
public:
M
Megvii Engine Team 已提交
115 116 117
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
118
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
M
Megvii Engine Team 已提交
119 120
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
121

122
protected:
M
Megvii Engine Team 已提交
123 124 125
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
126 127 128 129 130 131 132 133 134 135
};
using CvtColor = CvtColorForward;

/**
 * \brief Applies an affine transformation
 */
class WarpAffineBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(WarpAffineBase, OperatorBase);
    DEF_OPR_PARAM(WarpAffine);

136 137 138 139 140
public:
    using InterpolationMode = Param::InterpolationMode;
    using BorderMode = Param::BorderMode;

protected:
M
Megvii Engine Team 已提交
141 142 143
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& trans,
            const TensorLayout& dst);
144 145
    std::string param_msg() const;
    int get_real_coord(int p, int len);
146 147 148 149 150
};

class WarpAffineForward : public WarpAffineBase {
    DEF_OPR_IMPL(WarpAffineForward, WarpAffineBase, 2, 1);

151
public:
152 153 154 155 156 157 158 159
    /**
     * \param[in] src input tensor
     * \param[in] trans transform matrix tensor
     * \param[in] dst output tensor
     *
     * \warning src, trans, border_value, dst should be contiguous
     * The size of trans is N * 2 * 3
     */
M
Megvii Engine Team 已提交
160 161 162 163 164 165
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& trans,
            const TensorLayout& dst) = 0;
166

167
protected:
M
Megvii Engine Team 已提交
168 169 170
    void check_exec(
            const TensorLayout& src, const TensorLayout& trans, const TensorLayout& dst,
            size_t workspace_in_bytes);
171 172 173 174 175 176 177
};
using WarpAffine = WarpAffineForward;

class GaussianBlurBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(GaussianBlurBase, OperatorBase);
    DEF_OPR_PARAM(GaussianBlur);

178 179 180
protected:
    void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
181 182 183 184 185
};

class GaussianBlurForward : public GaussianBlurBase {
    DEF_OPR_IMPL(GaussianBlurForward, GaussianBlurBase, 1, 1);

186
public:
M
Megvii Engine Team 已提交
187 188 189
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
190
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
M
Megvii Engine Team 已提交
191 192
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
193

194
protected:
M
Megvii Engine Team 已提交
195 196 197
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
198 199 200 201 202 203 204 205 206 207 208 209 210 211 212
};
using GaussianBlur = GaussianBlurForward;

/**
 * \brief Resize opr.
 */
class ResizeBase : public OperatorBase {
    DEF_OPR_PARAM(Resize);
    DEF_OPR_IMPL(ResizeBase, OperatorBase, 1, 1);

public:
    using InterpolationMode = Param::InterpolationMode;

protected:
    //! get origin coord
213 214 215 216 217
    std::pair<float, int> get_cubic_coord(float scale, int idx);

    std::tuple<float, int, float, int> get_nearest_linear_coord(
            InterpolationMode imode, float scale, int size, int idx);

218 219 220
    //! get nearest index in src
    int get_nearest_src(float scale, int size, int idx);

221 222 223 224 225 226 227
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
};

class ResizeForward : public ResizeBase {
    DEF_OPR_IMPL(ResizeForward, ResizeBase, 1, 1);

public:
M
Megvii Engine Team 已提交
228 229 230
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
231

M
Megvii Engine Team 已提交
232 233
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
234 235

protected:
M
Megvii Engine Team 已提交
236 237 238
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
239 240 241 242 243 244 245
};
using Resize = ResizeForward;

class ResizeBackward : public ResizeBase {
    DEF_OPR_IMPL(ResizeBackward, ResizeBase, 1, 1);

public:
M
Megvii Engine Team 已提交
246 247 248
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
249

M
Megvii Engine Team 已提交
250 251
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& mat) = 0;
252 253

protected:
M
Megvii Engine Team 已提交
254 255 256
    void check_exec(
            const TensorLayout& diff, const TensorLayout& mat,
            size_t workspace_in_bytes);
257 258
};

/**
 * \brief Remap opr.
 */
class RemapBase : public OperatorBase {
    DEF_OPR_PARAM(Remap);
    DEF_OPR_IMPL(RemapBase, OperatorBase, 2, 1);

public:
    //! re-export the enums nested in Param so users can name them through
    //! the operator type
    using InterpolationMode = Param::InterpolationMode;
    using BorderMode = Param::BorderMode;

protected:
    //! check the layouts of the source, the coordinate map and the
    //! destination of a forward computation
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& map_xy,
            const TensorLayout& dst);
    //! deduce the layout of \p dst from \p src and \p map_xy
    //! (\p dst is the out-parameter)
    void deduce_layout_fwd(
            const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst);
};

class RemapForward : public RemapBase {
    DEF_OPR_IMPL(RemapForward, RemapBase, 2, 1);

public:
M
Megvii Engine Team 已提交
282 283 284
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
285

M
Megvii Engine Team 已提交
286 287
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst);
288

M
Megvii Engine Team 已提交
289 290 291
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& map_xy,
            const TensorLayout& dst) = 0;
292 293

protected:
M
Megvii Engine Team 已提交
294 295 296
    void check_exec(
            const TensorLayout& src, const TensorLayout& map_xy,
            const TensorLayout& dst, size_t workspace_in_bytes);
297 298 299
};
using Remap = RemapForward;

300 301 302 303
class RemapBackwardData : public RemapBase {
    DEF_OPR_IMPL(RemapBackwardData, RemapBase, 2, 1);

public:
M
Megvii Engine Team 已提交
304 305 306
    virtual void exec(
            _megdnn_tensor_in map_xy, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
307

M
Megvii Engine Team 已提交
308 309 310
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& map_xy, const TensorLayout& diff,
            const TensorLayout& grad) = 0;
311 312

protected:
M
Megvii Engine Team 已提交
313 314 315
    void check_exec(
            const TensorLayout& map_xy, const TensorLayout& diff,
            const TensorLayout& grad, size_t workspace_in_bytes);
316 317 318 319 320 321
};

class RemapBackwardMat : public RemapBase {
    DEF_OPR_IMPL(RemapBackwardMat, RemapBase, 3, 1);

public:
M
Megvii Engine Team 已提交
322 323 324
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_in diff,
            _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
325

M
Megvii Engine Team 已提交
326 327 328
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& map_xy,
            const TensorLayout& diff, const TensorLayout& grad) = 0;
329 330

protected:
M
Megvii Engine Team 已提交
331 332 333 334
    void check_exec(
            const TensorLayout& src, const TensorLayout& map_xy,
            const TensorLayout& diff, const TensorLayout& grad,
            size_t workspace_in_bytes);
335 336
};

337
class SeparableFilterBase : public OperatorBase {
338 339
    DEF_OPR_IMPL_CTOR(SeparableFilterBase, OperatorBase);
    DEF_OPR_PARAM(SeparableFilter);
340 341

protected:
M
Megvii Engine Team 已提交
342 343 344 345 346 347
    void deduce_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter_x,
            const TensorLayout& filter_y, TensorLayout& dst);
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter_x,
            const TensorLayout& filter_y, const TensorLayout& dst);
348 349
};

350
class SeparableFilterForward : public SeparableFilterBase {
351
    DEF_OPR_IMPL(SeparableFilterForward, SeparableFilterBase, 3, 1);
352 353

public:
M
Megvii Engine Team 已提交
354 355 356 357 358 359 360 361 362 363
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter_x,
            _megdnn_tensor_in filter_y, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter_x,
            const TensorLayout& filter_y, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter_x,
            const TensorLayout& filter_y, const TensorLayout& dst) = 0;
364 365

protected:
M
Megvii Engine Team 已提交
366 367 368 369
    void check_exec(
            const TensorLayout& src, const TensorLayout& filter_x,
            const TensorLayout& filter_y, const TensorLayout& dst,
            size_t workspace_in_bytes);
370 371 372 373 374 375 376 377
};
using SeparableFilter = SeparableFilterForward;

}  // namespace megdnn

#include "megdnn/internal/opr_header_epilogue.h"

// vim: syntax=cpp.doxygen