// search_compute.h
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#if !defined(PADDLE_WITH_ARM)
#include <immintrin.h>
#endif
#include <cfloat>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>

#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using LoD = framework::LoD;

// Issues a single GEMM through `blas`, deriving every leading dimension from
// the transpose flags and the logical sizes M, N, K:
//   lda = M if A is transposed, K otherwise;
//   ldb = K if B is transposed, N otherwise;
//   ldc = N.
// NOTE(review): these choices correspond to densely packed row-major operands
// — confirm against BlasT::GEMM's layout before reusing with strided buffers.
template <typename DeviceContext, typename T>
void call_gemm(const math::BlasT<DeviceContext, T>& blas,
               const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
               const int M, const int N, const int K, const T alpha, const T* A,
               const T* B, const T beta, T* C) {
  const bool a_transposed = (TransA != CblasNoTrans);
  const bool b_transposed = (TransB != CblasNoTrans);
  const int lda = a_transposed ? M : K;
  const int ldb = b_transposed ? K : N;
  const int ldc = N;
  blas.GEMM(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}

// Convenience overload: fetches the CPU BLAS wrapper from the execution
// context and forwards to the BlasT-based call_gemm, which derives
// lda/ldb/ldc from the transpose flags and M, N, K.
template <typename T>
void call_gemm(const framework::ExecutionContext& ctx,
               const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
               const int M, const int N, const int K, const T alpha, const T* A,
               const T* B, const T beta, T* C) {
  const auto cpu_blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
  call_gemm(cpu_blas, TransA, TransB, M, N, K, alpha, A, B, beta, C);
}

// Like the BlasT-based call_gemm, except that A's leading dimension is taken
// from the caller (`lda`) instead of being derived from TransA.
// B's leading dimension is still derived (N when B is not transposed, K
// otherwise), and C is written with leading dimension N.
template <typename DeviceContext, typename T>
void call_gemm_with_lda(const math::BlasT<DeviceContext, T>& blas,
                        const CBLAS_TRANSPOSE TransA,
                        const CBLAS_TRANSPOSE TransB, const int M, const int N,
                        const int K, const T alpha, const T* A, const T* B,
                        const T beta, T* C, int lda) {
  int ldb = K;
  if (TransB == CblasNoTrans) {
    ldb = N;
  }
  const int ldc = N;
  blas.GEMM(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
}

// Runs `batch` independent GEMMs, one per triple (A[i], B[i], C[i]), all
// sharing the same sizes, transpose flags, alpha and beta. Each entry is
// dispatched through the BlasT-based call_gemm. batch <= 0 does nothing.
//
// The CPU BLAS wrapper does not depend on the batch index, so it is fetched
// from `ctx` once instead of once per GEMM.
template <typename T>
void call_gemm_batched(const framework::ExecutionContext& ctx,
                       const CBLAS_TRANSPOSE TransA,
                       const CBLAS_TRANSPOSE TransB, const int M, const int N,
                       const int K, const T alpha, const T** A, const T** B,
                       const T beta, T** C, const int batch) {
  const auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
  for (int i = 0; i < batch; ++i) {
    call_gemm(blas, TransA, TransB, M, N, K, alpha, A[i], B[i], beta, C[i]);
  }
}

77 78
#if !defined(PADDLE_WITH_ARM)

A
Aurelius84 已提交
79 80 81 82 83 84 85 86 87 88 89
#define __m256x __m256

static const unsigned int AVX_STEP_SIZE = 8;
static const unsigned int AVX_CUT_LEN_MASK = 7U;

#define _mm256_mul_px _mm256_mul_ps
#define _mm256_add_px _mm256_add_ps
#define _mm256_load_px _mm256_loadu_ps
#define _mm256_store_px _mm256_storeu_ps
#define _mm256_broadcast_sx _mm256_broadcast_ss

90
#define __m128x __m128
91

92 93 94 95 96 97 98 99 100
static const unsigned int SSE_STEP_SIZE = 2;
static const unsigned int SSE_CUT_LEN_MASK = 1U;

#define _mm_add_px _mm_add_ps
#define _mm_mul_px _mm_mul_ps
#define _mm_load_px _mm_loadu_ps
#define _mm_store_px _mm_storeu_ps
#define _mm_load1_px _mm_load1_ps

101 102
#endif

103 104
template <typename T>
inline void axpy(const T* x, T* y, size_t len, const T alpha) {
A
Aurelius84 已提交
105 106 107
  unsigned int jjj, lll;
  jjj = lll = 0;

108
#ifdef PADDLE_WITH_AVX
A
Aurelius84 已提交
109 110 111 112 113 114 115 116
  lll = len & ~AVX_CUT_LEN_MASK;
  __m256x mm_alpha = _mm256_broadcast_sx(&alpha);
  for (jjj = 0; jjj < lll; jjj += AVX_STEP_SIZE) {
    _mm256_store_px(
        y + jjj,
        _mm256_add_px(_mm256_load_px(y + jjj),
                      _mm256_mul_px(mm_alpha, _mm256_load_px(x + jjj))));
  }
117 118
#elif defined(PADDLE_WITH_ARM)
  PADDLE_THROW(platform::errors::Unimplemented("axpy is not supported"));
119 120 121 122 123 124 125
#else
  lll = len & ~SSE_CUT_LEN_MASK;
  __m128x mm_alpha = _mm_load1_px(&alpha);
  for (jjj = 0; jjj < lll; jjj += SSE_STEP_SIZE) {
    _mm_store_px(y + jjj,
                 _mm_add_px(_mm_load_px(y + jjj),
                            _mm_mul_px(mm_alpha, _mm_load_px(x + jjj))));
A
Aurelius84 已提交
126 127
  }

128
#endif
129 130 131 132 133 134

  for (; jjj < len; jjj++) {
    y[jjj] += alpha * x[jjj];
  }
}

135 136
template <typename T>
inline void axpy_noadd(const T* x, T* y, size_t len, const T alpha) {
A
Aurelius84 已提交
137 138 139
  unsigned int jjj, lll;
  jjj = lll = 0;

140
#ifdef PADDLE_WITH_AVX
A
Aurelius84 已提交
141 142 143 144 145
  lll = len & ~AVX_CUT_LEN_MASK;
  __m256x mm_alpha = _mm256_broadcast_sx(&alpha);
  for (jjj = 0; jjj < lll; jjj += AVX_STEP_SIZE) {
    _mm256_store_px(y + jjj, _mm256_mul_px(mm_alpha, _mm256_load_px(x + jjj)));
  }
146 147
#elif defined(PADDLE_WITH_ARM)
  PADDLE_THROW(platform::errors::Unimplemented("axpy_noadd is not supported"));
148 149 150 151 152 153 154 155
#else
  lll = len & ~SSE_CUT_LEN_MASK;
  __m128x mm_alpha = _mm_load1_px(&alpha);
  for (jjj = 0; jjj < lll; jjj += SSE_STEP_SIZE) {
    _mm_store_px(y + jjj, _mm_mul_px(mm_alpha, _mm_load_px(x + jjj)));
  }

#endif
A
Aurelius84 已提交
156 157 158 159 160

  for (; jjj < len; jjj++) {
    y[jjj] = alpha * x[jjj];
  }
}
161 162 163

inline void axpy_noadd(const int8_t* x, int8_t* y, size_t len,
                       const float alpha) {
164
  PADDLE_THROW(platform::errors::Unimplemented(
165
      "int8_t input of axpy_noadd is not supported"));
166
}
A
Aurelius84 已提交
167

A
Aurelius84 已提交
168 169
}  // namespace operators
}  // namespace paddle