提交 7e349c25 编写于 作者: l702572275's avatar l702572275

Modify the avx template

上级 2175fea8
......@@ -20,7 +20,7 @@ limitations under the License.
namespace oneflow {
#ifdef WITH_AVX
template<typename T>
template<typename T, typename Enable = void>
class VectorizedAvx2 {
public:
static void fmadd(size_t begin, size_t end, const T* x, const T* y, T* out, T alpha);
......
......@@ -22,11 +22,10 @@ namespace oneflow {
#ifdef WITH_AVX
#include <immintrin.h>
template<>
class VectorizedAvx2<double> {
template<typename T>
class VectorizedAvx2<T, typename std::enable_if<std::is_same<T, double>::value>::type> {
public:
static void fmadd(size_t begin, size_t end, const double* x, const double* y, double* out,
double alpha) {
static void fmadd(size_t begin, size_t end, const T* x, const T* y, T* out, T alpha) {
size_t i = begin;
size_t stride = 4;
......@@ -51,7 +50,7 @@ class VectorizedAvx2<double> {
}
}
static void add(size_t begin, size_t end, const double* x, const double* y, double* out) {
static void add(size_t begin, size_t end, const T* x, const T* y, T* out) {
size_t i = begin;
size_t stride = 4;
......@@ -74,7 +73,7 @@ class VectorizedAvx2<double> {
}
}
static void sub(size_t begin, size_t end, const double* x, const double* y, double* out) {
static void sub(size_t begin, size_t end, const T* x, const T* y, T* out) {
size_t i = begin;
size_t stride = 4;
......@@ -97,7 +96,7 @@ class VectorizedAvx2<double> {
}
}
static void mul(size_t begin, size_t end, const double* x, const double* y, double* out) {
static void mul(size_t begin, size_t end, const T* x, const T* y, T* out) {
size_t i = begin;
size_t stride = 4;
......@@ -120,7 +119,7 @@ class VectorizedAvx2<double> {
}
}
static void div(size_t begin, size_t end, const double* x, const double* y, double* out) {
static void div(size_t begin, size_t end, const T* x, const T* y, T* out) {
size_t i = begin;
size_t stride = 4;
......
......@@ -22,11 +22,10 @@ namespace oneflow {
#ifdef WITH_AVX
#include <immintrin.h>
template<>
class VectorizedAvx2<float> {
template<typename T>
class VectorizedAvx2<T, typename std::enable_if<std::is_same<T, float>::value>::type> {
public:
static void fmadd(size_t begin, size_t end, const float* x, const float* y, float* out,
float alpha) {
static void fmadd(size_t begin, size_t end, const T* x, const T* y, T* out, T alpha) {
size_t i = begin;
size_t stride = 8;
......@@ -49,7 +48,7 @@ class VectorizedAvx2<float> {
for (; i < end; i++) { out[i] = x[i] * alpha + y[i]; }
}
}
static void add(int64_t begin, int64_t end, const float* x, const float* y, float* out) {
static void add(int64_t begin, int64_t end, const T* x, const T* y, T* out) {
int64_t i = begin;
int64_t stride = 8;
......@@ -71,7 +70,7 @@ class VectorizedAvx2<float> {
}
}
static void sub(size_t begin, size_t end, const float* x, const float* y, float* out) {
static void sub(size_t begin, size_t end, const T* x, const T* y, T* out) {
size_t i = begin;
size_t stride = 8;
......@@ -93,7 +92,7 @@ class VectorizedAvx2<float> {
}
}
static void mul(int64_t begin, int64_t end, const float* x, const float* y, float* out) {
static void mul(int64_t begin, int64_t end, const T* x, const T* y, T* out) {
int64_t i = begin;
int64_t stride = 8;
......@@ -115,7 +114,7 @@ class VectorizedAvx2<float> {
}
}
static void div(size_t begin, size_t end, const float* x, const float* y, float* out) {
static void div(size_t begin, size_t end, const T* x, const T* y, T* out) {
size_t i = begin;
size_t stride = 8;
......
......@@ -21,70 +21,27 @@ namespace oneflow {
#ifdef WITH_AVX
#include <immintrin.h>
// Scalar specialization of VectorizedAvx2 for int8_t. No SIMD intrinsics are
// used here; every operation is a plain element-wise loop over the half-open
// index range [begin, end).
//
// NOTE(review): `end` is treated as exclusive, matching the `i < end` tail loop
// of the float specialization. The previous `i <= end` bound read x/y and wrote
// out one element past the requested range.
template<>
class VectorizedAvx2<int8_t> {
 public:
  // out[i] = x[i] + y[i] for every i in [begin, end).
  // The operands are promoted to int by the arithmetic; the result is converted
  // back to int8_t on assignment.
  static void add(size_t begin, size_t end, const int8_t* x, const int8_t* y, int8_t* out) {
    for (size_t i = begin; i < end; i++) { out[i] = static_cast<int8_t>(x[i] + y[i]); }
  }

  // out[i] = x[i] - y[i] for every i in [begin, end).
  static void sub(size_t begin, size_t end, const int8_t* x, const int8_t* y, int8_t* out) {
    for (size_t i = begin; i < end; i++) { out[i] = static_cast<int8_t>(x[i] - y[i]); }
  }

  // out[i] = x[i] * y[i] for every i in [begin, end).
  static void mul(size_t begin, size_t end, const int8_t* x, const int8_t* y, int8_t* out) {
    for (size_t i = begin; i < end; i++) { out[i] = static_cast<int8_t>(x[i] * y[i]); }
  }

  // out[i] = x[i] / y[i] (integer division) for every i in [begin, end).
  // No zero check is performed: a zero divisor in y is undefined behavior.
  static void div(size_t begin, size_t end, const int8_t* x, const int8_t* y, int8_t* out) {
    for (size_t i = begin; i < end; i++) { out[i] = static_cast<int8_t>(x[i] / y[i]); }
  }
};
// Scalar specialization of VectorizedAvx2 for int32_t. No SIMD intrinsics are
// used here; every operation is a plain element-wise loop over the half-open
// index range [begin, end).
//
// NOTE(review): `end` is treated as exclusive, matching the `i < end` tail loop
// of the float specialization. The previous `i <= end` bound read x/y and wrote
// out one element past the requested range.
template<>
class VectorizedAvx2<int32_t> {
 public:
  // out[i] = x[i] + y[i] for every i in [begin, end).
  static void add(size_t begin, size_t end, const int32_t* x, const int32_t* y, int32_t* out) {
    for (size_t i = begin; i < end; i++) { out[i] = x[i] + y[i]; }
  }

  // out[i] = x[i] - y[i] for every i in [begin, end).
  static void sub(size_t begin, size_t end, const int32_t* x, const int32_t* y, int32_t* out) {
    for (size_t i = begin; i < end; i++) { out[i] = x[i] - y[i]; }
  }

  // out[i] = x[i] * y[i] for every i in [begin, end).
  static void mul(size_t begin, size_t end, const int32_t* x, const int32_t* y, int32_t* out) {
    for (size_t i = begin; i < end; i++) { out[i] = x[i] * y[i]; }
  }

  // out[i] = x[i] / y[i] (integer division) for every i in [begin, end).
  // No zero check is performed: a zero divisor in y is undefined behavior.
  static void div(size_t begin, size_t end, const int32_t* x, const int32_t* y, int32_t* out) {
    for (size_t i = begin; i < end; i++) { out[i] = x[i] / y[i]; }
  }
};
template<>
class VectorizedAvx2<int64_t> {
template<typename T>
class VectorizedAvx2<T, typename std::enable_if<!(std::is_same<T, float>::value
|| std::is_same<T, double>::value)>::type> {
public:
// static void fmadd(size_t begin, size_t end, const int32_t * x, const int32_t * y, int32_t *out,
// int32_t alpha)
// {
// for (size_t i = begin; i <= end; i ++) {
// out[i] = x[i] * alpha + y[i];
// }
// static void fmadd(size_t begin, size_t end, const T* x, const T* y, T* out, T alpha) {
// }
static void add(size_t begin, size_t end, const int64_t* x, const int64_t* y, int64_t* out) {
static void add(size_t begin, size_t end, const T* x, const T* y, T* out) {
for (size_t i = begin; i <= end; i++) { out[i] = x[i] + y[i]; }
}
static void sub(size_t begin, size_t end, const int64_t* x, const int64_t* y, int64_t* out) {
static void sub(size_t begin, size_t end, const T* x, const T* y, T* out) {
for (size_t i = begin; i <= end; i++) { out[i] = x[i] - y[i]; }
}
static void mul(size_t begin, size_t end, const int64_t* x, const int64_t* y, int64_t* out) {
static void mul(size_t begin, size_t end, const T* x, const T* y, T* out) {
for (size_t i = begin; i <= end; i++) { out[i] = x[i] * y[i]; }
}
static void div(size_t begin, size_t end, const int64_t* x, const int64_t* y, int64_t* out) {
static void div(size_t begin, size_t end, const T* x, const T* y, T* out) {
for (size_t i = begin; i <= end; i++) { out[i] = x[i] / y[i]; }
}
};
......
......@@ -20,17 +20,13 @@ limitations under the License.
namespace oneflow {
#ifdef WITH_AVX
template<typename T>
// Primary template for the AVX-512 element-wise kernels.
//
// Only declarations live here; the bodies are supplied by partial
// specializations that select themselves through the `Enable` parameter
// (std::enable_if on T). `Enable` defaults to void so callers simply write
// VectorizedAvx512<T>.
//
// All functions operate on the index range given by `begin` / `end` and write
// their result into `out`.
template<class T, class Enable = void>
class VectorizedAvx512 {
 public:
  // Element-wise binary arithmetic: out[i] = x[i] (+ | - | * | /) y[i].
  static void add(size_t begin, size_t end, const T* x, const T* y, T* out);
  static void sub(size_t begin, size_t end, const T* x, const T* y, T* out);
  static void mul(size_t begin, size_t end, const T* x, const T* y, T* out);
  static void div(size_t begin, size_t end, const T* x, const T* y, T* out);

  // Fused multiply-add with a scalar factor `alpha`.
  // NOTE(review): presumably out[i] = x[i] * alpha + y[i], as in the AVX2 float
  // tail loop -- confirm against the specializations.
  static void fmadd(size_t begin, size_t end, const T* x, const T* y, T* out, T alpha);
};
......
......@@ -22,11 +22,10 @@ namespace oneflow {
#ifdef WITH_AVX
#include <immintrin.h>
template<>
class VectorizedAvx512<double> {
template<typename T>
class VectorizedAvx512<T, typename std::enable_if<std::is_same<T, double>::value>::type> {
public:
static void fmadd(size_t begin, size_t end, const double* x, const double* y, double* out,
double alpha) {
static void fmadd(size_t begin, size_t end, const T* x, const T* y, T* out, T alpha) {
size_t i = begin;
size_t stride = 8;
......@@ -51,7 +50,7 @@ class VectorizedAvx512<double> {
}
}
static void add(size_t begin, size_t end, const double* x, const double* y, double* out) {
static void add(size_t begin, size_t end, const T* x, const T* y, T* out) {
size_t i = begin;
size_t stride = 8;
......@@ -74,7 +73,7 @@ class VectorizedAvx512<double> {
}
}
static void sub(size_t begin, size_t end, const double* x, const double* y, double* out) {
static void sub(size_t begin, size_t end, const T* x, const T* y, T* out) {
size_t i = begin;
size_t stride = 8;
......@@ -97,7 +96,7 @@ class VectorizedAvx512<double> {
}
}
static void mul(size_t begin, size_t end, const double* x, const double* y, double* out) {
static void mul(size_t begin, size_t end, const T* x, const T* y, T* out) {
size_t i = begin;
size_t stride = 8;
......@@ -120,7 +119,7 @@ class VectorizedAvx512<double> {
}
}
static void div(size_t begin, size_t end, const double* x, const double* y, double* out) {
static void div(size_t begin, size_t end, const T* x, const T* y, T* out) {
size_t i = begin;
size_t stride = 8;
......
......@@ -23,11 +23,10 @@ namespace oneflow {
#ifdef WITH_AVX
#include <immintrin.h>
template<>
class VectorizedAvx512<float> {
template<typename T>
class VectorizedAvx512<T, typename std::enable_if<std::is_same<T, float>::value>::type> {
public:
static void fmadd(size_t begin, size_t end, const float* x, const float* y, float* out,
float alpha) {
static void fmadd(size_t begin, size_t end, const T* x, const T* y, T* out, T alpha) {
size_t i = begin;
size_t stride = 16;
......@@ -51,7 +50,7 @@ class VectorizedAvx512<float> {
}
}
static void add(size_t begin, size_t end, const float* x, const float* y, float* out) {
static void add(size_t begin, size_t end, const T* x, const T* y, T* out) {
size_t i = begin;
size_t stride = 16;
......@@ -73,7 +72,7 @@ class VectorizedAvx512<float> {
}
}
static void sub(size_t begin, size_t end, const float* x, const float* y, float* out) {
static void sub(size_t begin, size_t end, const T* x, const T* y, T* out) {
size_t i = begin;
size_t stride = 16;
......@@ -95,7 +94,7 @@ class VectorizedAvx512<float> {
}
}
static void mul(size_t begin, size_t end, const float* x, const float* y, float* out) {
static void mul(size_t begin, size_t end, const T* x, const T* y, T* out) {
size_t i = begin;
size_t stride = 16;
......@@ -117,7 +116,7 @@ class VectorizedAvx512<float> {
}
}
static void div(size_t begin, size_t end, const float* x, const float* y, float* out) {
static void div(size_t begin, size_t end, const T* x, const T* y, T* out) {
size_t i = begin;
size_t stride = 16;
......
......@@ -22,62 +22,23 @@ namespace oneflow {
#ifdef WITH_AVX
#include <immintrin.h>
template<>
class VectorizedAvx512<int8_t> {
template<typename T>
class VectorizedAvx512<T, typename std::enable_if<!(std::is_same<T, float>::value
|| std::is_same<T, double>::value)>::type> {
public:
static void add(size_t begin, size_t end, const int8_t* x, const int8_t* y, int8_t* out) {
static void add(size_t begin, size_t end, const T* x, const T* y, T* out) {
for (size_t i = begin; i <= end; i++) { out[i] = x[i] + y[i]; }
}
static void sub(size_t begin, size_t end, const int8_t* x, const int8_t* y, int8_t* out) {
static void sub(size_t begin, size_t end, const T* x, const T* y, T* out) {
for (size_t i = begin; i <= end; i++) { out[i] = x[i] - y[i]; }
}
static void mul(size_t begin, size_t end, const int8_t* x, const int8_t* y, int8_t* out) {
static void mul(size_t begin, size_t end, const T* x, const T* y, T* out) {
for (size_t i = begin; i <= end; i++) { out[i] = x[i] * y[i]; }
}
static void div(size_t begin, size_t end, const int8_t* x, const int8_t* y, int8_t* out) {
for (size_t i = begin; i <= end; i++) { out[i] = x[i] / y[i]; }
}
};
// Scalar specialization of VectorizedAvx512 for int (the members take
// int32_t pointers). No SIMD intrinsics are used here; every operation is a
// plain element-wise loop over the half-open index range [begin, end).
//
// NOTE(review): `end` is treated as exclusive, matching the `i < end` tail loop
// used by the vectorized float kernels. The previous `i <= end` bound read x/y
// and wrote out one element past the requested range.
template<>
class VectorizedAvx512<int> {
 public:
  // out[i] = x[i] + y[i] for every i in [begin, end).
  static void add(size_t begin, size_t end, const int32_t* x, const int32_t* y, int32_t* out) {
    for (size_t i = begin; i < end; i++) { out[i] = x[i] + y[i]; }
  }

  // out[i] = x[i] - y[i] for every i in [begin, end).
  static void sub(size_t begin, size_t end, const int32_t* x, const int32_t* y, int32_t* out) {
    for (size_t i = begin; i < end; i++) { out[i] = x[i] - y[i]; }
  }

  // out[i] = x[i] * y[i] for every i in [begin, end).
  static void mul(size_t begin, size_t end, const int32_t* x, const int32_t* y, int32_t* out) {
    for (size_t i = begin; i < end; i++) { out[i] = x[i] * y[i]; }
  }

  // out[i] = x[i] / y[i] (integer division) for every i in [begin, end).
  // No zero check is performed: a zero divisor in y is undefined behavior.
  static void div(size_t begin, size_t end, const int32_t* x, const int32_t* y, int32_t* out) {
    for (size_t i = begin; i < end; i++) { out[i] = x[i] / y[i]; }
  }
};
template<>
class VectorizedAvx512<int64_t> {
public:
static void add(size_t begin, size_t end, const int64_t* x, const int64_t* y, int64_t* out) {
for (size_t i = begin; i <= end; i++) { out[i] = x[i] + y[i]; }
}
static void sub(size_t begin, size_t end, const int64_t* x, const int64_t* y, int64_t* out) {
for (size_t i = begin; i <= end; i++) { out[i] = x[i] - y[i]; }
}
static void mul(size_t begin, size_t end, const int64_t* x, const int64_t* y, int64_t* out) {
for (size_t i = begin; i <= end; i++) { out[i] = x[i] * y[i]; }
}
static void div(size_t begin, size_t end, const int64_t* x, const int64_t* y, int64_t* out) {
static void div(size_t begin, size_t end, const T* x, const T* y, T* out) {
for (size_t i = begin; i <= end; i++) { out[i] = x[i] / y[i]; }
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册