Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
0145f40f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0145f40f
编写于
1月 05, 2019
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
use height from params of jitcode
上级
e0591dee
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
117 addition
and
97 deletion
+117
-97
paddle/fluid/operators/jit/benchmark.cc
paddle/fluid/operators/jit/benchmark.cc
+2
-1
paddle/fluid/operators/jit/gen/seqpool.cc
paddle/fluid/operators/jit/gen/seqpool.cc
+10
-7
paddle/fluid/operators/jit/gen/seqpool.h
paddle/fluid/operators/jit/gen/seqpool.h
+90
-72
paddle/fluid/operators/jit/kernel_base.h
paddle/fluid/operators/jit/kernel_base.h
+3
-3
paddle/fluid/operators/jit/kernel_key.cc
paddle/fluid/operators/jit/kernel_key.cc
+2
-4
paddle/fluid/operators/jit/refer/refer.h
paddle/fluid/operators/jit/refer/refer.h
+0
-1
paddle/fluid/operators/jit/test.cc
paddle/fluid/operators/jit/test.cc
+4
-3
paddle/fluid/operators/math/sequence_pooling.cc
paddle/fluid/operators/math/sequence_pooling.cc
+6
-6
未找到文件。
paddle/fluid/operators/jit/benchmark.cc
浏览文件 @
0145f40f
...
@@ -195,8 +195,9 @@ void BenchSeqPoolKernel() {
...
@@ -195,8 +195,9 @@ void BenchSeqPoolKernel() {
std
::
vector
<
jit
::
SeqPoolType
>
pool_types
=
{
jit
::
SeqPoolType
::
kSum
};
std
::
vector
<
jit
::
SeqPoolType
>
pool_types
=
{
jit
::
SeqPoolType
::
kSum
};
for
(
auto
type
:
pool_types
)
{
for
(
auto
type
:
pool_types
)
{
for
(
int
w
:
TestSizes
())
{
for
(
int
w
:
TestSizes
())
{
jit
::
seq_pool_attr_t
attr
(
w
,
type
);
for
(
int
h
:
TestSizes
())
{
for
(
int
h
:
TestSizes
())
{
const
jit
::
seq_pool_attr_t
attr
(
h
,
w
,
type
)
;
attr
.
h
=
h
;
std
::
vector
<
T
>
x
(
h
*
w
),
y
(
w
);
std
::
vector
<
T
>
x
(
h
*
w
),
y
(
w
);
RandomVec
<
T
>
(
h
*
w
,
x
.
data
(),
-
2.
f
,
2.
f
);
RandomVec
<
T
>
(
h
*
w
,
x
.
data
(),
-
2.
f
,
2.
f
);
const
T
*
x_data
=
x
.
data
();
const
T
*
x_data
=
x
.
data
();
...
...
paddle/fluid/operators/jit/gen/seqpool.cc
浏览文件 @
0145f40f
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
* limitations under the License. */
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/seqpool.h"
#include "paddle/fluid/operators/jit/gen/seqpool.h"
#include <stddef.h> // offsetof
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/cpu_info.h"
...
@@ -21,20 +22,22 @@ namespace operators {
...
@@ -21,20 +22,22 @@ namespace operators {
namespace
jit
{
namespace
jit
{
namespace
gen
{
namespace
gen
{
thread_local
float
ALIGN32_BEG
float_h
[
1
]
ALIGN32_END
=
{
1.
f
};
// TODO(TJ): try move to private
void
SeqPoolJitCode
::
genCode
()
{
void
SeqPoolJitCode
::
genCode
()
{
constexpr
int
block
=
YMM_FLOAT_BLOCK
;
constexpr
int
block
=
YMM_FLOAT_BLOCK
;
constexpr
int
max_num_regs
=
8
;
constexpr
int
max_num_regs
=
8
;
const
int
num_block
=
w_
/
block
;
const
int
num_block
=
w_
/
block
;
const
int
num_groups
=
num_block
/
max_num_regs
;
const
int
num_groups
=
num_block
/
max_num_regs
;
int
rest_num_regs
=
num_block
%
max_num_regs
;
int
rest_num_regs
=
num_block
%
max_num_regs
;
if
(
type_
==
SeqPoolType
::
kAvg
)
{
mov
(
reg32_int_h
,
dword
[
param_attr
]);
float
scalar
=
1.
f
/
h_
;
if
(
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
)
{
mov
(
reg
32_scalar
,
scalar
);
mov
(
reg
_tmp
,
reinterpret_cast
<
size_t
>
(
float_h
)
);
}
else
if
(
type_
==
SeqPoolType
::
kSqrt
)
{
fild
(
dword
[
param_attr
]);
f
loat
scalar
=
1.
f
/
std
::
sqrt
(
static_cast
<
float
>
(
h_
)
);
f
stp
(
dword
[
reg_tmp
]
);
mov
(
reg32_
scalar
,
scalar
);
mov
(
reg32_
fp_h
,
dword
[
reg_tmp
]
);
}
}
const
int
group_len
=
max_num_regs
*
block
*
sizeof
(
float
);
const
int
group_len
=
max_num_regs
*
block
*
sizeof
(
float
);
for
(
int
g
=
0
;
g
<
num_groups
;
++
g
)
{
for
(
int
g
=
0
;
g
<
num_groups
;
++
g
)
{
pool_height
<
ymm_t
>
(
g
*
group_len
,
block
,
max_num_regs
);
pool_height
<
ymm_t
>
(
g
*
group_len
,
block
,
max_num_regs
);
...
...
paddle/fluid/operators/jit/gen/seqpool.h
浏览文件 @
0145f40f
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#include <string>
#include <string>
#include "glog/logging.h"
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/act.h" // for ones
#include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
...
@@ -29,7 +30,7 @@ class SeqPoolJitCode : public JitCode {
...
@@ -29,7 +30,7 @@ class SeqPoolJitCode : public JitCode {
explicit
SeqPoolJitCode
(
const
seq_pool_attr_t
&
attr
,
explicit
SeqPoolJitCode
(
const
seq_pool_attr_t
&
attr
,
size_t
code_size
=
256
*
1024
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
h_
(
attr
.
h
),
w_
(
attr
.
w
),
type_
(
attr
.
type
)
{
:
JitCode
(
code_size
,
code_ptr
),
w_
(
attr
.
w
),
type_
(
attr
.
type
)
{
if
(
type_
!=
SeqPoolType
::
kSum
)
{
if
(
type_
!=
SeqPoolType
::
kSum
)
{
LOG
(
FATAL
)
<<
"Only support sum pool yet "
;
LOG
(
FATAL
)
<<
"Only support sum pool yet "
;
}
}
...
@@ -55,13 +56,14 @@ class SeqPoolJitCode : public JitCode {
...
@@ -55,13 +56,14 @@ class SeqPoolJitCode : public JitCode {
void
pool_height
(
int
w_offset
,
int
block
,
int
max_num_regs
)
{
void
pool_height
(
int
w_offset
,
int
block
,
int
max_num_regs
)
{
int
offset
=
w_offset
;
int
offset
=
w_offset
;
for
(
int
i
=
0
;
i
<
max_num_regs
;
++
i
)
{
for
(
int
i
=
0
;
i
<
max_num_regs
;
++
i
)
{
vmovups
(
JMM
(
i
),
ptr
[
param
1
+
offset
]);
vmovups
(
JMM
(
i
),
ptr
[
param
_src
+
offset
]);
offset
+=
sizeof
(
float
)
*
block
;
offset
+=
sizeof
(
float
)
*
block
;
}
}
if
(
h_
>
1
)
{
cmp
(
reg32_int_h
,
1
);
Label
l_next_h
;
Label
l_next_h
,
l_h_done
;
mov
(
reg_h
,
1
);
jle
(
l_h_done
,
T_NEAR
);
mov
(
reg_tmp
,
param1
);
mov
(
reg_h_i
,
1
);
mov
(
reg_tmp
,
param_src
);
add
(
reg_tmp
,
w_
*
sizeof
(
float
)
+
w_offset
);
add
(
reg_tmp
,
w_
*
sizeof
(
float
)
+
w_offset
);
L
(
l_next_h
);
L
(
l_next_h
);
{
{
...
@@ -72,22 +74,30 @@ class SeqPoolJitCode : public JitCode {
...
@@ -72,22 +74,30 @@ class SeqPoolJitCode : public JitCode {
vaddps
(
JMM
(
i
),
JMM
(
i
),
JMM
(
i
+
max_num_regs
));
vaddps
(
JMM
(
i
),
JMM
(
i
),
JMM
(
i
+
max_num_regs
));
add
(
reg_ptr_src_i
,
sizeof
(
float
)
*
block
);
add
(
reg_ptr_src_i
,
sizeof
(
float
)
*
block
);
}
}
inc
(
reg_h
);
inc
(
reg_h_i
);
add
(
reg_tmp
,
w_
*
sizeof
(
float
));
add
(
reg_tmp
,
w_
*
sizeof
(
float
));
cmp
(
reg_h
,
h_
);
cmp
(
reg_h_i
,
reg32_int_h
);
jl
(
l_next_h
,
T_NEAR
);
jl
(
l_next_h
,
T_NEAR
);
}
}
}
L
(
l_h_done
);
// save right now
// save right now
if
(
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
)
{
if
(
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
)
{
vbroadcastss
(
JMM
(
max_num_regs
),
reg32_scalar
);
mov
(
reg_tmp
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovups
(
JMM
(
max_num_regs
),
ptr
[
reg_tmp
+
OFFSET_EXP_ONE
]);
movd
(
JMM
(
max_num_regs
+
1
),
reg32_fp_h
);
if
(
type_
==
SeqPoolType
::
kSqrt
)
{
vsqrtps
(
JMM
(
max_num_regs
+
1
),
JMM
(
max_num_regs
+
1
));
}
vdivps
(
JMM
(
max_num_regs
+
2
),
JMM
(
max_num_regs
),
JMM
(
max_num_regs
+
1
));
vbroadcastss
(
JMM
(
max_num_regs
),
JMM
(
max_num_regs
+
2
));
// TODO(TJ): fix me
}
}
offset
=
w_offset
;
offset
=
w_offset
;
for
(
int
i
=
0
;
i
<
max_num_regs
;
++
i
)
{
for
(
int
i
=
0
;
i
<
max_num_regs
;
++
i
)
{
if
(
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
)
{
if
(
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
)
{
vmulps
(
JMM
(
i
),
JMM
(
i
),
JMM
(
max_num_regs
));
vmulps
(
JMM
(
i
),
JMM
(
i
),
JMM
(
max_num_regs
));
}
}
vmovups
(
ptr
[
param
2
+
offset
],
JMM
(
i
));
vmovups
(
ptr
[
param
_dst
+
offset
],
JMM
(
i
));
offset
+=
sizeof
(
float
)
*
block
;
offset
+=
sizeof
(
float
)
*
block
;
}
}
}
}
...
@@ -97,15 +107,14 @@ class SeqPoolJitCode : public JitCode {
...
@@ -97,15 +107,14 @@ class SeqPoolJitCode : public JitCode {
const
bool
has_block4
=
rest
/
4
>
0
;
const
bool
has_block4
=
rest
/
4
>
0
;
const
bool
has_block2
=
(
rest
%
4
)
/
2
>
0
;
const
bool
has_block2
=
(
rest
%
4
)
/
2
>
0
;
const
bool
has_block1
=
(
rest
%
2
)
==
1
;
const
bool
has_block1
=
(
rest
%
2
)
==
1
;
if
(
h_
>
1
)
{
cmp
(
reg32_int_h
,
1
);
Label
l_next_h
;
Label
l_next_h
,
l_h_done
;
mov
(
reg_h
,
1
);
jle
(
l_h_done
,
T_NEAR
);
mov
(
reg_tmp
,
param1
);
mov
(
reg_h_i
,
1
);
mov
(
reg_tmp
,
param_src
);
add
(
reg_tmp
,
w_
*
sizeof
(
float
)
+
w_offset
);
add
(
reg_tmp
,
w_
*
sizeof
(
float
)
+
w_offset
);
L
(
l_next_h
);
L
(
l_next_h
);
{
{
// int used_regs =load_rest(rest, h * w_ * sizeof(float) + w_offset,
// max_num_regs);
int
reg_idx
=
0
;
int
reg_idx
=
0
;
mov
(
reg_ptr_src_i
,
reg_tmp
);
mov
(
reg_ptr_src_i
,
reg_tmp
);
if
(
has_block4
)
{
if
(
has_block4
)
{
...
@@ -127,17 +136,25 @@ class SeqPoolJitCode : public JitCode {
...
@@ -127,17 +136,25 @@ class SeqPoolJitCode : public JitCode {
for
(
int
i
=
0
;
i
<
reg_idx
;
++
i
)
{
for
(
int
i
=
0
;
i
<
reg_idx
;
++
i
)
{
vaddps
(
xmm_t
(
i
),
xmm_t
(
i
),
xmm_t
(
i
+
max_num_regs
));
vaddps
(
xmm_t
(
i
),
xmm_t
(
i
),
xmm_t
(
i
+
max_num_regs
));
}
}
inc
(
reg_h
);
inc
(
reg_h_i
);
add
(
reg_tmp
,
w_
*
sizeof
(
float
));
add
(
reg_tmp
,
w_
*
sizeof
(
float
));
cmp
(
reg_h
,
h_
);
cmp
(
reg_h_i
,
reg32_int_h
);
jl
(
l_next_h
,
T_NEAR
);
jl
(
l_next_h
,
T_NEAR
);
}
}
}
L
(
l_h_done
);
// save right now
// save right now
if
(
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
)
{
if
(
type_
==
SeqPoolType
::
kAvg
||
type_
==
SeqPoolType
::
kSqrt
)
{
vbroadcastss
(
xmm_t
(
max_num_regs
-
1
),
reg32_scalar
);
mov
(
reg_tmp
,
reinterpret_cast
<
size_t
>
(
exp_float_consts
));
vmovups
(
xmm_t
(
max_num_regs
),
ptr
[
reg_tmp
+
OFFSET_EXP_ONE
]);
movd
(
xmm_t
(
max_num_regs
+
1
),
reg32_fp_h
);
if
(
type_
==
SeqPoolType
::
kSqrt
)
{
vsqrtps
(
xmm_t
(
max_num_regs
+
1
),
xmm_t
(
max_num_regs
+
1
));
}
vdivps
(
xmm_t
(
max_num_regs
+
2
),
xmm_t
(
max_num_regs
),
xmm_t
(
max_num_regs
+
1
));
vbroadcastss
(
xmm_t
(
max_num_regs
),
xmm_t
(
max_num_regs
+
2
));
for
(
int
i
=
0
;
i
<
rest_used_num_regs
;
++
i
)
{
for
(
int
i
=
0
;
i
<
rest_used_num_regs
;
++
i
)
{
vmulps
(
xmm_t
(
i
),
xmm_t
(
i
),
xmm_t
(
max_num_regs
-
1
));
vmulps
(
xmm_t
(
i
),
xmm_t
(
i
),
xmm_t
(
max_num_regs
));
}
}
}
}
save_rest
(
rest
,
w_offset
);
save_rest
(
rest
,
w_offset
);
...
@@ -151,17 +168,17 @@ class SeqPoolJitCode : public JitCode {
...
@@ -151,17 +168,17 @@ class SeqPoolJitCode : public JitCode {
const
bool
has_block1
=
(
rest
%
2
)
==
1
;
const
bool
has_block1
=
(
rest
%
2
)
==
1
;
int
reg_idx
=
reg_start
;
int
reg_idx
=
reg_start
;
if
(
has_block4
)
{
if
(
has_block4
)
{
vmovups
(
xmm_t
(
reg_idx
+
num_shift_regs
),
ptr
[
param
1
+
w_offset
]);
vmovups
(
xmm_t
(
reg_idx
+
num_shift_regs
),
ptr
[
param
_src
+
w_offset
]);
w_offset
+=
sizeof
(
float
)
*
4
;
w_offset
+=
sizeof
(
float
)
*
4
;
reg_idx
++
;
reg_idx
++
;
}
}
if
(
has_block2
)
{
if
(
has_block2
)
{
vmovq
(
xmm_t
(
reg_idx
+
num_shift_regs
),
ptr
[
param
1
+
w_offset
]);
vmovq
(
xmm_t
(
reg_idx
+
num_shift_regs
),
ptr
[
param
_src
+
w_offset
]);
w_offset
+=
sizeof
(
float
)
*
2
;
w_offset
+=
sizeof
(
float
)
*
2
;
reg_idx
++
;
reg_idx
++
;
}
}
if
(
has_block1
)
{
if
(
has_block1
)
{
vmovss
(
xmm_t
(
reg_idx
+
num_shift_regs
),
ptr
[
param
1
+
w_offset
]);
vmovss
(
xmm_t
(
reg_idx
+
num_shift_regs
),
ptr
[
param
_src
+
w_offset
]);
reg_idx
++
;
reg_idx
++
;
}
}
return
reg_idx
;
return
reg_idx
;
...
@@ -174,32 +191,33 @@ class SeqPoolJitCode : public JitCode {
...
@@ -174,32 +191,33 @@ class SeqPoolJitCode : public JitCode {
const
bool
has_block1
=
(
rest
%
2
)
==
1
;
const
bool
has_block1
=
(
rest
%
2
)
==
1
;
int
reg_idx
=
reg_start
;
int
reg_idx
=
reg_start
;
if
(
has_block4
)
{
if
(
has_block4
)
{
vmovups
(
ptr
[
param
2
+
w_offset
],
xmm_t
(
reg_idx
));
vmovups
(
ptr
[
param
_dst
+
w_offset
],
xmm_t
(
reg_idx
));
w_offset
+=
sizeof
(
float
)
*
4
;
w_offset
+=
sizeof
(
float
)
*
4
;
reg_idx
++
;
reg_idx
++
;
}
}
if
(
has_block2
)
{
if
(
has_block2
)
{
vmovq
(
ptr
[
param
2
+
w_offset
],
xmm_t
(
reg_idx
));
vmovq
(
ptr
[
param
_dst
+
w_offset
],
xmm_t
(
reg_idx
));
w_offset
+=
sizeof
(
float
)
*
2
;
w_offset
+=
sizeof
(
float
)
*
2
;
reg_idx
++
;
reg_idx
++
;
}
}
if
(
has_block1
)
{
if
(
has_block1
)
{
vmovss
(
ptr
[
param
2
+
w_offset
],
xmm_t
(
reg_idx
));
vmovss
(
ptr
[
param
_dst
+
w_offset
],
xmm_t
(
reg_idx
));
}
}
}
}
private:
private:
int
h_
;
int
w_
;
int
w_
;
SeqPoolType
type_
;
SeqPoolType
type_
;
reg64_t
param1
{
abi_param1
};
reg64_t
param_src
{
abi_param1
};
reg64_t
param2
{
abi_param2
};
reg64_t
param_dst
{
abi_param2
};
reg64_t
param3
{
abi_param3
};
reg64_t
param_attr
{
abi_param3
};
reg32_t
reg32_scalar
{
r8d
};
reg64_t
reg_tmp
{
rax
};
reg32_t
reg32_int_h
{
r8d
};
reg32_t
reg32_fp_h
{
r9d
};
reg64_t
reg_h
{
r9
};
reg64_t
reg_h_i
{
r10
};
reg64_t
reg_ptr_src_i
{
r10
};
reg64_t
reg_ptr_src_i
{
r11
};
reg64_t
reg_tmp
{
r11
};
};
};
}
// namespace gen
}
// namespace gen
...
...
paddle/fluid/operators/jit/kernel_base.h
浏览文件 @
0145f40f
...
@@ -46,7 +46,7 @@ typedef enum {
...
@@ -46,7 +46,7 @@ typedef enum {
typedef
enum
{
typedef
enum
{
kNonePoolType
=
0
,
kNonePoolType
=
0
,
kSum
,
kSum
=
1
,
kAvg
,
kAvg
,
kSqrt
,
kSqrt
,
}
SeqPoolType
;
}
SeqPoolType
;
...
@@ -121,10 +121,10 @@ struct GRUTuples {
...
@@ -121,10 +121,10 @@ struct GRUTuples {
};
};
typedef
struct
seq_pool_attr_s
{
typedef
struct
seq_pool_attr_s
{
int
h
,
w
;
int
h
,
w
;
// h should always be the first one
SeqPoolType
type
;
SeqPoolType
type
;
seq_pool_attr_s
()
=
default
;
seq_pool_attr_s
()
=
default
;
explicit
seq_pool_attr_s
(
int
height
,
int
width
,
SeqPoolType
pool_type
)
explicit
seq_pool_attr_s
(
int
width
,
SeqPoolType
pool_type
,
int
height
=
1
)
:
h
(
height
),
w
(
width
),
type
(
pool_type
)
{}
:
h
(
height
),
w
(
width
),
type
(
pool_type
)
{}
}
seq_pool_attr_t
;
}
seq_pool_attr_t
;
...
...
paddle/fluid/operators/jit/kernel_key.cc
浏览文件 @
0145f40f
...
@@ -45,10 +45,8 @@ size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
...
@@ -45,10 +45,8 @@ size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
template
<
>
template
<
>
size_t
JitCodeKey
<
seq_pool_attr_t
>
(
const
seq_pool_attr_t
&
attr
)
{
size_t
JitCodeKey
<
seq_pool_attr_t
>
(
const
seq_pool_attr_t
&
attr
)
{
size_t
key
=
attr
.
w
;
size_t
key
=
attr
.
w
;
// TODO(TJ): support height, then removed it from key
constexpr
int
pool_type_shift
=
3
;
constexpr
int
w_shift
=
30
;
return
(
key
<<
pool_type_shift
)
+
static_cast
<
int
>
(
attr
.
type
);
return
(
key
<<
act_type_shift
)
+
static_cast
<
int
>
(
attr
.
type
)
+
(
static_cast
<
size_t
>
(
attr
.
h
)
<<
(
act_type_shift
+
w_shift
));
}
}
}
// namespace jit
}
// namespace jit
...
...
paddle/fluid/operators/jit/refer/refer.h
浏览文件 @
0145f40f
...
@@ -334,7 +334,6 @@ void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) {
...
@@ -334,7 +334,6 @@ void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) {
template
<
typename
T
>
template
<
typename
T
>
void
SeqPool
(
const
T
*
x
,
T
*
y
,
const
seq_pool_attr_t
*
attr
)
{
void
SeqPool
(
const
T
*
x
,
T
*
y
,
const
seq_pool_attr_t
*
attr
)
{
PADDLE_ENFORCE
(
attr
->
type
==
SeqPoolType
::
kSum
,
"Only support sum yet"
);
for
(
int
w
=
0
;
w
<
attr
->
w
;
++
w
)
{
for
(
int
w
=
0
;
w
<
attr
->
w
;
++
w
)
{
const
T
*
src
=
x
+
w
;
const
T
*
src
=
x
+
w
;
T
*
dst
=
y
+
w
;
T
*
dst
=
y
+
w
;
...
...
paddle/fluid/operators/jit/test.cc
浏览文件 @
0145f40f
...
@@ -439,9 +439,10 @@ void TestSeqPoolKernel() {
...
@@ -439,9 +439,10 @@ void TestSeqPoolKernel() {
// TODO(TJ): support more
// TODO(TJ): support more
std
::
vector
<
jit
::
SeqPoolType
>
pool_types
=
{
jit
::
SeqPoolType
::
kSum
};
std
::
vector
<
jit
::
SeqPoolType
>
pool_types
=
{
jit
::
SeqPoolType
::
kSum
};
for
(
auto
type
:
pool_types
)
{
for
(
auto
type
:
pool_types
)
{
for
(
int
h
:
TestSizes
())
{
for
(
int
w
:
TestSizes
())
{
for
(
int
w
:
TestSizes
())
{
const
jit
::
seq_pool_attr_t
attr
(
h
,
w
,
type
);
jit
::
seq_pool_attr_t
attr
(
w
,
type
);
for
(
int
h
:
TestSizes
())
{
attr
.
h
=
h
;
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
SeqPoolTuples
<
T
>>
();
auto
ref
=
jit
::
GetRefer
<
KT
,
jit
::
SeqPoolTuples
<
T
>>
();
EXPECT_TRUE
(
ref
!=
nullptr
);
EXPECT_TRUE
(
ref
!=
nullptr
);
std
::
vector
<
T
>
x
(
h
*
w
),
yref
(
w
);
std
::
vector
<
T
>
x
(
h
*
w
),
yref
(
w
);
...
...
paddle/fluid/operators/math/sequence_pooling.cc
浏览文件 @
0145f40f
...
@@ -252,14 +252,14 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
...
@@ -252,14 +252,14 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
place
));
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
place
));
const
T
*
src
=
input
.
data
<
T
>
();
const
T
*
src
=
input
.
data
<
T
>
();
T
*
dst
=
output
->
mutable_data
<
T
>
(
place
);
T
*
dst
=
output
->
mutable_data
<
T
>
(
place
);
jit
::
seq_pool_attr_t
attr
;
jit
::
seq_pool_attr_t
attr
(
attr
.
w
=
input
.
numel
()
/
input
.
dims
()[
0
];
static_cast
<
int
>
(
input
.
numel
()
/
input
.
dims
()[
0
]),
attr
.
type
=
jit
::
SeqPoolType
::
kSum
;
jit
::
SeqPoolType
::
kSum
);
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
lod
.
size
())
-
1
;
++
i
)
{
attr
.
h
=
static_cast
<
int
>
(
lod
[
i
+
1
]
-
lod
[
i
]);
auto
seqpool
=
auto
seqpool
=
jit
::
Get
<
jit
::
kSeqPool
,
jit
::
SeqPoolTuples
<
T
>
,
platform
::
CPUPlace
>
(
jit
::
Get
<
jit
::
kSeqPool
,
jit
::
SeqPoolTuples
<
T
>
,
platform
::
CPUPlace
>
(
attr
);
attr
);
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
lod
.
size
())
-
1
;
++
i
)
{
attr
.
h
=
static_cast
<
int
>
(
lod
[
i
+
1
]
-
lod
[
i
]);
seqpool
(
src
,
dst
,
&
attr
);
seqpool
(
src
,
dst
,
&
attr
);
dst
+=
attr
.
w
;
dst
+=
attr
.
w
;
src
+=
attr
.
h
*
attr
.
w
;
src
+=
attr
.
h
*
attr
.
w
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录