avx_helper.h 3.4 KB
Newer Older
1 2 3 4
/**
 * \file dnn/src/x86/avx_helper.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include "megdnn/arch.h"

#include <immintrin.h>
16
#ifdef WIN32
17 18 19
#include <avxintrin.h>
#include <avx2intrin.h>
#include <fmaintrin.h>
20
#endif
21

22 23 24 25
#if !defined (__clang__)
#pragma GCC target ("avx")
#endif

26 27 28 29 30 31 32 33 34 35
namespace megdnn {
namespace x86 {

/*!
 * \brief Build a __m256 from two unaligned 4-float loads.
 *
 * Lanes 0-3 are read from \p loaddr and lanes 4-7 from \p hiaddr, matching the
 * lane layout of the _mm256_loadu2_m128 intrinsic this emulates (presumably
 * provided because not every supported compiler ships that intrinsic -- TODO
 * confirm). Both loads are unaligned, so neither pointer needs 16- or 32-byte
 * alignment.
 */
MEGDNN_ATTRIBUTE_TARGET("avx")
static inline __m256 _mm256_loadu2_m128_emulate(
        const float *hiaddr, const float *loaddr) {
    const __m128 low_half = _mm_loadu_ps(loaddr);
    const __m128 high_half = _mm_loadu_ps(hiaddr);
    // The cast leaves the upper 128 bits undefined; they are overwritten
    // right away by inserting high_half at 128-bit lane index 1.
    const __m256 widened = _mm256_castps128_ps256(low_half);
    return _mm256_insertf128_ps(widened, high_half, 1);
}

36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
//! Fixed-width SIMD vector holding `len` elements of `ctype`.
//! The primary template is only declared: using a (ctype, len) pair without an
//! explicit specialization is a compile-time error.
template <typename ctype, size_t len>
struct Vector;

//! 8-lane float vector stored in an AVX __m256.
//!
//! Every member that executes AVX instructions is marked
//! MEGDNN_ATTRIBUTE_TARGET("avx"), presumably so the header can be used from
//! translation units that are not compiled with -mavx (TODO confirm against
//! megdnn/arch.h). Callers are still responsible for only reaching these code
//! paths on AVX-capable CPUs.
template <>
struct Vector<float, 8> {
    __m256 value;

    //! Leaves `value` uninitialized on purpose (emits no instructions);
    //! assign before reading.
    Vector() {}
    //! Broadcast `v` to all 8 lanes.
    Vector(const float v) MEGDNN_ATTRIBUTE_TARGET("avx") {
        value = _mm256_set1_ps(v);
    }
    Vector(const Vector& lr) MEGDNN_ATTRIBUTE_TARGET("avx") {
        value = lr.value;
    }
    //! A __m256 owns no resources, so "moving" is just a register copy.
    Vector(Vector&& lr) MEGDNN_ATTRIBUTE_TARGET("avx") { value = lr.value; }
    //! Wrap an existing __m256.
    Vector(const __m256& v) MEGDNN_ATTRIBUTE_TARGET("avx") { value = v; }

    //! Unaligned load of 8 consecutive floats starting at `addr`.
    static Vector load(const float* addr) MEGDNN_ATTRIBUTE_TARGET("avx") {
        Vector v;
        v.value = _mm256_loadu_ps(addr);
        return v;
    }
    //! Unaligned store of the 8 lanes of `v` to `addr`.
    static void save(float* addr, const Vector& v)
            MEGDNN_ATTRIBUTE_TARGET("avx") {
        _mm256_storeu_ps(addr, v.value);
    }
    //! Unaligned store of this vector's 8 lanes to `addr`.
    void save(float* addr) const MEGDNN_ATTRIBUTE_TARGET("avx") {
        save(addr, *this);
    }

    //! Lane-wise addition.
    Vector operator+(const Vector& lr) const MEGDNN_ATTRIBUTE_TARGET("avx") {
        Vector dst;
        dst.value = _mm256_add_ps(value, lr.value);
        return dst;
    }
    Vector& operator+=(const Vector& lr) MEGDNN_ATTRIBUTE_TARGET("avx") {
        value = _mm256_add_ps(value, lr.value);
        return *this;
    }
    //! Lane-wise subtraction.
    Vector operator-(const Vector& lr) const MEGDNN_ATTRIBUTE_TARGET("avx") {
        Vector dst;
        dst.value = _mm256_sub_ps(value, lr.value);
        return dst;
    }
    Vector& operator-=(const Vector& lr) MEGDNN_ATTRIBUTE_TARGET("avx") {
        value = _mm256_sub_ps(value, lr.value);
        return *this;
    }
    //! Multiply every lane by the scalar `lr`.
    Vector operator*(float lr) const MEGDNN_ATTRIBUTE_TARGET("avx") {
        Vector dst;
        dst.value = _mm256_mul_ps(value, _mm256_set1_ps(lr));
        return dst;
    }
    //! Lane-wise multiplication.
    Vector operator*(const Vector& lr) const MEGDNN_ATTRIBUTE_TARGET("avx") {
        Vector dst;
        dst.value = _mm256_mul_ps(value, lr.value);
        return dst;
    }
    Vector& operator*=(const Vector& lr) MEGDNN_ATTRIBUTE_TARGET("avx") {
        value = _mm256_mul_ps(value, lr.value);
        return *this;
    }
    Vector& operator=(const Vector& lr) MEGDNN_ATTRIBUTE_TARGET("avx") {
        value = lr.value;
        return *this;
    }
    //! Move assignment; like the move constructor this is a plain copy.
    Vector& operator=(Vector&& lr) MEGDNN_ATTRIBUTE_TARGET("avx") {
        value = lr.value;
        return *this;
    }
    //! Lane-wise negation. XOR with -0.0f flips only the IEEE sign bit, which
    //! is exactly what unary minus does (0.0f -> -0.0f, NaN payloads kept).
    //! Using the intrinsic avoids the GCC/Clang vector-extension `-value`,
    //! which compilers without that extension (e.g. MSVC) reject.
    Vector operator-() const MEGDNN_ATTRIBUTE_TARGET("avx") {
        Vector dst;
        dst.value = _mm256_xor_ps(value, _mm256_set1_ps(-0.0f));
        return dst;
    }
};

#if !defined (__clang__)
#pragma GCC reset_options
#endif

116 117 118 119
} // namespace x86
} // namespace megdnn

// vim: syntax=cpp.doxygen