提交 430adf43 编写于 作者: L Liu Yiqun

Move the definition of hl_vec_add/sub/mul/div/max/min to hl_tensor_ops.h

上级 8f5d22b0
...@@ -40,30 +40,6 @@ INLINE real hl_vec_set(const real r) { ...@@ -40,30 +40,6 @@ INLINE real hl_vec_set(const real r) {
return r; return r;
} }
INLINE real hl_vec_max(const real a, const real b) {
return a > b ? a : b;
}
INLINE real hl_vec_min(const real a, const real b) {
return a > b ? b : a;
}
INLINE real hl_vec_add(const real a, const real b) {
return a + b;
}
INLINE real hl_vec_sub(const real a, const real b) {
return a - b;
}
INLINE real hl_vec_mul(const real a, const real b) {
return a * b;
}
INLINE real hl_vec_div(const real a, const real b) {
return a / b;
}
INLINE real hl_vec_classification_error(const real a, INLINE real hl_vec_classification_error(const real a,
const real b, const real b,
const real p, const real p,
......
...@@ -44,31 +44,6 @@ inline float32x4_t hl_vec_set(const real f) { ...@@ -44,31 +44,6 @@ inline float32x4_t hl_vec_set(const real f) {
return vdupq_n_f32(f); return vdupq_n_f32(f);
} }
inline float32x4_t hl_vec_max(const float32x4_t a, const float32x4_t b) {
return vmaxq_f32(a, b);
}
inline float32x4_t hl_vec_min(const float32x4_t a, const float32x4_t b) {
return vminq_f32(a, b);
}
inline float32x4_t hl_vec_add(const float32x4_t a, const float32x4_t b) {
return vaddq_f32(a, b);
}
inline float32x4_t hl_vec_sub(const float32x4_t a, const float32x4_t b) {
return vsubq_f32(a, b);
}
inline float32x4_t hl_vec_mul(const float32x4_t a, const float32x4_t b) {
return vmulq_f32(a, b);
}
inline float32x4_t hl_vec_div(const float32x4_t a, const float32x4_t b) {
float32x4_t tmp = vrecpeq_f32(b);
return vmulq_f32(a, tmp);
}
inline float32x4_t hl_vec_classification_error(const float32x4_t a, inline float32x4_t hl_vec_classification_error(const float32x4_t a,
const float32x4_t b, const float32x4_t b,
const float32x4_t p, const float32x4_t p,
......
...@@ -45,30 +45,6 @@ inline __m128 hl_vec_set(const real f) { ...@@ -45,30 +45,6 @@ inline __m128 hl_vec_set(const real f) {
return _mm_set_ps1(f); return _mm_set_ps1(f);
} }
inline __m128 hl_vec_max(const __m128 a, const __m128 b) {
return _mm_max_ps(a, b);
}
inline __m128 hl_vec_min(const __m128 a, const __m128 b) {
return _mm_min_ps(a, b);
}
inline __m128 hl_vec_add(const __m128 a, const __m128 b) {
return _mm_add_ps(a, b);
}
inline __m128 hl_vec_sub(const __m128 a, const __m128 b) {
return _mm_sub_ps(a, b);
}
inline __m128 hl_vec_mul(const __m128 a, const __m128 b) {
return _mm_mul_ps(a, b);
}
inline __m128 hl_vec_div(const __m128 a, const __m128 b) {
return _mm_div_ps(a, b);
}
inline __m128 hl_vec_classification_error(const __m128 a, inline __m128 hl_vec_classification_error(const __m128 a,
const __m128 b, const __m128 b,
const __m128 p, const __m128 p,
...@@ -103,30 +79,6 @@ inline __m128d hl_vec_set(const real d) { ...@@ -103,30 +79,6 @@ inline __m128d hl_vec_set(const real d) {
#endif #endif
} }
inline __m128d hl_vec_max(const __m128d a, const __m128d b) {
return _mm_max_pd(a, b);
}
inline __m128d hl_vec_min(const __m128d a, const __m128d b) {
return _mm_min_pd(a, b);
}
inline __m128d hl_vec_add(const __m128d a, const __m128d b) {
return _mm_add_pd(a, b);
}
inline __m128d hl_vec_sub(const __m128d a, const __m128d b) {
return _mm_sub_pd(a, b);
}
inline __m128d hl_vec_mul(const __m128d a, const __m128d b) {
return _mm_mul_pd(a, b);
}
inline __m128d hl_vec_div(const __m128d a, const __m128d b) {
return _mm_div_pd(a, b);
}
inline __m128d hl_vec_classification_error(const __m128d a, inline __m128d hl_vec_classification_error(const __m128d a,
const __m128d b, const __m128d b,
const __m128d p, const __m128d p,
......
...@@ -16,13 +16,14 @@ limitations under the License. */ ...@@ -16,13 +16,14 @@ limitations under the License. */
#define HL_MATRIX_BASE_DETAIL_CUH_ #define HL_MATRIX_BASE_DETAIL_CUH_
#include "hl_matrix_type.cuh" #include "hl_matrix_type.cuh"
#include "hl_tensor_ops.h"
namespace aggregate { namespace aggregate {
class SSESum { class SSESum {
public: public:
static const bool sse = VECTOR_SIMD; static const bool sse = VECTOR_SIMD;
INLINE vecType vecOp(const vecType a, const vecType b) const { INLINE vecType vecOp(const vecType a, const vecType b) const {
return hl_vec_add(a, b); return hppl::binary::add<vecType>()(a, b);
} }
}; };
...@@ -30,7 +31,7 @@ class SSEMax { ...@@ -30,7 +31,7 @@ class SSEMax {
public: public:
static const bool sse = VECTOR_SIMD; static const bool sse = VECTOR_SIMD;
INLINE vecType vecOp(const vecType a, const vecType b) const { INLINE vecType vecOp(const vecType a, const vecType b) const {
return hl_vec_max(a, b); return hppl::binary::max<vecType>()(a, b);
} }
}; };
...@@ -38,7 +39,7 @@ class SSEMin { ...@@ -38,7 +39,7 @@ class SSEMin {
public: public:
static const bool sse = VECTOR_SIMD; static const bool sse = VECTOR_SIMD;
INLINE vecType vecOp(const vecType a, const vecType b) const { INLINE vecType vecOp(const vecType a, const vecType b) const {
return hl_vec_min(a, b); return hppl::binary::min<vecType>()(a, b);
} }
}; };
} // namespace aggregate } // namespace aggregate
...@@ -59,7 +60,7 @@ class SSEAdd { ...@@ -59,7 +60,7 @@ class SSEAdd {
public: public:
static const bool sse = VECTOR_SIMD; static const bool sse = VECTOR_SIMD;
INLINE vecType vecOp(const vecType a, const vecType b) const { INLINE vecType vecOp(const vecType a, const vecType b) const {
return hl_vec_add(a, b); return hppl::binary::add<vecType>()(a, b);
} }
}; };
...@@ -77,7 +78,7 @@ public: ...@@ -77,7 +78,7 @@ public:
mp2 = hl_vec_set(p2); mp2 = hl_vec_set(p2);
} }
INLINE vecType vecOp(const vecType a, const vecType b) const { INLINE vecType vecOp(const vecType a, const vecType b) const {
return hl_vec_add(hl_vec_mul(mp1, a), hl_vec_mul(mp2, b)); return hppl::binary::add_scale<vecType>(mp1, mp2)(a, b);
} }
}; };
...@@ -85,7 +86,7 @@ class SSESub { ...@@ -85,7 +86,7 @@ class SSESub {
public: public:
static const bool sse = VECTOR_SIMD; static const bool sse = VECTOR_SIMD;
INLINE vecType vecOp(const vecType a, const vecType b) const { INLINE vecType vecOp(const vecType a, const vecType b) const {
return hl_vec_sub(a, b); return hppl::binary::sub<vecType>()(a, b);
} }
}; };
...@@ -93,7 +94,7 @@ class SSEMul { ...@@ -93,7 +94,7 @@ class SSEMul {
public: public:
static const bool sse = VECTOR_SIMD; static const bool sse = VECTOR_SIMD;
INLINE vecType vecOp(const vecType a, const vecType b) const { INLINE vecType vecOp(const vecType a, const vecType b) const {
return hl_vec_mul(a, b); return hppl::binary::mul<vecType>()(a, b);
} }
}; };
...@@ -101,7 +102,7 @@ class SSEDiv { ...@@ -101,7 +102,7 @@ class SSEDiv {
public: public:
static const bool sse = VECTOR_SIMD; static const bool sse = VECTOR_SIMD;
INLINE vecType vecOp(const vecType a, const vecType b) const { INLINE vecType vecOp(const vecType a, const vecType b) const {
return hl_vec_div(a, b); return hppl::binary::div<vecType>()(a, b);
} }
}; };
...@@ -109,7 +110,8 @@ class SSESquaredDiff { ...@@ -109,7 +110,8 @@ class SSESquaredDiff {
public: public:
static const bool sse = VECTOR_SIMD; static const bool sse = VECTOR_SIMD;
INLINE vecType vecOp(const vecType a, const vecType b) const { INLINE vecType vecOp(const vecType a, const vecType b) const {
return hl_vec_mul(hl_vec_sub(a, b), hl_vec_sub(a, b)); vecType tmp = hppl::binary::sub<vecType>()(a, b);
return hppl::binary::mul<vecType>()(tmp, tmp);
} }
}; };
......
...@@ -38,10 +38,12 @@ typedef double2 vecType; ...@@ -38,10 +38,12 @@ typedef double2 vecType;
#endif #endif
#elif defined(__SSE3__) #elif defined(__SSE3__)
#include "hl_cpu_simd_sse.cuh" #include "hl_cpu_simd_sse.cuh"
#define PADDLE_USE_SSE3
#elif (defined(__ARM_NEON) || defined(__ARM_NEON__)) && !defined(__NVCC__) #elif (defined(__ARM_NEON) || defined(__ARM_NEON__)) && !defined(__NVCC__)
// Currently nvcc does not support neon intrinsic. // Currently nvcc does not support neon intrinsic.
// TODO: Extract simd intrinsic implementation from .cu files. // TODO: Extract simd intrinsic implementation from .cu files.
#include "hl_cpu_simd_neon.cuh" #include "hl_cpu_simd_neon.cuh"
#define PADDLE_USE_NEON
#else #else
#include "hl_cpu_scalar.cuh" #include "hl_cpu_scalar.cuh"
#endif #endif
......
...@@ -328,6 +328,208 @@ public: ...@@ -328,6 +328,208 @@ public:
INLINE T operator()(const T a, const T b) const { return a < b ? b : a; } INLINE T operator()(const T a, const T b) const { return a < b ? b : a; }
}; };
#ifdef PADDLE_USE_SSE3
#ifndef PADDLE_TYPE_DOUBLE
template <>
class add<__m128> {
public:
INLINE __m128 operator()(const __m128 a, const __m128 b) const {
return _mm_add_ps(a, b);
}
};
template <>
class add_scale<__m128> {
private:
const __m128 p1;
const __m128 p2;
public:
INLINE add_scale(const __m128 s1, const __m128 s2) : p1(s1), p2(s2) {}
INLINE __m128 operator()(const __m128 a, const __m128 b) const {
return _mm_add_ps(_mm_mul_ps(p1, a), _mm_mul_ps(p2, b));
}
};
template <>
class sub<__m128> {
public:
INLINE __m128 operator()(const __m128 a, const __m128 b) const {
return _mm_sub_ps(a, b);
}
};
template <>
class mul<__m128> {
public:
INLINE __m128 operator()(const __m128 a, const __m128 b) const {
return _mm_mul_ps(a, b);
}
};
template <>
class div<__m128> {
public:
INLINE __m128 operator()(const __m128 a, const __m128 b) const {
return _mm_div_ps(a, b);
}
};
template <>
class min<__m128> {
public:
INLINE __m128 operator()(const __m128 a, const __m128 b) const {
return _mm_min_ps(a, b);
}
};
template <>
class max<__m128> {
public:
INLINE __m128 operator()(const __m128 a, const __m128 b) const {
return _mm_max_ps(a, b);
}
};
#else
template <>
class add<__m128d> {
public:
INLINE __m128d operator()(const __m128d a, const __m128d b) const {
return _mm_add_pd(a, b);
}
};
template <>
class add_scale<__m128d> {
private:
const __m128d p1;
const __m128d p2;
public:
INLINE add_scale(const __m128d s1, const __m128d s2) : p1(s1), p2(s2) {}
INLINE __m128d operator()(const __m128d a, const __m128d b) const {
return _mm_add_pd(_mm_mul_pd(p1, a), _mm_mul_pd(p2, b));
}
};
template <>
class sub<__m128d> {
public:
INLINE __m128d operator()(const __m128d a, const __m128d b) const {
return _mm_sub_pd(a, b);
}
};
template <>
class mul<__m128d> {
public:
INLINE __m128d operator()(const __m128d a, const __m128d b) const {
return _mm_mul_pd(a, b);
}
};
template <>
class div<__m128d> {
public:
INLINE __m128d operator()(const __m128d a, const __m128d b) const {
return _mm_div_pd(a, b);
}
};
template <>
class min<__m128d> {
public:
INLINE __m128d operator()(const __m128d a, const __m128d b) const {
return _mm_min_pd(a, b);
}
};
template <>
class max<__m128d> {
public:
INLINE __m128d operator()(const __m128d a, const __m128d b) const {
return _mm_max_pd(a, b);
}
};
#endif // PADDLE_TYPE_DOUBLE
#endif // PADDLE_USE_SSE3
#ifdef PADDLE_USE_NEON
#ifndef PADDLE_TYPE_DOUBLE
template <>
class add<float32x4_t> {
public:
INLINE float32x4_t operator()(const float32x4_t a,
const float32x4_t b) const {
return vmulq_f32(a, b);
}
};
template <>
class add_scale<float32x4_t> {
private:
const float32x4_t p1;
const float32x4_t p2;
public:
INLINE add_scale(const float32x4_t s1, const float32x4_t s2)
: p1(s1), p2(s2) {}
INLINE float32x4_t operator()(const float32x4_t a,
const float32x4_t b) const {
return vaddq_f32(vmulq_f32(p1, a), vmulq_f32(p2, b));
}
};
template <>
class sub<float32x4_t> {
public:
INLINE float32x4_t operator()(const float32x4_t a,
const float32x4_t b) const {
return vsubq_f32(a, b);
}
};
template <>
class mul<float32x4_t> {
public:
INLINE float32x4_t operator()(const float32x4_t a,
const float32x4_t b) const {
return vmulq_f32(a, b);
}
};
template <>
class div<float32x4_t> {
public:
INLINE float32x4_t operator()(const float32x4_t a,
const float32x4_t b) const {
float32x4_t tmp = vrecpeq_f32(b);
return vmulq_f32(a, tmp);
}
};
template <>
class min<float32x4_t> {
public:
INLINE float32x4_t operator()(const float32x4_t a,
const float32x4_t b) const {
return vminq_f32(a, b);
}
};
template <>
class max<float32x4_t> {
public:
INLINE float32x4_t operator()(const float32x4_t a,
const float32x4_t b) const {
return vmaxq_f32(a, b);
}
}
#else
#error To be implemented
#endif // PADDLE_TYPE_DOUBLE
#endif // PADDLE_USE_NEON
} // namespace binary } // namespace binary
} // namespace hppl } // namespace hppl
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册