/**
 * \file dnn/include/megdnn/oprs/linalg.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include "megdnn/internal/opr_header_prologue.h"

namespace megdnn {

class BatchedMatrixMulForward
        : public OperatorBase,
          public detail::MultiAlgoOpr<BatchedMatrixMulForward, 3> {
    DEF_OPR_PARAM(MatrixMul);
    DEF_OPR_IMPL(BatchedMatrixMulForward, OperatorBase, 2, 1);

public:
    /**
     * \brief C = op(A) * op(B)
     * \param A (B, m, k) if transposeA is false, (B, k, m) otherwise
     * \param B (B, k, n) if transposeB is false, (B, n, k) otherwise
     * \param C (B, m, n)
     *
     * A, B and C must be 3-dimensional, and C must be contiguous. A and B
     * must have stride[2] == 1, stride[1] >= shape[2], and
     * stride[0] >= shape[1] * stride[1].
     *
     * op(A) = A if transposeA is false, otherwise op(A) = A^t.
     * op(B) = B if transposeB is false, otherwise op(B) = B^t.
     */
    virtual void exec(
            _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
            _megdnn_workspace workspace) = 0;
    void deduce_dtype(DType A, DType B, DType& C);
    void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;

    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD;
    }

protected:
    void check_exec(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
            size_t workspace_in_bytes);
};
using BatchedMatrixMul = BatchedMatrixMulForward;
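
/*
 * Usage sketch for BatchedMatrixMul (an illustration, not part of this
 * header; it assumes a megdnn::Handle* named `handle` obtained elsewhere,
 * caller-managed device buffers, and float32 data):
 *
 *     auto opr = handle->create_operator<BatchedMatrixMul>();
 *     opr->param().transposeA = false;
 *     opr->param().transposeB = false;
 *     TensorLayout A({8, 32, 64}, dtype::Float32());  // (Batch, m, k)
 *     TensorLayout B({8, 64, 16}, dtype::Float32());  // (Batch, k, n)
 *     TensorLayout C;
 *     opr->deduce_layout(A, B, C);                     // C becomes (8, 32, 16)
 *     size_t ws_size = opr->get_workspace_in_bytes(A, B, C);
 *     // bind device pointers to these layouts (megdnn::TensorND) and run
 *     // opr->exec(tensorA, tensorB, tensorC, workspace) with a workspace of
 *     // at least ws_size bytes.
 */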

class MatrixMulForward : public OperatorBase,
                         public detail::MultiAlgoOpr<MatrixMulForward, 3> {
    DEF_OPR_PARAM(MatrixMul);
    DEF_OPR_IMPL(MatrixMulForward, OperatorBase, 2, 1);

public:
    /**
     * \brief C = op(A) * op(B)
     * \param A (m, k) if transposeA is false, (k, m) otherwise
     * \param B (k, n) if transposeB is false, (n, k) otherwise
     * \param C (m, n)
     *
     * A, B and C must be 2-dimensional, and C must be contiguous. A and B
     * must have stride[1] == 1 and stride[0] >= shape[1].
     *
     * op(A) = A if transposeA is false, otherwise op(A) = A^t.
     * op(B) = B if transposeB is false, otherwise op(B) = B^t.
     */
    virtual void exec(
            _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
            _megdnn_workspace workspace) = 0;
    void deduce_dtype(DType A, DType B, DType& C);
    void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;

    static size_t pack_size(const Param::Format format);

    static Algorithm::OprType get_opr_type() {
        return Algorithm::OprType::MATRIX_MUL_FORWARD;
    }

protected:
    void check_exec(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
            size_t workspace_in_bytes);
};
using MatrixMul = MatrixMulForward;
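
/*
 * Usage sketch for MatrixMul (illustration only; same assumptions as the
 * BatchedMatrixMul sketch above). With transposeA set, the first operand is
 * stored as (k, m) and used as A^t:
 *
 *     auto opr = handle->create_operator<MatrixMul>();
 *     opr->param().transposeA = true;
 *     TensorLayout A({64, 32}, dtype::Float32());  // stored (k, m), used as A^t
 *     TensorLayout B({64, 16}, dtype::Float32());  // (k, n)
 *     TensorLayout C;
 *     opr->deduce_layout(A, B, C);                 // C becomes (32, 16)
 *     size_t ws_size = opr->get_workspace_in_bytes(A, B, C);
 *     // bind device pointers and call opr->exec(tensorA, tensorB, tensorC,
 *     // workspace) as with BatchedMatrixMul.
 */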

/*!
 * \brief compute the inverse of a batch of matrices
 *
 * Input and output tensors have the same shape [..., n, n] where the last two
 * dimensions represent the matrices.
 *
 * Currently only float32 is supported.
 */
class MatrixInverse : public OperatorBase {
    DEF_OPR_IMPL(MatrixInverse, OperatorBase, 1, 1);
    DEF_OPR_PARAM(Empty);

public:
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& dst);

protected:
    /*!
     * \brief get canonized params; throw exception on error.
     *
     * Note that \p batch and \p n can be null
     */
    static void canonize_params(const TensorLayout& layout, size_t* batch, size_t* n);

    /*!
     * \brief canonize and validate input params for exec() impls
     *
     * Since get_workspace_in_bytes() would be called, \p batch and \p n
     * cannot be null.
     */
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            _megdnn_workspace workspace, size_t* batch, size_t* n);

    virtual size_t get_workspace_in_bytes(
            size_t batch, size_t n, size_t dtype_size) = 0;
};
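
/*
 * Usage sketch for MatrixInverse (illustration only; assumes a
 * megdnn::Handle* named `handle` and float32 data, the only dtype
 * currently supported):
 *
 *     auto opr = handle->create_operator<MatrixInverse>();
 *     TensorLayout src({4, 3, 3}, dtype::Float32());  // a batch of four 3x3 matrices
 *     TensorLayout dst;
 *     opr->deduce_layout(src, dst);                    // dst is also (4, 3, 3)
 *     size_t ws_size = opr->get_workspace_in_bytes(src, dst);
 *     // bind device pointers and call opr->exec(src_tensor, dst_tensor, workspace).
 */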

//! inner product (dot product) of two vectors
class DotForward : public OperatorBase {
    DEF_OPR_PARAM(Empty);
    DEF_OPR_IMPL(DotForward, OperatorBase, 2, 1);

public:
    /**
     * \param[in] A
     * \param[in] B
     * \param[out] C
     *
     * Compute the dot product of A and B and store it in C.
     * A, B, C must be contiguous. A and B must have the same 1-dimensional
     * shape and non-negative strides. C must be scalar.
     */
    virtual void exec(
            _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
            _megdnn_workspace workspace) = 0;
    MGE_WIN_DECLSPEC_FUC void deduce_layout(
            const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;

protected:
    void check_exec(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
            size_t workspace_in_bytes);
};
using Dot = DotForward;
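
/*
 * Usage sketch for Dot (illustration only; assumes a megdnn::Handle* named
 * `handle`):
 *
 *     auto opr = handle->create_operator<Dot>();
 *     TensorLayout A({1024}, dtype::Float32());
 *     TensorLayout B({1024}, dtype::Float32());
 *     TensorLayout C;
 *     opr->deduce_layout(A, B, C);   // C becomes a scalar (single-element) layout
 *     size_t ws_size = opr->get_workspace_in_bytes(A, B, C);
 *     // bind device pointers and call opr->exec(tensorA, tensorB, tensorC, workspace).
 */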

/*!
 * \brief Compute the singular value decomposition of a batch of matrices
 *
 * Input tensors have the shape [..., m, n], where the last two
 * dimensions represent the matrices. For the output tensors u, s, vt,
 * the following equation holds: u * diag(s) * vt == src.
 *
 * Currently only float32 is supported.
 */
class SVDForward : public OperatorBase {
    DEF_OPR_IMPL(SVDForward, OperatorBase, 1, 3);
    DEF_OPR_PARAM(SVD);

public:
    /**
     * \brief u, s, vt = SVD(src) and u * diag(s) * vt == src
     * \param src (..., m, n) the input tensor; let p = min(m, n)
     * \param u (..., m, p) if full_matrices is false,
     *          (..., m, m) if full_matrices is true,
     *          an empty tensor if compute_uv is false.
     *          The left singular vectors.
     * \param s (..., p) the singular values.
     * \param vt (..., p, n) if full_matrices is false,
     *           (..., n, n) if full_matrices is true,
     *           an empty tensor if compute_uv is false.
     *           The right singular vectors.
     *
     * src must be contiguous. The computation may be significantly faster
     * if compute_uv is false (it defaults to true).
     *
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out u, _megdnn_tensor_out s,
            _megdnn_tensor_out vt, _megdnn_workspace workspace) = 0;
    void deduce_layout(
            const TensorLayout& src, TensorLayout& u, TensorLayout& s,
            TensorLayout& vt);
    size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& u, const TensorLayout& s,
            const TensorLayout& vt);

protected:
    static void canonize_params(
            const TensorLayout& layout, size_t* batch, size_t* m, size_t* n);
    virtual size_t get_workspace_in_bytes(
            size_t block_cnt, size_t m, size_t n, size_t dtype_size) = 0;
    void check_exec(
            const TensorLayout& src, const TensorLayout& u, const TensorLayout& s,
            const TensorLayout& vt, size_t workspace_in_bytes);
};

using SVD = SVDForward;
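
/*
 * Usage sketch for SVD (illustration only; assumes a megdnn::Handle* named
 * `handle` and that param::SVD exposes the compute_uv and full_matrices
 * flags described above):
 *
 *     auto opr = handle->create_operator<SVD>();
 *     opr->param().compute_uv = true;
 *     opr->param().full_matrices = false;
 *     TensorLayout src({2, 5, 3}, dtype::Float32());  // batch of 2, m = 5, n = 3
 *     TensorLayout u, s, vt;
 *     opr->deduce_layout(src, u, s, vt);  // u: (2, 5, 3), s: (2, 3), vt: (2, 3, 3)
 *     size_t ws_size = opr->get_workspace_in_bytes(src, u, s, vt);
 *     // bind device pointers and call opr->exec(src_t, u_t, s_t, vt_t, workspace).
 */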

}  // namespace megdnn

#include "megdnn/internal/opr_header_epilogue.h"

// vim: syntax=cpp.doxygen