// cv.h
#pragma once
#include "megdnn/internal/opr_header_prologue.h"

namespace megdnn {

/**
 * \brief This file contains CV operators. The layout is NHWC.
 */

class FlipBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(FlipBase, OperatorBase);
    DEF_OPR_PARAM(Flip);

14 15 16
protected:
    void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
17 18 19 20 21
};

class FlipForward : public FlipBase {
    DEF_OPR_IMPL(FlipForward, FlipBase, 1, 1);

22
public:
M
Megvii Engine Team 已提交
23 24 25
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
26
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
M
Megvii Engine Team 已提交
27 28
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
29

30
protected:
M
Megvii Engine Team 已提交
31 32 33
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
34 35 36 37 38 39 40
};
using Flip = FlipForward;

class RotateBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(RotateBase, OperatorBase);
    DEF_OPR_PARAM(Rotate);

41 42 43
protected:
    void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
44 45 46 47 48
};

class RotateForward : public RotateBase {
    DEF_OPR_IMPL(RotateForward, RotateBase, 1, 1);

49
public:
M
Megvii Engine Team 已提交
50 51 52
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
53
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
M
Megvii Engine Team 已提交
54 55
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
56

57
protected:
M
Megvii Engine Team 已提交
58 59 60
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
61 62 63 64 65 66 67
};
using Rotate = RotateForward;

class ROICopyBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(ROICopyBase, OperatorBase);
    DEF_OPR_PARAM(ROICopy);

68 69 70
protected:
    void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
71 72 73 74 75
};

class ROICopyForward : public ROICopyBase {
    DEF_OPR_IMPL(ROICopyForward, ROICopyBase, 1, 1);

76
public:
M
Megvii Engine Team 已提交
77 78 79
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
80
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
M
Megvii Engine Team 已提交
81 82
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
83

84
protected:
M
Megvii Engine Team 已提交
85 86 87
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
88 89 90 91 92 93 94
};
using ROICopy = ROICopyForward;

class CvtColorBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(CvtColorBase, OperatorBase);
    DEF_OPR_PARAM(CvtColor);

95 96 97
protected:
    void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
98 99 100 101 102
};

class CvtColorForward : public CvtColorBase {
    DEF_OPR_IMPL(CvtColorForward, CvtColorBase, 1, 1);

103
public:
M
Megvii Engine Team 已提交
104 105 106
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
107
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
M
Megvii Engine Team 已提交
108 109
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
110

111
protected:
M
Megvii Engine Team 已提交
112 113 114
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
115 116 117 118 119 120 121 122 123 124
};
using CvtColor = CvtColorForward;

/**
 * \brief Applices an affine transformation
 */
class WarpAffineBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(WarpAffineBase, OperatorBase);
    DEF_OPR_PARAM(WarpAffine);

125 126 127 128 129
public:
    using InterpolationMode = Param::InterpolationMode;
    using BorderMode = Param::BorderMode;

protected:
M
Megvii Engine Team 已提交
130 131 132
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& trans,
            const TensorLayout& dst);
133 134
    std::string param_msg() const;
    int get_real_coord(int p, int len);
135 136 137 138 139
};

class WarpAffineForward : public WarpAffineBase {
    DEF_OPR_IMPL(WarpAffineForward, WarpAffineBase, 2, 1);

140
public:
141 142 143 144 145 146 147 148
    /**
     * \param[in] src input tensor
     * \param[in] trans transform matrix tensor
     * \param[in] dst output tensor
     *
     * \warning src, trans, border_value, dst should be contiguous
     * The size of trans is N * 2 * 3
     */
M
Megvii Engine Team 已提交
149 150 151 152 153 154
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& trans,
            const TensorLayout& dst) = 0;
155

156
protected:
M
Megvii Engine Team 已提交
157 158 159
    void check_exec(
            const TensorLayout& src, const TensorLayout& trans, const TensorLayout& dst,
            size_t workspace_in_bytes);
160 161 162 163 164 165 166
};
using WarpAffine = WarpAffineForward;

class GaussianBlurBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(GaussianBlurBase, OperatorBase);
    DEF_OPR_PARAM(GaussianBlur);

167 168 169
protected:
    void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst);
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
170 171 172 173 174
};

class GaussianBlurForward : public GaussianBlurBase {
    DEF_OPR_IMPL(GaussianBlurForward, GaussianBlurBase, 1, 1);

175
public:
M
Megvii Engine Team 已提交
176 177 178
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
179
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
M
Megvii Engine Team 已提交
180 181
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
182

183
protected:
M
Megvii Engine Team 已提交
184 185 186
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
187 188 189 190 191 192 193 194 195 196 197 198 199 200 201
};
using GaussianBlur = GaussianBlurForward;

/**
 * \brief Resize opr.
 */
class ResizeBase : public OperatorBase {
    DEF_OPR_PARAM(Resize);
    DEF_OPR_IMPL(ResizeBase, OperatorBase, 1, 1);

public:
    using InterpolationMode = Param::InterpolationMode;

protected:
    //! get origin coord
202 203 204 205 206
    std::pair<float, int> get_cubic_coord(float scale, int idx);

    std::tuple<float, int, float, int> get_nearest_linear_coord(
            InterpolationMode imode, float scale, int size, int idx);

207 208 209
    //! get nearest index in src
    int get_nearest_src(float scale, int size, int idx);

210 211 212 213 214 215 216
    void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst);
};

class ResizeForward : public ResizeBase {
    DEF_OPR_IMPL(ResizeForward, ResizeBase, 1, 1);

public:
M
Megvii Engine Team 已提交
217 218 219
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
220

M
Megvii Engine Team 已提交
221 222
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;
223 224

protected:
M
Megvii Engine Team 已提交
225 226 227
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
228 229 230 231 232 233 234
};
using Resize = ResizeForward;

class ResizeBackward : public ResizeBase {
    DEF_OPR_IMPL(ResizeBackward, ResizeBase, 1, 1);

public:
M
Megvii Engine Team 已提交
235 236 237
    virtual void exec(
            _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
238

M
Megvii Engine Team 已提交
239 240
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& diff, const TensorLayout& mat) = 0;
241 242

protected:
M
Megvii Engine Team 已提交
243 244 245
    void check_exec(
            const TensorLayout& diff, const TensorLayout& mat,
            size_t workspace_in_bytes);
246 247
};

248 249 250 251 252 253 254 255 256 257 258 259
/**
 * \brief Remap opr.
 */
class RemapBase : public OperatorBase {
    DEF_OPR_PARAM(Remap);
    DEF_OPR_IMPL(RemapBase, OperatorBase, 2, 1);

public:
    using InterpolationMode = Param::InterpolationMode;
    using BorderMode = Param::BorderMode;

protected:
M
Megvii Engine Team 已提交
260 261 262 263 264
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& map_xy,
            const TensorLayout& dst);
    void deduce_layout_fwd(
            const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst);
265 266 267 268 269 270
};

class RemapForward : public RemapBase {
    DEF_OPR_IMPL(RemapForward, RemapBase, 2, 1);

public:
M
Megvii Engine Team 已提交
271 272 273
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
274

M
Megvii Engine Team 已提交
275 276
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst);
277

M
Megvii Engine Team 已提交
278 279 280
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& map_xy,
            const TensorLayout& dst) = 0;
281 282

protected:
M
Megvii Engine Team 已提交
283 284 285
    void check_exec(
            const TensorLayout& src, const TensorLayout& map_xy,
            const TensorLayout& dst, size_t workspace_in_bytes);
286 287 288
};
using Remap = RemapForward;

289 290 291 292
class RemapBackwardData : public RemapBase {
    DEF_OPR_IMPL(RemapBackwardData, RemapBase, 2, 1);

public:
M
Megvii Engine Team 已提交
293 294 295
    virtual void exec(
            _megdnn_tensor_in map_xy, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) = 0;
296

M
Megvii Engine Team 已提交
297 298 299
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& map_xy, const TensorLayout& diff,
            const TensorLayout& grad) = 0;
300 301

protected:
M
Megvii Engine Team 已提交
302 303 304
    void check_exec(
            const TensorLayout& map_xy, const TensorLayout& diff,
            const TensorLayout& grad, size_t workspace_in_bytes);
305 306 307 308 309 310
};

class RemapBackwardMat : public RemapBase {
    DEF_OPR_IMPL(RemapBackwardMat, RemapBase, 3, 1);

public:
M
Megvii Engine Team 已提交
311 312 313
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_in diff,
            _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
314

M
Megvii Engine Team 已提交
315 316 317
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& map_xy,
            const TensorLayout& diff, const TensorLayout& grad) = 0;
318 319

protected:
M
Megvii Engine Team 已提交
320 321 322 323
    void check_exec(
            const TensorLayout& src, const TensorLayout& map_xy,
            const TensorLayout& diff, const TensorLayout& grad,
            size_t workspace_in_bytes);
324 325
};

326
class SeparableFilterBase : public OperatorBase {
327 328
    DEF_OPR_IMPL_CTOR(SeparableFilterBase, OperatorBase);
    DEF_OPR_PARAM(SeparableFilter);
329 330

protected:
M
Megvii Engine Team 已提交
331 332 333 334 335 336
    void deduce_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter_x,
            const TensorLayout& filter_y, TensorLayout& dst);
    void check_layout_fwd(
            const TensorLayout& src, const TensorLayout& filter_x,
            const TensorLayout& filter_y, const TensorLayout& dst);
337 338
};

339
class SeparableFilterForward : public SeparableFilterBase {
340
    DEF_OPR_IMPL(SeparableFilterForward, SeparableFilterBase, 3, 1);
341 342

public:
M
Megvii Engine Team 已提交
343 344 345 346 347 348 349 350 351 352
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in filter_x,
            _megdnn_tensor_in filter_y, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& src, const TensorLayout& filter_x,
            const TensorLayout& filter_y, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& filter_x,
            const TensorLayout& filter_y, const TensorLayout& dst) = 0;
353 354

protected:
M
Megvii Engine Team 已提交
355 356 357 358
    void check_exec(
            const TensorLayout& src, const TensorLayout& filter_x,
            const TensorLayout& filter_y, const TensorLayout& dst,
            size_t workspace_in_bytes);
359 360 361 362 363 364 365 366
};
using SeparableFilter = SeparableFilterForward;

}  // namespace megdnn

#include "megdnn/internal/opr_header_epilogue.h"

// vim: syntax=cpp.doxygen