Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
36ed83d2
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
36ed83d2
编写于
9月 25, 2020
作者:
G
GaoWei8
提交者:
GitHub
9月 25, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refine PADDLE_ENFORCE (#27360)
* refine PADDLE_ENFORCE
上级
effd51b6
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
278 addition
and
76 deletion
+278
-76
paddle/fluid/operators/benchmark/op_tester.cc
paddle/fluid/operators/benchmark/op_tester.cc
+20
-12
paddle/fluid/operators/benchmark/op_tester_config.cc
paddle/fluid/operators/benchmark/op_tester_config.cc
+15
-5
paddle/fluid/operators/jit/benchmark.cc
paddle/fluid/operators/jit/benchmark.cc
+9
-3
paddle/fluid/operators/jit/gen/embseqpool.cc
paddle/fluid/operators/jit/gen/embseqpool.cc
+25
-5
paddle/fluid/operators/jit/gen/matmul.cc
paddle/fluid/operators/jit/gen/matmul.cc
+20
-4
paddle/fluid/operators/jit/gen/matmul.h
paddle/fluid/operators/jit/gen/matmul.h
+4
-1
paddle/fluid/operators/jit/gen/seqpool.cc
paddle/fluid/operators/jit/gen/seqpool.cc
+8
-2
paddle/fluid/operators/jit/gen/seqpool.h
paddle/fluid/operators/jit/gen/seqpool.h
+7
-2
paddle/fluid/operators/jit/gen/sgd.cc
paddle/fluid/operators/jit/gen/sgd.cc
+18
-3
paddle/fluid/operators/jit/gen/vbroadcast.cc
paddle/fluid/operators/jit/gen/vbroadcast.cc
+5
-1
paddle/fluid/operators/jit/gen_base.cc
paddle/fluid/operators/jit/gen_base.cc
+8
-3
paddle/fluid/operators/jit/helper.cc
paddle/fluid/operators/jit/helper.cc
+17
-6
paddle/fluid/operators/jit/helper.h
paddle/fluid/operators/jit/helper.h
+18
-7
paddle/fluid/operators/jit/more/mix/mix.cc
paddle/fluid/operators/jit/more/mix/mix.cc
+2
-1
paddle/fluid/operators/jit/more/mkl/mkl.h
paddle/fluid/operators/jit/more/mkl/mkl.h
+51
-10
paddle/fluid/operators/jit/refer/refer.h
paddle/fluid/operators/jit/refer/refer.h
+42
-9
paddle/fluid/operators/jit/test.cc
paddle/fluid/operators/jit/test.cc
+9
-2
未找到文件。
paddle/fluid/operators/benchmark/op_tester.cc
浏览文件 @
36ed83d2
...
@@ -47,8 +47,8 @@ void OpTester::Init(const OpTesterConfig &config) {
...
@@ -47,8 +47,8 @@ void OpTester::Init(const OpTesterConfig &config) {
CreateInputVarDesc
();
CreateInputVarDesc
();
CreateOutputVarDesc
();
CreateOutputVarDesc
();
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
NotFound
(
"Operator '%s' is not registered."
,
PADDLE_THROW
(
platform
::
errors
::
NotFound
(
config_
.
op_type
));
"Operator '%s' is not registered in OpTester."
,
config_
.
op_type
));
}
}
if
(
config_
.
device_id
>=
0
)
{
if
(
config_
.
device_id
>=
0
)
{
...
@@ -81,7 +81,8 @@ void OpTester::Run() {
...
@@ -81,7 +81,8 @@ void OpTester::Run() {
platform
::
EnableProfiler
(
platform
::
ProfilerState
::
kAll
);
platform
::
EnableProfiler
(
platform
::
ProfilerState
::
kAll
);
platform
::
SetDeviceId
(
config_
.
device_id
);
platform
::
SetDeviceId
(
config_
.
device_id
);
#else
#else
PADDLE_THROW
(
"'CUDAPlace' is not supported in CPU only device."
);
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"'CUDAPlace' is not supported in CPU only device."
));
#endif
#endif
}
}
...
@@ -162,7 +163,8 @@ framework::proto::VarType::Type OpTester::TransToVarType(std::string str) {
...
@@ -162,7 +163,8 @@ framework::proto::VarType::Type OpTester::TransToVarType(std::string str) {
}
else
if
(
str
==
"fp64"
)
{
}
else
if
(
str
==
"fp64"
)
{
return
framework
::
proto
::
VarType
::
FP64
;
return
framework
::
proto
::
VarType
::
FP64
;
}
else
{
}
else
{
PADDLE_THROW
(
"Unsupported dtype %s."
,
str
.
c_str
());
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported dtype %s in OpTester."
,
str
.
c_str
()));
}
}
}
}
...
@@ -233,8 +235,8 @@ void OpTester::CreateOpDesc() {
...
@@ -233,8 +235,8 @@ void OpTester::CreateOpDesc() {
case
framework
::
proto
::
AttrType
::
INTS
:
case
framework
::
proto
::
AttrType
::
INTS
:
case
framework
::
proto
::
AttrType
::
FLOATS
:
case
framework
::
proto
::
AttrType
::
FLOATS
:
case
framework
::
proto
::
AttrType
::
STRINGS
:
case
framework
::
proto
::
AttrType
::
STRINGS
:
PADDLE_THROW
(
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
platform
::
errors
::
Unimplemented
(
"Not supported STRINGS type
yet."
));
"Unsupported STRINGS type in OpTester
yet."
));
break
;
break
;
case
framework
::
proto
::
AttrType
::
LONG
:
{
case
framework
::
proto
::
AttrType
::
LONG
:
{
int64_t
value
=
StringTo
<
int64_t
>
(
value_str
);
int64_t
value
=
StringTo
<
int64_t
>
(
value_str
);
...
@@ -242,7 +244,8 @@ void OpTester::CreateOpDesc() {
...
@@ -242,7 +244,8 @@ void OpTester::CreateOpDesc() {
}
break
;
}
break
;
case
framework
::
proto
::
AttrType
::
LONGS
:
case
framework
::
proto
::
AttrType
::
LONGS
:
default:
default:
PADDLE_THROW
(
"Unsupport attr type %d"
,
type
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupport attr type %d in OpTester."
,
type
));
}
}
}
}
}
}
...
@@ -299,7 +302,8 @@ void OpTester::SetupTensor(framework::LoDTensor *tensor,
...
@@ -299,7 +302,8 @@ void OpTester::SetupTensor(framework::LoDTensor *tensor,
}
}
is
.
close
();
is
.
close
();
}
else
{
}
else
{
PADDLE_THROW
(
"Unsupported initializer %s."
,
initializer
.
c_str
());
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported initializer %s in OpTester."
,
initializer
.
c_str
()));
}
}
if
(
!
platform
::
is_cpu_place
(
place_
))
{
if
(
!
platform
::
is_cpu_place
(
place_
))
{
...
@@ -351,7 +355,8 @@ void OpTester::CreateVariables(framework::Scope *scope) {
...
@@ -351,7 +355,8 @@ void OpTester::CreateVariables(framework::Scope *scope) {
static_cast
<
double
>
(
1.0
),
item
.
second
.
initializer
,
static_cast
<
double
>
(
1.0
),
item
.
second
.
initializer
,
item
.
second
.
filename
);
item
.
second
.
filename
);
}
else
{
}
else
{
PADDLE_THROW
(
"Unsupported dtype %d."
,
data_type
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported dtype %d in OpTester."
,
data_type
));
}
}
VLOG
(
3
)
<<
"Set lod for tensor "
<<
var_name
;
VLOG
(
3
)
<<
"Set lod for tensor "
<<
var_name
;
...
@@ -473,7 +478,8 @@ std::string OpTester::DebugString() {
...
@@ -473,7 +478,8 @@ std::string OpTester::DebugString() {
<<
"
\n
"
;
<<
"
\n
"
;
}
break
;
}
break
;
default:
default:
PADDLE_THROW
(
"Unsupport attr type %d"
,
attr_type
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupport attr type %d in OpTester."
,
attr_type
));
}
}
ss
<<
GenSpaces
(
--
count
)
<<
"}
\n
"
;
ss
<<
GenSpaces
(
--
count
)
<<
"}
\n
"
;
}
}
...
@@ -484,8 +490,10 @@ std::string OpTester::DebugString() {
...
@@ -484,8 +490,10 @@ std::string OpTester::DebugString() {
TEST
(
op_tester
,
base
)
{
TEST
(
op_tester
,
base
)
{
if
(
!
FLAGS_op_config_list
.
empty
())
{
if
(
!
FLAGS_op_config_list
.
empty
())
{
std
::
ifstream
fin
(
FLAGS_op_config_list
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
std
::
ifstream
fin
(
FLAGS_op_config_list
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fin
),
"Cannot open file %s"
,
PADDLE_ENFORCE_EQ
(
FLAGS_op_config_list
.
c_str
());
static_cast
<
bool
>
(
fin
),
true
,
platform
::
errors
::
InvalidArgument
(
"OpTester cannot open file %s"
,
FLAGS_op_config_list
.
c_str
()));
std
::
vector
<
OpTesterConfig
>
op_configs
;
std
::
vector
<
OpTesterConfig
>
op_configs
;
while
(
!
fin
.
eof
())
{
while
(
!
fin
.
eof
())
{
VLOG
(
4
)
<<
"Reading config "
<<
op_configs
.
size
()
<<
"..."
;
VLOG
(
4
)
<<
"Reading config "
<<
op_configs
.
size
()
<<
"..."
;
...
...
paddle/fluid/operators/benchmark/op_tester_config.cc
浏览文件 @
36ed83d2
...
@@ -78,7 +78,8 @@ void OpInputConfig::ParseDType(std::istream& is) {
...
@@ -78,7 +78,8 @@ void OpInputConfig::ParseDType(std::istream& is) {
}
else
if
(
dtype_str
==
"fp64"
||
dtype_str
==
"double"
)
{
}
else
if
(
dtype_str
==
"fp64"
||
dtype_str
==
"double"
)
{
dtype
=
"fp64"
;
dtype
=
"fp64"
;
}
else
{
}
else
{
PADDLE_THROW
(
"Unsupported dtype %s"
,
dtype_str
.
c_str
());
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported dtype %s in OpInputConfig."
,
dtype_str
.
c_str
()));
}
}
VLOG
(
4
)
<<
"dtype of input "
<<
name
<<
" is: "
<<
dtype
;
VLOG
(
4
)
<<
"dtype of input "
<<
name
<<
" is: "
<<
dtype
;
}
}
...
@@ -91,7 +92,9 @@ void OpInputConfig::ParseInitializer(std::istream& is) {
...
@@ -91,7 +92,9 @@ void OpInputConfig::ParseInitializer(std::istream& is) {
const
std
::
vector
<
std
::
string
>
supported_initializers
=
{
"random"
,
"natural"
,
const
std
::
vector
<
std
::
string
>
supported_initializers
=
{
"random"
,
"natural"
,
"zeros"
,
"file"
};
"zeros"
,
"file"
};
if
(
!
Has
(
supported_initializers
,
initializer_str
))
{
if
(
!
Has
(
supported_initializers
,
initializer_str
))
{
PADDLE_THROW
(
"Unsupported initializer %s"
,
initializer_str
.
c_str
());
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported initializer %s in OpInputConfig."
,
initializer_str
.
c_str
()));
}
}
initializer
=
initializer_str
;
initializer
=
initializer_str
;
...
@@ -126,7 +129,12 @@ void OpInputConfig::ParseLoD(std::istream& is) {
...
@@ -126,7 +129,12 @@ void OpInputConfig::ParseLoD(std::istream& is) {
}
}
}
}
EraseEndSep
(
&
lod_str
);
EraseEndSep
(
&
lod_str
);
PADDLE_ENFORCE_GE
(
lod_str
.
length
(),
4U
);
PADDLE_ENFORCE_GE
(
lod_str
.
length
(),
4U
,
platform
::
errors
::
InvalidArgument
(
"The length of lod string should be "
"equal to or larger than 4. But length of lod string is %zu."
,
lod_str
.
length
()));
VLOG
(
4
)
<<
"lod: "
<<
lod_str
<<
", length: "
<<
lod_str
.
length
();
VLOG
(
4
)
<<
"lod: "
<<
lod_str
<<
", length: "
<<
lod_str
.
length
();
// Parse the lod_str
// Parse the lod_str
...
@@ -153,8 +161,10 @@ void OpInputConfig::ParseLoD(std::istream& is) {
...
@@ -153,8 +161,10 @@ void OpInputConfig::ParseLoD(std::istream& is) {
OpTesterConfig
::
OpTesterConfig
(
const
std
::
string
&
filename
)
{
OpTesterConfig
::
OpTesterConfig
(
const
std
::
string
&
filename
)
{
std
::
ifstream
fin
(
filename
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
std
::
ifstream
fin
(
filename
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fin
),
"Cannot open file %s"
,
PADDLE_ENFORCE_EQ
(
filename
.
c_str
());
static_cast
<
bool
>
(
fin
),
true
,
platform
::
errors
::
InvalidArgument
(
"OpTesterConfig cannot open file %s."
,
filename
.
c_str
()));
Init
(
fin
);
Init
(
fin
);
}
}
...
...
paddle/fluid/operators/jit/benchmark.cc
浏览文件 @
36ed83d2
...
@@ -136,7 +136,6 @@ void BenchAllImpls(const typename KernelTuple::attr_type& attr, Args... args) {
...
@@ -136,7 +136,6 @@ void BenchAllImpls(const typename KernelTuple::attr_type& attr, Args... args) {
}
}
using
Tensor
=
paddle
::
framework
::
Tensor
;
using
Tensor
=
paddle
::
framework
::
Tensor
;
template
<
typename
KernelTuple
,
typename
PlaceType
>
template
<
typename
KernelTuple
,
typename
PlaceType
>
void
BenchKernelXYZN
()
{
void
BenchKernelXYZN
()
{
using
T
=
typename
KernelTuple
::
data_type
;
using
T
=
typename
KernelTuple
::
data_type
;
...
@@ -320,8 +319,15 @@ void BenchKernelSgd() {
...
@@ -320,8 +319,15 @@ void BenchKernelSgd() {
const
T
lr
=
0.1
;
const
T
lr
=
0.1
;
auto
UnDuplicatedRandomVec
=
[](
int
n
,
const
int64_t
lower
,
auto
UnDuplicatedRandomVec
=
[](
int
n
,
const
int64_t
lower
,
const
int64_t
upper
)
->
std
::
vector
<
int64_t
>
{
const
int64_t
upper
)
->
std
::
vector
<
int64_t
>
{
PADDLE_ENFORCE_LE
(
static_cast
<
size_t
>
(
upper
-
lower
),
n
-
1
);
PADDLE_ENFORCE_LE
(
PADDLE_ENFORCE_GT
(
n
,
0
);
static_cast
<
size_t
>
(
upper
-
lower
),
n
-
1
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"The range of Sgd (upper - lower) should be equal to or lower "
"than n-1 (Sgd size -1). But upper - lower is %d and n-1 is %d."
,
static_cast
<
size_t
>
(
upper
-
lower
),
(
n
-
1
)));
PADDLE_ENFORCE_GT
(
n
,
0
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"The Sgd size should be larger than 0. But the n is %d."
,
n
));
std
::
vector
<
int64_t
>
all
,
out
;
std
::
vector
<
int64_t
>
all
,
out
;
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
all
.
push_back
(
i
);
all
.
push_back
(
i
);
...
...
paddle/fluid/operators/jit/gen/embseqpool.cc
浏览文件 @
36ed83d2
...
@@ -132,11 +132,31 @@ class EmbSeqPoolCreator : public JitCodeCreator<emb_seq_pool_attr_t> {
...
@@ -132,11 +132,31 @@ class EmbSeqPoolCreator : public JitCodeCreator<emb_seq_pool_attr_t> {
}
}
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
const
emb_seq_pool_attr_t
&
attr
)
const
override
{
const
emb_seq_pool_attr_t
&
attr
)
const
override
{
PADDLE_ENFORCE_GT
(
attr
.
table_height
,
0
);
PADDLE_ENFORCE_GT
(
attr
.
table_height
,
0
,
PADDLE_ENFORCE_GT
(
attr
.
table_width
,
0
);
platform
::
errors
::
InvalidArgument
(
PADDLE_ENFORCE_GT
(
attr
.
index_height
,
0
);
"The attribute table_height of EmbSeqPool should "
PADDLE_ENFORCE_GT
(
attr
.
index_width
,
0
);
"be larger than 0. But it is %d."
,
PADDLE_ENFORCE_GT
(
attr
.
out_width
,
0
);
attr
.
table_height
));
PADDLE_ENFORCE_GT
(
attr
.
table_width
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute table_width of EmbSeqPool should "
"be larger than 0. But it is %d."
,
attr
.
table_width
));
PADDLE_ENFORCE_GT
(
attr
.
index_height
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute index_height of EmbSeqPool should "
"be larger than 0. But it is %d."
,
attr
.
index_height
));
PADDLE_ENFORCE_GT
(
attr
.
index_width
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute index_width of EmbSeqPool should "
"be larger than 0. But it is %d."
,
attr
.
index_width
));
PADDLE_ENFORCE_GT
(
attr
.
out_width
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute out_width of EmbSeqPool should be "
"larger than 0. But it is %d."
,
attr
.
out_width
));
return
make_unique
<
EmbSeqPoolJitCode
>
(
attr
,
CodeSize
(
attr
));
return
make_unique
<
EmbSeqPoolJitCode
>
(
attr
,
CodeSize
(
attr
));
}
}
};
};
...
...
paddle/fluid/operators/jit/gen/matmul.cc
浏览文件 @
36ed83d2
...
@@ -29,7 +29,11 @@ void MatMulJitCode::genCode() {
...
@@ -29,7 +29,11 @@ void MatMulJitCode::genCode() {
preCode
();
preCode
();
int
block
,
rest
;
int
block
,
rest
;
const
auto
groups
=
packed_groups
(
n_
,
k_
,
&
block
,
&
rest
);
const
auto
groups
=
packed_groups
(
n_
,
k_
,
&
block
,
&
rest
);
PADDLE_ENFORCE_GT
(
groups
.
front
(),
0
);
PADDLE_ENFORCE_GT
(
groups
.
front
(),
0
,
platform
::
errors
::
InvalidArgument
(
"The number of rest registers should "
"be larger than 0. But it is %d."
,
groups
.
front
()));
const
int
block_len
=
sizeof
(
float
)
*
block
;
const
int
block_len
=
sizeof
(
float
)
*
block
;
const
int
x_reg_idx
=
(
block
==
ZMM_FLOAT_BLOCK
?
32
:
16
)
-
1
;
const
int
x_reg_idx
=
(
block
==
ZMM_FLOAT_BLOCK
?
32
:
16
)
-
1
;
...
@@ -118,9 +122,21 @@ class MatMulCreator : public JitCodeCreator<matmul_attr_t> {
...
@@ -118,9 +122,21 @@ class MatMulCreator : public JitCodeCreator<matmul_attr_t> {
}
}
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
const
matmul_attr_t
&
attr
)
const
override
{
const
matmul_attr_t
&
attr
)
const
override
{
PADDLE_ENFORCE_GT
(
attr
.
m
,
0
);
PADDLE_ENFORCE_GT
(
PADDLE_ENFORCE_GT
(
attr
.
n
,
0
);
attr
.
m
,
0
,
platform
::
errors
::
InvalidArgument
(
PADDLE_ENFORCE_GT
(
attr
.
k
,
0
);
"The attribute m (first matrix's row) of MatMul should "
"be larger than 0. But it is %d."
,
attr
.
m
));
PADDLE_ENFORCE_GT
(
attr
.
n
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute n (first matrix's col) of MatMul should "
"be larger than 0. But it is %d."
,
attr
.
n
));
PADDLE_ENFORCE_GT
(
attr
.
k
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute k (second matrix's col) of MatMul should "
"be larger than 0. But it is %d."
,
attr
.
k
));
return
make_unique
<
MatMulJitCode
>
(
attr
,
CodeSize
(
attr
));
return
make_unique
<
MatMulJitCode
>
(
attr
,
CodeSize
(
attr
));
}
}
};
};
...
...
paddle/fluid/operators/jit/gen/matmul.h
浏览文件 @
36ed83d2
...
@@ -33,7 +33,10 @@ class MatMulJitCode : public JitCode {
...
@@ -33,7 +33,10 @@ class MatMulJitCode : public JitCode {
size_t
code_size
=
256
*
1024
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
m_
(
attr
.
m
),
n_
(
attr
.
n
),
k_
(
attr
.
k
)
{
:
JitCode
(
code_size
,
code_ptr
),
m_
(
attr
.
m
),
n_
(
attr
.
n
),
k_
(
attr
.
k
)
{
PADDLE_ENFORCE_EQ
(
m_
,
1
,
"Only support m==1 yet"
);
PADDLE_ENFORCE_EQ
(
m_
,
1
,
platform
::
errors
::
Unimplemented
(
"Jitcode of matmul only support m==1 (first "
"matrix's row) now. But m is %d."
,
m_
));
this
->
genCode
();
this
->
genCode
();
}
}
...
...
paddle/fluid/operators/jit/gen/seqpool.cc
浏览文件 @
36ed83d2
...
@@ -70,8 +70,14 @@ class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
...
@@ -70,8 +70,14 @@ class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
}
}
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
const
seq_pool_attr_t
&
attr
)
const
override
{
const
seq_pool_attr_t
&
attr
)
const
override
{
PADDLE_ENFORCE_GT
(
attr
.
w
,
0
);
PADDLE_ENFORCE_GT
(
attr
.
w
,
0
,
platform
::
errors
::
InvalidArgument
(
PADDLE_ENFORCE_GT
(
attr
.
h
,
0
);
"The attribute width of SeqPool should "
"be larger than 0. But it is %d."
,
attr
.
w
));
PADDLE_ENFORCE_GT
(
attr
.
h
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute height of SeqPool should "
"be larger than 0. But it is %d."
,
attr
.
h
));
return
make_unique
<
SeqPoolJitCode
>
(
attr
,
CodeSize
(
attr
));
return
make_unique
<
SeqPoolJitCode
>
(
attr
,
CodeSize
(
attr
));
}
}
};
};
...
...
paddle/fluid/operators/jit/gen/seqpool.h
浏览文件 @
36ed83d2
...
@@ -127,8 +127,13 @@ class SeqPoolJitCode : public JitCode {
...
@@ -127,8 +127,13 @@ class SeqPoolJitCode : public JitCode {
vmovss
(
xmm_t
(
reg_idx
+
max_num_regs
),
ptr
[
reg_ptr_src_i
]);
vmovss
(
xmm_t
(
reg_idx
+
max_num_regs
),
ptr
[
reg_ptr_src_i
]);
reg_idx
++
;
reg_idx
++
;
}
}
PADDLE_ENFORCE_EQ
(
reg_idx
,
rest_used_num_regs
,
PADDLE_ENFORCE_EQ
(
"All heights should use same regs"
);
reg_idx
,
rest_used_num_regs
,
platform
::
errors
::
InvalidArgument
(
"All heights of SeqPool should use the same number of registers."
"It equals to the numbr of rest registers. But use %d registers "
"and the numbr of rest registers is %d."
,
reg_idx
,
rest_used_num_regs
));
for
(
int
i
=
0
;
i
<
reg_idx
;
++
i
)
{
for
(
int
i
=
0
;
i
<
reg_idx
;
++
i
)
{
vaddps
(
xmm_t
(
i
),
xmm_t
(
i
),
xmm_t
(
i
+
max_num_regs
));
vaddps
(
xmm_t
(
i
),
xmm_t
(
i
),
xmm_t
(
i
+
max_num_regs
));
}
}
...
...
paddle/fluid/operators/jit/gen/sgd.cc
浏览文件 @
36ed83d2
...
@@ -116,9 +116,24 @@ class SgdCreator : public JitCodeCreator<sgd_attr_t> {
...
@@ -116,9 +116,24 @@ class SgdCreator : public JitCodeCreator<sgd_attr_t> {
size_t
CodeSize
(
const
sgd_attr_t
&
attr
)
const
override
{
return
96
+
32
*
8
;
}
size_t
CodeSize
(
const
sgd_attr_t
&
attr
)
const
override
{
return
96
+
32
*
8
;
}
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
const
sgd_attr_t
&
attr
)
const
override
{
const
sgd_attr_t
&
attr
)
const
override
{
PADDLE_ENFORCE_EQ
(
attr
.
param_width
,
attr
.
grad_width
);
PADDLE_ENFORCE_EQ
(
attr
.
param_width
,
attr
.
grad_width
,
PADDLE_ENFORCE_LE
(
attr
.
selected_rows_size
,
attr
.
grad_height
);
platform
::
errors
::
InvalidArgument
(
PADDLE_ENFORCE_GE
(
attr
.
selected_rows_size
,
0
);
"The attribute param_width of Sgd should be "
"equal to the attribute grad_width. But param_width "
"is %d and grad_width is %d."
,
attr
.
param_width
,
attr
.
grad_width
));
PADDLE_ENFORCE_LE
(
attr
.
selected_rows_size
,
attr
.
grad_height
,
platform
::
errors
::
InvalidArgument
(
"The attribute selected_rows_size of Sgd should be "
"equal to or less than the attribute grad_height. "
"But selected_rows_size is %d and grad_height is %d."
,
attr
.
selected_rows_size
,
attr
.
grad_height
));
PADDLE_ENFORCE_GE
(
attr
.
selected_rows_size
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute selected_rows_size of Sgd should be "
"equal to or larger than 0. But selected_rows_size is %d."
,
attr
.
selected_rows_size
));
return
make_unique
<
SgdJitCode
>
(
attr
,
CodeSize
(
attr
));
return
make_unique
<
SgdJitCode
>
(
attr
,
CodeSize
(
attr
));
}
}
};
};
...
...
paddle/fluid/operators/jit/gen/vbroadcast.cc
浏览文件 @
36ed83d2
...
@@ -76,7 +76,11 @@ class VBroadcastCreator : public JitCodeCreator<int64_t> {
...
@@ -76,7 +76,11 @@ class VBroadcastCreator : public JitCodeCreator<int64_t> {
return
96
+
(
w
/
YMM_FLOAT_BLOCK
)
*
16
*
8
;
return
96
+
(
w
/
YMM_FLOAT_BLOCK
)
*
16
*
8
;
}
}
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
const
int64_t
&
w
)
const
override
{
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
const
int64_t
&
w
)
const
override
{
PADDLE_ENFORCE_GT
(
w
,
0
);
PADDLE_ENFORCE_GT
(
w
,
0
,
platform
::
errors
::
InvalidArgument
(
"The width of VBroadcast should be larger than 0. But w is %d."
,
w
));
return
make_unique
<
VBroadcastJitCode
>
(
w
,
CodeSize
(
w
));
return
make_unique
<
VBroadcastJitCode
>
(
w
,
CodeSize
(
w
));
}
}
};
};
...
...
paddle/fluid/operators/jit/gen_base.cc
浏览文件 @
36ed83d2
...
@@ -49,9 +49,14 @@ void GenBase::dumpCode(const unsigned char* code) const {
...
@@ -49,9 +49,14 @@ void GenBase::dumpCode(const unsigned char* code) const {
void
*
GenBase
::
operator
new
(
size_t
size
)
{
void
*
GenBase
::
operator
new
(
size_t
size
)
{
void
*
ptr
;
void
*
ptr
;
constexpr
size_t
alignment
=
32ul
;
constexpr
size_t
alignment
=
32ul
;
PADDLE_ENFORCE_EQ
(
posix_memalign
(
&
ptr
,
alignment
,
size
),
0
,
PADDLE_ENFORCE_EQ
(
"GenBase Alloc %ld error!"
,
size
);
posix_memalign
(
&
ptr
,
alignment
,
size
),
0
,
PADDLE_ENFORCE
(
ptr
,
"Fail to allocate GenBase CPU memory: size = %d ."
,
size
);
platform
::
errors
::
InvalidArgument
(
"Jitcode generator (GenBase) allocate %ld memory error!"
,
size
));
PADDLE_ENFORCE_NOT_NULL
(
ptr
,
platform
::
errors
::
InvalidArgument
(
"Fail to allocate jitcode generator "
"(GenBase) CPU memory: size = %d ."
,
size
));
return
ptr
;
return
ptr
;
}
}
...
...
paddle/fluid/operators/jit/helper.cc
浏览文件 @
36ed83d2
...
@@ -66,7 +66,8 @@ const char* to_string(KernelType kt) {
...
@@ -66,7 +66,8 @@ const char* to_string(KernelType kt) {
ONE_CASE
(
kEmbSeqPool
);
ONE_CASE
(
kEmbSeqPool
);
ONE_CASE
(
kSgd
);
ONE_CASE
(
kSgd
);
default:
default:
PADDLE_THROW
(
"Not support type: %d, or forget to add it."
,
kt
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"JIT kernel do not support type: %d."
,
kt
));
return
"NOT JITKernel"
;
return
"NOT JITKernel"
;
}
}
return
nullptr
;
return
nullptr
;
...
@@ -79,7 +80,8 @@ const char* to_string(SeqPoolType tp) {
...
@@ -79,7 +80,8 @@ const char* to_string(SeqPoolType tp) {
ONE_CASE
(
kAvg
);
ONE_CASE
(
kAvg
);
ONE_CASE
(
kSqrt
);
ONE_CASE
(
kSqrt
);
default:
default:
PADDLE_THROW
(
"Not support type: %d, or forget to add it."
,
tp
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"SeqPool JIT kernel do not support type: %d."
,
tp
));
return
"NOT PoolType"
;
return
"NOT PoolType"
;
}
}
return
nullptr
;
return
nullptr
;
...
@@ -100,7 +102,8 @@ KernelType to_kerneltype(const std::string& act) {
...
@@ -100,7 +102,8 @@ KernelType to_kerneltype(const std::string& act) {
}
else
if
(
lower
==
"tanh"
||
lower
==
"vtanh"
)
{
}
else
if
(
lower
==
"tanh"
||
lower
==
"vtanh"
)
{
return
kVTanh
;
return
kVTanh
;
}
}
PADDLE_THROW
(
"Not support type: %s, or forget to add this case"
,
act
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Act JIT kernel do not support type: %s."
,
act
));
return
kNone
;
return
kNone
;
}
}
...
@@ -109,12 +112,19 @@ void pack_weights<float>(const float* src, float* dst, int n, int k) {
...
@@ -109,12 +112,19 @@ void pack_weights<float>(const float* src, float* dst, int n, int k) {
int
block
,
rest
;
int
block
,
rest
;
const
auto
groups
=
packed_groups
(
n
,
k
,
&
block
,
&
rest
);
const
auto
groups
=
packed_groups
(
n
,
k
,
&
block
,
&
rest
);
std
::
for_each
(
groups
.
begin
(),
groups
.
end
(),
[
&
](
int
i
)
{
std
::
for_each
(
groups
.
begin
(),
groups
.
end
(),
[
&
](
int
i
)
{
PADDLE_ENFORCE_GT
(
i
,
0
,
"each element of groups should be larger than 0."
);
PADDLE_ENFORCE_GT
(
i
,
0
,
platform
::
errors
::
InvalidArgument
(
"Each element of groups should be larger than "
"0. However the element: %d doesn't satify."
,
i
));
});
});
int
sum
=
std
::
accumulate
(
groups
.
begin
(),
groups
.
end
(),
0
);
int
sum
=
std
::
accumulate
(
groups
.
begin
(),
groups
.
end
(),
0
);
std
::
memset
(
dst
,
0
,
k
*
sum
*
block
*
sizeof
(
float
));
std
::
memset
(
dst
,
0
,
k
*
sum
*
block
*
sizeof
(
float
));
PADDLE_ENFORCE_GE
(
sum
*
block
,
n
,
PADDLE_ENFORCE_GE
(
sum
*
block
,
n
,
"The packed n should be equal to or larger than n"
);
platform
::
errors
::
InvalidArgument
(
"The packed n (sum * block) should be equal to or "
"larger than n (matmul row size). "
"However, the packed n is %d and n is %d."
,
sum
*
block
,
n
));
const
int
block_len
=
sizeof
(
float
)
*
block
;
const
int
block_len
=
sizeof
(
float
)
*
block
;
int
n_offset
=
0
;
int
n_offset
=
0
;
...
@@ -136,7 +146,8 @@ void pack_weights<float>(const float* src, float* dst, int n, int k) {
...
@@ -136,7 +146,8 @@ void pack_weights<float>(const float* src, float* dst, int n, int k) {
template
<
typename
T
>
template
<
typename
T
>
typename
std
::
enable_if
<!
std
::
is_same
<
T
,
float
>::
value
>::
type
pack_weights
(
typename
std
::
enable_if
<!
std
::
is_same
<
T
,
float
>::
value
>::
type
pack_weights
(
const
T
*
src
,
T
*
dst
,
int
n
,
int
k
)
{
const
T
*
src
,
T
*
dst
,
int
n
,
int
k
)
{
PADDLE_THROW
(
"Only support pack with float type."
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Only supports pack weights with float type."
));
}
}
}
// namespace jit
}
// namespace jit
...
...
paddle/fluid/operators/jit/helper.h
浏览文件 @
36ed83d2
...
@@ -85,8 +85,10 @@ inline const Kernel* GetReferKernel() {
...
@@ -85,8 +85,10 @@ inline const Kernel* GetReferKernel() {
auto
&
ref_pool
=
ReferKernelPool
::
Instance
().
AllKernels
();
auto
&
ref_pool
=
ReferKernelPool
::
Instance
().
AllKernels
();
KernelKey
kkey
(
KernelTuple
::
kernel_type
,
platform
::
CPUPlace
());
KernelKey
kkey
(
KernelTuple
::
kernel_type
,
platform
::
CPUPlace
());
auto
ref_iter
=
ref_pool
.
find
(
kkey
);
auto
ref_iter
=
ref_pool
.
find
(
kkey
);
PADDLE_ENFORCE
(
ref_iter
!=
ref_pool
.
end
(),
PADDLE_ENFORCE_NE
(
"Every Kernel should have reference function."
);
ref_iter
,
ref_pool
.
end
(),
platform
::
errors
::
PreconditionNotMet
(
"Every Refer Kernel of jitcode should have reference function."
));
auto
&
ref_impls
=
ref_iter
->
second
;
auto
&
ref_impls
=
ref_iter
->
second
;
for
(
auto
&
impl
:
ref_impls
)
{
for
(
auto
&
impl
:
ref_impls
)
{
auto
i
=
dynamic_cast
<
const
ReferKernel
<
KernelTuple
>*>
(
impl
.
get
());
auto
i
=
dynamic_cast
<
const
ReferKernel
<
KernelTuple
>*>
(
impl
.
get
());
...
@@ -101,7 +103,9 @@ template <typename KernelTuple>
...
@@ -101,7 +103,9 @@ template <typename KernelTuple>
inline
typename
KernelTuple
::
func_type
GetReferFunc
()
{
inline
typename
KernelTuple
::
func_type
GetReferFunc
()
{
auto
ker
=
GetReferKernel
<
KernelTuple
>
();
auto
ker
=
GetReferKernel
<
KernelTuple
>
();
auto
p
=
dynamic_cast
<
const
ReferKernel
<
KernelTuple
>*>
(
ker
);
auto
p
=
dynamic_cast
<
const
ReferKernel
<
KernelTuple
>*>
(
ker
);
PADDLE_ENFORCE
(
p
,
"The Refer kernel should exsit"
);
PADDLE_ENFORCE_NOT_NULL
(
p
,
platform
::
errors
::
InvalidArgument
(
"Get the reference code of kernel in CPU "
"failed. The Refer kernel should exsit."
));
return
p
->
GetFunc
();
return
p
->
GetFunc
();
}
}
...
@@ -132,7 +136,9 @@ std::vector<const Kernel*> GetAllCandidateKernels(
...
@@ -132,7 +136,9 @@ std::vector<const Kernel*> GetAllCandidateKernels(
// The last implementation should be reference function on CPUPlace.
// The last implementation should be reference function on CPUPlace.
auto
ref
=
GetReferKernel
<
KernelTuple
>
();
auto
ref
=
GetReferKernel
<
KernelTuple
>
();
PADDLE_ENFORCE
(
ref
!=
nullptr
,
"Refer Kernel can not be empty."
);
PADDLE_ENFORCE_NOT_NULL
(
ref
,
platform
::
errors
::
InvalidArgument
(
"Get all candicate kernel in CPU failed. "
"The Refer Kernel can not be empty."
));
res
.
emplace_back
(
ref
);
res
.
emplace_back
(
ref
);
return
res
;
return
res
;
}
}
...
@@ -147,11 +153,14 @@ GetAllCandidateFuncsWithTypes(const typename KernelTuple::attr_type& attr) {
...
@@ -147,11 +153,14 @@ GetAllCandidateFuncsWithTypes(const typename KernelTuple::attr_type& attr) {
std
::
string
name
=
k
->
ImplType
();
std
::
string
name
=
k
->
ImplType
();
if
(
name
==
"JitCode"
)
{
if
(
name
==
"JitCode"
)
{
auto
i
=
dynamic_cast
<
const
GenBase
*>
(
k
);
auto
i
=
dynamic_cast
<
const
GenBase
*>
(
k
);
PADDLE_ENFORCE
(
i
,
"jitcode kernel cast can not fail."
);
PADDLE_ENFORCE_NOT_NULL
(
i
,
platform
::
errors
::
InvalidArgument
(
"Generate jitcode kernel (GenBase) failed."
));
res
.
emplace_back
(
std
::
make_pair
(
name
,
i
->
template
getCode
<
Func
>()));
res
.
emplace_back
(
std
::
make_pair
(
name
,
i
->
template
getCode
<
Func
>()));
}
else
{
}
else
{
auto
i
=
dynamic_cast
<
const
KernelMore
<
KernelTuple
>*>
(
k
);
auto
i
=
dynamic_cast
<
const
KernelMore
<
KernelTuple
>*>
(
k
);
PADDLE_ENFORCE
(
i
,
"kernel cast can not fail."
);
PADDLE_ENFORCE_NOT_NULL
(
i
,
platform
::
errors
::
InvalidArgument
(
"Kernel cast (KernelMore) failed."
));
res
.
emplace_back
(
std
::
make_pair
(
name
,
i
->
GetFunc
()));
res
.
emplace_back
(
std
::
make_pair
(
name
,
i
->
GetFunc
()));
}
}
}
}
...
@@ -173,7 +182,9 @@ template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
...
@@ -173,7 +182,9 @@ template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
typename
KernelTuple
::
func_type
GetDefaultBestFunc
(
typename
KernelTuple
::
func_type
GetDefaultBestFunc
(
const
typename
KernelTuple
::
attr_type
&
attr
)
{
const
typename
KernelTuple
::
attr_type
&
attr
)
{
auto
funcs
=
GetAllCandidateFuncs
<
KernelTuple
,
PlaceType
>
(
attr
);
auto
funcs
=
GetAllCandidateFuncs
<
KernelTuple
,
PlaceType
>
(
attr
);
PADDLE_ENFORCE_GE
(
funcs
.
size
(),
1UL
);
PADDLE_ENFORCE_GE
(
funcs
.
size
(),
1UL
,
platform
::
errors
::
InvalidArgument
(
"The candicate jit kernel is at least one in CPU."
));
// Here could do some runtime benchmark of this attr and return the best one.
// Here could do some runtime benchmark of this attr and return the best one.
// But yet just get the first one as the default best one,
// But yet just get the first one as the default best one,
// which is searched in order and tuned by offline.
// which is searched in order and tuned by offline.
...
...
paddle/fluid/operators/jit/more/mix/mix.cc
浏览文件 @
36ed83d2
...
@@ -95,7 +95,8 @@ void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT
...
@@ -95,7 +95,8 @@ void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT
}
else
if
(
type
==
kVIdentity
)
{
}
else
if
(
type
==
kVIdentity
)
{
return
KernelFuncs
<
VIdentityTuple
<
T
>
,
CPUPlace
>::
Cache
().
At
(
d
);
return
KernelFuncs
<
VIdentityTuple
<
T
>
,
CPUPlace
>::
Cache
().
At
(
d
);
}
}
PADDLE_THROW
(
"Not support type: %s"
,
type
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Act JIT kernel do not support type: %s"
,
type
));
return
nullptr
;
return
nullptr
;
}
}
...
...
paddle/fluid/operators/jit/more/mkl/mkl.h
浏览文件 @
36ed83d2
...
@@ -103,11 +103,24 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
...
@@ -103,11 +103,24 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
template
<
typename
T
>
template
<
typename
T
>
void
EmbSeqPool
(
const
T
*
table
,
const
int64_t
*
idx
,
T
*
out
,
void
EmbSeqPool
(
const
T
*
table
,
const
int64_t
*
idx
,
T
*
out
,
const
emb_seq_pool_attr_t
*
attr
)
{
const
emb_seq_pool_attr_t
*
attr
)
{
PADDLE_ENFORCE_EQ
(
attr
->
table_width
*
attr
->
index_width
,
attr
->
out_width
);
PADDLE_ENFORCE_EQ
(
attr
->
table_width
*
attr
->
index_width
,
attr
->
out_width
,
platform
::
errors
::
InvalidArgument
(
"The attribute table_width * index_width of EmbSeqPool should "
"be equal to out_width. But table_width * index_width is %d, "
"out_width is %d."
,
attr
->
table_width
*
attr
->
index_width
,
attr
->
out_width
));
auto
check_idx_value_valid
=
[
&
](
int64_t
i
)
{
auto
check_idx_value_valid
=
[
&
](
int64_t
i
)
{
PADDLE_ENFORCE_LT
(
idx
[
i
],
attr
->
table_height
,
"idx value: %d, i: %d"
,
PADDLE_ENFORCE_LT
(
idx
[
i
],
i
);
idx
[
i
],
attr
->
table_height
,
PADDLE_ENFORCE_GE
(
idx
[
i
],
0
,
"idx value: %d, i: %d"
,
idx
[
i
],
i
);
platform
::
errors
::
InvalidArgument
(
"The idx shoud be lower than the attribute table_height of "
"EmbSeqPool. But %dth of idx is %d and table_height is %d."
,
i
,
idx
[
i
],
attr
->
table_height
));
PADDLE_ENFORCE_GE
(
idx
[
i
],
0
,
platform
::
errors
::
InvalidArgument
(
"The idx shoud be equal to or larger than "
"the 0. But %dth of idx is %d."
,
i
,
idx
[
i
]));
};
};
for
(
int64_t
w
=
0
;
w
!=
attr
->
index_width
;
++
w
)
{
for
(
int64_t
w
=
0
;
w
!=
attr
->
index_width
;
++
w
)
{
...
@@ -168,22 +181,50 @@ void Softmax(const T* x, T* y, int n, int bs, int remain = 1) {
...
@@ -168,22 +181,50 @@ void Softmax(const T* x, T* y, int n, int bs, int remain = 1) {
template
<
typename
T
>
template
<
typename
T
>
void
Sgd
(
const
T
*
lr
,
const
T
*
param
,
const
T
*
grad
,
const
int64_t
*
rows
,
void
Sgd
(
const
T
*
lr
,
const
T
*
param
,
const
T
*
grad
,
const
int64_t
*
rows
,
T
*
out
,
const
sgd_attr_t
*
attr
)
{
T
*
out
,
const
sgd_attr_t
*
attr
)
{
PADDLE_ENFORCE_EQ
(
attr
->
param_width
,
attr
->
grad_width
);
PADDLE_ENFORCE_EQ
(
attr
->
param_width
,
attr
->
grad_width
,
PADDLE_ENFORCE_LE
(
attr
->
selected_rows_size
,
attr
->
grad_height
);
platform
::
errors
::
InvalidArgument
(
"The attribute param_width of Sgd should be "
"equal to the attribute grad_width. But param_width "
"is %d and grad_width is %d."
,
attr
->
param_width
,
attr
->
grad_width
));
PADDLE_ENFORCE_LE
(
attr
->
selected_rows_size
,
attr
->
grad_height
,
platform
::
errors
::
InvalidArgument
(
"The attribute selected_rows_size of Sgd should be "
"equal to or less than the attribute grad_height. "
"But selected_rows_size is %d and grad_height is %d."
,
attr
->
selected_rows_size
,
attr
->
grad_height
));
T
scalar
=
-
lr
[
0
];
T
scalar
=
-
lr
[
0
];
int
width
=
attr
->
grad_width
;
int
width
=
attr
->
grad_width
;
if
(
out
==
param
)
{
if
(
out
==
param
)
{
for
(
int64_t
i
=
0
;
i
<
attr
->
selected_rows_size
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
attr
->
selected_rows_size
;
++
i
)
{
auto
h_idx
=
rows
[
i
];
auto
h_idx
=
rows
[
i
];
PADDLE_ENFORCE_LT
(
h_idx
,
attr
->
param_height
);
PADDLE_ENFORCE_LT
(
h_idx
,
attr
->
param_height
,
PADDLE_ENFORCE_GE
(
h_idx
,
0
);
platform
::
errors
::
InvalidArgument
(
"The rows of Sgd should be "
"less than the attribute. But %dth of rows "
"is %d and grad_width is %d."
,
i
,
h_idx
,
attr
->
param_height
));
PADDLE_ENFORCE_GE
(
h_idx
,
0
,
platform
::
errors
::
InvalidArgument
(
"The rows of Sgd should be "
"larger than 0. But %dth of rows "
"is %d."
,
i
,
h_idx
));
VAXPY
(
scalar
,
grad
+
i
*
width
,
out
+
h_idx
*
width
,
width
);
VAXPY
(
scalar
,
grad
+
i
*
width
,
out
+
h_idx
*
width
,
width
);
}
}
}
else
{
}
else
{
for
(
int64_t
i
=
0
;
i
<
attr
->
selected_rows_size
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
attr
->
selected_rows_size
;
++
i
)
{
auto
h_idx
=
rows
[
i
];
auto
h_idx
=
rows
[
i
];
PADDLE_ENFORCE_LT
(
h_idx
,
attr
->
param_height
);
PADDLE_ENFORCE_LT
(
h_idx
,
attr
->
param_height
,
PADDLE_ENFORCE_GE
(
h_idx
,
0
);
platform
::
errors
::
InvalidArgument
(
"The rows of Sgd should be "
"less than the attribute. But %dth of rows "
"is %d and grad_width is %d."
,
i
,
h_idx
,
attr
->
param_height
));
PADDLE_ENFORCE_GE
(
h_idx
,
0
,
platform
::
errors
::
InvalidArgument
(
"The rows of Sgd should be "
"larger than 0. But %dth of rows "
"is %d."
,
i
,
h_idx
));
VScal
(
&
scalar
,
grad
+
i
*
width
,
out
+
h_idx
*
width
,
width
);
VScal
(
&
scalar
,
grad
+
i
*
width
,
out
+
h_idx
*
width
,
width
);
VAdd
(
param
+
h_idx
*
width
,
out
+
h_idx
*
width
,
out
+
h_idx
*
width
,
VAdd
(
param
+
h_idx
*
width
,
out
+
h_idx
*
width
,
out
+
h_idx
*
width
,
width
);
width
);
...
...
paddle/fluid/operators/jit/refer/refer.h
浏览文件 @
36ed83d2
...
@@ -147,7 +147,8 @@ void (*getActFunc(KernelType type))(const T*, T*, int) { // NOLINT
...
@@ -147,7 +147,8 @@ void (*getActFunc(KernelType type))(const T*, T*, int) { // NOLINT
}
else
if
(
type
==
kVIdentity
)
{
}
else
if
(
type
==
kVIdentity
)
{
return
VIdentity
<
T
>
;
return
VIdentity
<
T
>
;
}
}
PADDLE_THROW
(
"Not support type: %s"
,
type
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Act JIT kernel do not support type: %s."
,
type
));
return
nullptr
;
return
nullptr
;
}
}
...
@@ -465,12 +466,25 @@ void Softmax(const T* x, T* y, int n, int bs = 1, int remain = 1) {
...
@@ -465,12 +466,25 @@ void Softmax(const T* x, T* y, int n, int bs = 1, int remain = 1) {
template
<
typename
T
>
template
<
typename
T
>
void
EmbSeqPool
(
const
T
*
table
,
const
int64_t
*
idx
,
T
*
out
,
void
EmbSeqPool
(
const
T
*
table
,
const
int64_t
*
idx
,
T
*
out
,
const
emb_seq_pool_attr_t
*
attr
)
{
const
emb_seq_pool_attr_t
*
attr
)
{
PADDLE_ENFORCE_EQ
(
attr
->
table_width
*
attr
->
index_width
,
attr
->
out_width
);
PADDLE_ENFORCE_EQ
(
attr
->
table_width
*
attr
->
index_width
,
attr
->
out_width
,
platform
::
errors
::
InvalidArgument
(
"The attribute table_width * index_width of EmbSeqPool should "
"be equal to out_width. But table_width * index_width is %d and "
"out_width is %d."
,
attr
->
table_width
*
attr
->
index_width
,
attr
->
out_width
));
auto
check_idx_value_valid
=
[
&
](
int64_t
i
)
{
auto
check_idx_value_valid
=
[
&
](
int64_t
i
)
{
PADDLE_ENFORCE_LT
(
idx
[
i
],
attr
->
table_height
,
"idx value: %d, i: %d"
,
PADDLE_ENFORCE_LT
(
idx
[
i
],
i
);
idx
[
i
],
attr
->
table_height
,
PADDLE_ENFORCE_GE
(
idx
[
i
],
0
,
"idx value: %d, i: %d"
,
idx
[
i
],
i
);
platform
::
errors
::
InvalidArgument
(
"The idx shoud be lower than the attribute table_height of "
"EmbSeqPool. But %dth of idx is %d and table_height is %d."
,
i
,
idx
[
i
],
attr
->
table_height
));
PADDLE_ENFORCE_GE
(
idx
[
i
],
0
,
platform
::
errors
::
InvalidArgument
(
"The idx shoud be equal to or larger than "
"the 0. But %dth of idx is %d."
,
i
,
idx
[
i
]));
};
};
for
(
int64_t
w
=
0
;
w
!=
attr
->
index_width
;
++
w
)
{
for
(
int64_t
w
=
0
;
w
!=
attr
->
index_width
;
++
w
)
{
...
@@ -505,12 +519,31 @@ void EmbSeqPool(const T* table, const int64_t* idx, T* out,
...
@@ -505,12 +519,31 @@ void EmbSeqPool(const T* table, const int64_t* idx, T* out,
template
<
typename
T
>
template
<
typename
T
>
void
Sgd
(
const
T
*
lr
,
const
T
*
param
,
const
T
*
grad
,
const
int64_t
*
rows
,
void
Sgd
(
const
T
*
lr
,
const
T
*
param
,
const
T
*
grad
,
const
int64_t
*
rows
,
T
*
out
,
const
sgd_attr_t
*
attr
)
{
T
*
out
,
const
sgd_attr_t
*
attr
)
{
PADDLE_ENFORCE_EQ
(
attr
->
param_width
,
attr
->
grad_width
);
PADDLE_ENFORCE_EQ
(
attr
->
param_width
,
attr
->
grad_width
,
PADDLE_ENFORCE_LE
(
attr
->
selected_rows_size
,
attr
->
grad_height
);
platform
::
errors
::
InvalidArgument
(
"The attribute param_width of Sgd should be "
"equal to the attribute grad_width. But param_width "
"is %d and grad_width is %d."
,
attr
->
param_width
,
attr
->
grad_width
));
PADDLE_ENFORCE_LE
(
attr
->
selected_rows_size
,
attr
->
grad_height
,
platform
::
errors
::
InvalidArgument
(
"The attribute selected_rows_size of Sgd should be "
"equal to or less than the attribute grad_height. "
"But selected_rows_size is %d and grad_height is %d."
,
attr
->
selected_rows_size
,
attr
->
grad_height
));
for
(
int64_t
i
=
0
;
i
<
attr
->
selected_rows_size
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
attr
->
selected_rows_size
;
++
i
)
{
auto
h_idx
=
rows
[
i
];
auto
h_idx
=
rows
[
i
];
PADDLE_ENFORCE_LT
(
h_idx
,
attr
->
param_height
);
PADDLE_ENFORCE_LT
(
h_idx
,
attr
->
param_height
,
PADDLE_ENFORCE_GE
(
h_idx
,
0
);
platform
::
errors
::
InvalidArgument
(
"The rows of Sgd should be "
"less than the attribute. But %dth of rows "
"is %d and grad_width is %d."
,
i
,
h_idx
,
attr
->
param_height
));
PADDLE_ENFORCE_GE
(
h_idx
,
0
,
platform
::
errors
::
InvalidArgument
(
"The rows of Sgd should be "
"larger than 0. But %dth of rows "
"is %d."
,
i
,
h_idx
));
for
(
int64_t
j
=
0
;
j
<
attr
->
grad_width
;
++
j
)
{
for
(
int64_t
j
=
0
;
j
<
attr
->
grad_width
;
++
j
)
{
out
[
h_idx
*
attr
->
grad_width
+
j
]
=
out
[
h_idx
*
attr
->
grad_width
+
j
]
=
param
[
h_idx
*
attr
->
grad_width
+
j
]
-
param
[
h_idx
*
attr
->
grad_width
+
j
]
-
...
...
paddle/fluid/operators/jit/test.cc
浏览文件 @
36ed83d2
...
@@ -850,8 +850,15 @@ void TestKernelSgd() {
...
@@ -850,8 +850,15 @@ void TestKernelSgd() {
const
T
lr
=
0.1
;
const
T
lr
=
0.1
;
auto
UnDuplicatedRandomVec
=
[](
int
n
,
const
int64_t
lower
,
auto
UnDuplicatedRandomVec
=
[](
int
n
,
const
int64_t
lower
,
const
int64_t
upper
)
->
std
::
vector
<
int64_t
>
{
const
int64_t
upper
)
->
std
::
vector
<
int64_t
>
{
PADDLE_ENFORCE_LE
(
static_cast
<
size_t
>
(
upper
-
lower
),
n
-
1
);
PADDLE_ENFORCE_LE
(
static_cast
<
size_t
>
(
upper
-
lower
),
n
-
1
,
PADDLE_ENFORCE_GT
(
n
,
0
);
paddle
::
platform
::
errors
::
InvalidArgument
(
"The range of Sgd (upper - lower) should be lower "
"than n-1 (Sgd size -1). But the upper - lower is %d "
"and n-1 is %d."
,
static_cast
<
size_t
>
(
upper
-
lower
),
n
-
1
));
PADDLE_ENFORCE_GT
(
n
,
0
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"The Sgd size should be larger than 0. But the n is %d."
,
n
));
std
::
vector
<
int64_t
>
all
,
out
;
std
::
vector
<
int64_t
>
all
,
out
;
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
all
.
push_back
(
i
);
all
.
push_back
(
i
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录