提交 d651a911 编写于 作者: T tensor-tang

fix build on win, fix use condition of crf decoding and layer norm and

enhance test precision

test=develop
上级 141ebcd4
......@@ -20,4 +20,6 @@ endif()
# Helper library built from ${jit_kernel_cc_srcs}, plus its unit test.
cc_library(jit_kernel_helper SRCS ${jit_kernel_cc_srcs} DEPS ${JIT_KERNEL_DEPS})
cc_test(jit_kernel_test SRCS test.cc DEPS jit_kernel_helper)
# The benchmark binary is declared only on non-Windows platforms. It must not
# also be declared unconditionally: that would define the same target twice
# and re-enable it on Windows.
if(NOT WIN32)
  cc_binary(jit_kernel_benchmark SRCS benchmark.cc DEPS jit_kernel_helper)
endif()
......@@ -22,11 +22,16 @@ namespace operators {
namespace jit {
namespace more {
namespace intrinsic {
// Note: intrinsic code paths are selected at compile time, not at runtime.
// For example, code built with AVX but run on an AVX512 machine only uses AVX.
void CRFDecoding(const int seq_len, const float* x, const float* w,
float* alpha, int* track, int tag_num) {
const int step_size =
platform::MayIUse(platform::avx512f) ? ZMM_FLOAT_BLOCK : YMM_FLOAT_BLOCK;
#ifdef __AVX512F__
const int step_size = ZMM_FLOAT_BLOCK;
#else
const int step_size = YMM_FLOAT_BLOCK;
#endif
const int end = tag_num / step_size;
const int rest = tag_num % step_size;
/* Setup the alpha initial value.*/
......@@ -157,7 +162,12 @@ void CRFDecoding(const int seq_len, const float* x, const float* w,
}
// Reports whether the intrinsic CRF decoding kernel may be used for size d.
// Requires AVX to be available and d to be at least one SIMD block of floats.
// The block width is chosen at compile time via __AVX512F__ (ZMM block when
// defined, YMM block otherwise).
bool CRFDecodingKernel::UseMe(const int& d) const {
#ifdef __AVX512F__
  constexpr int block = ZMM_FLOAT_BLOCK;
#else
  constexpr int block = YMM_FLOAT_BLOCK;
#endif
  // Sizes smaller than one block are rejected so this kernel is not selected.
  return platform::MayIUse(platform::avx) && d >= block;
}
} // namespace intrinsic
......
......@@ -154,7 +154,7 @@ void LayerNorm(float* x, float* out, float* mean, float* var,
}
// Reports whether the intrinsic LayerNorm kernel may be used for size d:
// AVX must be available and d must be at least YMM_FLOAT_BLOCK.
bool LayerNormKernel::UseMe(const int& d) const {
  return platform::MayIUse(platform::avx) && d >= YMM_FLOAT_BLOCK;
}
} // namespace intrinsic
......
......@@ -37,7 +37,7 @@ template <typename T>
void ExpectEQ(const T* target, const T* refer, int n) {
if (std::is_floating_point<T>::value) {
for (int i = 0; i < n; ++i) {
EXPECT_NEAR(target[i], refer[i], 1e-3);
EXPECT_NEAR(target[i], refer[i], 1e-5);
}
} else {
for (int i = 0; i < n; ++i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册