Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
36ed83d2
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
36ed83d2
编写于
9月 25, 2020
作者:
G
GaoWei8
提交者:
GitHub
9月 25, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refine PADDLE_ENFORCE (#27360)
* refine PADDLE_ENFORCE
上级
effd51b6
变更
17
显示空白变更内容
内联
并排
Showing
17 changed file
with
278 addition
and
76 deletion
+278
-76
paddle/fluid/operators/benchmark/op_tester.cc
paddle/fluid/operators/benchmark/op_tester.cc
+20
-12
paddle/fluid/operators/benchmark/op_tester_config.cc
paddle/fluid/operators/benchmark/op_tester_config.cc
+15
-5
paddle/fluid/operators/jit/benchmark.cc
paddle/fluid/operators/jit/benchmark.cc
+9
-3
paddle/fluid/operators/jit/gen/embseqpool.cc
paddle/fluid/operators/jit/gen/embseqpool.cc
+25
-5
paddle/fluid/operators/jit/gen/matmul.cc
paddle/fluid/operators/jit/gen/matmul.cc
+20
-4
paddle/fluid/operators/jit/gen/matmul.h
paddle/fluid/operators/jit/gen/matmul.h
+4
-1
paddle/fluid/operators/jit/gen/seqpool.cc
paddle/fluid/operators/jit/gen/seqpool.cc
+8
-2
paddle/fluid/operators/jit/gen/seqpool.h
paddle/fluid/operators/jit/gen/seqpool.h
+7
-2
paddle/fluid/operators/jit/gen/sgd.cc
paddle/fluid/operators/jit/gen/sgd.cc
+18
-3
paddle/fluid/operators/jit/gen/vbroadcast.cc
paddle/fluid/operators/jit/gen/vbroadcast.cc
+5
-1
paddle/fluid/operators/jit/gen_base.cc
paddle/fluid/operators/jit/gen_base.cc
+8
-3
paddle/fluid/operators/jit/helper.cc
paddle/fluid/operators/jit/helper.cc
+17
-6
paddle/fluid/operators/jit/helper.h
paddle/fluid/operators/jit/helper.h
+18
-7
paddle/fluid/operators/jit/more/mix/mix.cc
paddle/fluid/operators/jit/more/mix/mix.cc
+2
-1
paddle/fluid/operators/jit/more/mkl/mkl.h
paddle/fluid/operators/jit/more/mkl/mkl.h
+51
-10
paddle/fluid/operators/jit/refer/refer.h
paddle/fluid/operators/jit/refer/refer.h
+42
-9
paddle/fluid/operators/jit/test.cc
paddle/fluid/operators/jit/test.cc
+9
-2
未找到文件。
paddle/fluid/operators/benchmark/op_tester.cc
浏览文件 @
36ed83d2
...
@@ -47,8 +47,8 @@ void OpTester::Init(const OpTesterConfig &config) {
...
@@ -47,8 +47,8 @@ void OpTester::Init(const OpTesterConfig &config) {
CreateInputVarDesc
();
CreateInputVarDesc
();
CreateOutputVarDesc
();
CreateOutputVarDesc
();
}
else
{
}
else
{
PADDLE_THROW
(
platform
::
errors
::
NotFound
(
"Operator '%s' is not registered."
,
PADDLE_THROW
(
platform
::
errors
::
NotFound
(
config_
.
op_type
));
"Operator '%s' is not registered in OpTester."
,
config_
.
op_type
));
}
}
if
(
config_
.
device_id
>=
0
)
{
if
(
config_
.
device_id
>=
0
)
{
...
@@ -81,7 +81,8 @@ void OpTester::Run() {
...
@@ -81,7 +81,8 @@ void OpTester::Run() {
platform
::
EnableProfiler
(
platform
::
ProfilerState
::
kAll
);
platform
::
EnableProfiler
(
platform
::
ProfilerState
::
kAll
);
platform
::
SetDeviceId
(
config_
.
device_id
);
platform
::
SetDeviceId
(
config_
.
device_id
);
#else
#else
PADDLE_THROW
(
"'CUDAPlace' is not supported in CPU only device."
);
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"'CUDAPlace' is not supported in CPU only device."
));
#endif
#endif
}
}
...
@@ -162,7 +163,8 @@ framework::proto::VarType::Type OpTester::TransToVarType(std::string str) {
...
@@ -162,7 +163,8 @@ framework::proto::VarType::Type OpTester::TransToVarType(std::string str) {
}
else
if
(
str
==
"fp64"
)
{
}
else
if
(
str
==
"fp64"
)
{
return
framework
::
proto
::
VarType
::
FP64
;
return
framework
::
proto
::
VarType
::
FP64
;
}
else
{
}
else
{
PADDLE_THROW
(
"Unsupported dtype %s."
,
str
.
c_str
());
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported dtype %s in OpTester."
,
str
.
c_str
()));
}
}
}
}
...
@@ -233,8 +235,8 @@ void OpTester::CreateOpDesc() {
...
@@ -233,8 +235,8 @@ void OpTester::CreateOpDesc() {
case
framework
::
proto
::
AttrType
::
INTS
:
case
framework
::
proto
::
AttrType
::
INTS
:
case
framework
::
proto
::
AttrType
::
FLOATS
:
case
framework
::
proto
::
AttrType
::
FLOATS
:
case
framework
::
proto
::
AttrType
::
STRINGS
:
case
framework
::
proto
::
AttrType
::
STRINGS
:
PADDLE_THROW
(
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
platform
::
errors
::
Unimplemented
(
"Not supported STRINGS type
yet."
));
"Unsupported STRINGS type in OpTester
yet."
));
break
;
break
;
case
framework
::
proto
::
AttrType
::
LONG
:
{
case
framework
::
proto
::
AttrType
::
LONG
:
{
int64_t
value
=
StringTo
<
int64_t
>
(
value_str
);
int64_t
value
=
StringTo
<
int64_t
>
(
value_str
);
...
@@ -242,7 +244,8 @@ void OpTester::CreateOpDesc() {
...
@@ -242,7 +244,8 @@ void OpTester::CreateOpDesc() {
}
break
;
}
break
;
case
framework
::
proto
::
AttrType
::
LONGS
:
case
framework
::
proto
::
AttrType
::
LONGS
:
default:
default:
PADDLE_THROW
(
"Unsupport attr type %d"
,
type
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupport attr type %d in OpTester."
,
type
));
}
}
}
}
}
}
...
@@ -299,7 +302,8 @@ void OpTester::SetupTensor(framework::LoDTensor *tensor,
...
@@ -299,7 +302,8 @@ void OpTester::SetupTensor(framework::LoDTensor *tensor,
}
}
is
.
close
();
is
.
close
();
}
else
{
}
else
{
PADDLE_THROW
(
"Unsupported initializer %s."
,
initializer
.
c_str
());
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported initializer %s in OpTester."
,
initializer
.
c_str
()));
}
}
if
(
!
platform
::
is_cpu_place
(
place_
))
{
if
(
!
platform
::
is_cpu_place
(
place_
))
{
...
@@ -351,7 +355,8 @@ void OpTester::CreateVariables(framework::Scope *scope) {
...
@@ -351,7 +355,8 @@ void OpTester::CreateVariables(framework::Scope *scope) {
static_cast
<
double
>
(
1.0
),
item
.
second
.
initializer
,
static_cast
<
double
>
(
1.0
),
item
.
second
.
initializer
,
item
.
second
.
filename
);
item
.
second
.
filename
);
}
else
{
}
else
{
PADDLE_THROW
(
"Unsupported dtype %d."
,
data_type
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported dtype %d in OpTester."
,
data_type
));
}
}
VLOG
(
3
)
<<
"Set lod for tensor "
<<
var_name
;
VLOG
(
3
)
<<
"Set lod for tensor "
<<
var_name
;
...
@@ -473,7 +478,8 @@ std::string OpTester::DebugString() {
...
@@ -473,7 +478,8 @@ std::string OpTester::DebugString() {
<<
"
\n
"
;
<<
"
\n
"
;
}
break
;
}
break
;
default:
default:
PADDLE_THROW
(
"Unsupport attr type %d"
,
attr_type
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupport attr type %d in OpTester."
,
attr_type
));
}
}
ss
<<
GenSpaces
(
--
count
)
<<
"}
\n
"
;
ss
<<
GenSpaces
(
--
count
)
<<
"}
\n
"
;
}
}
...
@@ -484,8 +490,10 @@ std::string OpTester::DebugString() {
...
@@ -484,8 +490,10 @@ std::string OpTester::DebugString() {
TEST
(
op_tester
,
base
)
{
TEST
(
op_tester
,
base
)
{
if
(
!
FLAGS_op_config_list
.
empty
())
{
if
(
!
FLAGS_op_config_list
.
empty
())
{
std
::
ifstream
fin
(
FLAGS_op_config_list
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
std
::
ifstream
fin
(
FLAGS_op_config_list
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fin
),
"Cannot open file %s"
,
PADDLE_ENFORCE_EQ
(
FLAGS_op_config_list
.
c_str
());
static_cast
<
bool
>
(
fin
),
true
,
platform
::
errors
::
InvalidArgument
(
"OpTester cannot open file %s"
,
FLAGS_op_config_list
.
c_str
()));
std
::
vector
<
OpTesterConfig
>
op_configs
;
std
::
vector
<
OpTesterConfig
>
op_configs
;
while
(
!
fin
.
eof
())
{
while
(
!
fin
.
eof
())
{
VLOG
(
4
)
<<
"Reading config "
<<
op_configs
.
size
()
<<
"..."
;
VLOG
(
4
)
<<
"Reading config "
<<
op_configs
.
size
()
<<
"..."
;
...
...
paddle/fluid/operators/benchmark/op_tester_config.cc
浏览文件 @
36ed83d2
...
@@ -78,7 +78,8 @@ void OpInputConfig::ParseDType(std::istream& is) {
...
@@ -78,7 +78,8 @@ void OpInputConfig::ParseDType(std::istream& is) {
}
else
if
(
dtype_str
==
"fp64"
||
dtype_str
==
"double"
)
{
}
else
if
(
dtype_str
==
"fp64"
||
dtype_str
==
"double"
)
{
dtype
=
"fp64"
;
dtype
=
"fp64"
;
}
else
{
}
else
{
PADDLE_THROW
(
"Unsupported dtype %s"
,
dtype_str
.
c_str
());
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported dtype %s in OpInputConfig."
,
dtype_str
.
c_str
()));
}
}
VLOG
(
4
)
<<
"dtype of input "
<<
name
<<
" is: "
<<
dtype
;
VLOG
(
4
)
<<
"dtype of input "
<<
name
<<
" is: "
<<
dtype
;
}
}
...
@@ -91,7 +92,9 @@ void OpInputConfig::ParseInitializer(std::istream& is) {
...
@@ -91,7 +92,9 @@ void OpInputConfig::ParseInitializer(std::istream& is) {
const
std
::
vector
<
std
::
string
>
supported_initializers
=
{
"random"
,
"natural"
,
const
std
::
vector
<
std
::
string
>
supported_initializers
=
{
"random"
,
"natural"
,
"zeros"
,
"file"
};
"zeros"
,
"file"
};
if
(
!
Has
(
supported_initializers
,
initializer_str
))
{
if
(
!
Has
(
supported_initializers
,
initializer_str
))
{
PADDLE_THROW
(
"Unsupported initializer %s"
,
initializer_str
.
c_str
());
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported initializer %s in OpInputConfig."
,
initializer_str
.
c_str
()));
}
}
initializer
=
initializer_str
;
initializer
=
initializer_str
;
...
@@ -126,7 +129,12 @@ void OpInputConfig::ParseLoD(std::istream& is) {
...
@@ -126,7 +129,12 @@ void OpInputConfig::ParseLoD(std::istream& is) {
}
}
}
}
EraseEndSep
(
&
lod_str
);
EraseEndSep
(
&
lod_str
);
PADDLE_ENFORCE_GE
(
lod_str
.
length
(),
4U
);
PADDLE_ENFORCE_GE
(
lod_str
.
length
(),
4U
,
platform
::
errors
::
InvalidArgument
(
"The length of lod string should be "
"equal to or larger than 4. But length of lod string is %zu."
,
lod_str
.
length
()));
VLOG
(
4
)
<<
"lod: "
<<
lod_str
<<
", length: "
<<
lod_str
.
length
();
VLOG
(
4
)
<<
"lod: "
<<
lod_str
<<
", length: "
<<
lod_str
.
length
();
// Parse the lod_str
// Parse the lod_str
...
@@ -153,8 +161,10 @@ void OpInputConfig::ParseLoD(std::istream& is) {
...
@@ -153,8 +161,10 @@ void OpInputConfig::ParseLoD(std::istream& is) {
OpTesterConfig
::
OpTesterConfig
(
const
std
::
string
&
filename
)
{
OpTesterConfig
::
OpTesterConfig
(
const
std
::
string
&
filename
)
{
std
::
ifstream
fin
(
filename
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
std
::
ifstream
fin
(
filename
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
fin
),
"Cannot open file %s"
,
PADDLE_ENFORCE_EQ
(
filename
.
c_str
());
static_cast
<
bool
>
(
fin
),
true
,
platform
::
errors
::
InvalidArgument
(
"OpTesterConfig cannot open file %s."
,
filename
.
c_str
()));
Init
(
fin
);
Init
(
fin
);
}
}
...
...
paddle/fluid/operators/jit/benchmark.cc
浏览文件 @
36ed83d2
...
@@ -136,7 +136,6 @@ void BenchAllImpls(const typename KernelTuple::attr_type& attr, Args... args) {
...
@@ -136,7 +136,6 @@ void BenchAllImpls(const typename KernelTuple::attr_type& attr, Args... args) {
}
}
using
Tensor
=
paddle
::
framework
::
Tensor
;
using
Tensor
=
paddle
::
framework
::
Tensor
;
template
<
typename
KernelTuple
,
typename
PlaceType
>
template
<
typename
KernelTuple
,
typename
PlaceType
>
void
BenchKernelXYZN
()
{
void
BenchKernelXYZN
()
{
using
T
=
typename
KernelTuple
::
data_type
;
using
T
=
typename
KernelTuple
::
data_type
;
...
@@ -320,8 +319,15 @@ void BenchKernelSgd() {
...
@@ -320,8 +319,15 @@ void BenchKernelSgd() {
const
T
lr
=
0.1
;
const
T
lr
=
0.1
;
auto
UnDuplicatedRandomVec
=
[](
int
n
,
const
int64_t
lower
,
auto
UnDuplicatedRandomVec
=
[](
int
n
,
const
int64_t
lower
,
const
int64_t
upper
)
->
std
::
vector
<
int64_t
>
{
const
int64_t
upper
)
->
std
::
vector
<
int64_t
>
{
PADDLE_ENFORCE_LE
(
static_cast
<
size_t
>
(
upper
-
lower
),
n
-
1
);
PADDLE_ENFORCE_LE
(
PADDLE_ENFORCE_GT
(
n
,
0
);
static_cast
<
size_t
>
(
upper
-
lower
),
n
-
1
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"The range of Sgd (upper - lower) should be equal to or lower "
"than n-1 (Sgd size -1). But upper - lower is %d and n-1 is %d."
,
static_cast
<
size_t
>
(
upper
-
lower
),
(
n
-
1
)));
PADDLE_ENFORCE_GT
(
n
,
0
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"The Sgd size should be larger than 0. But the n is %d."
,
n
));
std
::
vector
<
int64_t
>
all
,
out
;
std
::
vector
<
int64_t
>
all
,
out
;
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
all
.
push_back
(
i
);
all
.
push_back
(
i
);
...
...
paddle/fluid/operators/jit/gen/embseqpool.cc
浏览文件 @
36ed83d2
...
@@ -132,11 +132,31 @@ class EmbSeqPoolCreator : public JitCodeCreator<emb_seq_pool_attr_t> {
...
@@ -132,11 +132,31 @@ class EmbSeqPoolCreator : public JitCodeCreator<emb_seq_pool_attr_t> {
}
}
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
const
emb_seq_pool_attr_t
&
attr
)
const
override
{
const
emb_seq_pool_attr_t
&
attr
)
const
override
{
PADDLE_ENFORCE_GT
(
attr
.
table_height
,
0
);
PADDLE_ENFORCE_GT
(
attr
.
table_height
,
0
,
PADDLE_ENFORCE_GT
(
attr
.
table_width
,
0
);
platform
::
errors
::
InvalidArgument
(
PADDLE_ENFORCE_GT
(
attr
.
index_height
,
0
);
"The attribute table_height of EmbSeqPool should "
PADDLE_ENFORCE_GT
(
attr
.
index_width
,
0
);
"be larger than 0. But it is %d."
,
PADDLE_ENFORCE_GT
(
attr
.
out_width
,
0
);
attr
.
table_height
));
PADDLE_ENFORCE_GT
(
attr
.
table_width
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute table_width of EmbSeqPool should "
"be larger than 0. But it is %d."
,
attr
.
table_width
));
PADDLE_ENFORCE_GT
(
attr
.
index_height
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute index_height of EmbSeqPool should "
"be larger than 0. But it is %d."
,
attr
.
index_height
));
PADDLE_ENFORCE_GT
(
attr
.
index_width
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute index_width of EmbSeqPool should "
"be larger than 0. But it is %d."
,
attr
.
index_width
));
PADDLE_ENFORCE_GT
(
attr
.
out_width
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute out_width of EmbSeqPool should be "
"larger than 0. But it is %d."
,
attr
.
out_width
));
return
make_unique
<
EmbSeqPoolJitCode
>
(
attr
,
CodeSize
(
attr
));
return
make_unique
<
EmbSeqPoolJitCode
>
(
attr
,
CodeSize
(
attr
));
}
}
};
};
...
...
paddle/fluid/operators/jit/gen/matmul.cc
浏览文件 @
36ed83d2
...
@@ -29,7 +29,11 @@ void MatMulJitCode::genCode() {
...
@@ -29,7 +29,11 @@ void MatMulJitCode::genCode() {
preCode
();
preCode
();
int
block
,
rest
;
int
block
,
rest
;
const
auto
groups
=
packed_groups
(
n_
,
k_
,
&
block
,
&
rest
);
const
auto
groups
=
packed_groups
(
n_
,
k_
,
&
block
,
&
rest
);
PADDLE_ENFORCE_GT
(
groups
.
front
(),
0
);
PADDLE_ENFORCE_GT
(
groups
.
front
(),
0
,
platform
::
errors
::
InvalidArgument
(
"The number of rest registers should "
"be larger than 0. But it is %d."
,
groups
.
front
()));
const
int
block_len
=
sizeof
(
float
)
*
block
;
const
int
block_len
=
sizeof
(
float
)
*
block
;
const
int
x_reg_idx
=
(
block
==
ZMM_FLOAT_BLOCK
?
32
:
16
)
-
1
;
const
int
x_reg_idx
=
(
block
==
ZMM_FLOAT_BLOCK
?
32
:
16
)
-
1
;
...
@@ -118,9 +122,21 @@ class MatMulCreator : public JitCodeCreator<matmul_attr_t> {
...
@@ -118,9 +122,21 @@ class MatMulCreator : public JitCodeCreator<matmul_attr_t> {
}
}
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
const
matmul_attr_t
&
attr
)
const
override
{
const
matmul_attr_t
&
attr
)
const
override
{
PADDLE_ENFORCE_GT
(
attr
.
m
,
0
);
PADDLE_ENFORCE_GT
(
PADDLE_ENFORCE_GT
(
attr
.
n
,
0
);
attr
.
m
,
0
,
platform
::
errors
::
InvalidArgument
(
PADDLE_ENFORCE_GT
(
attr
.
k
,
0
);
"The attribute m (first matrix's row) of MatMul should "
"be larger than 0. But it is %d."
,
attr
.
m
));
PADDLE_ENFORCE_GT
(
attr
.
n
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute n (first matrix's col) of MatMul should "
"be larger than 0. But it is %d."
,
attr
.
n
));
PADDLE_ENFORCE_GT
(
attr
.
k
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute k (second matrix's col) of MatMul should "
"be larger than 0. But it is %d."
,
attr
.
k
));
return
make_unique
<
MatMulJitCode
>
(
attr
,
CodeSize
(
attr
));
return
make_unique
<
MatMulJitCode
>
(
attr
,
CodeSize
(
attr
));
}
}
};
};
...
...
paddle/fluid/operators/jit/gen/matmul.h
浏览文件 @
36ed83d2
...
@@ -33,7 +33,10 @@ class MatMulJitCode : public JitCode {
...
@@ -33,7 +33,10 @@ class MatMulJitCode : public JitCode {
size_t
code_size
=
256
*
1024
,
size_t
code_size
=
256
*
1024
,
void
*
code_ptr
=
nullptr
)
void
*
code_ptr
=
nullptr
)
:
JitCode
(
code_size
,
code_ptr
),
m_
(
attr
.
m
),
n_
(
attr
.
n
),
k_
(
attr
.
k
)
{
:
JitCode
(
code_size
,
code_ptr
),
m_
(
attr
.
m
),
n_
(
attr
.
n
),
k_
(
attr
.
k
)
{
PADDLE_ENFORCE_EQ
(
m_
,
1
,
"Only support m==1 yet"
);
PADDLE_ENFORCE_EQ
(
m_
,
1
,
platform
::
errors
::
Unimplemented
(
"Jitcode of matmul only support m==1 (first "
"matrix's row) now. But m is %d."
,
m_
));
this
->
genCode
();
this
->
genCode
();
}
}
...
...
paddle/fluid/operators/jit/gen/seqpool.cc
浏览文件 @
36ed83d2
...
@@ -70,8 +70,14 @@ class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
...
@@ -70,8 +70,14 @@ class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
}
}
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
const
seq_pool_attr_t
&
attr
)
const
override
{
const
seq_pool_attr_t
&
attr
)
const
override
{
PADDLE_ENFORCE_GT
(
attr
.
w
,
0
);
PADDLE_ENFORCE_GT
(
attr
.
w
,
0
,
platform
::
errors
::
InvalidArgument
(
PADDLE_ENFORCE_GT
(
attr
.
h
,
0
);
"The attribute width of SeqPool should "
"be larger than 0. But it is %d."
,
attr
.
w
));
PADDLE_ENFORCE_GT
(
attr
.
h
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute height of SeqPool should "
"be larger than 0. But it is %d."
,
attr
.
h
));
return
make_unique
<
SeqPoolJitCode
>
(
attr
,
CodeSize
(
attr
));
return
make_unique
<
SeqPoolJitCode
>
(
attr
,
CodeSize
(
attr
));
}
}
};
};
...
...
paddle/fluid/operators/jit/gen/seqpool.h
浏览文件 @
36ed83d2
...
@@ -127,8 +127,13 @@ class SeqPoolJitCode : public JitCode {
...
@@ -127,8 +127,13 @@ class SeqPoolJitCode : public JitCode {
vmovss
(
xmm_t
(
reg_idx
+
max_num_regs
),
ptr
[
reg_ptr_src_i
]);
vmovss
(
xmm_t
(
reg_idx
+
max_num_regs
),
ptr
[
reg_ptr_src_i
]);
reg_idx
++
;
reg_idx
++
;
}
}
PADDLE_ENFORCE_EQ
(
reg_idx
,
rest_used_num_regs
,
PADDLE_ENFORCE_EQ
(
"All heights should use same regs"
);
reg_idx
,
rest_used_num_regs
,
platform
::
errors
::
InvalidArgument
(
"All heights of SeqPool should use the same number of registers."
"It equals to the numbr of rest registers. But use %d registers "
"and the numbr of rest registers is %d."
,
reg_idx
,
rest_used_num_regs
));
for
(
int
i
=
0
;
i
<
reg_idx
;
++
i
)
{
for
(
int
i
=
0
;
i
<
reg_idx
;
++
i
)
{
vaddps
(
xmm_t
(
i
),
xmm_t
(
i
),
xmm_t
(
i
+
max_num_regs
));
vaddps
(
xmm_t
(
i
),
xmm_t
(
i
),
xmm_t
(
i
+
max_num_regs
));
}
}
...
...
paddle/fluid/operators/jit/gen/sgd.cc
浏览文件 @
36ed83d2
...
@@ -116,9 +116,24 @@ class SgdCreator : public JitCodeCreator<sgd_attr_t> {
...
@@ -116,9 +116,24 @@ class SgdCreator : public JitCodeCreator<sgd_attr_t> {
size_t
CodeSize
(
const
sgd_attr_t
&
attr
)
const
override
{
return
96
+
32
*
8
;
}
size_t
CodeSize
(
const
sgd_attr_t
&
attr
)
const
override
{
return
96
+
32
*
8
;
}
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
const
sgd_attr_t
&
attr
)
const
override
{
const
sgd_attr_t
&
attr
)
const
override
{
PADDLE_ENFORCE_EQ
(
attr
.
param_width
,
attr
.
grad_width
);
PADDLE_ENFORCE_EQ
(
attr
.
param_width
,
attr
.
grad_width
,
PADDLE_ENFORCE_LE
(
attr
.
selected_rows_size
,
attr
.
grad_height
);
platform
::
errors
::
InvalidArgument
(
PADDLE_ENFORCE_GE
(
attr
.
selected_rows_size
,
0
);
"The attribute param_width of Sgd should be "
"equal to the attribute grad_width. But param_width "
"is %d and grad_width is %d."
,
attr
.
param_width
,
attr
.
grad_width
));
PADDLE_ENFORCE_LE
(
attr
.
selected_rows_size
,
attr
.
grad_height
,
platform
::
errors
::
InvalidArgument
(
"The attribute selected_rows_size of Sgd should be "
"equal to or less than the attribute grad_height. "
"But selected_rows_size is %d and grad_height is %d."
,
attr
.
selected_rows_size
,
attr
.
grad_height
));
PADDLE_ENFORCE_GE
(
attr
.
selected_rows_size
,
0
,
platform
::
errors
::
InvalidArgument
(
"The attribute selected_rows_size of Sgd should be "
"equal to or larger than 0. But selected_rows_size is %d."
,
attr
.
selected_rows_size
));
return
make_unique
<
SgdJitCode
>
(
attr
,
CodeSize
(
attr
));
return
make_unique
<
SgdJitCode
>
(
attr
,
CodeSize
(
attr
));
}
}
};
};
...
...
paddle/fluid/operators/jit/gen/vbroadcast.cc
浏览文件 @
36ed83d2
...
@@ -76,7 +76,11 @@ class VBroadcastCreator : public JitCodeCreator<int64_t> {
...
@@ -76,7 +76,11 @@ class VBroadcastCreator : public JitCodeCreator<int64_t> {
return
96
+
(
w
/
YMM_FLOAT_BLOCK
)
*
16
*
8
;
return
96
+
(
w
/
YMM_FLOAT_BLOCK
)
*
16
*
8
;
}
}
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
const
int64_t
&
w
)
const
override
{
std
::
unique_ptr
<
GenBase
>
CreateJitCode
(
const
int64_t
&
w
)
const
override
{
PADDLE_ENFORCE_GT
(
w
,
0
);
PADDLE_ENFORCE_GT
(
w
,
0
,
platform
::
errors
::
InvalidArgument
(
"The width of VBroadcast should be larger than 0. But w is %d."
,
w
));
return
make_unique
<
VBroadcastJitCode
>
(
w
,
CodeSize
(
w
));
return
make_unique
<
VBroadcastJitCode
>
(
w
,
CodeSize
(
w
));
}
}
};
};
...
...
paddle/fluid/operators/jit/gen_base.cc
浏览文件 @
36ed83d2
...
@@ -49,9 +49,14 @@ void GenBase::dumpCode(const unsigned char* code) const {
...
@@ -49,9 +49,14 @@ void GenBase::dumpCode(const unsigned char* code) const {
void
*
GenBase
::
operator
new
(
size_t
size
)
{
void
*
GenBase
::
operator
new
(
size_t
size
)
{
void
*
ptr
;
void
*
ptr
;
constexpr
size_t
alignment
=
32ul
;
constexpr
size_t
alignment
=
32ul
;
PADDLE_ENFORCE_EQ
(
posix_memalign
(
&
ptr
,
alignment
,
size
),
0
,
PADDLE_ENFORCE_EQ
(
"GenBase Alloc %ld error!"
,
size
);
posix_memalign
(
&
ptr
,
alignment
,
size
),
0
,
PADDLE_ENFORCE
(
ptr
,
"Fail to allocate GenBase CPU memory: size = %d ."
,
size
);
platform
::
errors
::
InvalidArgument
(
"Jitcode generator (GenBase) allocate %ld memory error!"
,
size
));
PADDLE_ENFORCE_NOT_NULL
(
ptr
,
platform
::
errors
::
InvalidArgument
(
"Fail to allocate jitcode generator "
"(GenBase) CPU memory: size = %d ."
,
size
));
return
ptr
;
return
ptr
;
}
}
...
...
paddle/fluid/operators/jit/helper.cc
浏览文件 @
36ed83d2
...
@@ -66,7 +66,8 @@ const char* to_string(KernelType kt) {
...
@@ -66,7 +66,8 @@ const char* to_string(KernelType kt) {
ONE_CASE
(
kEmbSeqPool
);
ONE_CASE
(
kEmbSeqPool
);
ONE_CASE
(
kSgd
);
ONE_CASE
(
kSgd
);
default:
default:
PADDLE_THROW
(
"Not support type: %d, or forget to add it."
,
kt
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"JIT kernel do not support type: %d."
,
kt
));
return
"NOT JITKernel"
;
return
"NOT JITKernel"
;
}
}
return
nullptr
;
return
nullptr
;
...
@@ -79,7 +80,8 @@ const char* to_string(SeqPoolType tp) {
...
@@ -79,7 +80,8 @@ const char* to_string(SeqPoolType tp) {
ONE_CASE
(
kAvg
);
ONE_CASE
(
kAvg
);
ONE_CASE
(
kSqrt
);
ONE_CASE
(
kSqrt
);
default:
default:
PADDLE_THROW
(
"Not support type: %d, or forget to add it."
,
tp
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"SeqPool JIT kernel do not support type: %d."
,
tp
));
return
"NOT PoolType"
;
return
"NOT PoolType"
;
}
}
return
nullptr
;
return
nullptr
;
...
@@ -100,7 +102,8 @@ KernelType to_kerneltype(const std::string& act) {
...
@@ -100,7 +102,8 @@ KernelType to_kerneltype(const std::string& act) {
}
else
if
(
lower
==
"tanh"
||
lower
==
"vtanh"
)
{
}
else
if
(
lower
==
"tanh"
||
lower
==
"vtanh"
)
{
return
kVTanh
;
return
kVTanh
;
}
}
PADDLE_THROW
(
"Not support type: %s, or forget to add this case"
,
act
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Act JIT kernel do not support type: %s."
,
act
));
return
kNone
;
return
kNone
;
}
}
...
@@ -109,12 +112,19 @@ void pack_weights<float>(const float* src, float* dst, int n, int k) {
...
@@ -109,12 +112,19 @@ void pack_weights<float>(const float* src, float* dst, int n, int k) {
int
block
,
rest
;
int
block
,
rest
;
const
auto
groups
=
packed_groups
(
n
,
k
,
&
block
,
&
rest
);
const
auto
groups
=
packed_groups
(
n
,
k
,
&
block
,
&
rest
);
std
::
for_each
(
groups
.
begin
(),
groups
.
end
(),
[
&
](
int
i
)
{
std
::
for_each
(
groups
.
begin
(),
groups
.
end
(),
[
&
](
int
i
)
{
PADDLE_ENFORCE_GT
(
i
,
0
,
"each element of groups should be larger than 0."
);
PADDLE_ENFORCE_GT
(
i
,
0
,
platform
::
errors
::
InvalidArgument
(
"Each element of groups should be larger than "
"0. However the element: %d doesn't satify."
,
i
));
});
});
int
sum
=
std
::
accumulate
(
groups
.
begin
(),
groups
.
end
(),
0
);
int
sum
=
std
::
accumulate
(
groups
.
begin
(),
groups
.
end
(),
0
);
std
::
memset
(
dst
,
0
,
k
*
sum
*
block
*
sizeof
(
float
));
std
::
memset
(
dst
,
0
,
k
*
sum
*
block
*
sizeof
(
float
));
PADDLE_ENFORCE_GE
(
sum
*
block
,
n
,
PADDLE_ENFORCE_GE
(
sum
*
block
,
n
,
"The packed n should be equal to or larger than n"
);
platform
::
errors
::
InvalidArgument
(
"The packed n (sum * block) should be equal to or "
"larger than n (matmul row size). "
"However, the packed n is %d and n is %d."
,
sum
*
block
,
n
));
const
int
block_len
=
sizeof
(
float
)
*
block
;
const
int
block_len
=
sizeof
(
float
)
*
block
;
int
n_offset
=
0
;
int
n_offset
=
0
;
...
@@ -136,7 +146,8 @@ void pack_weights<float>(const float* src, float* dst, int n, int k) {
...
@@ -136,7 +146,8 @@ void pack_weights<float>(const float* src, float* dst, int n, int k) {
template
<
typename
T
>
template
<
typename
T
>
typename
std
::
enable_if
<!
std
::
is_same
<
T
,
float
>::
value
>::
type
pack_weights
(
typename
std
::
enable_if
<!
std
::
is_same
<
T
,
float
>::
value
>::
type
pack_weights
(
const
T
*
src
,
T
*
dst
,
int
n
,
int
k
)
{
const
T
*
src
,
T
*
dst
,
int
n
,
int
k
)
{
PADDLE_THROW
(
"Only support pack with float type."
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Only supports pack weights with float type."
));
}
}
}
// namespace jit
}
// namespace jit
...
...
paddle/fluid/operators/jit/helper.h
浏览文件 @
36ed83d2
...
@@ -85,8 +85,10 @@ inline const Kernel* GetReferKernel() {
...
@@ -85,8 +85,10 @@ inline const Kernel* GetReferKernel() {
auto
&
ref_pool
=
ReferKernelPool
::
Instance
().
AllKernels
();
auto
&
ref_pool
=
ReferKernelPool
::
Instance
().
AllKernels
();
KernelKey
kkey
(
KernelTuple
::
kernel_type
,
platform
::
CPUPlace
());
KernelKey
kkey
(
KernelTuple
::
kernel_type
,
platform
::
CPUPlace
());
auto
ref_iter
=
ref_pool
.
find
(
kkey
);
auto
ref_iter
=
ref_pool
.
find
(
kkey
);
PADDLE_ENFORCE
(
ref_iter
!=
ref_pool
.
end
(),
PADDLE_ENFORCE_NE
(
"Every Kernel should have reference function."
);
ref_iter
,
ref_pool
.
end
(),
platform
::
errors
::
PreconditionNotMet
(
"Every Refer Kernel of jitcode should have reference function."
));
auto
&
ref_impls
=
ref_iter
->
second
;
auto
&
ref_impls
=
ref_iter
->
second
;
for
(
auto
&
impl
:
ref_impls
)
{
for
(
auto
&
impl
:
ref_impls
)
{
auto
i
=
dynamic_cast
<
const
ReferKernel
<
KernelTuple
>*>
(
impl
.
get
());
auto
i
=
dynamic_cast
<
const
ReferKernel
<
KernelTuple
>*>
(
impl
.
get
());
...
@@ -101,7 +103,9 @@ template <typename KernelTuple>
...
@@ -101,7 +103,9 @@ template <typename KernelTuple>
inline
typename
KernelTuple
::
func_type
GetReferFunc
()
{
inline
typename
KernelTuple
::
func_type
GetReferFunc
()
{
auto
ker
=
GetReferKernel
<
KernelTuple
>
();
auto
ker
=
GetReferKernel
<
KernelTuple
>
();
auto
p
=
dynamic_cast
<
const
ReferKernel
<
KernelTuple
>*>
(
ker
);
auto
p
=
dynamic_cast
<
const
ReferKernel
<
KernelTuple
>*>
(
ker
);
PADDLE_ENFORCE
(
p
,
"The Refer kernel should exsit"
);
PADDLE_ENFORCE_NOT_NULL
(
p
,
platform
::
errors
::
InvalidArgument
(
"Get the reference code of kernel in CPU "
"failed. The Refer kernel should exsit."
));
return
p
->
GetFunc
();
return
p
->
GetFunc
();
}
}
...
@@ -132,7 +136,9 @@ std::vector<const Kernel*> GetAllCandidateKernels(
...
@@ -132,7 +136,9 @@ std::vector<const Kernel*> GetAllCandidateKernels(
// The last implementation should be reference function on CPUPlace.
// The last implementation should be reference function on CPUPlace.
auto
ref
=
GetReferKernel
<
KernelTuple
>
();
auto
ref
=
GetReferKernel
<
KernelTuple
>
();
PADDLE_ENFORCE
(
ref
!=
nullptr
,
"Refer Kernel can not be empty."
);
PADDLE_ENFORCE_NOT_NULL
(
ref
,
platform
::
errors
::
InvalidArgument
(
"Get all candicate kernel in CPU failed. "
"The Refer Kernel can not be empty."
));
res
.
emplace_back
(
ref
);
res
.
emplace_back
(
ref
);
return
res
;
return
res
;
}
}
...
@@ -147,11 +153,14 @@ GetAllCandidateFuncsWithTypes(const typename KernelTuple::attr_type& attr) {
...
@@ -147,11 +153,14 @@ GetAllCandidateFuncsWithTypes(const typename KernelTuple::attr_type& attr) {
std
::
string
name
=
k
->
ImplType
();
std
::
string
name
=
k
->
ImplType
();
if
(
name
==
"JitCode"
)
{
if
(
name
==
"JitCode"
)
{
auto
i
=
dynamic_cast
<
const
GenBase
*>
(
k
);
auto
i
=
dynamic_cast
<
const
GenBase
*>
(
k
);
PADDLE_ENFORCE
(
i
,
"jitcode kernel cast can not fail."
);
PADDLE_ENFORCE_NOT_NULL
(
i
,
platform
::
errors
::
InvalidArgument
(
"Generate jitcode kernel (GenBase) failed."
));
res
.
emplace_back
(
std
::
make_pair
(
name
,
i
->
template
getCode
<
Func
>()));
res
.
emplace_back
(
std
::
make_pair
(
name
,
i
->
template
getCode
<
Func
>()));
}
else
{
}
else
{
auto
i
=
dynamic_cast
<
const
KernelMore
<
KernelTuple
>*>
(
k
);
auto
i
=
dynamic_cast
<
const
KernelMore
<
KernelTuple
>*>
(
k
);
PADDLE_ENFORCE
(
i
,
"kernel cast can not fail."
);
PADDLE_ENFORCE_NOT_NULL
(
i
,
platform
::
errors
::
InvalidArgument
(
"Kernel cast (KernelMore) failed."
));
res
.
emplace_back
(
std
::
make_pair
(
name
,
i
->
GetFunc
()));
res
.
emplace_back
(
std
::
make_pair
(
name
,
i
->
GetFunc
()));
}
}
}
}
...
@@ -173,7 +182,9 @@ template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
...
@@ -173,7 +182,9 @@ template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
typename
KernelTuple
::
func_type
GetDefaultBestFunc
(
typename
KernelTuple
::
func_type
GetDefaultBestFunc
(
const
typename
KernelTuple
::
attr_type
&
attr
)
{
const
typename
KernelTuple
::
attr_type
&
attr
)
{
auto
funcs
=
GetAllCandidateFuncs
<
KernelTuple
,
PlaceType
>
(
attr
);
auto
funcs
=
GetAllCandidateFuncs
<
KernelTuple
,
PlaceType
>
(
attr
);
PADDLE_ENFORCE_GE
(
funcs
.
size
(),
1UL
);
PADDLE_ENFORCE_GE
(
funcs
.
size
(),
1UL
,
platform
::
errors
::
InvalidArgument
(
"The candicate jit kernel is at least one in CPU."
));
// Here could do some runtime benchmark of this attr and return the best one.
// Here could do some runtime benchmark of this attr and return the best one.
// But yet just get the first one as the default best one,
// But yet just get the first one as the default best one,
// which is searched in order and tuned by offline.
// which is searched in order and tuned by offline.
...
...
paddle/fluid/operators/jit/more/mix/mix.cc
浏览文件 @
36ed83d2
...
@@ -95,7 +95,8 @@ void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT
...
@@ -95,7 +95,8 @@ void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT
}
else
if
(
type
==
kVIdentity
)
{
}
else
if
(
type
==
kVIdentity
)
{
return
KernelFuncs
<
VIdentityTuple
<
T
>
,
CPUPlace
>::
Cache
().
At
(
d
);
return
KernelFuncs
<
VIdentityTuple
<
T
>
,
CPUPlace
>::
Cache
().
At
(
d
);
}
}
PADDLE_THROW
(
"Not support type: %s"
,
type
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Act JIT kernel do not support type: %s"
,
type
));
return
nullptr
;
return
nullptr
;
}
}
...
...
paddle/fluid/operators/jit/more/mkl/mkl.h
浏览文件 @
36ed83d2
...
@@ -103,11 +103,24 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
...
@@ -103,11 +103,24 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
template
<
typename
T
>
template
<
typename
T
>
void
EmbSeqPool
(
const
T
*
table
,
const
int64_t
*
idx
,
T
*
out
,
void
EmbSeqPool
(
const
T
*
table
,
const
int64_t
*
idx
,
T
*
out
,
const
emb_seq_pool_attr_t
*
attr
)
{
const
emb_seq_pool_attr_t
*
attr
)
{
PADDLE_ENFORCE_EQ
(
attr
->
table_width
*
attr
->
index_width
,
attr
->
out_width
);
PADDLE_ENFORCE_EQ
(
attr
->
table_width
*
attr
->
index_width
,
attr
->
out_width
,
platform
::
errors
::
InvalidArgument
(
"The attribute table_width * index_width of EmbSeqPool should "
"be equal to out_width. But table_width * index_width is %d, "
"out_width is %d."
,
attr
->
table_width
*
attr
->
index_width
,
attr
->
out_width
));
auto
check_idx_value_valid
=
[
&
](
int64_t
i
)
{
auto
check_idx_value_valid
=
[
&
](
int64_t
i
)
{
PADDLE_ENFORCE_LT
(
idx
[
i
],
attr
->
table_height
,
"idx value: %d, i: %d"
,
PADDLE_ENFORCE_LT
(
idx
[
i
],
i
);
idx
[
i
],
attr
->
table_height
,
PADDLE_ENFORCE_GE
(
idx
[
i
],
0
,
"idx value: %d, i: %d"
,
idx
[
i
],
i
);
platform
::
errors
::
InvalidArgument
(
"The idx shoud be lower than the attribute table_height of "
"EmbSeqPool. But %dth of idx is %d and table_height is %d."
,
i
,
idx
[
i
],
attr
->
table_height
));
PADDLE_ENFORCE_GE
(
idx
[
i
],
0
,
platform
::
errors
::
InvalidArgument
(
"The idx shoud be equal to or larger than "
"the 0. But %dth of idx is %d."
,
i
,
idx
[
i
]));
};
};
for
(
int64_t
w
=
0
;
w
!=
attr
->
index_width
;
++
w
)
{
for
(
int64_t
w
=
0
;
w
!=
attr
->
index_width
;
++
w
)
{
...
@@ -168,22 +181,50 @@ void Softmax(const T* x, T* y, int n, int bs, int remain = 1) {
...
@@ -168,22 +181,50 @@ void Softmax(const T* x, T* y, int n, int bs, int remain = 1) {
template
<
typename
T
>
template
<
typename
T
>
void
Sgd
(
const
T
*
lr
,
const
T
*
param
,
const
T
*
grad
,
const
int64_t
*
rows
,
void
Sgd
(
const
T
*
lr
,
const
T
*
param
,
const
T
*
grad
,
const
int64_t
*
rows
,
T
*
out
,
const
sgd_attr_t
*
attr
)
{
T
*
out
,
const
sgd_attr_t
*
attr
)
{
PADDLE_ENFORCE_EQ
(
attr
->
param_width
,
attr
->
grad_width
);
PADDLE_ENFORCE_EQ
(
attr
->
param_width
,
attr
->
grad_width
,
PADDLE_ENFORCE_LE
(
attr
->
selected_rows_size
,
attr
->
grad_height
);
platform
::
errors
::
InvalidArgument
(
"The attribute param_width of Sgd should be "
"equal to the attribute grad_width. But param_width "
"is %d and grad_width is %d."
,
attr
->
param_width
,
attr
->
grad_width
));
PADDLE_ENFORCE_LE
(
attr
->
selected_rows_size
,
attr
->
grad_height
,
platform
::
errors
::
InvalidArgument
(
"The attribute selected_rows_size of Sgd should be "
"equal to or less than the attribute grad_height. "
"But selected_rows_size is %d and grad_height is %d."
,
attr
->
selected_rows_size
,
attr
->
grad_height
));
T
scalar
=
-
lr
[
0
];
T
scalar
=
-
lr
[
0
];
int
width
=
attr
->
grad_width
;
int
width
=
attr
->
grad_width
;
if
(
out
==
param
)
{
if
(
out
==
param
)
{
for
(
int64_t
i
=
0
;
i
<
attr
->
selected_rows_size
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
attr
->
selected_rows_size
;
++
i
)
{
auto
h_idx
=
rows
[
i
];
auto
h_idx
=
rows
[
i
];
PADDLE_ENFORCE_LT
(
h_idx
,
attr
->
param_height
);
PADDLE_ENFORCE_LT
(
h_idx
,
attr
->
param_height
,
PADDLE_ENFORCE_GE
(
h_idx
,
0
);
platform
::
errors
::
InvalidArgument
(
"The rows of Sgd should be "
"less than the attribute. But %dth of rows "
"is %d and grad_width is %d."
,
i
,
h_idx
,
attr
->
param_height
));
PADDLE_ENFORCE_GE
(
h_idx
,
0
,
platform
::
errors
::
InvalidArgument
(
"The rows of Sgd should be "
"larger than 0. But %dth of rows "
"is %d."
,
i
,
h_idx
));
VAXPY
(
scalar
,
grad
+
i
*
width
,
out
+
h_idx
*
width
,
width
);
VAXPY
(
scalar
,
grad
+
i
*
width
,
out
+
h_idx
*
width
,
width
);
}
}
}
else
{
}
else
{
for
(
int64_t
i
=
0
;
i
<
attr
->
selected_rows_size
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
attr
->
selected_rows_size
;
++
i
)
{
auto
h_idx
=
rows
[
i
];
auto
h_idx
=
rows
[
i
];
PADDLE_ENFORCE_LT
(
h_idx
,
attr
->
param_height
);
PADDLE_ENFORCE_LT
(
h_idx
,
attr
->
param_height
,
PADDLE_ENFORCE_GE
(
h_idx
,
0
);
platform
::
errors
::
InvalidArgument
(
"The rows of Sgd should be "
"less than the attribute. But %dth of rows "
"is %d and grad_width is %d."
,
i
,
h_idx
,
attr
->
param_height
));
PADDLE_ENFORCE_GE
(
h_idx
,
0
,
platform
::
errors
::
InvalidArgument
(
"The rows of Sgd should be "
"larger than 0. But %dth of rows "
"is %d."
,
i
,
h_idx
));
VScal
(
&
scalar
,
grad
+
i
*
width
,
out
+
h_idx
*
width
,
width
);
VScal
(
&
scalar
,
grad
+
i
*
width
,
out
+
h_idx
*
width
,
width
);
VAdd
(
param
+
h_idx
*
width
,
out
+
h_idx
*
width
,
out
+
h_idx
*
width
,
VAdd
(
param
+
h_idx
*
width
,
out
+
h_idx
*
width
,
out
+
h_idx
*
width
,
width
);
width
);
...
...
paddle/fluid/operators/jit/refer/refer.h
浏览文件 @
36ed83d2
...
@@ -147,7 +147,8 @@ void (*getActFunc(KernelType type))(const T*, T*, int) { // NOLINT
...
@@ -147,7 +147,8 @@ void (*getActFunc(KernelType type))(const T*, T*, int) { // NOLINT
}
else
if
(
type
==
kVIdentity
)
{
}
else
if
(
type
==
kVIdentity
)
{
return
VIdentity
<
T
>
;
return
VIdentity
<
T
>
;
}
}
PADDLE_THROW
(
"Not support type: %s"
,
type
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Act JIT kernel do not support type: %s."
,
type
));
return
nullptr
;
return
nullptr
;
}
}
...
@@ -465,12 +466,25 @@ void Softmax(const T* x, T* y, int n, int bs = 1, int remain = 1) {
...
@@ -465,12 +466,25 @@ void Softmax(const T* x, T* y, int n, int bs = 1, int remain = 1) {
template
<
typename
T
>
template
<
typename
T
>
void
EmbSeqPool
(
const
T
*
table
,
const
int64_t
*
idx
,
T
*
out
,
void
EmbSeqPool
(
const
T
*
table
,
const
int64_t
*
idx
,
T
*
out
,
const
emb_seq_pool_attr_t
*
attr
)
{
const
emb_seq_pool_attr_t
*
attr
)
{
PADDLE_ENFORCE_EQ
(
attr
->
table_width
*
attr
->
index_width
,
attr
->
out_width
);
PADDLE_ENFORCE_EQ
(
attr
->
table_width
*
attr
->
index_width
,
attr
->
out_width
,
platform
::
errors
::
InvalidArgument
(
"The attribute table_width * index_width of EmbSeqPool should "
"be equal to out_width. But table_width * index_width is %d and "
"out_width is %d."
,
attr
->
table_width
*
attr
->
index_width
,
attr
->
out_width
));
auto
check_idx_value_valid
=
[
&
](
int64_t
i
)
{
auto
check_idx_value_valid
=
[
&
](
int64_t
i
)
{
PADDLE_ENFORCE_LT
(
idx
[
i
],
attr
->
table_height
,
"idx value: %d, i: %d"
,
PADDLE_ENFORCE_LT
(
idx
[
i
],
i
);
idx
[
i
],
attr
->
table_height
,
PADDLE_ENFORCE_GE
(
idx
[
i
],
0
,
"idx value: %d, i: %d"
,
idx
[
i
],
i
);
platform
::
errors
::
InvalidArgument
(
"The idx shoud be lower than the attribute table_height of "
"EmbSeqPool. But %dth of idx is %d and table_height is %d."
,
i
,
idx
[
i
],
attr
->
table_height
));
PADDLE_ENFORCE_GE
(
idx
[
i
],
0
,
platform
::
errors
::
InvalidArgument
(
"The idx shoud be equal to or larger than "
"the 0. But %dth of idx is %d."
,
i
,
idx
[
i
]));
};
};
for
(
int64_t
w
=
0
;
w
!=
attr
->
index_width
;
++
w
)
{
for
(
int64_t
w
=
0
;
w
!=
attr
->
index_width
;
++
w
)
{
...
@@ -505,12 +519,31 @@ void EmbSeqPool(const T* table, const int64_t* idx, T* out,
...
@@ -505,12 +519,31 @@ void EmbSeqPool(const T* table, const int64_t* idx, T* out,
template
<
typename
T
>
template
<
typename
T
>
void
Sgd
(
const
T
*
lr
,
const
T
*
param
,
const
T
*
grad
,
const
int64_t
*
rows
,
void
Sgd
(
const
T
*
lr
,
const
T
*
param
,
const
T
*
grad
,
const
int64_t
*
rows
,
T
*
out
,
const
sgd_attr_t
*
attr
)
{
T
*
out
,
const
sgd_attr_t
*
attr
)
{
PADDLE_ENFORCE_EQ
(
attr
->
param_width
,
attr
->
grad_width
);
PADDLE_ENFORCE_EQ
(
attr
->
param_width
,
attr
->
grad_width
,
PADDLE_ENFORCE_LE
(
attr
->
selected_rows_size
,
attr
->
grad_height
);
platform
::
errors
::
InvalidArgument
(
"The attribute param_width of Sgd should be "
"equal to the attribute grad_width. But param_width "
"is %d and grad_width is %d."
,
attr
->
param_width
,
attr
->
grad_width
));
PADDLE_ENFORCE_LE
(
attr
->
selected_rows_size
,
attr
->
grad_height
,
platform
::
errors
::
InvalidArgument
(
"The attribute selected_rows_size of Sgd should be "
"equal to or less than the attribute grad_height. "
"But selected_rows_size is %d and grad_height is %d."
,
attr
->
selected_rows_size
,
attr
->
grad_height
));
for
(
int64_t
i
=
0
;
i
<
attr
->
selected_rows_size
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
attr
->
selected_rows_size
;
++
i
)
{
auto
h_idx
=
rows
[
i
];
auto
h_idx
=
rows
[
i
];
PADDLE_ENFORCE_LT
(
h_idx
,
attr
->
param_height
);
PADDLE_ENFORCE_LT
(
h_idx
,
attr
->
param_height
,
PADDLE_ENFORCE_GE
(
h_idx
,
0
);
platform
::
errors
::
InvalidArgument
(
"The rows of Sgd should be "
"less than the attribute. But %dth of rows "
"is %d and grad_width is %d."
,
i
,
h_idx
,
attr
->
param_height
));
PADDLE_ENFORCE_GE
(
h_idx
,
0
,
platform
::
errors
::
InvalidArgument
(
"The rows of Sgd should be "
"larger than 0. But %dth of rows "
"is %d."
,
i
,
h_idx
));
for
(
int64_t
j
=
0
;
j
<
attr
->
grad_width
;
++
j
)
{
for
(
int64_t
j
=
0
;
j
<
attr
->
grad_width
;
++
j
)
{
out
[
h_idx
*
attr
->
grad_width
+
j
]
=
out
[
h_idx
*
attr
->
grad_width
+
j
]
=
param
[
h_idx
*
attr
->
grad_width
+
j
]
-
param
[
h_idx
*
attr
->
grad_width
+
j
]
-
...
...
paddle/fluid/operators/jit/test.cc
浏览文件 @
36ed83d2
...
@@ -850,8 +850,15 @@ void TestKernelSgd() {
...
@@ -850,8 +850,15 @@ void TestKernelSgd() {
const
T
lr
=
0.1
;
const
T
lr
=
0.1
;
auto
UnDuplicatedRandomVec
=
[](
int
n
,
const
int64_t
lower
,
auto
UnDuplicatedRandomVec
=
[](
int
n
,
const
int64_t
lower
,
const
int64_t
upper
)
->
std
::
vector
<
int64_t
>
{
const
int64_t
upper
)
->
std
::
vector
<
int64_t
>
{
PADDLE_ENFORCE_LE
(
static_cast
<
size_t
>
(
upper
-
lower
),
n
-
1
);
PADDLE_ENFORCE_LE
(
static_cast
<
size_t
>
(
upper
-
lower
),
n
-
1
,
PADDLE_ENFORCE_GT
(
n
,
0
);
paddle
::
platform
::
errors
::
InvalidArgument
(
"The range of Sgd (upper - lower) should be lower "
"than n-1 (Sgd size -1). But the upper - lower is %d "
"and n-1 is %d."
,
static_cast
<
size_t
>
(
upper
-
lower
),
n
-
1
));
PADDLE_ENFORCE_GT
(
n
,
0
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"The Sgd size should be larger than 0. But the n is %d."
,
n
));
std
::
vector
<
int64_t
>
all
,
out
;
std
::
vector
<
int64_t
>
all
,
out
;
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
all
.
push_back
(
i
);
all
.
push_back
(
i
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录