// utils.h - fixed-length float vector wrappers built on the general-intrinsic (GI) float API.
#pragma once

#include <cstring>
#include "src/common/utils.h"
#include "src/fallback/general_intrinsic/gi_float.h"

namespace megdnn {
namespace fallback {

//! Generic fixed-length vector wrapper, parameterized by the element type
//! \p ctype and the length \p len. Only declared here: usable definitions are
//! provided through explicit specializations (e.g. Vector<float, 4> and
//! Vector<float, 8>).
template <typename ctype, size_t len>
struct Vector;

template <>
struct Vector<float, 4> {
15
    GI_FLOAT32_FIXLEN_t value;
16
    Vector() {}
17
    Vector(const float v) { value = GiFloat32Type2FixLenType(GiBroadcastFloat32(v)); }
18 19
    Vector(const Vector& lr) { value = lr.value; }
    Vector(const Vector&& lr) { value = std::move(lr.value); }
20
    Vector(const GI_FLOAT32_t& v) { value = GiFloat32Type2FixLenType(v); }
21 22
    static Vector load(const float* addr) {
        Vector v;
23
        v.value = GiFloat32Type2FixLenType(GiLoadFloat32(addr));
24 25
        return v;
    }
26 27 28
    static void save(float* addr, const Vector& v) {
        GiStoreFloat32(addr, GiFixLenType2GiFloat32Type(v.value));
    }
29 30 31
    void save(float* addr) { save(addr, *this); }
    Vector operator+(const Vector& lr) {
        Vector dst;
32 33 34
        dst.value = GiFloat32Type2FixLenType(GiAddFloat32(
                GiFixLenType2GiFloat32Type(value),
                GiFixLenType2GiFloat32Type(lr.value)));
35 36 37
        return dst;
    }
    Vector& operator+=(const Vector& lr) {
38 39 40
        value = GiFloat32Type2FixLenType(GiAddFloat32(
                GiFixLenType2GiFloat32Type(value),
                GiFixLenType2GiFloat32Type(lr.value)));
41 42 43 44
        return *this;
    }
    Vector operator-(const Vector& lr) {
        Vector dst;
45 46 47
        dst.value = GiFloat32Type2FixLenType(GiSubtractFloat32(
                GiFixLenType2GiFloat32Type(value),
                GiFixLenType2GiFloat32Type(lr.value)));
48 49 50
        return dst;
    }
    Vector& operator-=(const Vector& lr) {
51 52 53
        value = GiFloat32Type2FixLenType(GiSubtractFloat32(
                GiFixLenType2GiFloat32Type(value),
                GiFixLenType2GiFloat32Type(lr.value)));
54 55 56 57
        return *this;
    }
    Vector operator*(float lr) {
        Vector dst;
58 59
        dst.value = GiFloat32Type2FixLenType(
                GiMultiplyScalerFloat32(GiFixLenType2GiFloat32Type(value), lr));
60 61 62 63
        return dst;
    }
    Vector operator*(const Vector& lr) {
        Vector dst;
64 65 66
        dst.value = GiFloat32Type2FixLenType(GiMultiplyFloat32(
                GiFixLenType2GiFloat32Type(value),
                GiFixLenType2GiFloat32Type(lr.value)));
67 68 69
        return dst;
    }
    Vector& operator*=(const Vector& lr) {
70 71 72
        value = GiFloat32Type2FixLenType(GiMultiplyFloat32(
                GiFixLenType2GiFloat32Type(value),
                GiFixLenType2GiFloat32Type(lr.value)));
73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
        return *this;
    }
    Vector& operator=(const Vector& lr) {
        value = lr.value;
        return *this;
    }
    Vector& operator=(const Vector&& lr) {
        value = std::move(lr.value);
        return *this;
    }
    Vector operator-() {
        Vector dst;
        dst.value = -value;
        return dst;
    }
};

template <>
struct Vector<float, 8> {
92
    GI_FLOAT32_FIXLEN_V2_t value;
93 94
    Vector() {}
    Vector(const float v) {
95 96
        value.val[0] = GiFloat32Type2FixLenType(GiBroadcastFloat32(v));
        value.val[1] = GiFloat32Type2FixLenType(GiBroadcastFloat32(v));
97 98 99
    }
    Vector(const Vector& lr) { value = lr.value; }
    Vector(const Vector&& lr) { value = std::move(lr.value); }
100
    Vector(const GI_FLOAT32_V2_t& v) { value = GiFloat32Type2FixLenV2Type(v); }
101 102
    static Vector load(const float* addr) {
        Vector v;
103
        v.value = GiFloat32Type2FixLenV2Type(GiLoadFloat32V2(addr));
104 105
        return v;
    }
106 107 108
    static void save(float* addr, const Vector& v) {
        GiStoreFloat32V2(addr, GiFixLenType2GiFloat32V2Type(v.value));
    }
109 110 111 112

    void save(float* addr) { save(addr, *this); }
    Vector operator+(const Vector& lr) {
        Vector dst;
113 114 115 116 117 118
        dst.value.val[0] = GiFloat32Type2FixLenType(GiAddFloat32(
                GiFixLenType2GiFloat32Type(value.val[0]),
                GiFixLenType2GiFloat32Type(lr.value.val[0])));
        dst.value.val[1] = GiFloat32Type2FixLenType(GiAddFloat32(
                GiFixLenType2GiFloat32Type(value.val[1]),
                GiFixLenType2GiFloat32Type(lr.value.val[1])));
119 120 121
        return dst;
    }
    Vector& operator+=(const Vector& lr) {
122 123 124 125 126 127
        value.val[0] = GiFloat32Type2FixLenType(GiAddFloat32(
                GiFixLenType2GiFloat32Type(value.val[0]),
                GiFixLenType2GiFloat32Type(lr.value.val[0])));
        value.val[1] = GiFloat32Type2FixLenType(GiAddFloat32(
                GiFixLenType2GiFloat32Type(value.val[1]),
                GiFixLenType2GiFloat32Type(lr.value.val[1])));
128 129 130
        return *this;
    }
    Vector& add(const Vector& lr) {
131 132 133 134 135 136
        value.val[0] = GiFloat32Type2FixLenType(GiAddFloat32(
                GiFixLenType2GiFloat32Type(value.val[0]),
                GiFixLenType2GiFloat32Type(lr.value.val[0])));
        value.val[1] = GiFloat32Type2FixLenType(GiAddFloat32(
                GiFixLenType2GiFloat32Type(value.val[1]),
                GiFixLenType2GiFloat32Type(lr.value.val[1])));
137 138 139 140
        return *this;
    }
    Vector operator-(const Vector& lr) {
        Vector dst;
141 142 143 144 145 146
        dst.value.val[0] = GiFloat32Type2FixLenType(GiSubtractFloat32(
                GiFixLenType2GiFloat32Type(value.val[0]),
                GiFixLenType2GiFloat32Type(lr.value.val[0])));
        dst.value.val[1] = GiFloat32Type2FixLenType(GiSubtractFloat32(
                GiFixLenType2GiFloat32Type(value.val[1]),
                GiFixLenType2GiFloat32Type(lr.value.val[1])));
147 148 149
        return dst;
    }
    Vector& operator-=(const Vector& lr) {
150 151 152 153 154 155
        value.val[0] = GiFloat32Type2FixLenType(GiSubtractFloat32(
                GiFixLenType2GiFloat32Type(value.val[0]),
                GiFixLenType2GiFloat32Type(lr.value.val[0])));
        value.val[1] = GiFloat32Type2FixLenType(GiSubtractFloat32(
                GiFixLenType2GiFloat32Type(value.val[1]),
                GiFixLenType2GiFloat32Type(lr.value.val[1])));
156 157 158 159
        return *this;
    }
    Vector operator*(float lr) {
        Vector dst;
160 161 162 163
        dst.value.val[0] = GiFloat32Type2FixLenType(
                GiMultiplyScalerFloat32(GiFixLenType2GiFloat32Type(value.val[0]), lr));
        dst.value.val[1] = GiFloat32Type2FixLenType(
                GiMultiplyScalerFloat32(GiFixLenType2GiFloat32Type(value.val[1]), lr));
164 165 166 167
        return dst;
    }
    //! val + lr * n
    Vector& mla(const Vector& lr, float n) {
168 169 170 171 172 173
        value.val[0] = GiFloat32Type2FixLenType(GiMultiplyAddScalarFloat32(
                GiFixLenType2GiFloat32Type(value.val[0]),
                GiFixLenType2GiFloat32Type(lr.value.val[0]), n));
        value.val[1] = GiFloat32Type2FixLenType(GiMultiplyAddScalarFloat32(
                GiFixLenType2GiFloat32Type(value.val[1]),
                GiFixLenType2GiFloat32Type(lr.value.val[1]), n));
174 175 176 177 178
        return *this;
    }

    Vector operator*(const Vector& lr) {
        Vector dst;
179 180 181 182 183 184
        dst.value.val[0] = GiFloat32Type2FixLenType(GiMultiplyFloat32(
                GiFixLenType2GiFloat32Type(value.val[0]),
                GiFixLenType2GiFloat32Type(lr.value.val[0])));
        dst.value.val[1] = GiFloat32Type2FixLenType(GiMultiplyFloat32(
                GiFixLenType2GiFloat32Type(value.val[1]),
                GiFixLenType2GiFloat32Type(lr.value.val[1])));
185 186 187
        return dst;
    }
    Vector& operator*=(const Vector& lr) {
188 189 190 191 192 193
        value.val[0] = GiFloat32Type2FixLenType(GiMultiplyFloat32(
                GiFixLenType2GiFloat32Type(value.val[0]),
                GiFixLenType2GiFloat32Type(lr.value.val[0])));
        value.val[1] = GiFloat32Type2FixLenType(GiMultiplyFloat32(
                GiFixLenType2GiFloat32Type(value.val[1]),
                GiFixLenType2GiFloat32Type(lr.value.val[1])));
194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215
        return *this;
    }
    Vector& operator=(const Vector& lr) {
        value = lr.value;
        return *this;
    }
    Vector& operator=(const Vector&& lr) {
        value = std::move(lr.value);
        return *this;
    }
    Vector operator-() {
        Vector dst;
        dst.value.val[0] = -value.val[0];
        dst.value.val[1] = -value.val[1];
        return dst;
    }
};

}  // namespace fallback
}  // namespace megdnn

// vim: syntax=cpp.doxygen