From 987f065a82390cc2228c2acfe7f0a726d64a8a54 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Fri, 26 May 2017 22:55:27 -0700 Subject: [PATCH] 1. Support scalar computing. 2. Centralize the use of sse and neon instrinsic. 3. Disable neon intrinsic when enable gpu. --- paddle/cuda/include/hl_cpu_matrix_kernel.cuh | 36 +-- ...el.cuh => hl_cpu_matrix_kernel_detail.cuh} | 122 ++++--- paddle/cuda/include/hl_cpu_scalar.cuh | 74 +++++ paddle/cuda/include/hl_cpu_simd_neon.cuh | 98 ++++++ paddle/cuda/include/hl_cpu_simd_sse.cuh | 142 +++++++++ paddle/cuda/include/hl_matrix_base.cuh | 26 +- paddle/cuda/include/hl_matrix_base_detail.cuh | 151 +++++++++ paddle/cuda/include/hl_matrix_base_neon.cuh | 161 ---------- paddle/cuda/include/hl_matrix_base_sse.cuh | 211 ------------ paddle/cuda/include/hl_matrix_type.cuh | 40 ++- paddle/cuda/include/hl_neon_matrix_kernel.cuh | 299 ------------------ 11 files changed, 545 insertions(+), 815 deletions(-) rename paddle/cuda/include/{hl_sse_matrix_kernel.cuh => hl_cpu_matrix_kernel_detail.cuh} (89%) create mode 100644 paddle/cuda/include/hl_cpu_scalar.cuh create mode 100644 paddle/cuda/include/hl_cpu_simd_neon.cuh create mode 100644 paddle/cuda/include/hl_cpu_simd_sse.cuh create mode 100644 paddle/cuda/include/hl_matrix_base_detail.cuh delete mode 100644 paddle/cuda/include/hl_matrix_base_neon.cuh delete mode 100644 paddle/cuda/include/hl_matrix_base_sse.cuh delete mode 100644 paddle/cuda/include/hl_neon_matrix_kernel.cuh diff --git a/paddle/cuda/include/hl_cpu_matrix_kernel.cuh b/paddle/cuda/include/hl_cpu_matrix_kernel.cuh index 9c49a4bd208..aaa24325514 100644 --- a/paddle/cuda/include/hl_cpu_matrix_kernel.cuh +++ b/paddle/cuda/include/hl_cpu_matrix_kernel.cuh @@ -17,10 +17,9 @@ limitations under the License. */ #include #include "hl_base.h" -#if defined(__ARM_NEON__) || defined(__ARM_NEON) -#include "hl_neon_matrix_kernel.cuh" -#else -#include "hl_sse_matrix_kernel.cuh" + +#ifndef __CUDA_ARCH__ +#include "hl_cpu_matrix_kernel_detail.cuh" #endif /** @@ -114,35 +113,6 @@ void hl_cpu_apply_quaternary_op(Op op, } } -template -void hl_matrix_row_op(Agg agg, Op op, Saver sv, - int dimM, int dimN, - real *dst, int ld, - real *A, int lda) { - for (int i = 0; i < dimM; i++) { - real tmp = agg.init(); - for (int j = 0; j < dimN; j++) { - tmp = agg(tmp, op(A[i * lda + j])); - } - dst[i*ld] = sv(dst[i*ld], tmp); - } -} - -template -void hl_matrix_row_op(Agg agg, Op op, Saver sv, - int dimM, int dimN, - real *dst, int ld, - real *A, int lda, - real *B, int ldb) { - for (int i = 0; i < dimM; i++) { - real tmp = agg.init(); - for (int j = 0; j < dimN; j++) { - tmp = agg(tmp, op(A[i * lda + j], B[i * ldb + j])); - } - dst[i*ld] = sv(dst[i*ld], tmp); - } -} - template void hl_cpu_matrix_row_op(Agg agg, Op op, Saver sv, int dimM, int dimN, diff --git a/paddle/cuda/include/hl_sse_matrix_kernel.cuh b/paddle/cuda/include/hl_cpu_matrix_kernel_detail.cuh similarity index 89% rename from paddle/cuda/include/hl_sse_matrix_kernel.cuh rename to paddle/cuda/include/hl_cpu_matrix_kernel_detail.cuh index 9e50580669d..85ca836fdc4 100644 --- a/paddle/cuda/include/hl_sse_matrix_kernel.cuh +++ b/paddle/cuda/include/hl_cpu_matrix_kernel_detail.cuh @@ -13,26 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ -#ifndef HL_SSE_MATRIX_KERNEL_CUH_ -#define HL_SSE_MATRIX_KERNEL_CUH_ +#ifndef HL_MATRIX_KERNEL_DETAIL_CUH_ +#define HL_MATRIX_KERNEL_DETAIL_CUH_ #include "hl_matrix_type.cuh" -#define VECTOR_SIZE 16 - -#ifndef PADDLE_TYPE_DOUBLE -/* number of float in vector */ -#define VECTOR_LEN 4 -#define VECTOR_SET _mm_set_ps1 -#else -#if defined(__APPLE__) || defined(__OSX__) -#define _mm_set_pd1 _mm_set1_pd -#endif -/* number of double in vector */ -#define VECTOR_LEN 2 -#define VECTOR_SET _mm_set_pd1 -#endif - inline bool hl_check_align(size_t size) { return !(size & (VECTOR_SIZE - 1)); } @@ -41,27 +26,63 @@ inline bool hl_check_align(void *ptr) { return hl_check_align(reinterpret_cast(ptr)); } -#ifndef PADDLE_TYPE_DOUBLE -template -inline real hl_agg_op(Agg agg, vecType mm) { - __m128 lo = _mm_unpacklo_ps(mm, mm); - __m128 hi = _mm_unpackhi_ps(mm, mm); - __m128 tmp1 = agg.vecOp(lo, hi); - __m128 tmp2 = _mm_movehl_ps(tmp1, tmp1); - __m128 ret = agg.vecOp(tmp1, tmp2); +template +void hl_matrix_row_op(Agg agg, Op op, Saver sv, + int dimM, int dimN, + real *dst, int ld, + real *A, int lda) { + for (int i = 0; i < dimM; i++) { + real tmp = agg.init(); + for (int j = 0; j < dimN; j++) { + tmp = agg(tmp, op(A[i * lda + j])); + } + dst[i*ld] = sv(dst[i*ld], tmp); + } +} - return _mm_cvtss_f32(ret); +template +void hl_matrix_row_op(Agg agg, Op op, Saver sv, + int dimM, int dimN, + real *dst, int ld, + real *A, int lda, + real *B, int ldb) { + for (int i = 0; i < dimM; i++) { + real tmp = agg.init(); + for (int j = 0; j < dimN; j++) { + tmp = agg(tmp, op(A[i * lda + j], B[i * ldb + j])); + } + dst[i*ld] = sv(dst[i*ld], tmp); + } } -#else -template -inline real hl_agg_op(Agg agg, vecType mm) { - __m128d lo = _mm_unpacklo_pd(mm, mm); - __m128d hi = _mm_unpackhi_pd(mm, mm); - __m128d ret = agg.vecOp(lo, hi); - - return _mm_cvtsd_f64(ret); + +template +void hl_matrix_column_op(Agg agg, Op op, Saver sv, + int dimM, int dimN, + real *dst, + real *A, int lda) { + for (int j = 0; j < dimN; j++) { + real tmp = agg.init(); + for (int i = 0; i < dimM; i++) { + tmp = agg(tmp, op(A[i * lda + j])); + } + dst[j] = sv(dst[j], tmp); + } +} + +template +void hl_matrix_column_op(Agg agg, Op op, Saver sv, + int dimM, int dimN, + real *dst, + real *A, int lda, + real *B, int ldb) { + for (int j = 0; j < dimN; j++) { + real tmp = agg.init(); + for (int i = 0; i < dimM; i++) { + tmp = agg(tmp, op(A[i * lda + j], B[i * ldb + j])); + } + dst[j] = sv(dst[j], tmp); + } } -#endif template void hl_sse_matrix_row_op(Agg agg, Op op, Saver sv, @@ -118,35 +139,6 @@ void hl_sse_matrix_row_op(Agg agg, Op op, Saver sv, } } -template -void hl_matrix_column_op(Agg agg, Op op, Saver sv, - int dimM, int dimN, - real *dst, - real *A, int lda) { - for (int j = 0; j < dimN; j++) { - real tmp = agg.init(); - for (int i = 0; i < dimM; i++) { - tmp = agg(tmp, op(A[i * lda + j])); - } - dst[j] = sv(dst[j], tmp); - } -} - -template -void hl_matrix_column_op(Agg agg, Op op, Saver sv, - int dimM, int dimN, - real *dst, - real *A, int lda, - real *B, int ldb) { - for (int j = 0; j < dimN; j++) { - real tmp = agg.init(); - for (int i = 0; i < dimM; i++) { - tmp = agg(tmp, op(A[i * lda + j], B[i * ldb + j])); - } - dst[j] = sv(dst[j], tmp); - } -} - /* * MaxRow greater than or equal dimN * dimN is multiples of VECTOR_LEN @@ -315,4 +307,4 @@ void hl_sse_matrix_column_op(Agg agg, Op op, Saver sv, } } -#endif /* HL_SSE_MATRIX_KERNEL_CUH_ */ +#endif /* HL_MATRIX_KERNEL_DETAIL_CUH_ */ diff --git a/paddle/cuda/include/hl_cpu_scalar.cuh b/paddle/cuda/include/hl_cpu_scalar.cuh new file mode 100644 index 00000000000..cddf08ce6b6 --- /dev/null +++ b/paddle/cuda/include/hl_cpu_scalar.cuh @@ -0,0 +1,74 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef HL_CPU_SCALAR_CUH_ +#define HL_CPU_SCALAR_CUH_ + +#define VECTOR_SIMD false +#define VECTOR_SET hl_vec_set + +#ifndef PADDLE_TYPE_DOUBLE +/* size of float */ +#define VECTOR_SIZE 4 +#else +/* size of double */ +#define VECTOR_SIZE 8 +#endif + +typedef real vecType; + +/* Consider a real as a vector */ +#define VECTOR_LEN 1 + +template +inline real hl_agg_op(Agg agg, vecType mm) { + return mm; +} + +INLINE real hl_vec_set(const real r) { + return r; +} + +INLINE real hl_vec_max(const real a, const real b) { + return a > b ? a : b; +} + +INLINE real hl_vec_min(const real a, const real b) { + return a > b ? b : a; +} + +INLINE real hl_vec_add(const real a, const real b) { + return a + b; +} + +INLINE real hl_vec_sub(const real a, const real b) { + return a - b; +} + +INLINE real hl_vec_mul(const real a, const real b) { + return a * b; +} + +INLINE real hl_vec_div(const real a, const real b) { + return a / b; +} + +INLINE real hl_vec_classification_error(const real a, + const real b, + const real p, + const real r) { + return ((a > p) == (b > p)) ? 0.0f : 1.0f; +} + +#endif // HL_CPU_SCALAR_CUH_ diff --git a/paddle/cuda/include/hl_cpu_simd_neon.cuh b/paddle/cuda/include/hl_cpu_simd_neon.cuh new file mode 100644 index 00000000000..9ff360c576f --- /dev/null +++ b/paddle/cuda/include/hl_cpu_simd_neon.cuh @@ -0,0 +1,98 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef HL_CPU_SIMD_NEON_CUH_ +#define HL_CPU_SIMD_NEON_CUH_ + +#include + +#define VECTOR_SIMD true +#define VECTOR_SIZE 16 +#define VECTOR_SET hl_vec_set + +#ifndef PADDLE_TYPE_DOUBLE + +typedef float32x4_t vecType; + +/* number of float in vector */ +#define VECTOR_LEN 4 + +template +inline real hl_agg_op(Agg agg, vecType mm) { + float32x4_t rev = vrev64q_f32(mm); + float32x4_t tmp1 = agg.vecOp(rev, rev); + float32x2_t lo = vget_high_f32(rev); + float32x2_t hi = vget_low_f32(rev); + float32x4_t tmp2 = vcombine_f32(hi, lo); + float32x4_t ret = agg.vecOp(tmp1, tmp2); + + return vgetq_lane_f32(ret, 0); +} + +inline float32x4_t hl_vec_set(const real f) { + return vdupq_n_f32(f); +} + +inline float32x4_t hl_vec_max(const float32x4_t a, const float32x4_t b) { + return vmaxq_f32(a, b); +} + +inline float32x4_t hl_vec_min(const float32x4_t a, const float32x4_t b) { + return vminq_f32(a, b); +} + +inline float32x4_t hl_vec_add(const float32x4_t a, const float32x4_t b) { + return vaddq_f32(a, b); +} + +inline float32x4_t hl_vec_sub(const float32x4_t a, const float32x4_t b) { + return vsubq_f32(a, b); +} + +inline float32x4_t hl_vec_mul(const float32x4_t a, const float32x4_t b) { + return vmulq_f32(a, b); +} + +inline float32x4_t hl_vec_div(const float32x4_t a, const float32x4_t b) { + float32x4_t tmp = vrecpeq_f32(b); + return vmulq_f32(a, tmp); +} + +inline float32x4_t hl_vec_classification_error(const float32x4_t a, + const float32x4_t b, + const float32x4_t p, + const float32x4_t r) { + uint32x4_t tmp1 = vcgtq_f32(a, p); + uint32x4_t tmp2 = vcgtq_f32(b, p); + uint32x4_t tmp3 = veorq_u32(tmp1, tmp2); + return vcvtq_f32_u32(vandq_u32(tmp3, vcvtq_u32_f32(r))); +} + +#else + +#ifdef __aarch64__ +typedef float64x2_t vecType; + +/* number of float in vector */ +#define VECTOR_LEN 2 +#define VECTOR_SET vdupq_n_f64 + +#error To be implemented +#else +#error NEON instructions does not support double precision +#endif // __aarch64__ + +#endif + +#endif // HL_CPU_SIMD_NEON_CUH_ diff --git a/paddle/cuda/include/hl_cpu_simd_sse.cuh b/paddle/cuda/include/hl_cpu_simd_sse.cuh new file mode 100644 index 00000000000..9a918770b14 --- /dev/null +++ b/paddle/cuda/include/hl_cpu_simd_sse.cuh @@ -0,0 +1,142 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef HL_CPU_SIMD_SSE_CUH_ +#define HL_CPU_SIMD_SSE_CUH_ + +#include +#include +#include + +#define VECTOR_SIMD true +#define VECTOR_SIZE 16 +#define VECTOR_SET hl_vec_set + +#ifndef PADDLE_TYPE_DOUBLE + +typedef __m128 vecType; + +/* number of float in vector */ +#define VECTOR_LEN 4 + +template +inline real hl_agg_op(Agg agg, vecType mm) { + __m128 lo = _mm_unpacklo_ps(mm, mm); + __m128 hi = _mm_unpackhi_ps(mm, mm); + __m128 tmp1 = agg.vecOp(lo, hi); + __m128 tmp2 = _mm_movehl_ps(tmp1, tmp1); + __m128 ret = agg.vecOp(tmp1, tmp2); + + return _mm_cvtss_f32(ret); +} + +inline __m128 hl_vec_set(const real f) { + return _mm_set_ps1(f); +} + +inline __m128 hl_vec_max(const __m128 a, const __m128 b) { + return _mm_max_ps(a, b); +} + +inline __m128 hl_vec_min(const __m128 a, const __m128 b) { + return _mm_min_ps(a, b); +} + +inline __m128 hl_vec_add(const __m128 a, const __m128 b) { + return _mm_add_ps(a, b); +} + +inline __m128 hl_vec_sub(const __m128 a, const __m128 b) { + return _mm_sub_ps(a, b); +} + +inline __m128 hl_vec_mul(const __m128 a, const __m128 b) { + return _mm_mul_ps(a, b); +} + +inline __m128 hl_vec_div(const __m128 a, const __m128 b) { + return _mm_div_ps(a, b); +} + +inline __m128 hl_vec_classification_error(const __m128 a, + const __m128 b, + const __m128 p, + const __m128 r) { + __m128 tmp1 = _mm_cmpgt_ps(a, p); + __m128 tmp2 = _mm_cmpgt_ps(b, p); + __m128 tmp3 = _mm_xor_ps(tmp1, tmp2); + return _mm_and_ps(tmp3, r); +} + +#else + +typedef __m128d vecType; + +/* number of double in vector */ +#define VECTOR_LEN 2 + +template +inline real hl_agg_op(Agg agg, vecType mm) { + __m128d lo = _mm_unpacklo_pd(mm, mm); + __m128d hi = _mm_unpackhi_pd(mm, mm); + __m128d ret = agg.vecOp(lo, hi); + + return _mm_cvtsd_f64(ret); +} + +inline __m128d hl_vec_set(const real d) { +#if defined(__APPLE__) || defined(__OSX__) + return _mm_set1_pd(d); +#else + return _mm_set_pd1(d); +#endif +} + +inline __m128d hl_vec_max(const __m128d a, const __m128d b) { + return _mm_max_pd(a, b); +} + +inline __m128d hl_vec_min(const __m128d a, const __m128d b) { + return _mm_min_pd(a, b); +} + +inline __m128d hl_vec_add(const __m128d a, const __m128d b) { + return _mm_add_pd(a, b); +} + +inline __m128d hl_vec_sub(const __m128d a, const __m128d b) { + return _mm_sub_pd(a, b); +} + +inline __m128d hl_vec_mul(const __m128d a, const __m128d b) { + return _mm_mul_pd(a, b); +} + +inline __m128d hl_vec_div(const __m128d a, const __m128d b) { + return _mm_div_pd(a, b); +} + +inline __m128d hl_vec_classification_error(const __m128d a, + const __m128d b, + const __m128d p, + const __m128d r) { + __m128d tmp1 = _mm_cmpgt_pd(a, p); + __m128d tmp2 = _mm_cmpgt_pd(b, p); + __m128d tmp3 = _mm_xor_pd(tmp1, tmp2); + return _mm_and_pd(tmp3, r); +} + +#endif + +#endif // HL_CPU_SIMD_SSE_CUH_ diff --git a/paddle/cuda/include/hl_matrix_base.cuh b/paddle/cuda/include/hl_matrix_base.cuh index 8b755c1095c..53fdb47ec9c 100644 --- a/paddle/cuda/include/hl_matrix_base.cuh +++ b/paddle/cuda/include/hl_matrix_base.cuh @@ -18,26 +18,6 @@ limitations under the License. */ #include "hl_matrix_type.cuh" -#ifdef __CUDA_ARCH__ -/** - * CUDA kernel inline function - */ -#define INLINE __device__ inline -#else -/** - * CPP inline function - */ -#define INLINE inline -#endif - -#ifndef PADDLE_TYPE_DOUBLE -#define DEVICE_FMAX fmaxf -#define DEVICE_FMIN fminf -#else -#define DEVICE_FMAX fmax -#define DEVICE_FMIN fmin -#endif - class BaseOp { public: static const bool sse = false; @@ -66,10 +46,8 @@ typedef BaseOp SSESquaredDiff; typedef BaseOp SSEFirst; typedef BaseOp SSESecond; typedef BaseOp SSEClassificationError; -#elif defined(__ARM__NEON__) || defined(__ARM_NEON) -#include "hl_matrix_base_neon.cuh" #else -#include "hl_matrix_base_sse.cuh" +#include "hl_matrix_base_detail.cuh" #endif namespace aggregate { @@ -124,7 +102,7 @@ public: add2(const real s1, const real s2) : SSEAdd2(s1, s2), p1(s1), p2(s2) {} INLINE real operator()(const real a, const real b) const { - return p1 * a + p2 * b; + return p1 * a + p2 * b; } }; diff --git a/paddle/cuda/include/hl_matrix_base_detail.cuh b/paddle/cuda/include/hl_matrix_base_detail.cuh new file mode 100644 index 00000000000..50079ed53de --- /dev/null +++ b/paddle/cuda/include/hl_matrix_base_detail.cuh @@ -0,0 +1,151 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef HL_MATRIX_BASE_DETAIL_CUH_ +#define HL_MATRIX_BASE_DETAIL_CUH_ + +#include "hl_matrix_type.cuh" + +namespace aggregate { +class SSESum { +public: + static const bool sse = VECTOR_SIMD; + INLINE vecType vecOp(const vecType a, const vecType b) const { + return hl_vec_add(a, b); + } +}; + +class SSEMax { +public: + static const bool sse = VECTOR_SIMD; + INLINE vecType vecOp(const vecType a, const vecType b) const { + return hl_vec_max(a, b); + } +}; + +class SSEMin { +public: + static const bool sse = VECTOR_SIMD; + INLINE vecType vecOp(const vecType a, const vecType b) const { + return hl_vec_min(a, b); + } +}; +} // namespace aggregate + +namespace base { +namespace unary { +class SSEIdentity { +public: + static const bool sse = VECTOR_SIMD; + INLINE vecType vecOp(const vecType a) const { + return a; + } +}; +} // namespace unary + +namespace binary { +class SSEAdd { +public: + static const bool sse = VECTOR_SIMD; + INLINE vecType vecOp(const vecType a, const vecType b) const { + return hl_vec_add(a, b); + } +}; + +class SSEAdd2 { +public: + static const bool sse = VECTOR_SIMD; + const real p1; + const real p2; + vecType mp1; + vecType mp2; + +public: + SSEAdd2(const real s1, const real s2) : p1(s1), p2(s2) { + mp1 = hl_vec_set(p1); + mp2 = hl_vec_set(p2); + } + INLINE vecType vecOp(const vecType a, const vecType b) const { + return hl_vec_add(hl_vec_mul(mp1, a), hl_vec_mul(mp2, b)); + } +}; + +class SSESub { +public: + static const bool sse = VECTOR_SIMD; + INLINE vecType vecOp(const vecType a, const vecType b) const { + return hl_vec_sub(a, b); + } +}; + +class SSEMul { +public: + static const bool sse = VECTOR_SIMD; + INLINE vecType vecOp(const vecType a, const vecType b) const { + return hl_vec_mul(a, b); + } +}; + +class SSEDiv { +public: + static const bool sse = VECTOR_SIMD; + INLINE vecType vecOp(const vecType a, const vecType b) const { + return hl_vec_div(a, b); + } +}; + +class SSESquaredDiff { +public: + static const bool sse = VECTOR_SIMD; + INLINE vecType vecOp(const vecType a, const vecType b) const { + return hl_vec_mul(hl_vec_sub(a, b), hl_vec_sub(a, b)); + } +}; + +class SSEFirst { +public: + static const bool sse = VECTOR_SIMD; + INLINE vecType vecOp(const vecType a, const vecType b) const { + return a; + } +}; + +class SSESecond { +public: + static const bool sse = VECTOR_SIMD; + INLINE vecType vecOp(const vecType a, const vecType b) const { + return b; + } +}; + +class SSEClassificationError { +public: + static const bool sse = VECTOR_SIMD; + const real p; + vecType mp; + vecType result; + +public: + explicit SSEClassificationError(const real s) : p(s) { + mp = hl_vec_set(p); + result = hl_vec_set(1.0f); + } + INLINE vecType vecOp(const vecType a, const vecType b) const { + return hl_vec_classification_error(a, b, mp, result); + } +}; +} // namespace binary +} // namespace base + +#endif /* HL_MATRIX_BASE_DETAIL_CUH_ */ diff --git a/paddle/cuda/include/hl_matrix_base_neon.cuh b/paddle/cuda/include/hl_matrix_base_neon.cuh deleted file mode 100644 index e13019f5ee2..00000000000 --- a/paddle/cuda/include/hl_matrix_base_neon.cuh +++ /dev/null @@ -1,161 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - - -#ifndef HL_MATRIX_BASE_NEON_CUH_ -#define HL_MATRIX_BASE_NEON_CUH_ - -namespace aggregate { -class SSESum { -public: - static const bool sse = true; - INLINE float32x4_t vecOp(const float32x4_t a, const float32x4_t b) const { - return vaddq_f32(a, b); - } -}; - -class SSEMax { -public: - static const bool sse = true; - INLINE float32x4_t vecOp(const float32x4_t a, const float32x4_t b) const { - return vmaxq_f32(a, b); - } -}; - -class SSEMin { -public: - static const bool sse = true; - INLINE float32x4_t vecOp(const float32x4_t a, const float32x4_t b) const { - return vminq_f32(a, b); - } -}; -} // namespace aggregate - -namespace base { -namespace unary { -class SSEIdentity { -public: - static const bool sse = true; - INLINE float32x4_t vecOp(const float32x4_t a) const { - return a; - } -}; -} // namespace unary - -namespace binary { -class SSEAdd { -public: - static const bool sse = true; - INLINE float32x4_t vecOp(const float32x4_t a, const float32x4_t b) const { - return vaddq_f32(a, b); - } -}; - -class SSEAdd2 { -public: - static const bool sse = true; - const real p1; - const real p2; - float32x4_t mp1; - float32x4_t mp2; - -public: - SSEAdd2(const real s1, const real s2) : p1(s1), p2(s2) { - mp1 = vdupq_n_f32(p1); - mp2 = vdupq_n_f32(p2); - } - INLINE float32x4_t vecOp(const float32x4_t a, const float32x4_t b) const { - float32x4_t tmp1, tmp2; - tmp1 = vmulq_f32(mp1, a); - tmp2 = vmulq_f32(mp2, b); - return vaddq_f32(tmp1, tmp2); - } -}; - -class SSESub { -public: - static const bool sse = true; - INLINE float32x4_t vecOp(const float32x4_t a, const float32x4_t b) const { - return vsubq_f32(a, b); - } -}; - -class SSEMul { -public: - static const bool sse = true; - INLINE float32x4_t vecOp(const float32x4_t a, const float32x4_t b) const { - return vmulq_f32(a, b); - } -}; - -class SSEDiv { -public: - static const bool sse = true; - INLINE float32x4_t vecOp(const float32x4_t a, const float32x4_t b) const { - float32x4_t tmp; - tmp = vrecpeq_f32(b); - return vmulq_f32(a, tmp); - } -}; - -class SSESquaredDiff { -public: - static const bool sse = true; - INLINE float32x4_t vecOp(const float32x4_t a, const float32x4_t b) const { - float32x4_t tmp; - tmp = vsubq_f32(a, b); - return vmulq_f32(tmp, tmp); - } -}; - -class SSEFirst { -public: - static const bool sse = true; - INLINE float32x4_t vecOp(const float32x4_t a, const float32x4_t b) const { - return a; - } -}; - -class SSESecond { -public: - static const bool sse = true; - INLINE float32x4_t vecOp(const float32x4_t a, const float32x4_t b) const { - return b; - } -}; - -class SSEClassificationError { -public: - static const bool sse = true; - const real p; - float32x4_t mp; - uint32x4_t result; - -public: - explicit SSEClassificationError(const real s) : p(s) { - mp = vdupq_n_f32(p); - result = vdupq_n_u32(1); - } - // TODO: to be check - INLINE float32x4_t vecOp(const float32x4_t a, const float32x4_t b) const { - uint32x4_t tmp1 = vcgtq_f32(a, mp); - uint32x4_t tmp2 = vcgtq_f32(b, mp); - uint32x4_t tmp3 = veorq_u32(tmp1, tmp2); - return vcvtq_f32_u32(vandq_u32(tmp3, result)); - } -}; -} // namespace binary -} // namespace base - -#endif /* HL_MATRIX_BASE_NEON_CUH_ */ diff --git a/paddle/cuda/include/hl_matrix_base_sse.cuh b/paddle/cuda/include/hl_matrix_base_sse.cuh deleted file mode 100644 index db6c9cca03a..00000000000 --- a/paddle/cuda/include/hl_matrix_base_sse.cuh +++ /dev/null @@ -1,211 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - - -#ifndef HL_MATRIX_BASE_SSE_CUH_ -#define HL_MATRIX_BASE_SSE_CUH_ - -namespace aggregate { -class SSESum { -public: - static const bool sse = true; - INLINE __m128 vecOp(const __m128 a, const __m128 b) const { - return _mm_add_ps(a, b); - } - INLINE __m128d vecOp(const __m128d a, const __m128d b) const { - return _mm_add_pd(a, b); - } -}; - -class SSEMax { -public: - static const bool sse = true; - INLINE __m128 vecOp(const __m128 a, const __m128 b) const { - return _mm_max_ps(a, b); - } - INLINE __m128d vecOp(const __m128d a, const __m128d b) const { - return _mm_max_pd(a, b); - } -}; - -class SSEMin { -public: - static const bool sse = true; - INLINE __m128 vecOp(const __m128 a, const __m128 b) const { - return _mm_min_ps(a, b); - } - INLINE __m128d vecOp(const __m128d a, const __m128d b) const { - return _mm_min_pd(a, b); - } -}; -} // namespace aggregate - -namespace base { -namespace unary { -class SSEIdentity { -public: - static const bool sse = true; - INLINE __m128 vecOp(const __m128 a) const { - return a; - } - INLINE __m128d vecOp(const __m128d a) const { - return a; - } -}; -} // namespace unary - -namespace binary { -class SSEAdd { -public: - static const bool sse = true; - INLINE __m128 vecOp(const __m128 a, const __m128 b) const { - return _mm_add_ps(a, b); - } - INLINE __m128d vecOp(const __m128d a, const __m128d b) const { - return _mm_add_pd(a, b); - } -}; - -class SSEAdd2 { -public: - static const bool sse = true; - const real p1; - const real p2; - union {__m128 f; __m128d d;} mp1; - union {__m128 f; __m128d d;} mp2; - -public: - SSEAdd2(const real s1, const real s2) : p1(s1), p2(s2) { - if (sizeof(real) == sizeof(float)) { - mp1.f = _mm_set1_ps(p1); - mp2.f = _mm_set1_ps(p2); - } else { - mp1.d = _mm_set1_pd(p1); - mp2.d = _mm_set1_pd(p2); - } - } - INLINE __m128 vecOp(const __m128 a, const __m128 b) const { - __m128 tmp1, tmp2; - tmp1 = _mm_mul_ps(mp1.f, a); - tmp2 = _mm_mul_ps(mp2.f, b); - return _mm_add_ps(tmp1, tmp2); - } - INLINE __m128d vecOp(const __m128d a, const __m128d b) const { - __m128d tmp1, tmp2; - tmp1 = _mm_mul_pd(mp1.d, a); - tmp2 = _mm_mul_pd(mp2.d, b); - return _mm_add_pd(tmp1, tmp2); - } -}; - -class SSESub { -public: - static const bool sse = true; - INLINE __m128 vecOp(const __m128 a, const __m128 b) const { - return _mm_sub_ps(a, b); - } - INLINE __m128d vecOp(const __m128d a, const __m128d b) const { - return _mm_sub_pd(a, b); - } -}; - -class SSEMul { -public: - static const bool sse = true; - INLINE __m128 vecOp(const __m128 a, const __m128 b) const { - return _mm_mul_ps(a, b); - } - INLINE __m128d vecOp(const __m128d a, const __m128d b) const { - return _mm_mul_pd(a, b); - } -}; - -class SSEDiv { -public: - static const bool sse = true; - INLINE __m128 vecOp(const __m128 a, const __m128 b) const { - return _mm_div_ps(a, b); - } - INLINE __m128d vecOp(const __m128d a, const __m128d b) const { - return _mm_div_pd(a, b); - } -}; - -class SSESquaredDiff { -public: - static const bool sse = true; - INLINE __m128 vecOp(const __m128 a, const __m128 b) const { - return _mm_mul_ps(_mm_sub_ps(a, b), _mm_sub_ps(a, b)); - } - INLINE __m128d vecOp(const __m128d a, const __m128d b) const { - return _mm_mul_pd(_mm_sub_pd(a, b), _mm_sub_pd(a, b)); - } -}; - -class SSEFirst { -public: - static const bool sse = true; - INLINE __m128 vecOp(const __m128 a, const __m128 b) const { - return a; - } - INLINE __m128d vecOp(const __m128d a, const __m128d b) const { - return a; - } -}; - -class SSESecond { -public: - static const bool sse = true; - INLINE __m128 vecOp(const __m128 a, const __m128 b) const { - return b; - } - INLINE __m128d vecOp(const __m128d a, const __m128d b) const { - return b; - } -}; - -class SSEClassificationError { -public: - static const bool sse = true; - const real p; - union {__m128 f; __m128d d;} mp; - union {__m128 f; __m128d d;} result; - -public: - explicit SSEClassificationError(const real s) : p(s) { - if (sizeof(real) == sizeof(float)) { - mp.f = _mm_set1_ps(p); - result.f = _mm_set1_ps(1.0f); - } else { - mp.d = _mm_set1_pd(p); - result.d = _mm_set1_pd(1.0); - } - } - INLINE __m128 vecOp(const __m128 a, const __m128 b) const { - __m128 tmp1 = _mm_cmpgt_ps(a, mp.f); - __m128 tmp2 = _mm_cmpgt_ps(b, mp.f); - __m128 tmp3 = _mm_xor_ps(tmp1, tmp2); - return _mm_and_ps(tmp3, result.f); - } - INLINE __m128d vecOp(const __m128d a, const __m128d b) const { - __m128d tmp1 = _mm_cmpgt_pd(a, mp.d); - __m128d tmp2 = _mm_cmpgt_pd(b, mp.d); - __m128d tmp3 = _mm_xor_pd(tmp1, tmp2); - return _mm_and_pd(tmp3, result.d); - } -}; -} // namespace binary -} // namespace base - -#endif /* HL_MATRIX_BASE_SSE_CUH_ */ diff --git a/paddle/cuda/include/hl_matrix_type.cuh b/paddle/cuda/include/hl_matrix_type.cuh index f965ba96679..77f73167fe6 100644 --- a/paddle/cuda/include/hl_matrix_type.cuh +++ b/paddle/cuda/include/hl_matrix_type.cuh @@ -17,35 +17,31 @@ limitations under the License. */ #include "hl_base.h" -#if defined(__CUDA_ARCH__) +#ifdef __CUDA_ARCH__ +/** + * CUDA kernel inline function + */ +#define INLINE __device__ inline +#else +/** + * CPP inline function + */ +#define INLINE inline +#endif + +#ifdef __CUDA_ARCH__ #include #ifndef PADDLE_TYPE_DOUBLE typedef float4 vecType; #else typedef double2 vecType; #endif -#elif (defined __ARM_NEON) || (defined __ARM_NEON__) -#include -#ifndef PADDLE_TYPE_DOUBLE -typedef float32x4_t vecType; -#else -#error NEON instructions does not support double precision -#endif +#elif defined(__SSE3__) +#include "hl_cpu_simd_sse.cuh" +#elif (defined(__ARM_NEON) || defined(__ARM_NEON__)) && !defined(__NVCC__) +#include "hl_cpu_simd_neon.cuh" #else -#include -#include -#include -#ifndef PADDLE_TYPE_DOUBLE -typedef __m128 vecType; -#else -typedef __m128d vecType; -#endif -#endif - -#ifdef __CUDA_ARCH__ -#define INLINE __device__ inline -#else -#define INLINE inline +#include "hl_cpu_scalar.cuh" #endif #endif // HL_MATRIX_TYPE_CUH_ diff --git a/paddle/cuda/include/hl_neon_matrix_kernel.cuh b/paddle/cuda/include/hl_neon_matrix_kernel.cuh deleted file mode 100644 index 7b4e5b00079..00000000000 --- a/paddle/cuda/include/hl_neon_matrix_kernel.cuh +++ /dev/null @@ -1,299 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - - -#ifndef HL_NEON_MATRIX_KERNEL_CUH_ -#define HL_NEON_MATRIX_KERNEL_CUH_ - -#include "hl_matrix_type.cuh" - -#define VECTOR_SIZE 16 - -/* number of float in vector */ -#define VECTOR_LEN 4 -#define VECTOR_SET vdupq_n_f32 - -inline bool hl_check_align(size_t size) { - return !(size & (VECTOR_SIZE - 1)); -} - -inline bool hl_check_align(void *ptr) { - return hl_check_align(reinterpret_cast(ptr)); -} - -template -inline real hl_agg_op(Agg agg, vecType mm) { - float32x4_t rev = vrev64q_f32(mm); - float32x4_t tmp1 = agg.vecOp(rev, rev); - float32x2_t lo = vget_high_f32(rev); - float32x2_t hi = vget_low_f32(rev); - float32x4_t tmp2 = vcombine_f32(hi, lo); - float32x4_t ret = agg.vecOp(tmp1, tmp2); - - return vgetq_lane_f32(ret, 0); -} - -template -void hl_sse_matrix_row_op(Agg agg, Op op, Saver sv, - int dimM, int dimN, - real *dst, int ld, - real *A, int lda) { - for (int i = 0; i < dimM; i++, A += lda) { - vecType mm = VECTOR_SET(agg.init()); - vecType *a = (vecType*)(A); - for (int j = 0; j < dimN / VECTOR_LEN; j++, a++) { - mm = agg.vecOp(mm, op.vecOp(*a)); - } - - int rem = dimN % VECTOR_LEN; - if (rem) { - real tmp = hl_agg_op(agg, mm); - real *a = A + (dimN / VECTOR_LEN) * VECTOR_LEN; - for (int j = 0; j < rem; j++) { - tmp = agg(tmp, op(a[j])); - } - dst[i*ld] = sv(dst[i*ld], tmp); - } else { - dst[i*ld] = sv(dst[i*ld], hl_agg_op(agg, mm)); - } - } -} - -template -void hl_sse_matrix_row_op(Agg agg, Op op, Saver sv, - int dimM, int dimN, - real *dst, int ld, - real *A, int lda, - real *B, int ldb) { - for (int i = 0; i < dimM; i++, A += lda, B += ldb) { - vecType mm = VECTOR_SET(agg.init()); - vecType *a = (vecType*)(A); - vecType *b = (vecType*)(B); - for (int j = 0; j < dimN / VECTOR_LEN; j++, a++, b++) { - mm = agg.vecOp(mm, op.vecOp(*a, *b)); - } - - int rem = dimN % VECTOR_LEN; - if (rem) { - real tmp = hl_agg_op(agg, mm); - real *a = A + (dimN / VECTOR_LEN) * VECTOR_LEN; - real *b = B + (dimN / VECTOR_LEN) * VECTOR_LEN; - for (int j = 0; j < rem; j++) { - tmp = agg(tmp, op(a[j], b[j])); - } - dst[i*ld] = sv(dst[i*ld], tmp); - } else { - dst[i*ld] = sv(dst[i*ld], hl_agg_op(agg, mm)); - } - } -} - -template -void hl_matrix_column_op(Agg agg, Op op, Saver sv, - int dimM, int dimN, - real *dst, - real *A, int lda) { - for (int j = 0; j < dimN; j++) { - real tmp = agg.init(); - for (int i = 0; i < dimM; i++) { - tmp = agg(tmp, op(A[i * lda + j])); - } - dst[j] = sv(dst[j], tmp); - } -} - -template -void hl_matrix_column_op(Agg agg, Op op, Saver sv, - int dimM, int dimN, - real *dst, - real *A, int lda, - real *B, int ldb) { - for (int j = 0; j < dimN; j++) { - real tmp = agg.init(); - for (int i = 0; i < dimM; i++) { - tmp = agg(tmp, op(A[i * lda + j], B[i * ldb + j])); - } - dst[j] = sv(dst[j], tmp); - } -} - -/* - * MaxRow greater than or equal dimN - * dimN is multiples of VECTOR_LEN - * so rem <= MaxRow / VECTOR_LEN - */ -template -void hl_sse_column_op_with_rem(Agg agg, Op op, Saver sv, - int dimM, int dimN, - real *dst, - real *A, int lda) { - vecType mm[MaxRow / VECTOR_LEN]; - for (int n = 0; n < MaxRow / VECTOR_LEN; n++) { - mm[n] = VECTOR_SET(agg.init()); - } - - for (int i = 0; i < dimM; i++) { - vecType *a = (vecType*)(A + i * lda); - for (int n = 0; n < dimN / VECTOR_LEN; n++) { - mm[n] = agg.vecOp(mm[n], op.vecOp(a[n])); - } - } - - vecType *result = (vecType*)(dst); - for (int n = 0; n < dimN / VECTOR_LEN; n++) { - result[n] = sv.vecOp(result[n], mm[n]); - } - - int rem = dimN % VECTOR_LEN; - if (rem) { - A += (dimN / VECTOR_LEN) * VECTOR_LEN; - dst += (dimN / VECTOR_LEN) * VECTOR_LEN; - hl_matrix_column_op(agg, op, sv, dimM, rem, dst, A, lda); - } -} - -/* - * dimN is multiples of VECTOR_LEN - * dimN greater than Step - */ -template -void hl_sse_matrix_column_op(Agg agg, Op op, Saver sv, - int dimM, int dimN, - real *dst, - real *A, int lda) { - for (int j = 0; j < dimN / Step; j++, dst += Step, A += Step) { - vecType mm[Step / VECTOR_LEN]; - for (int n = 0; n < Step / VECTOR_LEN; n++) { - mm[n] = VECTOR_SET(agg.init()); - } - - for (int i = 0; i < dimM; i++) { - vecType *a = (vecType*)(A + i * lda); - for (int n = 0; n < Step / VECTOR_LEN; n++) { - mm[n] = agg.vecOp(mm[n], op.vecOp(a[n])); - } - } - - vecType *result = (vecType*)(dst); - for (int n = 0; n < Step / VECTOR_LEN; n++) { - result[n] = sv.vecOp(result[n], mm[n]); - } - } - - int remRow = dimN % Step; - if (remRow) { - hl_sse_column_op_with_rem(agg, op, sv, dimM, remRow, dst, A, lda); - } -} - -template -void hl_sse_matrix_column_op(Agg agg, Op op, Saver sv, - int dimM, int dimN, - real *dst, - real *A, int lda) { - if (dimN <= 16) { - hl_sse_matrix_column_op<16>(agg, op, sv, dimM, dimN, dst, A, lda); - } else if (dimN <= 32) { - hl_sse_matrix_column_op<32>(agg, op, sv, dimM, dimN, dst, A, lda); - } else if (dimN <= 1024 || dimM <= 512) { - hl_sse_matrix_column_op<64>(agg, op, sv, dimM, dimN, dst, A, lda); - } else { - hl_sse_matrix_column_op<1024>(agg, op, sv, dimM, dimN, dst, A, lda); - } -} - -template -void hl_sse_column_op_with_rem(Agg agg, Op op, Saver sv, - int dimM, int dimN, - real *dst, - real *A, int lda, - real *B, int ldb) { - vecType mm[MaxRow / VECTOR_LEN]; - for (int n = 0; n < MaxRow / VECTOR_LEN; n++) { - mm[n] = VECTOR_SET(agg.init()); - } - - for (int i = 0; i < dimM; i++) { - vecType *a = (vecType*)(A + i * lda); - vecType *b = (vecType*)(B + i * ldb); - for (int n = 0; n < dimN / VECTOR_LEN; n++) { - mm[n] = agg.vecOp(mm[n], op.vecOp(a[n], b[n])); - } - } - - vecType *result = (vecType*)(dst); - for (int n = 0; n < dimN / VECTOR_LEN; n++) { - result[n] = sv.vecOp(result[n], mm[n]); - } - - int rem = dimN % VECTOR_LEN; - if (rem) { - A += (dimN / VECTOR_LEN) * VECTOR_LEN; - B += (dimN / VECTOR_LEN) * VECTOR_LEN; - dst += (dimN / VECTOR_LEN) * VECTOR_LEN; - hl_matrix_column_op(agg, op, sv, dimM, rem, dst, A, lda, B, ldb); - } -} - -template -void hl_sse_matrix_column_op(Agg agg, Op op, Saver sv, - int dimM, int dimN, - real *dst, - real *A, int lda, - real *B, int ldb) { - for (int j = 0; j < dimN / Step; j++, dst += Step, A += Step, B += Step) { - vecType mm[Step / VECTOR_LEN]; - for (int n = 0; n < Step / VECTOR_LEN; n++) { - mm[n] = VECTOR_SET(agg.init()); - } - - for (int i = 0; i < dimM; i++) { - vecType *a = (vecType*)(A + i * lda); - vecType *b = (vecType*)(B + i * ldb); - for (int n = 0; n < Step / VECTOR_LEN; n++) { - mm[n] = agg.vecOp(mm[n], op.vecOp(a[n], b[n])); - } - } - - vecType *result = (vecType*)(dst); - for (int n = 0; n < Step / VECTOR_LEN; n++) { - result[n] = sv.vecOp(result[n], mm[n]); - } - } - - int remRow = dimN % Step; - if (remRow) { - hl_sse_column_op_with_rem( - agg, op, sv, dimM, remRow, dst, A, lda, B, ldb); - } -} - -template -void hl_sse_matrix_column_op(Agg agg, Op op, Saver sv, - int dimM, int dimN, - real *dst, - real *A, int lda, - real *B, int ldb) { - if (dimN <= 16) { - hl_sse_matrix_column_op<16>(agg, op, sv, dimM, dimN, dst, A, lda, B, ldb); - } else if (dimN <= 32) { - hl_sse_matrix_column_op<32>(agg, op, sv, dimM, dimN, dst, A, lda, B, ldb); - } else if (dimN <= 1024 || dimM <= 512) { - hl_sse_matrix_column_op<64>(agg, op, sv, dimM, dimN, dst, A, lda, B, ldb); - } else { - hl_sse_matrix_column_op<1024>(agg, op, sv, dimM, dimN, dst, A, lda, B, ldb); - } -} - -#endif /* HL_NEON_MATRIX_KERNEL_CUH_ */ -- GitLab