未验证 提交 5e6848d9 编写于 作者: W Wilber 提交者: GitHub

[Cherry-pick] Support compile for arm ft (#25241)

上级 4d8c10ae
......@@ -88,6 +88,7 @@ option(WITH_DGC "Use DGC(Deep Gradient Compression) or not" ${WITH_DISTRIBUTE}
option(SANITIZER_TYPE "Choose the type of sanitizer, options are: Address, Leak, Memory, Thread, Undefined" OFF)
option(WITH_LITE "Compile Paddle Fluid with Lite Engine" OFF)
option(WITH_NCCL "Compile PaddlePaddle with NCCL support" ON)
option(WITH_ARM "Compile PaddlePaddle with arm support" OFF)
# PY_VERSION
if(NOT PY_VERSION)
......@@ -199,6 +200,12 @@ if(WITH_AMD_GPU)
include(hip)
endif(WITH_AMD_GPU)
if(WITH_ARM)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
add_definitions(-DPADDLE_WITH_ARM)
endif()
set(PADDLE_PYTHON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/python/build")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")
......
......@@ -19,6 +19,9 @@ SET(CBLAS_SOURCE_DIR ${THIRD_PARTY_PATH}/openblas/src/extern_openblas)
SET(CBLAS_INSTALL_DIR ${THIRD_PARTY_PATH}/install/openblas)
SET(CBLAS_REPOSITORY https://github.com/xianyi/OpenBLAS.git)
SET(CBLAS_TAG v0.3.7)
IF(WITH_ARM)
SET(CBLAS_TAG v0.2.18)
ENDIF()
cache_third_party(extern_openblas
REPOSITORY ${CBLAS_REPOSITORY}
TAG ${CBLAS_TAG}
......
......@@ -187,7 +187,7 @@ set(GPU_COMMON_FLAGS
-Wno-error=unused-function # Warnings in Numpy Header.
-Wno-error=array-bounds # Warnings in Eigen::array
)
if (NOT WITH_NV_JETSON)
if (NOT WITH_NV_JETSON AND NOT WITH_ARM)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -m64")
endif()
endif(NOT WIN32)
......
......@@ -288,8 +288,8 @@ class CPUMatchMatrixTensorOPGradKernel : public framework::OpKernel<T> {
auto* r_data = bottom_r_data + (offset_r[b] + j) * dim_in;
auto* r_diff = bottom_r_diff + (offset_r[b] + j) * dim_in;
if (diff != 0.0) {
avx_axpy(r_data, l_trans_diff, dim_in, diff);
avx_axpy(l_trans_data, r_diff, dim_in, diff);
axpy(r_data, l_trans_diff, dim_in, diff);
axpy(l_trans_data, r_diff, dim_in, diff);
}
}
}
......
......@@ -385,7 +385,7 @@ class CPUPyramidHashOPKernel : public framework::OpKernel<T> {
}
auto weight_type = _blobs_0->type();
if (_is_training == 0 && weight_type != framework::proto::VarType::INT8) {
avx_axpy_noadd(top_data, top_data, top->dims()[0] * top->dims()[1],
axpy_noadd(top_data, top_data, top->dims()[0] * top->dims()[1],
_drop_out_percent);
}
}
......@@ -451,7 +451,7 @@ class CPUPyramidHashOPGradKernel : public framework::OpKernel<T> {
int _space_len) const {
for (int j = 0; j != _num_emb; j += _rand_len) {
unsigned int pos = XXH32(hash_id, len * sizeof(T), j) % _space_len;
avx_axpy(top_pos + j, weights + pos, _rand_len, mlr);
axpy(top_pos + j, weights + pos, _rand_len, mlr);
}
}
......@@ -525,9 +525,7 @@ REGISTER_OPERATOR(pyramid_hash_grad, ops::PyramidHashOpGrad);
REGISTER_OP_CPU_KERNEL(
pyramid_hash, ops::CPUPyramidHashOPKernel<plt::CPUDeviceContext, float>,
ops::CPUPyramidHashOPKernel<plt::CPUDeviceContext, double>,
ops::CPUPyramidHashOPKernel<plt::CPUDeviceContext, int8_t>);
REGISTER_OP_CPU_KERNEL(
pyramid_hash_grad,
ops::CPUPyramidHashOPGradKernel<plt::CPUDeviceContext, float>,
ops::CPUPyramidHashOPGradKernel<plt::CPUDeviceContext, double>);
ops::CPUPyramidHashOPGradKernel<plt::CPUDeviceContext, float>);
......@@ -14,7 +14,9 @@ limitations under the License. */
#pragma once
#if !defined(PADDLE_WITH_ARM)
#include <immintrin.h>
#endif
#include <cfloat>
#include <cmath>
#include <cstring>
......@@ -72,6 +74,8 @@ void call_gemm_batched(const framework::ExecutionContext& ctx,
}
}
#if !defined(PADDLE_WITH_ARM)
#define __m256x __m256
static const unsigned int AVX_STEP_SIZE = 8;
......@@ -83,16 +87,25 @@ static const unsigned int AVX_CUT_LEN_MASK = 7U;
#define _mm256_store_px _mm256_storeu_ps
#define _mm256_broadcast_sx _mm256_broadcast_ss
#define _mm256_mul_pd _mm256_mul_pd
#define _mm256_add_pd _mm256_add_pd
#define _mm256_load_pd _mm256_loadu_pd
#define _mm256_store_pd _mm256_storeu_pd
#define _mm256_broadcast_sd _mm256_broadcast_sd
#define __m128x __m128
static const unsigned int SSE_STEP_SIZE = 2;
static const unsigned int SSE_CUT_LEN_MASK = 1U;
#define _mm_add_px _mm_add_ps
#define _mm_mul_px _mm_mul_ps
#define _mm_load_px _mm_loadu_ps
#define _mm_store_px _mm_storeu_ps
#define _mm_load1_px _mm_load1_ps
inline void avx_axpy(const float* x, float* y, size_t len, const float alpha) {
#endif
template <typename T>
inline void axpy(const T* x, T* y, size_t len, const T alpha) {
unsigned int jjj, lll;
jjj = lll = 0;
#ifdef PADDLE_WITH_AVX
lll = len & ~AVX_CUT_LEN_MASK;
__m256x mm_alpha = _mm256_broadcast_sx(&alpha);
for (jjj = 0; jjj < lll; jjj += AVX_STEP_SIZE) {
......@@ -101,66 +114,55 @@ inline void avx_axpy(const float* x, float* y, size_t len, const float alpha) {
_mm256_add_px(_mm256_load_px(y + jjj),
_mm256_mul_px(mm_alpha, _mm256_load_px(x + jjj))));
}
for (; jjj < len; jjj++) {
y[jjj] += alpha * x[jjj];
#elif defined(PADDLE_WITH_ARM)
PADDLE_THROW(platform::errors::Unimplemented("axpy is not supported"));
#else
lll = len & ~SSE_CUT_LEN_MASK;
__m128x mm_alpha = _mm_load1_px(&alpha);
for (jjj = 0; jjj < lll; jjj += SSE_STEP_SIZE) {
_mm_store_px(y + jjj,
_mm_add_px(_mm_load_px(y + jjj),
_mm_mul_px(mm_alpha, _mm_load_px(x + jjj))));
}
}
inline void avx_axpy(const double* x, double* y, size_t len,
const float alpha) {
unsigned int jjj, lll;
jjj = lll = 0;
lll = len & ~AVX_CUT_LEN_MASK;
double alpha_d = static_cast<double>(alpha);
__m256d mm_alpha = _mm256_broadcast_sd(&alpha_d);
for (jjj = 0; jjj < lll; jjj += AVX_STEP_SIZE) {
_mm256_store_pd(
y + jjj,
_mm256_add_pd(_mm256_load_pd(y + jjj),
_mm256_mul_pd(mm_alpha, _mm256_load_pd(x + jjj))));
}
#endif
for (; jjj < len; jjj++) {
y[jjj] += alpha * x[jjj];
}
}
inline void avx_axpy_noadd(const double* x, double* y, size_t len,
const float alpha) {
unsigned int jjj, lll;
jjj = lll = 0;
double alpha_d = static_cast<double>(alpha);
lll = len & ~AVX_CUT_LEN_MASK;
__m256d mm_alpha = _mm256_broadcast_sd(&alpha_d);
for (jjj = 0; jjj < lll; jjj += AVX_STEP_SIZE) {
_mm256_store_pd(y + jjj, _mm256_mul_pd(mm_alpha, _mm256_load_pd(x + jjj)));
}
for (; jjj < len; jjj++) {
y[jjj] = alpha * x[jjj];
}
}
inline void avx_axpy_noadd(const float* x, float* y, size_t len,
const float alpha) {
template <typename T>
inline void axpy_noadd(const T* x, T* y, size_t len, const T alpha) {
unsigned int jjj, lll;
jjj = lll = 0;
#ifdef PADDLE_WITH_AVX
lll = len & ~AVX_CUT_LEN_MASK;
__m256x mm_alpha = _mm256_broadcast_sx(&alpha);
for (jjj = 0; jjj < lll; jjj += AVX_STEP_SIZE) {
_mm256_store_px(y + jjj, _mm256_mul_px(mm_alpha, _mm256_load_px(x + jjj)));
}
#elif defined(PADDLE_WITH_ARM)
PADDLE_THROW(platform::errors::Unimplemented("axpy_noadd is not supported"));
#else
lll = len & ~SSE_CUT_LEN_MASK;
__m128x mm_alpha = _mm_load1_px(&alpha);
for (jjj = 0; jjj < lll; jjj += SSE_STEP_SIZE) {
_mm_store_px(y + jjj, _mm_mul_px(mm_alpha, _mm_load_px(x + jjj)));
}
#endif
for (; jjj < len; jjj++) {
y[jjj] = alpha * x[jjj];
}
}
inline void avx_axpy_noadd(const int8_t* x, int8_t* y, size_t len,
inline void axpy_noadd(const int8_t* x, int8_t* y, size_t len,
const float alpha) {
PADDLE_THROW(platform::errors::Unimplemented(
"int8_t input of avx_axpy_noadd is not supported"));
"int8_t input of axpy_noadd is not supported"));
}
} // namespace operators
......
......@@ -139,6 +139,7 @@ bool MayIUse(const cpu_isa_t cpu_isa) {
if (cpu_isa == isa_any) {
return true;
} else {
#if !defined(WITH_NV_JETSON) && !defined(PADDLE_WITH_ARM)
int reg[4];
cpuid(reg, 0);
int nIds = reg[0];
......@@ -168,6 +169,7 @@ bool MayIUse(const cpu_isa_t cpu_isa) {
}
}
#endif
#endif
} // namespace platform
} // namespace paddle
......@@ -40,12 +40,14 @@ limitations under the License. */
#ifdef _WIN32
#define cpuid(reg, x) __cpuidex(reg, x, 0)
#else
#if !defined(WITH_NV_JETSON) && !defined(PADDLE_WITH_ARM)
#include <cpuid.h>
inline void cpuid(int reg[4], int x) {
__cpuid_count(x, 0, reg[0], reg[1], reg[2], reg[3]);
}
#endif
#endif
#endif
namespace paddle {
namespace platform {
......
......@@ -6,6 +6,7 @@ import shutil
import sys
import fnmatch
import errno
import platform
from contextlib import contextmanager
from setuptools import Command
......@@ -301,6 +302,7 @@ if '${CMAKE_BUILD_TYPE}' == 'Release':
command = "install_name_tool -id \"@loader_path/../libs/\" ${PADDLE_BINARY_DIR}/python/paddle/fluid/${FLUID_CORE_NAME}" + '.so'
else:
command = "patchelf --set-rpath '$ORIGIN/../libs/' ${PADDLE_BINARY_DIR}/python/paddle/fluid/${FLUID_CORE_NAME}" + '.so'
if platform.machine() != 'aarch64':
if os.system(command) != 0:
raise Exception("patch ${FLUID_CORE_NAME}.%s failed, command: %s" % (ext_name, command))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册