提交 64a8e6d2 编写于 作者: T tensor-tang

refine the threshold functions

上级 32822b2a
......@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <limits>
#include <vector>
#include "paddle/fluid/operators/math/math_function.h"
......@@ -161,6 +162,25 @@ struct CBlas<platform::float16> {
}
#endif
};
template <typename T>
inline static bool UseXSMM(const int &m, const int &n, const int &k,
bool transa, bool transb, const T &alpha,
const T &beta) {
#ifdef PADDLE_WITH_LIBXSMM
// Refer to https://github.com/hfp/libxsmm/blob/master/README.md
// But the threshold is custom
constexpr int LIBXSMM_THRESHOLD = 20 * 20 * 20;
if (m * n * k > LIBXSMM_THRESHOLD || transa || transb ||
std::abs<T>(alpha - static_cast<T>(1) >
std::numeric_limits<T>::epsilon()) ||
std::abs<T>(beta) > std::numeric_limits<T>::epsilon()) {
return false;
} else {
return true;
}
#endif
return false;
}
template <>
template <typename T>
......@@ -172,8 +192,8 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
#ifdef PADDLE_WITH_LIBXSMM
if (M * N * K < 128 * 128 * 128 && transA == CblasNoTrans &&
transB == CblasNoTrans) {
if (UseXSMM(M, N, K, transA != CblasNoTrans, transB != CblasNoTrans, alpha,
beta)) {
// refer to https://github.com/hfp/libxsmm/blob/master/README.md
// Note: SMM use ColMajor
const char transa = 'N';
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册