Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
64d5b438
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
64d5b438
编写于
10月 24, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix crf decode avx512
上级
21487d78
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
11 addition
and
12 deletion
+11
-12
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
+11
-12
未找到文件。
paddle/fluid/operators/math/jit_kernel_crf_decode.cc
浏览文件 @
64d5b438
...
@@ -156,17 +156,16 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
...
@@ -156,17 +156,16 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
} \
} \
}
}
#define INTRIAVX2_FLOAT(
block)
\
#define INTRIAVX2_FLOAT(
isa, block)
\
template <> \
template <> \
CRFDecodeKernelImpl<float, jit::avx2, block>::CRFDecodeKernelImpl( \
CRFDecodeKernelImpl<float, isa, block>::CRFDecodeKernelImpl(int tag_num) \
int tag_num) \
: CRFDecodeKernel<float>() { \
: CRFDecodeKernel<float>() { \
this->num_ = tag_num; \
this->num_ = tag_num; \
this->end_ = this->num_ / AVX2_FLOAT_BLOCK; \
this->end_ = this->num_ / AVX2_FLOAT_BLOCK; \
this->rest_ = this->num_ % AVX2_FLOAT_BLOCK; \
this->rest_ = this->num_ % AVX2_FLOAT_BLOCK; \
} \
} \
template <> \
template <> \
void CRFDecodeKernelImpl<float,
jit::avx2, block>::Compute(
\
void CRFDecodeKernelImpl<float,
isa, block>::Compute(
\
const int seq_len, const float* x, const float* w, float* alpha, \
const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \
int* track) const { \
INIT_ALPHA(AVX2_FLOAT_BLOCK) \
INIT_ALPHA(AVX2_FLOAT_BLOCK) \
...
@@ -224,7 +223,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
...
@@ -224,7 +223,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
int j_offset = 0; \
int j_offset = 0; \
for (int j = 0; j <= this->end_; ++j) { \
for (int j = 0; j <= this->end_; ++j) { \
/* Initialize the variables of maximum score and location.*/
\
/* Initialize the variables of maximum score and location.*/
\
__m512 max_score = _mm512_set1_ps(-std::numeric_limits<
T>::max());
\
__m512 max_score = _mm512_set1_ps(-std::numeric_limits<
float>::max());
\
__m512i max_j = _mm512_setzero_si512(); \
__m512i max_j = _mm512_setzero_si512(); \
/* Calculate the offset of transition_weights.*/
\
/* Calculate the offset of transition_weights.*/
\
int trans_offset = state_trans_base_idx * this->num_ + j_offset; \
int trans_offset = state_trans_base_idx * this->num_ + j_offset; \
...
@@ -245,7 +244,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
...
@@ -245,7 +244,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
__m512 x_content = \
__m512 x_content = \
_mm512_loadu_ps(x + seq_offset + this->num_ + j_offset); \
_mm512_loadu_ps(x + seq_offset + this->num_ + j_offset); \
max_score = _mm512_add_ps(max_score, x_content); \
max_score = _mm512_add_ps(max_score, x_content); \
_mm512_storeu_ps(alpha
_value + seq_offset + this->tag_num_ + j_offset,
\
_mm512_storeu_ps(alpha
+ seq_offset + this->num_ + j_offset,
\
max_score); \
max_score); \
_mm512_storeu_si512(reinterpret_cast<__m512i*>(track + seq_offset + \
_mm512_storeu_si512(reinterpret_cast<__m512i*>(track + seq_offset + \
this->num_ + j_offset), \
this->num_ + j_offset), \
...
@@ -271,14 +270,14 @@ INTRIAVX_FLOAT(kEQ16);
...
@@ -271,14 +270,14 @@ INTRIAVX_FLOAT(kEQ16);
INTRIAVX_FLOAT
(
kGT16
);
INTRIAVX_FLOAT
(
kGT16
);
#endif
#endif
#ifdef __AVX2__
#ifdef __AVX2__
INTRIAVX2_FLOAT
(
kEQ8
);
INTRIAVX2_FLOAT
(
jit
::
avx2
,
kEQ8
);
INTRIAVX2_FLOAT
(
kGT8LT16
);
INTRIAVX2_FLOAT
(
jit
::
avx2
,
kGT8LT16
);
INTRIAVX2_FLOAT
(
kEQ16
);
INTRIAVX2_FLOAT
(
jit
::
avx2
,
kEQ16
);
INTRIAVX2_FLOAT
(
kGT16
);
INTRIAVX2_FLOAT
(
jit
::
avx2
,
kGT16
);
#endif
#endif
#ifdef __AVX512F__
#ifdef __AVX512F__
INTRIAVX2_FLOAT
(
kEQ8
);
INTRIAVX2_FLOAT
(
jit
::
avx512f
,
kEQ8
);
INTRIAVX2_FLOAT
(
kGT8LT16
);
INTRIAVX2_FLOAT
(
jit
::
avx512f
,
kGT8LT16
);
INTRIAVX512_FLOAT
(
kEQ16
);
INTRIAVX512_FLOAT
(
kEQ16
);
INTRIAVX512_FLOAT
(
kGT16
);
INTRIAVX512_FLOAT
(
kGT16
);
#endif
#endif
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录