提交 64a8e6d2 编写于 作者: T tensor-tang

refine the threshold functions

上级 32822b2a
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <limits>
#include <vector> #include <vector>
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
...@@ -161,6 +162,25 @@ struct CBlas<platform::float16> { ...@@ -161,6 +162,25 @@ struct CBlas<platform::float16> {
} }
#endif #endif
}; };
template <typename T>
inline static bool UseXSMM(const int &m, const int &n, const int &k,
bool transa, bool transb, const T &alpha,
const T &beta) {
#ifdef PADDLE_WITH_LIBXSMM
// Refer to https://github.com/hfp/libxsmm/blob/master/README.md
// But the threshold is custom
constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20;
if (m * n * k > LIBXSMM_THRESHOLD || transa || transb ||
std::abs<T>(alpha - static_cast<T>(1) >
std::numeric_limits<T>::epsilon()) ||
std::abs<T>(beta) > std::numeric_limits<T>::epsilon()) {
return false;
} else {
return true;
}
#endif
return false;
}
template <> template <>
template <typename T> template <typename T>
...@@ -172,8 +192,8 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA, ...@@ -172,8 +192,8 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
int ldb = (transB == CblasNoTrans) ? N : K; int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N; int ldc = N;
#ifdef PADDLE_WITH_LIBXSMM #ifdef PADDLE_WITH_LIBXSMM
if (M * N * K < 128 * 128 * 128 && transA == CblasNoTrans && if (UseXSMM(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha,
transB == CblasNoTrans) { beta)) {
// refer to https://github.com/hfp/libxsmm/blob/master/README.md // refer to https://github.com/hfp/libxsmm/blob/master/README.md
// Note: SMM use ColMajor // Note: SMM use ColMajor
const char transa = 'N'; const char transa = 'N';
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册