diff --git a/paddle/cuda/include/hl_cpu_scalar.cuh b/paddle/cuda/include/hl_cpu_scalar.cuh index c5e58015f3192e38a916fcae78c3c389b270b0b0..cddf08ce6b615dbb62885daf29c8c378892d40bf 100644 --- a/paddle/cuda/include/hl_cpu_scalar.cuh +++ b/paddle/cuda/include/hl_cpu_scalar.cuh @@ -15,25 +15,60 @@ limitations under the License. */ #ifndef HL_CPU_SCALAR_CUH_ #define HL_CPU_SCALAR_CUH_ +#define VECTOR_SIMD false +#define VECTOR_SET hl_vec_set + #ifndef PADDLE_TYPE_DOUBLE /* size of float */ -#define VECTOR_SIZE 4 +#define VECTOR_SIZE 4 #else /* size of double */ -#define VECTOR_SIZE 8 +#define VECTOR_SIZE 8 #endif typedef real vecType; -inline void set_zero(vecType &mm) { mm = (vecType) 0.0f; } - /* Consider a real as a vector */ -#define VECTOR_LEN 1 -#define VECTOR_SET set_zero +#define VECTOR_LEN 1 template inline real hl_agg_op(Agg agg, vecType mm) { return mm; } +INLINE real hl_vec_set(const real r) { + return r; +} + +INLINE real hl_vec_max(const real a, const real b) { + return a > b ? a : b; +} + +INLINE real hl_vec_min(const real a, const real b) { + return a > b ? b : a; +} + +INLINE real hl_vec_add(const real a, const real b) { + return a + b; +} + +INLINE real hl_vec_sub(const real a, const real b) { + return a - b; +} + +INLINE real hl_vec_mul(const real a, const real b) { + return a * b; +} + +INLINE real hl_vec_div(const real a, const real b) { + return a / b; +} + +INLINE real hl_vec_classification_error(const real a, + const real b, + const real p, + const real r) { + return ((a > p) == (b > p)) ? 0.0f : 1.0f; +} + #endif // HL_CPU_SCALAR_CUH_ diff --git a/paddle/cuda/include/hl_cpu_simd_neon.cuh b/paddle/cuda/include/hl_cpu_simd_neon.cuh index aaba35df09167ea575789a2895fcd92f94216eb9..9ff360c576fe1f948999144e6fcba8c111d29008 100644 --- a/paddle/cuda/include/hl_cpu_simd_neon.cuh +++ b/paddle/cuda/include/hl_cpu_simd_neon.cuh @@ -17,15 +17,16 @@ limitations under the License. */ #include -#define VECTOR_SIZE 16 +#define VECTOR_SIMD true +#define VECTOR_SIZE 16 +#define VECTOR_SET hl_vec_set #ifndef PADDLE_TYPE_DOUBLE typedef float32x4_t vecType; /* number of float in vector */ -#define VECTOR_LEN 4 -#define VECTOR_SET vdupq_n_f32 +#define VECTOR_LEN 4 template inline real hl_agg_op(Agg agg, vecType mm) { @@ -39,19 +40,58 @@ inline real hl_agg_op(Agg agg, vecType mm) { return vgetq_lane_f32(ret, 0); } +inline float32x4_t hl_vec_set(const real f) { + return vdupq_n_f32(f); +} + +inline float32x4_t hl_vec_max(const float32x4_t a, const float32x4_t b) { + return vmaxq_f32(a, b); +} + +inline float32x4_t hl_vec_min(const float32x4_t a, const float32x4_t b) { + return vminq_f32(a, b); +} + +inline float32x4_t hl_vec_add(const float32x4_t a, const float32x4_t b) { + return vaddq_f32(a, b); +} + +inline float32x4_t hl_vec_sub(const float32x4_t a, const float32x4_t b) { + return vsubq_f32(a, b); +} + +inline float32x4_t hl_vec_mul(const float32x4_t a, const float32x4_t b) { + return vmulq_f32(a, b); +} + +inline float32x4_t hl_vec_div(const float32x4_t a, const float32x4_t b) { + float32x4_t tmp = vrecpeq_f32(b); + return vmulq_f32(a, tmp); +} + +inline float32x4_t hl_vec_classification_error(const float32x4_t a, + const float32x4_t b, + const float32x4_t p, + const float32x4_t r) { + uint32x4_t tmp1 = vcgtq_f32(a, p); + uint32x4_t tmp2 = vcgtq_f32(b, p); + uint32x4_t tmp3 = veorq_u32(tmp1, tmp2); + return vcvtq_f32_u32(vandq_u32(tmp3, vcvtq_u32_f32(r))); +} + #else #ifdef __aarch64__ typedef float64x2_t vecType; /* number of float in vector */ -#define VECTOR_LEN 2 -#define VECTOR_SET vdupq_n_f64 +#define VECTOR_LEN 2 +#define VECTOR_SET vdupq_n_f64 #error To be implemented #else #error NEON instructions does not support double precision -#endif +#endif // __aarch64__ #endif diff --git a/paddle/cuda/include/hl_cpu_simd_sse.cuh b/paddle/cuda/include/hl_cpu_simd_sse.cuh index 99286c1a3f07d22aa10cdfde176e4ea812ab29c6..9a918770b14d0c106b89bfbc82ef5f7feec17b8c 100644 --- a/paddle/cuda/include/hl_cpu_simd_sse.cuh +++ b/paddle/cuda/include/hl_cpu_simd_sse.cuh @@ -12,22 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifndef HL_SIMD_SSE_CUH_ -#define HL_SIMD_SSE_CUH_ +#ifndef HL_CPU_SIMD_SSE_CUH_ +#define HL_CPU_SIMD_SSE_CUH_ #include #include #include -#define VECTOR_SIZE 16 +#define VECTOR_SIMD true +#define VECTOR_SIZE 16 +#define VECTOR_SET hl_vec_set #ifndef PADDLE_TYPE_DOUBLE typedef __m128 vecType; /* number of float in vector */ -#define VECTOR_LEN 4 -#define VECTOR_SET _mm_set_ps1 +#define VECTOR_LEN 4 template inline real hl_agg_op(Agg agg, vecType mm) { @@ -40,16 +41,50 @@ inline real hl_agg_op(Agg agg, vecType mm) { return _mm_cvtss_f32(ret); } +inline __m128 hl_vec_set(const real f) { + return _mm_set_ps1(f); +} + +inline __m128 hl_vec_max(const __m128 a, const __m128 b) { + return _mm_max_ps(a, b); +} + +inline __m128 hl_vec_min(const __m128 a, const __m128 b) { + return _mm_min_ps(a, b); +} + +inline __m128 hl_vec_add(const __m128 a, const __m128 b) { + return _mm_add_ps(a, b); +} + +inline __m128 hl_vec_sub(const __m128 a, const __m128 b) { + return _mm_sub_ps(a, b); +} + +inline __m128 hl_vec_mul(const __m128 a, const __m128 b) { + return _mm_mul_ps(a, b); +} + +inline __m128 hl_vec_div(const __m128 a, const __m128 b) { + return _mm_div_ps(a, b); +} + +inline __m128 hl_vec_classification_error(const __m128 a, + const __m128 b, + const __m128 p, + const __m128 r) { + __m128 tmp1 = _mm_cmpgt_ps(a, p); + __m128 tmp2 = _mm_cmpgt_ps(b, p); + __m128 tmp3 = _mm_xor_ps(tmp1, tmp2); + return _mm_and_ps(tmp3, r); +} + #else typedef __m128d vecType; /* number of double in vector */ -#define VECTOR_LEN 2 -#if defined(__APPLE__) || defined(__OSX__) -#define _mm_set_pd1 _mm_set1_pd -#endif -#define VECTOR_SET _mm_set_pd1 +#define VECTOR_LEN 2 template inline real hl_agg_op(Agg agg, vecType mm) { @@ -60,6 +95,48 @@ inline real hl_agg_op(Agg agg, vecType mm) { return _mm_cvtsd_f64(ret); } +inline __m128d hl_vec_set(const real d) { +#if defined(__APPLE__) || defined(__OSX__) + return _mm_set1_pd(d); +#else + return _mm_set_pd1(d); +#endif +} + +inline __m128d hl_vec_max(const __m128d a, const __m128d b) { + return _mm_max_pd(a, b); +} + +inline __m128d hl_vec_min(const __m128d a, const __m128d b) { + return _mm_min_pd(a, b); +} + +inline __m128d hl_vec_add(const __m128d a, const __m128d b) { + return _mm_add_pd(a, b); +} + +inline __m128d hl_vec_sub(const __m128d a, const __m128d b) { + return _mm_sub_pd(a, b); +} + +inline __m128d hl_vec_mul(const __m128d a, const __m128d b) { + return _mm_mul_pd(a, b); +} + +inline __m128d hl_vec_div(const __m128d a, const __m128d b) { + return _mm_div_pd(a, b); +} + +inline __m128d hl_vec_classification_error(const __m128d a, + const __m128d b, + const __m128d p, + const __m128d r) { + __m128d tmp1 = _mm_cmpgt_pd(a, p); + __m128d tmp2 = _mm_cmpgt_pd(b, p); + __m128d tmp3 = _mm_xor_pd(tmp1, tmp2); + return _mm_and_pd(tmp3, r); +} + #endif -#endif // HL_SIMD_SSE_CUH_ +#endif // HL_CPU_SIMD_SSE_CUH_ diff --git a/paddle/cuda/include/hl_matrix_base.cuh b/paddle/cuda/include/hl_matrix_base.cuh index 545120128b41d919d9df4ec179b85997603a05f2..53fdb47ec9c05f5cf85d0956176ad9abf6d656f9 100644 --- a/paddle/cuda/include/hl_matrix_base.cuh +++ b/paddle/cuda/include/hl_matrix_base.cuh @@ -18,26 +18,6 @@ limitations under the License. */ #include "hl_matrix_type.cuh" -#ifdef __CUDA_ARCH__ -/** - * CUDA kernel inline function - */ -#define INLINE __device__ inline -#else -/** - * CPP inline function - */ -#define INLINE inline -#endif - -#ifndef PADDLE_TYPE_DOUBLE -#define DEVICE_FMAX fmaxf -#define DEVICE_FMIN fminf -#else -#define DEVICE_FMAX fmax -#define DEVICE_FMIN fmin -#endif - class BaseOp { public: static const bool sse = false; @@ -52,11 +32,7 @@ public: } }; -#if defined(__SSE3__) -#include "hl_matrix_base_sse.cuh" -#elif (defined(__ARM__NEON__) || defined(__ARM_NEON)) -#include "hl_matrix_base_neon.cuh" -#else +#ifdef __CUDA_ARCH__ typedef BaseOp SSESum; typedef BaseOp SSEMax; typedef BaseOp SSEMin; @@ -70,6 +46,8 @@ typedef BaseOp SSESquaredDiff; typedef BaseOp SSEFirst; typedef BaseOp SSESecond; typedef BaseOp SSEClassificationError; +#else +#include "hl_matrix_base_detail.cuh" #endif namespace aggregate { @@ -124,7 +102,7 @@ public: add2(const real s1, const real s2) : SSEAdd2(s1, s2), p1(s1), p2(s2) {} INLINE real operator()(const real a, const real b) const { - return p1 * a + p2 * b; + return p1 * a + p2 * b; } }; diff --git a/paddle/cuda/include/hl_matrix_base_detail.cuh b/paddle/cuda/include/hl_matrix_base_detail.cuh new file mode 100644 index 0000000000000000000000000000000000000000..50079ed53de7a7b6026284afb73b5335096c145b --- /dev/null +++ b/paddle/cuda/include/hl_matrix_base_detail.cuh @@ -0,0 +1,151 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef HL_MATRIX_BASE_DETAIL_CUH_ +#define HL_MATRIX_BASE_DETAIL_CUH_ + +#include "hl_matrix_type.cuh" + +namespace aggregate { +class SSESum { +public: + static const bool sse = VECTOR_SIMD; + INLINE vecType vecOp(const vecType a, const vecType b) const { + return hl_vec_add(a, b); + } +}; + +class SSEMax { +public: + static const bool sse = VECTOR_SIMD; + INLINE vecType vecOp(const vecType a, const vecType b) const { + return hl_vec_max(a, b); + } +}; + +class SSEMin { +public: + static const bool sse = VECTOR_SIMD; + INLINE vecType vecOp(const vecType a, const vecType b) const { + return hl_vec_min(a, b); + } +}; +} // namespace aggregate + +namespace base { +namespace unary { +class SSEIdentity { +public: + static const bool sse = VECTOR_SIMD; + INLINE vecType vecOp(const vecType a) const { + return a; + } +}; +} // namespace unary + +namespace binary { +class SSEAdd { +public: + static const bool sse = VECTOR_SIMD; + INLINE vecType vecOp(const vecType a, const vecType b) const { + return hl_vec_add(a, b); + } +}; + +class SSEAdd2 { +public: + static const bool sse = VECTOR_SIMD; + const real p1; + const real p2; + vecType mp1; + vecType mp2; + +public: + SSEAdd2(const real s1, const real s2) : p1(s1), p2(s2) { + mp1 = hl_vec_set(p1); + mp2 = hl_vec_set(p2); + } + INLINE vecType vecOp(const vecType a, const vecType b) const { + return hl_vec_add(hl_vec_mul(mp1, a), hl_vec_mul(mp2, b)); + } +}; + +class SSESub { +public: + static const bool sse = VECTOR_SIMD; + INLINE vecType vecOp(const vecType a, const vecType b) const { + return hl_vec_sub(a, b); + } +}; + +class SSEMul { +public: + static const bool sse = VECTOR_SIMD; + INLINE vecType vecOp(const vecType a, const vecType b) const { + return hl_vec_mul(a, b); + } +}; + +class SSEDiv { +public: + static const bool sse = VECTOR_SIMD; + INLINE vecType vecOp(const vecType a, const vecType b) const { + return hl_vec_div(a, b); + } +}; + +class SSESquaredDiff { +public: + static const bool sse = VECTOR_SIMD; + INLINE vecType vecOp(const vecType a, const vecType b) const { + return hl_vec_mul(hl_vec_sub(a, b), hl_vec_sub(a, b)); + } +}; + +class SSEFirst { +public: + static const bool sse = VECTOR_SIMD; + INLINE vecType vecOp(const vecType a, const vecType b) const { + return a; + } +}; + +class SSESecond { +public: + static const bool sse = VECTOR_SIMD; + INLINE vecType vecOp(const vecType a, const vecType b) const { + return b; + } +}; + +class SSEClassificationError { +public: + static const bool sse = VECTOR_SIMD; + const real p; + vecType mp; + vecType result; + +public: + explicit SSEClassificationError(const real s) : p(s) { + mp = hl_vec_set(p); + result = hl_vec_set(1.0f); + } + INLINE vecType vecOp(const vecType a, const vecType b) const { + return hl_vec_classification_error(a, b, mp, result); + } +}; +} // namespace binary +} // namespace base + +#endif /* HL_MATRIX_BASE_DETAIL_CUH_ */ diff --git a/paddle/cuda/include/hl_matrix_base_neon.cuh b/paddle/cuda/include/hl_matrix_base_neon.cuh deleted file mode 100644 index e13019f5ee24ad600005c99678426ee3808b0e54..0000000000000000000000000000000000000000 --- a/paddle/cuda/include/hl_matrix_base_neon.cuh +++ /dev/null @@ -1,161 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - - -#ifndef HL_MATRIX_BASE_NEON_CUH_ -#define HL_MATRIX_BASE_NEON_CUH_ - -namespace aggregate { -class SSESum { -public: - static const bool sse = true; - INLINE float32x4_t vecOp(const float32x4_t a, const float32x4_t b) const { - return vaddq_f32(a, b); - } -}; - -class SSEMax { -public: - static const bool sse = true; - INLINE float32x4_t vecOp(const float32x4_t a, const float32x4_t b) const { - return vmaxq_f32(a, b); - } -}; - -class SSEMin { -public: - static const bool sse = true; - INLINE float32x4_t vecOp(const float32x4_t a, const float32x4_t b) const { - return vminq_f32(a, b); - } -}; -} // namespace aggregate - -namespace base { -namespace unary { -class SSEIdentity { -public: - static const bool sse = true; - INLINE float32x4_t vecOp(const float32x4_t a) const { - return a; - } -}; -} // namespace unary - -namespace binary { -class SSEAdd { -public: - static const bool sse = true; - INLINE float32x4_t vecOp(const float32x4_t a, const float32x4_t b) const { - return vaddq_f32(a, b); - } -}; - -class SSEAdd2 { -public: - static const bool sse = true; - const real p1; - const real p2; - float32x4_t mp1; - float32x4_t mp2; - -public: - SSEAdd2(const real s1, const real s2) : p1(s1), p2(s2) { - mp1 = vdupq_n_f32(p1); - mp2 = vdupq_n_f32(p2); - } - INLINE float32x4_t vecOp(const float32x4_t a, const float32x4_t b) const { - float32x4_t tmp1, tmp2; - tmp1 = vmulq_f32(mp1, a); - tmp2 = vmulq_f32(mp2, b); - return vaddq_f32(tmp1, tmp2); - } -}; - -class SSESub { -public: - static const bool sse = true; - INLINE float32x4_t vecOp(const float32x4_t a, const float32x4_t b) const { - return vsubq_f32(a, b); - } -}; - -class SSEMul { -public: - static const bool sse = true; - INLINE float32x4_t vecOp(const float32x4_t a, const float32x4_t b) const { - return vmulq_f32(a, b); - } -}; - -class SSEDiv { -public: - static const bool sse = true; - INLINE float32x4_t vecOp(const float32x4_t a, const float32x4_t b) const { - float32x4_t tmp; - tmp = vrecpeq_f32(b); - return vmulq_f32(a, tmp); - } -}; - -class SSESquaredDiff { -public: - static const bool sse = true; - INLINE float32x4_t vecOp(const float32x4_t a, const float32x4_t b) const { - float32x4_t tmp; - tmp = vsubq_f32(a, b); - return vmulq_f32(tmp, tmp); - } -}; - -class SSEFirst { -public: - static const bool sse = true; - INLINE float32x4_t vecOp(const float32x4_t a, const float32x4_t b) const { - return a; - } -}; - -class SSESecond { -public: - static const bool sse = true; - INLINE float32x4_t vecOp(const float32x4_t a, const float32x4_t b) const { - return b; - } -}; - -class SSEClassificationError { -public: - static const bool sse = true; - const real p; - float32x4_t mp; - uint32x4_t result; - -public: - explicit SSEClassificationError(const real s) : p(s) { - mp = vdupq_n_f32(p); - result = vdupq_n_u32(1); - } - // TODO: to be check - INLINE float32x4_t vecOp(const float32x4_t a, const float32x4_t b) const { - uint32x4_t tmp1 = vcgtq_f32(a, mp); - uint32x4_t tmp2 = vcgtq_f32(b, mp); - uint32x4_t tmp3 = veorq_u32(tmp1, tmp2); - return vcvtq_f32_u32(vandq_u32(tmp3, result)); - } -}; -} // namespace binary -} // namespace base - -#endif /* HL_MATRIX_BASE_NEON_CUH_ */ diff --git a/paddle/cuda/include/hl_matrix_base_sse.cuh b/paddle/cuda/include/hl_matrix_base_sse.cuh deleted file mode 100644 index db6c9cca03a8974a15cd2e7fbaf73033e3a57f4b..0000000000000000000000000000000000000000 --- a/paddle/cuda/include/hl_matrix_base_sse.cuh +++ /dev/null @@ -1,211 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - - -#ifndef HL_MATRIX_BASE_SSE_CUH_ -#define HL_MATRIX_BASE_SSE_CUH_ - -namespace aggregate { -class SSESum { -public: - static const bool sse = true; - INLINE __m128 vecOp(const __m128 a, const __m128 b) const { - return _mm_add_ps(a, b); - } - INLINE __m128d vecOp(const __m128d a, const __m128d b) const { - return _mm_add_pd(a, b); - } -}; - -class SSEMax { -public: - static const bool sse = true; - INLINE __m128 vecOp(const __m128 a, const __m128 b) const { - return _mm_max_ps(a, b); - } - INLINE __m128d vecOp(const __m128d a, const __m128d b) const { - return _mm_max_pd(a, b); - } -}; - -class SSEMin { -public: - static const bool sse = true; - INLINE __m128 vecOp(const __m128 a, const __m128 b) const { - return _mm_min_ps(a, b); - } - INLINE __m128d vecOp(const __m128d a, const __m128d b) const { - return _mm_min_pd(a, b); - } -}; -} // namespace aggregate - -namespace base { -namespace unary { -class SSEIdentity { -public: - static const bool sse = true; - INLINE __m128 vecOp(const __m128 a) const { - return a; - } - INLINE __m128d vecOp(const __m128d a) const { - return a; - } -}; -} // namespace unary - -namespace binary { -class SSEAdd { -public: - static const bool sse = true; - INLINE __m128 vecOp(const __m128 a, const __m128 b) const { - return _mm_add_ps(a, b); - } - INLINE __m128d vecOp(const __m128d a, const __m128d b) const { - return _mm_add_pd(a, b); - } -}; - -class SSEAdd2 { -public: - static const bool sse = true; - const real p1; - const real p2; - union {__m128 f; __m128d d;} mp1; - union {__m128 f; __m128d d;} mp2; - -public: - SSEAdd2(const real s1, const real s2) : p1(s1), p2(s2) { - if (sizeof(real) == sizeof(float)) { - mp1.f = _mm_set1_ps(p1); - mp2.f = _mm_set1_ps(p2); - } else { - mp1.d = _mm_set1_pd(p1); - mp2.d = _mm_set1_pd(p2); - } - } - INLINE __m128 vecOp(const __m128 a, const __m128 b) const { - __m128 tmp1, tmp2; - tmp1 = _mm_mul_ps(mp1.f, a); - tmp2 = _mm_mul_ps(mp2.f, b); - return _mm_add_ps(tmp1, tmp2); - } - INLINE __m128d vecOp(const __m128d a, const __m128d b) const { - __m128d tmp1, tmp2; - tmp1 = _mm_mul_pd(mp1.d, a); - tmp2 = _mm_mul_pd(mp2.d, b); - return _mm_add_pd(tmp1, tmp2); - } -}; - -class SSESub { -public: - static const bool sse = true; - INLINE __m128 vecOp(const __m128 a, const __m128 b) const { - return _mm_sub_ps(a, b); - } - INLINE __m128d vecOp(const __m128d a, const __m128d b) const { - return _mm_sub_pd(a, b); - } -}; - -class SSEMul { -public: - static const bool sse = true; - INLINE __m128 vecOp(const __m128 a, const __m128 b) const { - return _mm_mul_ps(a, b); - } - INLINE __m128d vecOp(const __m128d a, const __m128d b) const { - return _mm_mul_pd(a, b); - } -}; - -class SSEDiv { -public: - static const bool sse = true; - INLINE __m128 vecOp(const __m128 a, const __m128 b) const { - return _mm_div_ps(a, b); - } - INLINE __m128d vecOp(const __m128d a, const __m128d b) const { - return _mm_div_pd(a, b); - } -}; - -class SSESquaredDiff { -public: - static const bool sse = true; - INLINE __m128 vecOp(const __m128 a, const __m128 b) const { - return _mm_mul_ps(_mm_sub_ps(a, b), _mm_sub_ps(a, b)); - } - INLINE __m128d vecOp(const __m128d a, const __m128d b) const { - return _mm_mul_pd(_mm_sub_pd(a, b), _mm_sub_pd(a, b)); - } -}; - -class SSEFirst { -public: - static const bool sse = true; - INLINE __m128 vecOp(const __m128 a, const __m128 b) const { - return a; - } - INLINE __m128d vecOp(const __m128d a, const __m128d b) const { - return a; - } -}; - -class SSESecond { -public: - static const bool sse = true; - INLINE __m128 vecOp(const __m128 a, const __m128 b) const { - return b; - } - INLINE __m128d vecOp(const __m128d a, const __m128d b) const { - return b; - } -}; - -class SSEClassificationError { -public: - static const bool sse = true; - const real p; - union {__m128 f; __m128d d;} mp; - union {__m128 f; __m128d d;} result; - -public: - explicit SSEClassificationError(const real s) : p(s) { - if (sizeof(real) == sizeof(float)) { - mp.f = _mm_set1_ps(p); - result.f = _mm_set1_ps(1.0f); - } else { - mp.d = _mm_set1_pd(p); - result.d = _mm_set1_pd(1.0); - } - } - INLINE __m128 vecOp(const __m128 a, const __m128 b) const { - __m128 tmp1 = _mm_cmpgt_ps(a, mp.f); - __m128 tmp2 = _mm_cmpgt_ps(b, mp.f); - __m128 tmp3 = _mm_xor_ps(tmp1, tmp2); - return _mm_and_ps(tmp3, result.f); - } - INLINE __m128d vecOp(const __m128d a, const __m128d b) const { - __m128d tmp1 = _mm_cmpgt_pd(a, mp.d); - __m128d tmp2 = _mm_cmpgt_pd(b, mp.d); - __m128d tmp3 = _mm_xor_pd(tmp1, tmp2); - return _mm_and_pd(tmp3, result.d); - } -}; -} // namespace binary -} // namespace base - -#endif /* HL_MATRIX_BASE_SSE_CUH_ */ diff --git a/paddle/cuda/include/hl_matrix_type.cuh b/paddle/cuda/include/hl_matrix_type.cuh index 7d6face5f0e5436a601017c14e9068e81e2cd901..2ced2fb1ab1afc8b0619a904709ce4504291dd86 100644 --- a/paddle/cuda/include/hl_matrix_type.cuh +++ b/paddle/cuda/include/hl_matrix_type.cuh @@ -17,6 +17,18 @@ limitations under the License. */ #include "hl_base.h" +#ifdef __CUDA_ARCH__ +/** + * CUDA kernel inline function + */ +#define INLINE __device__ inline +#else +/** + * CPP inline function + */ +#define INLINE inline +#endif + #ifdef __CUDA_ARCH__ #include #ifndef PADDLE_TYPE_DOUBLE @@ -32,10 +44,4 @@ typedef double2 vecType; #include "hl_cpu_scalar.cuh" #endif -#ifdef __CUDA_ARCH__ -#define INLINE __device__ inline -#else -#define INLINE inline -#endif - #endif // HL_MATRIX_TYPE_CUH_