Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
3e01a404
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3e01a404
编写于
12月 28, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add refer seqpool jitkernel
上级
796322d3
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
45 addition
and
0 deletion
+45
-0
paddle/fluid/operators/jit/kernel_base.h
paddle/fluid/operators/jit/kernel_base.h
+20
-0
paddle/fluid/operators/jit/kernel_key.cc
paddle/fluid/operators/jit/kernel_key.cc
+6
-0
paddle/fluid/operators/jit/refer/CMakeLists.txt
paddle/fluid/operators/jit/refer/CMakeLists.txt
+1
-0
paddle/fluid/operators/jit/refer/refer.cc
paddle/fluid/operators/jit/refer/refer.cc
+2
-0
paddle/fluid/operators/jit/refer/refer.h
paddle/fluid/operators/jit/refer/refer.h
+16
-0
未找到文件。
paddle/fluid/operators/jit/kernel_base.h
浏览文件 @
3e01a404
...
@@ -41,6 +41,7 @@ typedef enum {
...
@@ -41,6 +41,7 @@ typedef enum {
kCRFDecoding
,
kCRFDecoding
,
kLayerNorm
,
kLayerNorm
,
kNCHW16CMulNC
,
kNCHW16CMulNC
,
kSeqPool
,
}
KernelType
;
}
KernelType
;
template
<
typename
T
>
template
<
typename
T
>
...
@@ -112,6 +113,25 @@ struct GRUTuples {
...
@@ -112,6 +113,25 @@ struct GRUTuples {
typedef
void
(
*
func_type
)(
gru_t
*
,
const
gru_attr_t
*
);
typedef
void
(
*
func_type
)(
gru_t
*
,
const
gru_attr_t
*
);
};
};
typedef
enum
{
non
=
0
,
sum
,
avg
,
sqrt
,
}
SeqPoolType
;
typedef
struct
{
int
h
,
w
;
SeqPoolType
type
;
}
seq_pool_attr_t
;
template
<
typename
T
>
struct
SeqPoolTuples
{
typedef
T
data_type
;
typedef
seq_pool_attr_t
attr_type
;
typedef
void
(
*
func_type
)(
const
T
*
,
T
*
,
const
seq_pool_attr_t
*
);
};
template
<
typename
T
>
template
<
typename
T
>
struct
CRFDecodingTuples
{
struct
CRFDecodingTuples
{
typedef
T
data_type
;
typedef
T
data_type
;
...
...
paddle/fluid/operators/jit/kernel_key.cc
浏览文件 @
3e01a404
...
@@ -42,6 +42,12 @@ size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
...
@@ -42,6 +42,12 @@ size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
(
static_cast
<
int
>
(
attr
.
act_cand
)
<<
act_type_shift
);
(
static_cast
<
int
>
(
attr
.
act_cand
)
<<
act_type_shift
);
}
}
template
<
>
size_t
JitCodeKey
<
seq_pool_attr_t
>
(
const
seq_pool_attr_t
&
attr
)
{
size_t
key
=
static_cast
<
size_t
>
(
attr
.
type
);
return
key
+
(
attr
.
w
<<
act_type_shift
);
}
}
// namespace jit
}
// namespace jit
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/jit/refer/CMakeLists.txt
浏览文件 @
3e01a404
...
@@ -26,3 +26,4 @@ USE_JITKERNEL_REFER(kGRUHtPart2)
...
@@ -26,3 +26,4 @@ USE_JITKERNEL_REFER(kGRUHtPart2)
USE_JITKERNEL_REFER
(
kCRFDecoding
)
USE_JITKERNEL_REFER
(
kCRFDecoding
)
USE_JITKERNEL_REFER
(
kLayerNorm
)
USE_JITKERNEL_REFER
(
kLayerNorm
)
USE_JITKERNEL_REFER
(
kNCHW16CMulNC
)
USE_JITKERNEL_REFER
(
kNCHW16CMulNC
)
USE_JITKERNEL_REFER
(
kSeqPool
)
paddle/fluid/operators/jit/refer/refer.cc
浏览文件 @
3e01a404
...
@@ -47,4 +47,6 @@ REGISTER_REFER_KERNEL(kLayerNorm, LayerNorm);
...
@@ -47,4 +47,6 @@ REGISTER_REFER_KERNEL(kLayerNorm, LayerNorm);
REGISTER_REFER_KERNEL
(
kNCHW16CMulNC
,
NCHW16CMulNC
);
REGISTER_REFER_KERNEL
(
kNCHW16CMulNC
,
NCHW16CMulNC
);
REGISTER_REFER_KERNEL
(
kSeqPool
,
SeqPool
);
#undef REGISTER_REFER_KERNEL
#undef REGISTER_REFER_KERNEL
paddle/fluid/operators/jit/refer/refer.h
浏览文件 @
3e01a404
...
@@ -332,6 +332,20 @@ void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) {
...
@@ -332,6 +332,20 @@ void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) {
}
}
}
}
template
<
typename
T
>
void
SeqPool
(
const
T
*
x
,
T
*
y
,
const
seq_pool_attr_t
*
attr
)
{
PADDLE_ENFORCE
(
attr
->
type
==
SeqPoolType
::
sum
,
"Only support sum yet"
);
for
(
int
w
=
0
;
w
<
attr
->
w
;
++
w
)
{
const
T
*
src
=
x
+
w
;
T
*
dst
=
y
+
w
;
*
dst
=
static_cast
<
T
>
(
0
);
for
(
int
h
=
0
;
h
<
attr
->
h
;
++
h
)
{
*
dst
=
*
dst
+
*
src
;
src
+=
attr
->
w
;
}
}
}
#define DECLARE_REFER_KERNEL(name, tuples) \
#define DECLARE_REFER_KERNEL(name, tuples) \
template <typename T> \
template <typename T> \
class name##Kernel : public ReferKernel<tuples<T>> { \
class name##Kernel : public ReferKernel<tuples<T>> { \
...
@@ -370,6 +384,8 @@ DECLARE_REFER_KERNEL(LayerNorm, LayerNormTuples);
...
@@ -370,6 +384,8 @@ DECLARE_REFER_KERNEL(LayerNorm, LayerNormTuples);
DECLARE_REFER_KERNEL
(
NCHW16CMulNC
,
NCHW16CMulNCTuples
);
DECLARE_REFER_KERNEL
(
NCHW16CMulNC
,
NCHW16CMulNCTuples
);
DECLARE_REFER_KERNEL
(
SeqPool
,
SeqPoolTuples
);
#undef DECLARE_REFER_KERNEL
#undef DECLARE_REFER_KERNEL
}
// namespace refer
}
// namespace refer
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录