Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
36ed83d2
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
36ed83d2
编写于
9月 25, 2020
作者:
G
GaoWei8
提交者:
GitHub
9月 25, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refine PADDLE_ENFORCE (#27360)
* refine PADDLE_ENFORCE
上级
effd51b6
变更
17
显示空白变更内容
内联
并排
Showing
17 changed file
with
278 addition
and
76 deletion
+278
-76
paddle/fluid/operators/benchmark/op_tester.cc
paddle/fluid/operators/benchmark/op_tester.cc
+20
-12
paddle/fluid/operators/benchmark/op_tester_config.cc
paddle/fluid/operators/benchmark/op_tester_config.cc
+15
-5
paddle/fluid/operators/jit/benchmark.cc
paddle/fluid/operators/jit/benchmark.cc
+9
-3
paddle/fluid/operators/jit/gen/embseqpool.cc
paddle/fluid/operators/jit/gen/embseqpool.cc
+25
-5
paddle/fluid/operators/jit/gen/matmul.cc
paddle/fluid/operators/jit/gen/matmul.cc
+20
-4
paddle/fluid/operators/jit/gen/matmul.h
paddle/fluid/operators/jit/gen/matmul.h
+4
-1
paddle/fluid/operators/jit/gen/seqpool.cc
paddle/fluid/operators/jit/gen/seqpool.cc
+8
-2
paddle/fluid/operators/jit/gen/seqpool.h
paddle/fluid/operators/jit/gen/seqpool.h
+7
-2
paddle/fluid/operators/jit/gen/sgd.cc
paddle/fluid/operators/jit/gen/sgd.cc
+18
-3
paddle/fluid/operators/jit/gen/vbroadcast.cc
paddle/fluid/operators/jit/gen/vbroadcast.cc
+5
-1
paddle/fluid/operators/jit/gen_base.cc
paddle/fluid/operators/jit/gen_base.cc
+8
-3
paddle/fluid/operators/jit/helper.cc
paddle/fluid/operators/jit/helper.cc
+17
-6
paddle/fluid/operators/jit/helper.h
paddle/fluid/operators/jit/helper.h
+18
-7
paddle/fluid/operators/jit/more/mix/mix.cc
paddle/fluid/operators/jit/more/mix/mix.cc
+2
-1
paddle/fluid/operators/jit/more/mkl/mkl.h
paddle/fluid/operators/jit/more/mkl/mkl.h
+51
-10
paddle/fluid/operators/jit/refer/refer.h
paddle/fluid/operators/jit/refer/refer.h
+42
-9
paddle/fluid/operators/jit/test.cc
paddle/fluid/operators/jit/test.cc
+9
-2
未找到文件。
paddle/fluid/operators/benchmark/op_tester.cc
浏览文件 @
36ed83d2
...
...
@@ -47,8 +47,8 @@ void OpTester::Init(const OpTesterConfig &config) {
CreateInputVarDesc
();
CreateOutputVarDesc
();
}
else
{
PADDLE_THROW
(
platform
::
errors
::
NotFound
(
"Operator '%s' is not registered."
,
config_
.
op_type
));
PADDLE_THROW
(
platform
::
errors
::
NotFound
(
"Operator '%s' is not registered in OpTester."
,
config_
.
op_type
));
}
if
(
config_
.
device_id
>=
0
)
{
...
...
@@ -81,7 +81,8 @@ void OpTester::Run() {
platform
::
EnableProfiler
(
platform
::
ProfilerState
::
kAll
);
platform
::
SetDeviceId
(
config_
.
device_id
);
#else
PADDLE_THROW
(
"'CUDAPlace' is not supported in CPU only device."
);
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"'CUDAPlace' is not supported in CPU only device."
));
#endif
}
...
...
@@ -162,7 +163,8 @@ framework::proto::VarType::Type OpTester::TransToVarType(std::string str) {
}
else
if
(
str
==
"fp64"
)
{
return
framework
::
proto
::
VarType
::
FP64
;
}
else
{
PADDLE_THROW
(
"Unsupported dtype %s."
,
str
.
c_str
());
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported dtype %s in OpTester."
,
str
.
c_str
()));
}
}
...
...
@@ -233,8 +235,8 @@ void OpTester::CreateOpDesc() {
case
framework
::
proto
::
AttrType
::
INTS
:
case
framework
::
proto
::
AttrType
::
FLOATS
:
case
framework
::
proto
::
AttrType
::
STRINGS
:
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Not supported STRINGS type
yet."
));
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported STRINGS type in OpTester
yet."
));
break
;
case
framework
::
proto
::
AttrType
::
LONG
:
{
int64_t
value
=
StringTo
<
int64_t
>
(
value_str
);
...
...
@@ -242,7 +244,8 @@ void OpTester::CreateOpDesc() {
}
break
;
case
framework
::
proto
::
AttrType
::
LONGS
:
default:
PADDLE_THROW
(
"Unsupport attr type %d"
,
type
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupport attr type %d in OpTester."
,
type
));
}
}
}
...
...
@@ -299,7 +302,8 @@ void OpTester::SetupTensor(framework::LoDTensor *tensor,
}
is
.
close
();
}
else
{
PADDLE_THROW
(
"Unsupported initializer %s."
,
initializer
.
c_str
());
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported initializer %s in OpTester."
,
initializer
.
c_str
()));
}
if
(
!
platform
::
is_cpu_place
(
place_
))
{
...
...
@@ -351,7 +355,8 @@ void OpTester::CreateVariables(framework::Scope *scope) {
static_cast
<
double
>
(
1.0
),
item
.
second
.
initializer
,
item
.
second
.
filename
);
}
else
{
PADDLE_THROW
(
"Unsupported dtype %d."
,
data_type
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported dtype %d in OpTester."
,
data_type
));
}
VLOG
(
3
)
<<
"Set lod for tensor "
<<
var_name
;
...
...
@@ -473,7 +478,8 @@ std::string OpTester::DebugString() {
<<
"
\n
"
;
}
break
;
default:
PADDLE_THROW
(
"Unsupport attr type %d"
,
attr_type
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupport attr type %d in OpTester."
,
attr_type
));
}
ss
<<
GenSpaces
(
--
count
)
<<
"}
\n
"
;
}
...
...
@@ -484,8 +490,10 @@ std::string OpTester::DebugString() {
TEST
(
op_tester
,
base
)
{
if
(
!
FLAGS_op_config_list
.
empty
())
{
std
::
ifstream
fin
(
FLAGS_op_config_list
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fin
),
"Cannot open file %s"
,
FLAGS_op_config_list
.
c_str
());
PADDLE_ENFORCE_EQ
(
static_cast
<
bool
>
(
fin
),
true
,
platform
::
errors
::
InvalidArgument
(
"OpTester cannot open file %s"
,
FLAGS_op_config_list
.
c_str
()));
std
::
vector
<
OpTesterConfig
>
op_configs
;
while
(
!
fin
.
eof
())
{
VLOG
(
4
)
<<
"Reading config "
<<
op_configs
.
size
()
<<
"..."
;
...
...
paddle/fluid/operators/benchmark/op_tester_config.cc
浏览文件 @
36ed83d2
...
...
@@ -78,7 +78,8 @@ void OpInputConfig::ParseDType(std::istream& is) {
}
else
if
(
dtype_str
==
"fp64"
||
dtype_str
==
"double"
)
{
dtype
=
"fp64"
;
}
else
{
PADDLE_THROW
(
"Unsupported dtype %s"
,
dtype_str
.
c_str
());
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported dtype %s in OpInputConfig."
,
dtype_str
.
c_str
()));
}
VLOG
(
4
)
<<
"dtype of input "
<<
name
<<
" is: "
<<
dtype
;
}
...
...
@@ -91,7 +92,9 @@ void OpInputConfig::ParseInitializer(std::istream& is) {
const
std
::
vector
<
std
::
string
>
supported_initializers
=
{
"random"
,
"natural"
,
"zeros"
,
"file"
};
if
(
!
Has
(
supported_initializers
,
initializer_str
))
{
PADDLE_THROW
(
"Unsupported initializer %s"
,
initializer_str
.
c_str
());
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported initializer %s in OpInputConfig."
,
initializer_str
.
c_str
()));
}
initializer
=
initializer_str
;
...
...
@@ -126,7 +129,12 @@ void OpInputConfig::ParseLoD(std::istream& is) {
}
}
EraseEndSep
(
&
lod_str
);
PADDLE_ENFORCE_GE
(
lod_str
.
length
(),
4U
);
PADDLE_ENFORCE_GE
(
lod_str
.
length
(),
4U
,
platform
::
errors
::
InvalidArgument
(
"The length of lod string should be "
"equal to or larger than 4. But length of lod string is %zu."
,
lod_str
.
length
()));
VLOG
(
4
)
<<
"lod: "
<<
lod_str
<<
", length: "
<<
lod_str
.
length
();
// Parse the lod_str
...
...
@@ -153,8 +161,10 @@ void OpInputConfig::ParseLoD(std::istream& is) {
OpTesterConfig
::
OpTesterConfig
(
const
std
::
string
&
filename
)
{
std
::
ifstream
fin
(
filename
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fin
),
"Cannot open file %s"
,
filename
.
c_str
());
PADDLE_ENFORCE_EQ
(
static_cast
<
bool
>
(
fin
),
true
,
platform
::
errors
::
InvalidArgument
(
"OpTesterConfig cannot open file %s."
,
filename
.
c_str
()));
Init
(
fin
);
}
...
...
paddle/fluid/operators/jit/benchmark.cc
浏览文件 @
36ed83d2
...
...
@@ -136,7 +136,6 @@ void BenchAllImpls(const typename KernelTuple::attr_type& attr, Args... args) {
}
using
Tensor
=
paddle
::
framework
::
Tensor
;
template
<
typename
KernelTuple
,
typename
PlaceType
>
void
BenchKernelXYZN
()
{
using
T
=
typename
KernelTuple
::
data_type
;
...
...
@@ -320,8 +319,15 @@ void BenchKernelSgd() {
const
T
lr
=
0.1
;
auto
UnDuplicatedRandomVec
=
[](
int
n
,
const
int64_t
lower
,
const
int64_t
upper
)
->
std
::
vector
<
int64_t
>
{
PADDLE_ENFORCE_LE
(
static_cast
<
size_t
>
(
upper
-
lower
),
n
-
1
);
PADDLE_ENFORCE_GT
(
n
,
0
);
PADDLE_ENFORCE_LE
(
static_cast
<
size_t
>
(
upper
-
lower
),
n
-
1
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"The range of Sgd (upper - lower) should be equal to or lower "
"than n-1 (Sgd size -1). But upper - lower is %d and n-1 is %d."
,
static_cast
<
size_t
>
(
upper
-
lower
),
(
n
-
1
)));
PADDLE_ENFORCE_GT
(
n
,
0
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"The Sgd size should be larger than 0. But the n is %d."
,
n
));
std
::
vector
<
int64_t
>
all
,
out
;
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
all
.
push_back
(
i
);
...
...
paddle/fluid/operators/jit/gen/embseqpool.cc
浏览文件 @
36ed83d2
...
...
@@ -132,11 +132,31 @@ class EmbSeqPoolCreator : public JitCodeCreator<emb_seq_pool_attr_t> {
}
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
const
emb_seq_pool_attr_t
&
attr
)
const
override
{
PADDLE_ENFORCE_GT
(
attr
.
table_height
,
0
);
PADDLE_ENFORCE_GT
(
attr
.
table_width
,
0
);
PADDLE_ENFORCE_GT
(
attr
.
index_height
,
0
);
PADDLE_ENFORCE_GT
(
attr
.
index_width
,
0
);
PADDLE_ENFORCE_GT
(
attr
.
out_width
,
0
);
PADDLE_ENFORCE_GT
(
attr
.
table_height
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute table_height of EmbSeqPool should "
"be larger than 0. But it is %d."
,
attr
.
table_height
));
PADDLE_ENFORCE_GT
(
attr
.
table_width
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute table_width of EmbSeqPool should "
"be larger than 0. But it is %d."
,
attr
.
table_width
));
PADDLE_ENFORCE_GT
(
attr
.
index_height
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute index_height of EmbSeqPool should "
"be larger than 0. But it is %d."
,
attr
.
index_height
));
PADDLE_ENFORCE_GT
(
attr
.
index_width
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute index_width of EmbSeqPool should "
"be larger than 0. But it is %d."
,
attr
.
index_width
));
PADDLE_ENFORCE_GT
(
attr
.
out_width
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute out_width of EmbSeqPool should be "
"larger than 0. But it is %d."
,
attr
.
out_width
));
return
make_unique
<
EmbSeqPoolJitCode
>
(
attr
,
CodeSize
(
attr
));
}
};
...
...
paddle/fluid/operators/jit/gen/matmul.cc
浏览文件 @
36ed83d2
...
...
@@ -29,7 +29,11 @@ void MatMulJitCode::genCode() {
preCode
();
int
block
,
rest
;
const
auto
groups
=
packed_groups
(
n_
,
k_
,
&
block
,
&
rest
);
PADDLE_ENFORCE_GT
(
groups
.
front
(),
0
);
PADDLE_ENFORCE_GT
(
groups
.
front
(),
0
,
platform
::
errors
::
InvalidArgument
(
"The number of rest registers should "
"be larger than 0. But it is %d."
,
groups
.
front
()));
const
int
block_len
=
sizeof
(
float
)
*
block
;
const
int
x_reg_idx
=
(
block
==
ZMM_FLOAT_BLOCK
?
32
:
16
)
-
1
;
...
...
@@ -118,9 +122,21 @@ class MatMulCreator : public JitCodeCreator<matmul_attr_t> {
}
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
const
matmul_attr_t
&
attr
)
const
override
{
PADDLE_ENFORCE_GT
(
attr
.
m
,
0
);
PADDLE_ENFORCE_GT
(
attr
.
n
,
0
);
PADDLE_ENFORCE_GT
(
attr
.
k
,
0
);
PADDLE_ENFORCE_GT
(
attr
.
m
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute m (first matrix's row) of MatMul should "
"be larger than 0. But it is %d."
,
attr
.
m
));
PADDLE_ENFORCE_GT
(
attr
.
n
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute n (first matrix's col) of MatMul should "
"be larger than 0. But it is %d."
,
attr
.
n
));
PADDLE_ENFORCE_GT
(
attr
.
k
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute k (second matrix's col) of MatMul should "
"be larger than 0. But it is %d."
,
attr
.
k
));
return
make_unique
<
MatMulJitCode
>
(
attr
,
CodeSize
(
attr
));
}
};
...
...
paddle/fluid/operators/jit/gen/matmul.h
浏览文件 @
36ed83d2
...
...
@@ -33,7 +33,10 @@ class MatMulJitCode : public JitCode {
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
m_
(
attr
.
m
),
n_
(
attr
.
n
),
k_
(
attr
.
k
)
{
PADDLE_ENFORCE_EQ
(
m_
,
1
,
"Only support m==1 yet"
);
PADDLE_ENFORCE_EQ
(
m_
,
1
,
platform
::
errors
::
Unimplemented
(
"Jitcode of matmul only support m==1 (first "
"matrix's row) now. But m is %d."
,
m_
));
this
->
genCode
();
}
...
...
paddle/fluid/operators/jit/gen/seqpool.cc
浏览文件 @
36ed83d2
...
...
@@ -70,8 +70,14 @@ class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
}
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
const
seq_pool_attr_t
&
attr
)
const
override
{
PADDLE_ENFORCE_GT
(
attr
.
w
,
0
);
PADDLE_ENFORCE_GT
(
attr
.
h
,
0
);
PADDLE_ENFORCE_GT
(
attr
.
w
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute width of SeqPool should "
"be larger than 0. But it is %d."
,
attr
.
w
));
PADDLE_ENFORCE_GT
(
attr
.
h
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute height of SeqPool should "
"be larger than 0. But it is %d."
,
attr
.
h
));
return
make_unique
<
SeqPoolJitCode
>
(
attr
,
CodeSize
(
attr
));
}
};
...
...
paddle/fluid/operators/jit/gen/seqpool.h
浏览文件 @
36ed83d2
...
...
@@ -127,8 +127,13 @@ class SeqPoolJitCode : public JitCode {
vmovss
(
xmm_t
(
reg_idx
+
max_num_regs
),
ptr
[
reg_ptr_src_i
]);
reg_idx
++
;
}
PADDLE_ENFORCE_EQ
(
reg_idx
,
rest_used_num_regs
,
"All heights should use same regs"
);
PADDLE_ENFORCE_EQ
(
reg_idx
,
rest_used_num_regs
,
platform
::
errors
::
InvalidArgument
(
"All heights of SeqPool should use the same number of registers."
"It equals to the numbr of rest registers. But use %d registers "
"and the numbr of rest registers is %d."
,
reg_idx
,
rest_used_num_regs
));
for
(
int
i
=
0
;
i
<
reg_idx
;
++
i
)
{
vaddps
(
xmm_t
(
i
),
xmm_t
(
i
),
xmm_t
(
i
+
max_num_regs
));
}
...
...
paddle/fluid/operators/jit/gen/sgd.cc
浏览文件 @
36ed83d2
...
...
@@ -116,9 +116,24 @@ class SgdCreator : public JitCodeCreator<sgd_attr_t> {
size_t
CodeSize
(
const
sgd_attr_t
&
attr
)
const
override
{
return
96
+
32
*
8
;
}
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
const
sgd_attr_t
&
attr
)
const
override
{
PADDLE_ENFORCE_EQ
(
attr
.
param_width
,
attr
.
grad_width
);
PADDLE_ENFORCE_LE
(
attr
.
selected_rows_size
,
attr
.
grad_height
);
PADDLE_ENFORCE_GE
(
attr
.
selected_rows_size
,
0
);
PADDLE_ENFORCE_EQ
(
attr
.
param_width
,
attr
.
grad_width
,
platform
::
errors
::
InvalidArgument
(
"The attribute param_width of Sgd should be "
"equal to the attribute grad_width. But param_width "
"is %d and grad_width is %d."
,
attr
.
param_width
,
attr
.
grad_width
));
PADDLE_ENFORCE_LE
(
attr
.
selected_rows_size
,
attr
.
grad_height
,
platform
::
errors
::
InvalidArgument
(
"The attribute selected_rows_size of Sgd should be "
"equal to or less than the attribute grad_height. "
"But selected_rows_size is %d and grad_height is %d."
,
attr
.
selected_rows_size
,
attr
.
grad_height
));
PADDLE_ENFORCE_GE
(
attr
.
selected_rows_size
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute selected_rows_size of Sgd should be "
"equal to or larger than 0. But selected_rows_size is %d."
,
attr
.
selected_rows_size
));
return
make_unique
<
SgdJitCode
>
(
attr
,
CodeSize
(
attr
));
}
};
...
...
paddle/fluid/operators/jit/gen/vbroadcast.cc
浏览文件 @
36ed83d2
...
...
@@ -76,7 +76,11 @@ class VBroadcastCreator : public JitCodeCreator<int64_t> {
return
96
+
(
w
/
YMM_FLOAT_BLOCK
)
*
16
*
8
;
}
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
const
int64_t
&
w
)
const
override
{
PADDLE_ENFORCE_GT
(
w
,
0
);
PADDLE_ENFORCE_GT
(
w
,
0
,
platform
::
errors
::
InvalidArgument
(
"The width of VBroadcast should be larger than 0. But w is %d."
,
w
));
return
make_unique
<
VBroadcastJitCode
>
(
w
,
CodeSize
(
w
));
}
};
...
...
paddle/fluid/operators/jit/gen_base.cc
浏览文件 @
36ed83d2
...
...
@@ -49,9 +49,14 @@ void GenBase::dumpCode(const unsigned char* code) const {
void
*
GenBase
::
operator
new
(
size_t
size
)
{
void
*
ptr
;
constexpr
size_t
alignment
=
32ul
;
PADDLE_ENFORCE_EQ
(
posix_memalign
(
&
ptr
,
alignment
,
size
),
0
,
"GenBase Alloc %ld error!"
,
size
);
PADDLE_ENFORCE
(
ptr
,
"Fail to allocate GenBase CPU memory: size = %d ."
,
size
);
PADDLE_ENFORCE_EQ
(
posix_memalign
(
&
ptr
,
alignment
,
size
),
0
,
platform
::
errors
::
InvalidArgument
(
"Jitcode generator (GenBase) allocate %ld memory error!"
,
size
));
PADDLE_ENFORCE_NOT_NULL
(
ptr
,
platform
::
errors
::
InvalidArgument
(
"Fail to allocate jitcode generator "
"(GenBase) CPU memory: size = %d ."
,
size
));
return
ptr
;
}
...
...
paddle/fluid/operators/jit/helper.cc
浏览文件 @
36ed83d2
...
...
@@ -66,7 +66,8 @@ const char* to_string(KernelType kt) {
ONE_CASE
(
kEmbSeqPool
);
ONE_CASE
(
kSgd
);
default:
PADDLE_THROW
(
"Not support type: %d, or forget to add it."
,
kt
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"JIT kernel do not support type: %d."
,
kt
));
return
"NOT JITKernel"
;
}
return
nullptr
;
...
...
@@ -79,7 +80,8 @@ const char* to_string(SeqPoolType tp) {
ONE_CASE
(
kAvg
);
ONE_CASE
(
kSqrt
);
default:
PADDLE_THROW
(
"Not support type: %d, or forget to add it."
,
tp
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"SeqPool JIT kernel do not support type: %d."
,
tp
));
return
"NOT PoolType"
;
}
return
nullptr
;
...
...
@@ -100,7 +102,8 @@ KernelType to_kerneltype(const std::string& act) {
}
else
if
(
lower
==
"tanh"
||
lower
==
"vtanh"
)
{
return
kVTanh
;
}
PADDLE_THROW
(
"Not support type: %s, or forget to add this case"
,
act
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Act JIT kernel do not support type: %s."
,
act
));
return
kNone
;
}
...
...
@@ -109,12 +112,19 @@ void pack_weights<float>(const float* src, float* dst, int n, int k) {
int
block
,
rest
;
const
auto
groups
=
packed_groups
(
n
,
k
,
&
block
,
&
rest
);
std
::
for_each
(
groups
.
begin
(),
groups
.
end
(),
[
&
](
int
i
)
{
PADDLE_ENFORCE_GT
(
i
,
0
,
"each element of groups should be larger than 0."
);
PADDLE_ENFORCE_GT
(
i
,
0
,
platform
::
errors
::
InvalidArgument
(
"Each element of groups should be larger than "
"0. However the element: %d doesn't satify."
,
i
));
});
int
sum
=
std
::
accumulate
(
groups
.
begin
(),
groups
.
end
(),
0
);
std
::
memset
(
dst
,
0
,
k
*
sum
*
block
*
sizeof
(
float
));
PADDLE_ENFORCE_GE
(
sum
*
block
,
n
,
"The packed n should be equal to or larger than n"
);
platform
::
errors
::
InvalidArgument
(
"The packed n (sum * block) should be equal to or "
"larger than n (matmul row size). "
"However, the packed n is %d and n is %d."
,
sum
*
block
,
n
));
const
int
block_len
=
sizeof
(
float
)
*
block
;
int
n_offset
=
0
;
...
...
@@ -136,7 +146,8 @@ void pack_weights<float>(const float* src, float* dst, int n, int k) {
template
<
typename
T
>
typename
std
::
enable_if
<!
std
::
is_same
<
T
,
float
>::
value
>::
type
pack_weights
(
const
T
*
src
,
T
*
dst
,
int
n
,
int
k
)
{
PADDLE_THROW
(
"Only support pack with float type."
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Only supports pack weights with float type."
));
}
}
// namespace jit
...
...
paddle/fluid/operators/jit/helper.h
浏览文件 @
36ed83d2
...
...
@@ -85,8 +85,10 @@ inline const Kernel* GetReferKernel() {
auto
&
ref_pool
=
ReferKernelPool
::
Instance
().
AllKernels
();
KernelKey
kkey
(
KernelTuple
::
kernel_type
,
platform
::
CPUPlace
());
auto
ref_iter
=
ref_pool
.
find
(
kkey
);
PADDLE_ENFORCE
(
ref_iter
!=
ref_pool
.
end
(),
"Every Kernel should have reference function."
);
PADDLE_ENFORCE_NE
(
ref_iter
,
ref_pool
.
end
(),
platform
::
errors
::
PreconditionNotMet
(
"Every Refer Kernel of jitcode should have reference function."
));
auto
&
ref_impls
=
ref_iter
->
second
;
for
(
auto
&
impl
:
ref_impls
)
{
auto
i
=
dynamic_cast
<
const
ReferKernel
<
KernelTuple
>*>
(
impl
.
get
());
...
...
@@ -101,7 +103,9 @@ template <typename KernelTuple>
inline
typename
KernelTuple
::
func_type
GetReferFunc
()
{
auto
ker
=
GetReferKernel
<
KernelTuple
>
();
auto
p
=
dynamic_cast
<
const
ReferKernel
<
KernelTuple
>*>
(
ker
);
PADDLE_ENFORCE
(
p
,
"The Refer kernel should exsit"
);
PADDLE_ENFORCE_NOT_NULL
(
p
,
platform
::
errors
::
InvalidArgument
(
"Get the reference code of kernel in CPU "
"failed. The Refer kernel should exsit."
));
return
p
->
GetFunc
();
}
...
...
@@ -132,7 +136,9 @@ std::vector<const Kernel*> GetAllCandidateKernels(
// The last implementation should be reference function on CPUPlace.
auto
ref
=
GetReferKernel
<
KernelTuple
>
();
PADDLE_ENFORCE
(
ref
!=
nullptr
,
"Refer Kernel can not be empty."
);
PADDLE_ENFORCE_NOT_NULL
(
ref
,
platform
::
errors
::
InvalidArgument
(
"Get all candicate kernel in CPU failed. "
"The Refer Kernel can not be empty."
));
res
.
emplace_back
(
ref
);
return
res
;
}
...
...
@@ -147,11 +153,14 @@ GetAllCandidateFuncsWithTypes(const typename KernelTuple::attr_type& attr) {
std
::
string
name
=
k
->
ImplType
();
if
(
name
==
"JitCode"
)
{
auto
i
=
dynamic_cast
<
const
GenBase
*>
(
k
);
PADDLE_ENFORCE
(
i
,
"jitcode kernel cast can not fail."
);
PADDLE_ENFORCE_NOT_NULL
(
i
,
platform
::
errors
::
InvalidArgument
(
"Generate jitcode kernel (GenBase) failed."
));
res
.
emplace_back
(
std
::
make_pair
(
name
,
i
->
template
getCode
<
Func
>()));
}
else
{
auto
i
=
dynamic_cast
<
const
KernelMore
<
KernelTuple
>*>
(
k
);
PADDLE_ENFORCE
(
i
,
"kernel cast can not fail."
);
PADDLE_ENFORCE_NOT_NULL
(
i
,
platform
::
errors
::
InvalidArgument
(
"Kernel cast (KernelMore) failed."
));
res
.
emplace_back
(
std
::
make_pair
(
name
,
i
->
GetFunc
()));
}
}
...
...
@@ -173,7 +182,9 @@ template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
typename
KernelTuple
::
func_type
GetDefaultBestFunc
(
const
typename
KernelTuple
::
attr_type
&
attr
)
{
auto
funcs
=
GetAllCandidateFuncs
<
KernelTuple
,
PlaceType
>
(
attr
);
PADDLE_ENFORCE_GE
(
funcs
.
size
(),
1UL
);
PADDLE_ENFORCE_GE
(
funcs
.
size
(),
1UL
,
platform
::
errors
::
InvalidArgument
(
"The candicate jit kernel is at least one in CPU."
));
// Here could do some runtime benchmark of this attr and return the best one.
// But yet just get the first one as the default best one,
// which is searched in order and tuned by offline.
...
...
paddle/fluid/operators/jit/more/mix/mix.cc
浏览文件 @
36ed83d2
...
...
@@ -95,7 +95,8 @@ void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT
}
else
if
(
type
==
kVIdentity
)
{
return
KernelFuncs
<
VIdentityTuple
<
T
>
,
CPUPlace
>::
Cache
().
At
(
d
);
}
PADDLE_THROW
(
"Not support type: %s"
,
type
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Act JIT kernel do not support type: %s"
,
type
));
return
nullptr
;
}
...
...
paddle/fluid/operators/jit/more/mkl/mkl.h
浏览文件 @
36ed83d2
...
...
@@ -103,11 +103,24 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
template
<
typename
T
>
void
EmbSeqPool
(
const
T
*
table
,
const
int64_t
*
idx
,
T
*
out
,
const
emb_seq_pool_attr_t
*
attr
)
{
PADDLE_ENFORCE_EQ
(
attr
->
table_width
*
attr
->
index_width
,
attr
->
out_width
);
PADDLE_ENFORCE_EQ
(
attr
->
table_width
*
attr
->
index_width
,
attr
->
out_width
,
platform
::
errors
::
InvalidArgument
(
"The attribute table_width * index_width of EmbSeqPool should "
"be equal to out_width. But table_width * index_width is %d, "
"out_width is %d."
,
attr
->
table_width
*
attr
->
index_width
,
attr
->
out_width
));
auto
check_idx_value_valid
=
[
&
](
int64_t
i
)
{
PADDLE_ENFORCE_LT
(
idx
[
i
],
attr
->
table_height
,
"idx value: %d, i: %d"
,
idx
[
i
],
i
);
PADDLE_ENFORCE_GE
(
idx
[
i
],
0
,
"idx value: %d, i: %d"
,
idx
[
i
],
i
);
PADDLE_ENFORCE_LT
(
idx
[
i
],
attr
->
table_height
,
platform
::
errors
::
InvalidArgument
(
"The idx shoud be lower than the attribute table_height of "
"EmbSeqPool. But %dth of idx is %d and table_height is %d."
,
i
,
idx
[
i
],
attr
->
table_height
));
PADDLE_ENFORCE_GE
(
idx
[
i
],
0
,
platform
::
errors
::
InvalidArgument
(
"The idx shoud be equal to or larger than "
"the 0. But %dth of idx is %d."
,
i
,
idx
[
i
]));
};
for
(
int64_t
w
=
0
;
w
!=
attr
->
index_width
;
++
w
)
{
...
...
@@ -168,22 +181,50 @@ void Softmax(const T* x, T* y, int n, int bs, int remain = 1) {
template
<
typename
T
>
void
Sgd
(
const
T
*
lr
,
const
T
*
param
,
const
T
*
grad
,
const
int64_t
*
rows
,
T
*
out
,
const
sgd_attr_t
*
attr
)
{
PADDLE_ENFORCE_EQ
(
attr
->
param_width
,
attr
->
grad_width
);
PADDLE_ENFORCE_LE
(
attr
->
selected_rows_size
,
attr
->
grad_height
);
PADDLE_ENFORCE_EQ
(
attr
->
param_width
,
attr
->
grad_width
,
platform
::
errors
::
InvalidArgument
(
"The attribute param_width of Sgd should be "
"equal to the attribute grad_width. But param_width "
"is %d and grad_width is %d."
,
attr
->
param_width
,
attr
->
grad_width
));
PADDLE_ENFORCE_LE
(
attr
->
selected_rows_size
,
attr
->
grad_height
,
platform
::
errors
::
InvalidArgument
(
"The attribute selected_rows_size of Sgd should be "
"equal to or less than the attribute grad_height. "
"But selected_rows_size is %d and grad_height is %d."
,
attr
->
selected_rows_size
,
attr
->
grad_height
));
T
scalar
=
-
lr
[
0
];
int
width
=
attr
->
grad_width
;
if
(
out
==
param
)
{
for
(
int64_t
i
=
0
;
i
<
attr
->
selected_rows_size
;
++
i
)
{
auto
h_idx
=
rows
[
i
];
PADDLE_ENFORCE_LT
(
h_idx
,
attr
->
param_height
);
PADDLE_ENFORCE_GE
(
h_idx
,
0
);
PADDLE_ENFORCE_LT
(
h_idx
,
attr
->
param_height
,
platform
::
errors
::
InvalidArgument
(
"The rows of Sgd should be "
"less than the attribute. But %dth of rows "
"is %d and grad_width is %d."
,
i
,
h_idx
,
attr
->
param_height
));
PADDLE_ENFORCE_GE
(
h_idx
,
0
,
platform
::
errors
::
InvalidArgument
(
"The rows of Sgd should be "
"larger than 0. But %dth of rows "
"is %d."
,
i
,
h_idx
));
VAXPY
(
scalar
,
grad
+
i
*
width
,
out
+
h_idx
*
width
,
width
);
}
}
else
{
for
(
int64_t
i
=
0
;
i
<
attr
->
selected_rows_size
;
++
i
)
{
auto
h_idx
=
rows
[
i
];
PADDLE_ENFORCE_LT
(
h_idx
,
attr
->
param_height
);
PADDLE_ENFORCE_GE
(
h_idx
,
0
);
PADDLE_ENFORCE_LT
(
h_idx
,
attr
->
param_height
,
platform
::
errors
::
InvalidArgument
(
"The rows of Sgd should be "
"less than the attribute. But %dth of rows "
"is %d and grad_width is %d."
,
i
,
h_idx
,
attr
->
param_height
));
PADDLE_ENFORCE_GE
(
h_idx
,
0
,
platform
::
errors
::
InvalidArgument
(
"The rows of Sgd should be "
"larger than 0. But %dth of rows "
"is %d."
,
i
,
h_idx
));
VScal
(
&
scalar
,
grad
+
i
*
width
,
out
+
h_idx
*
width
,
width
);
VAdd
(
param
+
h_idx
*
width
,
out
+
h_idx
*
width
,
out
+
h_idx
*
width
,
width
);
...
...
paddle/fluid/operators/jit/refer/refer.h
浏览文件 @
36ed83d2
...
...
@@ -147,7 +147,8 @@ void (*getActFunc(KernelType type))(const T*, T*, int) { // NOLINT
}
else
if
(
type
==
kVIdentity
)
{
return
VIdentity
<
T
>
;
}
PADDLE_THROW
(
"Not support type: %s"
,
type
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Act JIT kernel do not support type: %s."
,
type
));
return
nullptr
;
}
...
...
@@ -465,12 +466,25 @@ void Softmax(const T* x, T* y, int n, int bs = 1, int remain = 1) {
template
<
typename
T
>
void
EmbSeqPool
(
const
T
*
table
,
const
int64_t
*
idx
,
T
*
out
,
const
emb_seq_pool_attr_t
*
attr
)
{
PADDLE_ENFORCE_EQ
(
attr
->
table_width
*
attr
->
index_width
,
attr
->
out_width
);
PADDLE_ENFORCE_EQ
(
attr
->
table_width
*
attr
->
index_width
,
attr
->
out_width
,
platform
::
errors
::
InvalidArgument
(
"The attribute table_width * index_width of EmbSeqPool should "
"be equal to out_width. But table_width * index_width is %d and "
"out_width is %d."
,
attr
->
table_width
*
attr
->
index_width
,
attr
->
out_width
));
auto
check_idx_value_valid
=
[
&
](
int64_t
i
)
{
PADDLE_ENFORCE_LT
(
idx
[
i
],
attr
->
table_height
,
"idx value: %d, i: %d"
,
idx
[
i
],
i
);
PADDLE_ENFORCE_GE
(
idx
[
i
],
0
,
"idx value: %d, i: %d"
,
idx
[
i
],
i
);
PADDLE_ENFORCE_LT
(
idx
[
i
],
attr
->
table_height
,
platform
::
errors
::
InvalidArgument
(
"The idx shoud be lower than the attribute table_height of "
"EmbSeqPool. But %dth of idx is %d and table_height is %d."
,
i
,
idx
[
i
],
attr
->
table_height
));
PADDLE_ENFORCE_GE
(
idx
[
i
],
0
,
platform
::
errors
::
InvalidArgument
(
"The idx shoud be equal to or larger than "
"the 0. But %dth of idx is %d."
,
i
,
idx
[
i
]));
};
for
(
int64_t
w
=
0
;
w
!=
attr
->
index_width
;
++
w
)
{
...
...
@@ -505,12 +519,31 @@ void EmbSeqPool(const T* table, const int64_t* idx, T* out,
template
<
typename
T
>
void
Sgd
(
const
T
*
lr
,
const
T
*
param
,
const
T
*
grad
,
const
int64_t
*
rows
,
T
*
out
,
const
sgd_attr_t
*
attr
)
{
PADDLE_ENFORCE_EQ
(
attr
->
param_width
,
attr
->
grad_width
);
PADDLE_ENFORCE_LE
(
attr
->
selected_rows_size
,
attr
->
grad_height
);
PADDLE_ENFORCE_EQ
(
attr
->
param_width
,
attr
->
grad_width
,
platform
::
errors
::
InvalidArgument
(
"The attribute param_width of Sgd should be "
"equal to the attribute grad_width. But param_width "
"is %d and grad_width is %d."
,
attr
->
param_width
,
attr
->
grad_width
));
PADDLE_ENFORCE_LE
(
attr
->
selected_rows_size
,
attr
->
grad_height
,
platform
::
errors
::
InvalidArgument
(
"The attribute selected_rows_size of Sgd should be "
"equal to or less than the attribute grad_height. "
"But selected_rows_size is %d and grad_height is %d."
,
attr
->
selected_rows_size
,
attr
->
grad_height
));
for
(
int64_t
i
=
0
;
i
<
attr
->
selected_rows_size
;
++
i
)
{
auto
h_idx
=
rows
[
i
];
PADDLE_ENFORCE_LT
(
h_idx
,
attr
->
param_height
);
PADDLE_ENFORCE_GE
(
h_idx
,
0
);
PADDLE_ENFORCE_LT
(
h_idx
,
attr
->
param_height
,
platform
::
errors
::
InvalidArgument
(
"The rows of Sgd should be "
"less than the attribute. But %dth of rows "
"is %d and grad_width is %d."
,
i
,
h_idx
,
attr
->
param_height
));
PADDLE_ENFORCE_GE
(
h_idx
,
0
,
platform
::
errors
::
InvalidArgument
(
"The rows of Sgd should be "
"larger than 0. But %dth of rows "
"is %d."
,
i
,
h_idx
));
for
(
int64_t
j
=
0
;
j
<
attr
->
grad_width
;
++
j
)
{
out
[
h_idx
*
attr
->
grad_width
+
j
]
=
param
[
h_idx
*
attr
->
grad_width
+
j
]
-
...
...
paddle/fluid/operators/jit/test.cc
浏览文件 @
36ed83d2
...
...
@@ -850,8 +850,15 @@ void TestKernelSgd() {
const
T
lr
=
0.1
;
auto
UnDuplicatedRandomVec
=
[](
int
n
,
const
int64_t
lower
,
const
int64_t
upper
)
->
std
::
vector
<
int64_t
>
{
PADDLE_ENFORCE_LE
(
static_cast
<
size_t
>
(
upper
-
lower
),
n
-
1
);
PADDLE_ENFORCE_GT
(
n
,
0
);
PADDLE_ENFORCE_LE
(
static_cast
<
size_t
>
(
upper
-
lower
),
n
-
1
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"The range of Sgd (upper - lower) should be lower "
"than n-1 (Sgd size -1). But the upper - lower is %d "
"and n-1 is %d."
,
static_cast
<
size_t
>
(
upper
-
lower
),
n
-
1
));
PADDLE_ENFORCE_GT
(
n
,
0
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"The Sgd size should be larger than 0. But the n is %d."
,
n
));
std
::
vector
<
int64_t
>
all
,
out
;
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
all
.
push_back
(
i
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录