提交 676995c8 编写于 作者: Y Yihua Xu 提交者: tensor-tang

Optimze Gelu with MKL Erf function (#15770)

* Optimize for gelu operator

* Set up the low accuracy mode of MKL ERF function.

test=develop

* Only enable MKLML ERF when OS is linux

* Use the speical mklml version included vmsErf function to verify gelu mkl kernel.

test=develop

* Add the CUDA macro to avoid NVCC's compile issue.

test=develop

* Add the TODO comments for mklml library modification.

test=develop

* Clean Code

test=develop

* Add the comment of marco for NVCC compiler.

test=develop
上级 c4faf36e
......@@ -40,7 +40,9 @@ IF(WIN32)
SET(MKLML_SHARED_LIB ${MKLML_LIB_DIR}/mklml.dll)
SET(MKLML_SHARED_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.dll)
ELSE()
SET(MKLML_VER "mklml_lnx_${TIME_VERSION}" CACHE STRING "" FORCE)
#TODO(intel-huying):
# Now enable Erf function in mklml library temporarily, it will be updated as offical version later.
SET(MKLML_VER "VsErf_mklml_lnx_${TIME_VERSION}" CACHE STRING "" FORCE)
SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE)
SET(MKLML_LIB ${MKLML_LIB_DIR}/libmklml_intel.so)
SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so)
......
......@@ -11,6 +11,7 @@ limitations under the License. */
#pragma once
#include <glog/logging.h>
#include <algorithm>
#include <string>
#include <unordered_set>
#include <utility>
......@@ -24,6 +25,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/float16.h"
#ifdef PADDLE_WITH_MKLDNN
......@@ -301,8 +303,28 @@ template <typename T>
struct GeluFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
// Because the execute or device context can not be deliver here, it keep the
// marco for NVCC.
#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
!defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
auto x_data = x.data();
auto out_data = out.data();
int n = std::min(x.size(), out.size());
std::memset(out_data, 0, n * sizeof(T));
math::CBlas<T>::AXPY(n, static_cast<T>(M_SQRT1_2), x_data, 1, out_data, 1);
math::CBlas<T>::VMERF(n, out_data, out_data, VML_LA);
for (int i = 0; i < n; i++) {
out_data[i] += static_cast<T>(1);
}
math::CBlas<T>::VMUL(n, x_data, out_data, out_data);
for (int i = 0; i < n; i++) {
out_data[i] *= static_cast<T>(0.5);
}
#else
auto temp = (x * static_cast<T>(M_SQRT1_2)).erf();
out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp);
#endif
}
};
......
......@@ -184,6 +184,9 @@ class Blas {
template <typename T>
void VINV(int n, const T* a, T* y) const;
template <typename T>
void VMERF(int n, const T* a, T* y, int64_t mode) const;
private:
const DeviceContext& context_;
};
......@@ -290,6 +293,11 @@ class BlasT : private Blas<DeviceContext> {
Base()->template VINV<T>(args...);
}
template <typename... ARGS>
void VMERF(ARGS... args) const {
Base()->template VMERF<T>(args...);
}
private:
const Blas<DeviceContext>* Base() const {
return static_cast<const Blas<DeviceContext>*>(this);
......
......@@ -123,6 +123,11 @@ struct CBlas<float> {
static void VINV(ARGS... args) {
platform::dynload::vsInv(args...);
}
template <typename... ARGS>
static void VMERF(ARGS... args) {
platform::dynload::vmsErf(args...);
}
};
template <>
......@@ -223,6 +228,11 @@ struct CBlas<double> {
static void VINV(ARGS... args) {
platform::dynload::vdInv(args...);
}
template <typename... ARGS>
static void VMERF(ARGS... args) {
platform::dynload::vmdErf(args...);
}
};
#else
......@@ -625,6 +635,19 @@ void Blas<DeviceContext>::VINV(int n, const T *a, T *y) const {
#endif
}
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::VMERF(int n, const T *a, T *y,
int64_t mode) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::VMERF(n, a, y, mode);
#else
for (int i = 0; i < n; ++i) {
y[i] = std::erf(a[i]);
}
#endif
}
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -86,6 +86,8 @@ extern void* mklml_dso_handle;
__macro(vdPowx); \
__macro(vsInv); \
__macro(vdInv); \
__macro(vmsErf); \
__macro(vmdErf); \
__macro(MKL_Set_Num_Threads)
MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册