提交 676995c8 编写于 作者: Y Yihua Xu 提交者: tensor-tang

Optimze Gelu with MKL Erf function (#15770)

* Optimize for gelu operator

* Set up the low accuracy mode of MKL ERF function.

test=develop

* Only enable MKLML ERF when OS is linux

* Use the speical mklml version included vmsErf function to verify gelu mkl kernel.

test=develop

* Add the CUDA macro to avoid NVCC's compile issue.

test=develop

* Add the TODO comments for mklml library modification.

test=develop

* Clean Code

test=develop

* Add the comment of marco for NVCC compiler.

test=develop
上级 c4faf36e
...@@ -39,8 +39,10 @@ IF(WIN32) ...@@ -39,8 +39,10 @@ IF(WIN32)
SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.lib) SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.lib)
SET(MKLML_SHARED_LIB ${MKLML_LIB_DIR}/mklml.dll) SET(MKLML_SHARED_LIB ${MKLML_LIB_DIR}/mklml.dll)
SET(MKLML_SHARED_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.dll) SET(MKLML_SHARED_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.dll)
ELSE() ELSE()
SET(MKLML_VER "mklml_lnx_${TIME_VERSION}" CACHE STRING "" FORCE) #TODO(intel-huying):
# Now enable Erf function in mklml library temporarily, it will be updated as offical version later.
SET(MKLML_VER "VsErf_mklml_lnx_${TIME_VERSION}" CACHE STRING "" FORCE)
SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE) SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE)
SET(MKLML_LIB ${MKLML_LIB_DIR}/libmklml_intel.so) SET(MKLML_LIB ${MKLML_LIB_DIR}/libmklml_intel.so)
SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so) SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so)
......
...@@ -11,6 +11,7 @@ limitations under the License. */ ...@@ -11,6 +11,7 @@ limitations under the License. */
#pragma once #pragma once
#include <glog/logging.h> #include <glog/logging.h>
#include <algorithm>
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
...@@ -24,6 +25,7 @@ limitations under the License. */ ...@@ -24,6 +25,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
...@@ -301,8 +303,28 @@ template <typename T> ...@@ -301,8 +303,28 @@ template <typename T>
struct GeluFunctor : public BaseActivationFunctor<T> { struct GeluFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out> template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const { void operator()(Device d, X x, Out out) const {
// Because the execute or device context can not be deliver here, it keep the
// marco for NVCC.
#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
!defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
auto x_data = x.data();
auto out_data = out.data();
int n = std::min(x.size(), out.size());
std::memset(out_data, 0, n * sizeof(T));
math::CBlas<T>::AXPY(n, static_cast<T>(M_SQRT1_2), x_data, 1, out_data, 1);
math::CBlas<T>::VMERF(n, out_data, out_data, VML_LA);
for (int i = 0; i < n; i++) {
out_data[i] += static_cast<T>(1);
}
math::CBlas<T>::VMUL(n, x_data, out_data, out_data);
for (int i = 0; i < n; i++) {
out_data[i] *= static_cast<T>(0.5);
}
#else
auto temp = (x * static_cast<T>(M_SQRT1_2)).erf(); auto temp = (x * static_cast<T>(M_SQRT1_2)).erf();
out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp); out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp);
#endif
} }
}; };
......
...@@ -184,6 +184,9 @@ class Blas { ...@@ -184,6 +184,9 @@ class Blas {
template <typename T> template <typename T>
void VINV(int n, const T* a, T* y) const; void VINV(int n, const T* a, T* y) const;
template <typename T>
void VMERF(int n, const T* a, T* y, int64_t mode) const;
private: private:
const DeviceContext& context_; const DeviceContext& context_;
}; };
...@@ -290,6 +293,11 @@ class BlasT : private Blas<DeviceContext> { ...@@ -290,6 +293,11 @@ class BlasT : private Blas<DeviceContext> {
Base()->template VINV<T>(args...); Base()->template VINV<T>(args...);
} }
template <typename... ARGS>
void VMERF(ARGS... args) const {
Base()->template VMERF<T>(args...);
}
private: private:
const Blas<DeviceContext>* Base() const { const Blas<DeviceContext>* Base() const {
return static_cast<const Blas<DeviceContext>*>(this); return static_cast<const Blas<DeviceContext>*>(this);
......
...@@ -123,6 +123,11 @@ struct CBlas<float> { ...@@ -123,6 +123,11 @@ struct CBlas<float> {
static void VINV(ARGS... args) { static void VINV(ARGS... args) {
platform::dynload::vsInv(args...); platform::dynload::vsInv(args...);
} }
template <typename... ARGS>
static void VMERF(ARGS... args) {
platform::dynload::vmsErf(args...);
}
}; };
template <> template <>
...@@ -223,6 +228,11 @@ struct CBlas<double> { ...@@ -223,6 +228,11 @@ struct CBlas<double> {
static void VINV(ARGS... args) { static void VINV(ARGS... args) {
platform::dynload::vdInv(args...); platform::dynload::vdInv(args...);
} }
template <typename... ARGS>
static void VMERF(ARGS... args) {
platform::dynload::vmdErf(args...);
}
}; };
#else #else
...@@ -625,6 +635,19 @@ void Blas<DeviceContext>::VINV(int n, const T *a, T *y) const { ...@@ -625,6 +635,19 @@ void Blas<DeviceContext>::VINV(int n, const T *a, T *y) const {
#endif #endif
} }
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::VMERF(int n, const T *a, T *y,
int64_t mode) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::VMERF(n, a, y, mode);
#else
for (int i = 0; i < n; ++i) {
y[i] = std::erf(a[i]);
}
#endif
}
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -86,6 +86,8 @@ extern void* mklml_dso_handle; ...@@ -86,6 +86,8 @@ extern void* mklml_dso_handle;
__macro(vdPowx); \ __macro(vdPowx); \
__macro(vsInv); \ __macro(vsInv); \
__macro(vdInv); \ __macro(vdInv); \
__macro(vmsErf); \
__macro(vmdErf); \
__macro(MKL_Set_Num_Threads) __macro(MKL_Set_Num_Threads)
MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP); MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册