Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
600f6d82
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
600f6d82
编写于
12月 21, 2018
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
polish code
test=develop
上级
89b9d86d
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
79 addition
and
81 deletion
+79
-81
paddle/fluid/framework/ddim.cc
paddle/fluid/framework/ddim.cc
+14
-16
paddle/fluid/framework/ddim.h
paddle/fluid/framework/ddim.h
+7
-7
paddle/fluid/framework/dim.h
paddle/fluid/framework/dim.h
+58
-58
未找到文件。
paddle/fluid/framework/ddim.cc
浏览文件 @
600f6d82
...
@@ -42,7 +42,8 @@ struct DDimEqualityVisitor {
...
@@ -42,7 +42,8 @@ struct DDimEqualityVisitor {
};
};
bool
DDim
::
operator
==
(
const
DDim
&
d
)
const
{
bool
DDim
::
operator
==
(
const
DDim
&
d
)
const
{
return
rank_
==
d
.
rank_
&&
this
->
apply_visitor
(
DDimEqualityVisitor
(
d
.
Get
()));
return
size
()
==
d
.
size
()
&&
this
->
apply_visitor
(
DDimEqualityVisitor
(
d
.
Get
()));
}
}
bool
DDim
::
operator
!=
(
const
DDim
&
d
)
const
{
return
!
(
*
this
==
d
);
}
bool
DDim
::
operator
!=
(
const
DDim
&
d
)
const
{
return
!
(
*
this
==
d
);
}
...
@@ -61,7 +62,7 @@ struct DDimPlusVisitor {
...
@@ -61,7 +62,7 @@ struct DDimPlusVisitor {
};
};
DDim
DDim
::
operator
+
(
const
DDim
&
d
)
const
{
DDim
DDim
::
operator
+
(
const
DDim
&
d
)
const
{
PADDLE_ENFORCE
(
rank_
==
d
.
rank_
);
PADDLE_ENFORCE
(
size
()
==
d
.
size
()
);
DDim
ret
;
DDim
ret
;
ret
.
rank_
=
rank_
;
ret
.
rank_
=
rank_
;
ret
.
apply_visitor
(
DDimPlusVisitor
(
Get
(),
d
.
Get
()));
ret
.
apply_visitor
(
DDimPlusVisitor
(
Get
(),
d
.
Get
()));
...
@@ -82,7 +83,7 @@ struct DDimMulVisitor {
...
@@ -82,7 +83,7 @@ struct DDimMulVisitor {
};
};
DDim
DDim
::
operator
*
(
const
DDim
&
d
)
const
{
DDim
DDim
::
operator
*
(
const
DDim
&
d
)
const
{
PADDLE_ENFORCE
(
rank_
==
d
.
rank_
);
PADDLE_ENFORCE
(
size
()
==
d
.
size
()
);
DDim
ret
;
DDim
ret
;
ret
.
rank_
=
rank_
;
ret
.
rank_
=
rank_
;
ret
.
apply_visitor
(
DDimMulVisitor
(
Get
(),
d
.
Get
()));
ret
.
apply_visitor
(
DDimMulVisitor
(
Get
(),
d
.
Get
()));
...
@@ -121,13 +122,11 @@ int64_t product(const DDim& ddim) {
...
@@ -121,13 +122,11 @@ int64_t product(const DDim& ddim) {
}
}
DDim
slice_ddim
(
const
DDim
&
dim
,
int
begin
,
int
end
)
{
DDim
slice_ddim
(
const
DDim
&
dim
,
int
begin
,
int
end
)
{
PADDLE_ENFORCE
(
begin
>=
0
,
PADDLE_ENFORCE
(
begin
>=
0
&&
end
<=
dim
.
size
(),
"Begin index can't be less than zero in ddim slice."
);
"[begin(%d), end(%d)) must be inside [0, %d) in ddim slice."
,
int
len
=
end
-
begin
;
begin
,
end
,
dim
.
size
());
DDim
ret
;
// Constructor of DDim would check whether end - begin is valid
ret
.
rank_
=
len
;
return
DDim
(
dim
.
Get
()
+
begin
,
end
-
begin
);
dynamic_dim_assign
(
dim
.
Get
()
+
begin
,
ret
.
GetMutable
(),
ret
.
rank_
);
return
ret
;
}
}
int
arity
(
const
DDim
&
d
)
{
return
d
.
size
();
}
int
arity
(
const
DDim
&
d
)
{
return
d
.
size
();
}
...
@@ -138,8 +137,8 @@ struct DDimPrinter {
...
@@ -138,8 +137,8 @@ struct DDimPrinter {
std
::
ostream
&
os
;
std
::
ostream
&
os
;
explicit
DDimPrinter
(
std
::
ostream
&
os_
)
:
os
(
os_
)
{}
explicit
DDimPrinter
(
std
::
ostream
&
os_
)
:
os
(
os_
)
{}
template
<
typename
T
>
template
<
int
D
>
void
operator
()(
const
T
&
t
)
{
void
operator
()(
const
Dim
<
D
>
&
t
)
{
os
<<
t
;
os
<<
t
;
}
}
};
};
...
@@ -152,12 +151,11 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
...
@@ -152,12 +151,11 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
}
}
DDim
flatten_to_2d
(
const
DDim
&
src
,
int
num_col_dims
)
{
DDim
flatten_to_2d
(
const
DDim
&
src
,
int
num_col_dims
)
{
int
rank
=
src
.
size
();
return
DDim
({
product
(
slice_ddim
(
src
,
0
,
num_col_dims
)),
return
make_ddim
({
product
(
slice_ddim
(
src
,
0
,
num_col_dims
)),
product
(
slice_ddim
(
src
,
num_col_dims
,
src
.
size
()))});
product
(
slice_ddim
(
src
,
num_col_dims
,
rank
))});
}
}
DDim
flatten_to_1d
(
const
DDim
&
src
)
{
return
make_dd
im
({
product
(
src
)});
}
DDim
flatten_to_1d
(
const
DDim
&
src
)
{
return
DD
im
({
product
(
src
)});
}
DDim
stride
(
const
DDim
&
ddim
)
{
DDim
stride
(
const
DDim
&
ddim
)
{
DDim
strides
;
DDim
strides
;
...
...
paddle/fluid/framework/ddim.h
浏览文件 @
600f6d82
...
@@ -124,16 +124,16 @@ class DDim {
...
@@ -124,16 +124,16 @@ class DDim {
inline
int
size
()
const
{
return
rank_
;
}
inline
int
size
()
const
{
return
rank_
;
}
private:
private:
template
<
int
M
>
template
<
int
D
>
inline
Dim
<
M
>&
UnsafeCast
()
{
inline
Dim
<
D
>&
UnsafeCast
()
{
return
const_cast
<
Dim
<
M
>&>
(
const_cast
<
const
DDim
*>
(
this
)
->
UnsafeCast
<
M
>
());
return
const_cast
<
Dim
<
D
>&>
(
const_cast
<
const
DDim
*>
(
this
)
->
UnsafeCast
<
D
>
());
}
}
template
<
int
M
>
template
<
int
D
>
inline
const
Dim
<
M
>&
UnsafeCast
()
const
{
inline
const
Dim
<
D
>&
UnsafeCast
()
const
{
static_assert
(
M
>=
0
&&
M
<=
kMaxRank
,
"Invalid rank"
);
static_assert
(
D
>=
0
&&
D
<=
kMaxRank
,
"Invalid rank"
);
auto
*
p
=
static_cast
<
const
void
*>
(
&
dim_
);
auto
*
p
=
static_cast
<
const
void
*>
(
&
dim_
);
return
*
reinterpret_cast
<
const
Dim
<
M
>*>
(
p
);
return
*
reinterpret_cast
<
const
Dim
<
D
>*>
(
p
);
}
}
friend
DDim
slice_ddim
(
const
DDim
&
dim
,
int
begin
,
int
end
);
friend
DDim
slice_ddim
(
const
DDim
&
dim
,
int
begin
,
int
end
);
...
...
paddle/fluid/framework/dim.h
浏览文件 @
600f6d82
...
@@ -28,17 +28,17 @@ namespace paddle {
...
@@ -28,17 +28,17 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
// Statically sized, statically indexed dimension
// Statically sized, statically indexed dimension
template
<
int
N
>
template
<
int
D
>
class
Dim
:
public
Array
<
int64_t
,
N
>
{
class
Dim
:
public
Array
<
int64_t
,
D
>
{
public:
public:
static_assert
(
N
>=
0
,
"N
must be not less than 0"
);
static_assert
(
D
>=
0
,
"D
must be not less than 0"
);
static
constexpr
int
kRank
=
N
;
static
constexpr
int
kRank
=
D
;
using
BaseClass
=
Array
<
int64_t
,
N
>
;
using
BaseClass
=
Array
<
int64_t
,
D
>
;
inline
Dim
(
int64_t
head
,
const
Dim
<
N
-
1
>&
tail
)
{
inline
Dim
(
int64_t
head
,
const
Dim
<
D
-
1
>&
tail
)
{
(
*
this
)[
0
]
=
head
;
(
*
this
)[
0
]
=
head
;
new
(
this
->
GetMutable
()
+
1
)
Dim
<
N
-
1
>
(
tail
);
new
(
this
->
GetMutable
()
+
1
)
Dim
<
D
-
1
>
(
tail
);
}
}
template
<
typename
...
Args
>
template
<
typename
...
Args
>
...
@@ -47,7 +47,7 @@ class Dim : public Array<int64_t, N> {
...
@@ -47,7 +47,7 @@ class Dim : public Array<int64_t, N> {
/** Construct a Dim from a linear index and size. Uses Fortran order
/** Construct a Dim from a linear index and size. Uses Fortran order
* indexing. */
* indexing. */
HOSTDEVICE
Dim
(
int64_t
idx
,
const
Dim
<
N
>&
size
);
HOSTDEVICE
Dim
(
int64_t
idx
,
const
Dim
<
D
>&
size
);
/** Construct a Dim with each dimension set to the given index */
/** Construct a Dim with each dimension set to the given index */
HOSTDEVICE
explicit
Dim
(
int64_t
idx
)
{
this
->
Fill
(
idx
);
}
HOSTDEVICE
explicit
Dim
(
int64_t
idx
)
{
this
->
Fill
(
idx
);
}
...
@@ -77,42 +77,42 @@ struct FortranOrderIndexingConstructorFunctor<kStart, kEnd, true> {
...
@@ -77,42 +77,42 @@ struct FortranOrderIndexingConstructorFunctor<kStart, kEnd, true> {
};
};
}
// namespace detail
}
// namespace detail
template
<
int
N
>
template
<
int
D
>
HOSTDEVICE
Dim
<
N
>::
Dim
(
int64_t
idx
,
const
Dim
<
N
>&
size
)
{
HOSTDEVICE
Dim
<
D
>::
Dim
(
int64_t
idx
,
const
Dim
<
D
>&
size
)
{
detail
::
FortranOrderIndexingConstructorFunctor
<
0
,
N
,
N
==
0
>::
Run
(
detail
::
FortranOrderIndexingConstructorFunctor
<
0
,
D
,
D
==
0
>::
Run
(
size
.
Get
(),
&
idx
,
this
->
GetMutable
());
size
.
Get
(),
&
idx
,
this
->
GetMutable
());
}
}
template
<
int
idx
,
int
N
>
template
<
int
idx
,
int
D
>
HOSTDEVICE
inline
int64_t
get
(
const
Dim
<
N
>&
dim
)
{
HOSTDEVICE
inline
int64_t
get
(
const
Dim
<
D
>&
dim
)
{
return
dim
[
idx
];
return
dim
[
idx
];
}
}
template
<
int
idx
,
int
N
>
template
<
int
idx
,
int
D
>
HOSTDEVICE
inline
int64_t
&
get
(
Dim
<
N
>&
dim
)
{
// NOLINT
HOSTDEVICE
inline
int64_t
&
get
(
Dim
<
D
>&
dim
)
{
// NOLINT
return
dim
[
idx
];
return
dim
[
idx
];
}
}
template
<
int
N
>
template
<
int
D
>
HOSTDEVICE
inline
int64_t
get
(
const
Dim
<
N
>&
dim
,
int
idx
)
{
HOSTDEVICE
inline
int64_t
get
(
const
Dim
<
D
>&
dim
,
int
idx
)
{
return
dim
[
idx
];
return
dim
[
idx
];
}
}
template
<
int
N
>
template
<
int
D
>
HOSTDEVICE
inline
int64_t
&
get
(
Dim
<
N
>&
dim
,
int
idx
)
{
// NOLINT
HOSTDEVICE
inline
int64_t
&
get
(
Dim
<
D
>&
dim
,
int
idx
)
{
// NOLINT
return
dim
[
idx
];
return
dim
[
idx
];
}
}
// Dot product of two dims
// Dot product of two dims
template
<
int
N
>
template
<
int
D
>
HOSTDEVICE
inline
int64_t
linearize
(
const
Dim
<
N
>&
a
,
const
Dim
<
N
>&
b
)
{
HOSTDEVICE
inline
int64_t
linearize
(
const
Dim
<
D
>&
a
,
const
Dim
<
D
>&
b
)
{
return
UnrollProduct
<
N
>::
Run
(
a
.
Get
(),
b
.
Get
());
return
UnrollProduct
<
D
>::
Run
(
a
.
Get
(),
b
.
Get
());
}
}
// Product of a Dim
// Product of a Dim
template
<
int
N
>
template
<
int
D
>
HOSTDEVICE
inline
int64_t
product
(
const
Dim
<
N
>&
a
)
{
HOSTDEVICE
inline
int64_t
product
(
const
Dim
<
D
>&
a
)
{
return
UnrollProduct
<
N
>::
Run
(
a
.
Get
());
return
UnrollProduct
<
D
>::
Run
(
a
.
Get
());
}
}
// Is 0 <= idx_i < size_i for all i?
// Is 0 <= idx_i < size_i for all i?
...
@@ -135,9 +135,9 @@ struct ContainedFunctor<kStart, kEnd, true> {
...
@@ -135,9 +135,9 @@ struct ContainedFunctor<kStart, kEnd, true> {
};
};
}
// namespace detail
}
// namespace detail
template
<
int
N
>
template
<
int
D
>
HOSTDEVICE
inline
bool
contained
(
const
Dim
<
N
>&
idx
,
const
Dim
<
N
>&
size
)
{
HOSTDEVICE
inline
bool
contained
(
const
Dim
<
D
>&
idx
,
const
Dim
<
D
>&
size
)
{
return
detail
::
ContainedFunctor
<
0
,
N
,
N
==
0
>::
Run
(
idx
.
Get
(),
size
.
Get
());
return
detail
::
ContainedFunctor
<
0
,
D
,
D
==
0
>::
Run
(
idx
.
Get
(),
size
.
Get
());
}
}
/**
/**
...
@@ -160,40 +160,40 @@ struct ExPrefixMulFunctor<kStart, kEnd, true> {
...
@@ -160,40 +160,40 @@ struct ExPrefixMulFunctor<kStart, kEnd, true> {
};
};
}
// namespace detail
}
// namespace detail
template
<
int
N
>
template
<
int
D
>
HOSTDEVICE
inline
Dim
<
N
>
ex_prefix_mul
(
const
Dim
<
N
>&
src
)
{
HOSTDEVICE
inline
Dim
<
D
>
ex_prefix_mul
(
const
Dim
<
D
>&
src
)
{
Dim
<
N
>
ret
;
Dim
<
D
>
ret
;
detail
::
ExPrefixMulFunctor
<
0
,
N
,
N
==
0
>::
Run
(
src
.
Get
(),
ret
.
GetMutable
());
detail
::
ExPrefixMulFunctor
<
0
,
D
,
D
==
0
>::
Run
(
src
.
Get
(),
ret
.
GetMutable
());
return
ret
;
return
ret
;
}
}
/**
/**
* Add two dimensions together
* Add two dimensions together
*/
*/
template
<
int
N
>
template
<
int
D
>
HOSTDEVICE
inline
Dim
<
N
>
dim_plus
(
const
Dim
<
N
>&
a
,
const
Dim
<
N
>&
b
)
{
HOSTDEVICE
inline
Dim
<
D
>
dim_plus
(
const
Dim
<
D
>&
a
,
const
Dim
<
D
>&
b
)
{
Dim
<
N
>
ret
;
Dim
<
D
>
ret
;
UnrollAdd
<
N
>::
Run
(
a
.
Get
(),
b
.
Get
(),
ret
.
GetMutable
());
UnrollAdd
<
D
>::
Run
(
a
.
Get
(),
b
.
Get
(),
ret
.
GetMutable
());
return
ret
;
return
ret
;
}
}
template
<
int
N
>
template
<
int
D
>
HOSTDEVICE
inline
Dim
<
N
>
operator
+
(
const
Dim
<
N
>&
lhs
,
const
Dim
<
N
>&
rhs
)
{
HOSTDEVICE
inline
Dim
<
D
>
operator
+
(
const
Dim
<
D
>&
lhs
,
const
Dim
<
D
>&
rhs
)
{
return
dim_plus
(
lhs
,
rhs
);
return
dim_plus
(
lhs
,
rhs
);
}
}
/**
/**
* Multiply two dimensions together
* Multiply two dimensions together
*/
*/
template
<
int
N
>
template
<
int
D
>
HOSTDEVICE
inline
Dim
<
N
>
dim_mult
(
const
Dim
<
N
>&
a
,
const
Dim
<
N
>&
b
)
{
HOSTDEVICE
inline
Dim
<
D
>
dim_mult
(
const
Dim
<
D
>&
a
,
const
Dim
<
D
>&
b
)
{
Dim
<
N
>
ret
;
Dim
<
D
>
ret
;
UnrollMul
<
N
>::
Run
(
a
.
Get
(),
b
.
Get
(),
ret
.
GetMutable
());
UnrollMul
<
D
>::
Run
(
a
.
Get
(),
b
.
Get
(),
ret
.
GetMutable
());
return
ret
;
return
ret
;
}
}
template
<
int
i
>
template
<
int
D
>
HOSTDEVICE
Dim
<
i
>
operator
*
(
const
Dim
<
i
>&
lhs
,
const
Dim
<
i
>&
rhs
)
{
HOSTDEVICE
Dim
<
D
>
operator
*
(
const
Dim
<
D
>&
lhs
,
const
Dim
<
D
>&
rhs
)
{
return
dim_mult
(
lhs
,
rhs
);
return
dim_mult
(
lhs
,
rhs
);
}
}
...
@@ -224,10 +224,10 @@ struct NormalizeStridesFunctor<kStart, kEnd, true> {
...
@@ -224,10 +224,10 @@ struct NormalizeStridesFunctor<kStart, kEnd, true> {
};
};
}
// namespace detail
}
// namespace detail
template
<
int
N
>
template
<
int
D
>
HOSTDEVICE
Dim
<
N
>
normalize_strides
(
const
Dim
<
N
>&
size
,
const
Dim
<
N
>&
stride
)
{
HOSTDEVICE
Dim
<
D
>
normalize_strides
(
const
Dim
<
D
>&
size
,
const
Dim
<
D
>&
stride
)
{
Dim
<
N
>
ret
;
Dim
<
D
>
ret
;
detail
::
NormalizeStridesFunctor
<
0
,
N
,
N
==
0
>::
Run
(
size
.
Get
(),
stride
.
Get
(),
detail
::
NormalizeStridesFunctor
<
0
,
D
,
D
==
0
>::
Run
(
size
.
Get
(),
stride
.
Get
(),
ret
.
GetMutable
());
ret
.
GetMutable
());
return
ret
;
return
ret
;
}
}
...
@@ -245,10 +245,10 @@ HOSTDEVICE inline Dim<sizeof...(Args)> make_dim(Args... idxes) {
...
@@ -245,10 +245,10 @@ HOSTDEVICE inline Dim<sizeof...(Args)> make_dim(Args... idxes) {
}
}
// Allows us to output a Dim
// Allows us to output a Dim
template
<
int
N
>
template
<
int
D
>
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Dim
<
N
>&
d
)
{
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Dim
<
D
>&
d
)
{
os
<<
d
[
0
];
os
<<
d
[
0
];
for
(
int
i
=
1
;
i
<
N
;
++
i
)
{
for
(
int
i
=
1
;
i
<
D
;
++
i
)
{
os
<<
", "
<<
d
[
i
];
os
<<
", "
<<
d
[
i
];
}
}
return
os
;
return
os
;
...
@@ -258,23 +258,23 @@ inline std::ostream& operator<<(std::ostream& os, const Dim<0>& d) {
...
@@ -258,23 +258,23 @@ inline std::ostream& operator<<(std::ostream& os, const Dim<0>& d) {
return
os
;
return
os
;
}
}
template
<
int
N
>
template
<
int
D
>
HOST
std
::
string
Dim
<
N
>::
to_string
()
const
{
HOST
std
::
string
Dim
<
D
>::
to_string
()
const
{
std
::
stringstream
stream
;
std
::
stringstream
stream
;
stream
<<
*
this
;
stream
<<
*
this
;
return
stream
.
str
();
return
stream
.
str
();
}
}
template
<
int
N
>
template
<
int
D
>
HOSTDEVICE
Dim
<
N
>
linear_to_dimension
(
int
linear_index
,
const
Dim
<
N
>&
extents
)
{
HOSTDEVICE
Dim
<
D
>
linear_to_dimension
(
int
linear_index
,
const
Dim
<
D
>&
extents
)
{
Dim
<
N
>
result
;
Dim
<
D
>
result
;
for
(
int
i
=
0
;
i
<
N
-
1
;
++
i
)
{
for
(
int
i
=
0
;
i
<
D
-
1
;
++
i
)
{
result
[
i
]
=
linear_index
%
extents
[
i
];
result
[
i
]
=
linear_index
%
extents
[
i
];
linear_index
/=
extents
[
i
];
linear_index
/=
extents
[
i
];
}
}
result
[
N
-
1
]
=
linear_index
;
result
[
D
-
1
]
=
linear_index
;
return
result
;
return
result
;
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录