Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
600f6d82
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
600f6d82
编写于
12月 21, 2018
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
polish code
test=develop
上级
89b9d86d
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
79 addition
and
81 deletion
+79
-81
paddle/fluid/framework/ddim.cc
paddle/fluid/framework/ddim.cc
+14
-16
paddle/fluid/framework/ddim.h
paddle/fluid/framework/ddim.h
+7
-7
paddle/fluid/framework/dim.h
paddle/fluid/framework/dim.h
+58
-58
未找到文件。
paddle/fluid/framework/ddim.cc
浏览文件 @
600f6d82
...
@@ -42,7 +42,8 @@ struct DDimEqualityVisitor {
...
@@ -42,7 +42,8 @@ struct DDimEqualityVisitor {
};
};
bool
DDim
::
operator
==
(
const
DDim
&
d
)
const
{
bool
DDim
::
operator
==
(
const
DDim
&
d
)
const
{
return
rank_
==
d
.
rank_
&&
this
->
apply_visitor
(
DDimEqualityVisitor
(
d
.
Get
()));
return
size
()
==
d
.
size
()
&&
this
->
apply_visitor
(
DDimEqualityVisitor
(
d
.
Get
()));
}
}
bool
DDim
::
operator
!=
(
const
DDim
&
d
)
const
{
return
!
(
*
this
==
d
);
}
bool
DDim
::
operator
!=
(
const
DDim
&
d
)
const
{
return
!
(
*
this
==
d
);
}
...
@@ -61,7 +62,7 @@ struct DDimPlusVisitor {
...
@@ -61,7 +62,7 @@ struct DDimPlusVisitor {
};
};
DDim
DDim
::
operator
+
(
const
DDim
&
d
)
const
{
DDim
DDim
::
operator
+
(
const
DDim
&
d
)
const
{
PADDLE_ENFORCE
(
rank_
==
d
.
rank_
);
PADDLE_ENFORCE
(
size
()
==
d
.
size
()
);
DDim
ret
;
DDim
ret
;
ret
.
rank_
=
rank_
;
ret
.
rank_
=
rank_
;
ret
.
apply_visitor
(
DDimPlusVisitor
(
Get
(),
d
.
Get
()));
ret
.
apply_visitor
(
DDimPlusVisitor
(
Get
(),
d
.
Get
()));
...
@@ -82,7 +83,7 @@ struct DDimMulVisitor {
...
@@ -82,7 +83,7 @@ struct DDimMulVisitor {
};
};
DDim
DDim
::
operator
*
(
const
DDim
&
d
)
const
{
DDim
DDim
::
operator
*
(
const
DDim
&
d
)
const
{
PADDLE_ENFORCE
(
rank_
==
d
.
rank_
);
PADDLE_ENFORCE
(
size
()
==
d
.
size
()
);
DDim
ret
;
DDim
ret
;
ret
.
rank_
=
rank_
;
ret
.
rank_
=
rank_
;
ret
.
apply_visitor
(
DDimMulVisitor
(
Get
(),
d
.
Get
()));
ret
.
apply_visitor
(
DDimMulVisitor
(
Get
(),
d
.
Get
()));
...
@@ -121,13 +122,11 @@ int64_t product(const DDim& ddim) {
...
@@ -121,13 +122,11 @@ int64_t product(const DDim& ddim) {
}
}
DDim
slice_ddim
(
const
DDim
&
dim
,
int
begin
,
int
end
)
{
DDim
slice_ddim
(
const
DDim
&
dim
,
int
begin
,
int
end
)
{
PADDLE_ENFORCE
(
begin
>=
0
,
PADDLE_ENFORCE
(
begin
>=
0
&&
end
<=
dim
.
size
(),
"Begin index can't be less than zero in ddim slice."
);
"[begin(%d), end(%d)) must be inside [0, %d) in ddim slice."
,
int
len
=
end
-
begin
;
begin
,
end
,
dim
.
size
());
DDim
ret
;
// Constructor of DDim would check whether end - begin is valid
ret
.
rank_
=
len
;
return
DDim
(
dim
.
Get
()
+
begin
,
end
-
begin
);
dynamic_dim_assign
(
dim
.
Get
()
+
begin
,
ret
.
GetMutable
(),
ret
.
rank_
);
return
ret
;
}
}
int
arity
(
const
DDim
&
d
)
{
return
d
.
size
();
}
int
arity
(
const
DDim
&
d
)
{
return
d
.
size
();
}
...
@@ -138,8 +137,8 @@ struct DDimPrinter {
...
@@ -138,8 +137,8 @@ struct DDimPrinter {
std
::
ostream
&
os
;
std
::
ostream
&
os
;
explicit
DDimPrinter
(
std
::
ostream
&
os_
)
:
os
(
os_
)
{}
explicit
DDimPrinter
(
std
::
ostream
&
os_
)
:
os
(
os_
)
{}
template
<
typename
T
>
template
<
int
D
>
void
operator
()(
const
T
&
t
)
{
void
operator
()(
const
Dim
<
D
>
&
t
)
{
os
<<
t
;
os
<<
t
;
}
}
};
};
...
@@ -152,12 +151,11 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
...
@@ -152,12 +151,11 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
}
}
DDim
flatten_to_2d
(
const
DDim
&
src
,
int
num_col_dims
)
{
DDim
flatten_to_2d
(
const
DDim
&
src
,
int
num_col_dims
)
{
int
rank
=
src
.
size
();
return
DDim
({
product
(
slice_ddim
(
src
,
0
,
num_col_dims
)),
return
make_ddim
({
product
(
slice_ddim
(
src
,
0
,
num_col_dims
)),
product
(
slice_ddim
(
src
,
num_col_dims
,
src
.
size
()))});
product
(
slice_ddim
(
src
,
num_col_dims
,
rank
))});
}
}
DDim
flatten_to_1d
(
const
DDim
&
src
)
{
return
make_dd
im
({
product
(
src
)});
}
DDim
flatten_to_1d
(
const
DDim
&
src
)
{
return
DD
im
({
product
(
src
)});
}
DDim
stride
(
const
DDim
&
ddim
)
{
DDim
stride
(
const
DDim
&
ddim
)
{
DDim
strides
;
DDim
strides
;
...
...
paddle/fluid/framework/ddim.h
浏览文件 @
600f6d82
...
@@ -124,16 +124,16 @@ class DDim {
...
@@ -124,16 +124,16 @@ class DDim {
inline
int
size
()
const
{
return
rank_
;
}
inline
int
size
()
const
{
return
rank_
;
}
private:
private:
template
<
int
M
>
template
<
int
D
>
inline
Dim
<
M
>&
UnsafeCast
()
{
inline
Dim
<
D
>&
UnsafeCast
()
{
return
const_cast
<
Dim
<
M
>&>
(
const_cast
<
const
DDim
*>
(
this
)
->
UnsafeCast
<
M
>
());
return
const_cast
<
Dim
<
D
>&>
(
const_cast
<
const
DDim
*>
(
this
)
->
UnsafeCast
<
D
>
());
}
}
template
<
int
M
>
template
<
int
D
>
inline
const
Dim
<
M
>&
UnsafeCast
()
const
{
inline
const
Dim
<
D
>&
UnsafeCast
()
const
{
static_assert
(
M
>=
0
&&
M
<=
kMaxRank
,
"Invalid rank"
);
static_assert
(
D
>=
0
&&
D
<=
kMaxRank
,
"Invalid rank"
);
auto
*
p
=
static_cast
<
const
void
*>
(
&
dim_
);
auto
*
p
=
static_cast
<
const
void
*>
(
&
dim_
);
return
*
reinterpret_cast
<
const
Dim
<
M
>*>
(
p
);
return
*
reinterpret_cast
<
const
Dim
<
D
>*>
(
p
);
}
}
friend
DDim
slice_ddim
(
const
DDim
&
dim
,
int
begin
,
int
end
);
friend
DDim
slice_ddim
(
const
DDim
&
dim
,
int
begin
,
int
end
);
...
...
paddle/fluid/framework/dim.h
浏览文件 @
600f6d82
...
@@ -28,17 +28,17 @@ namespace paddle {
...
@@ -28,17 +28,17 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
// Statically sized, statically indexed dimension
// Statically sized, statically indexed dimension
template
<
int
N
>
template
<
int
D
>
class
Dim
:
public
Array
<
int64_t
,
N
>
{
class
Dim
:
public
Array
<
int64_t
,
D
>
{
public:
public:
static_assert
(
N
>=
0
,
"N
must be not less than 0"
);
static_assert
(
D
>=
0
,
"D
must be not less than 0"
);
static
constexpr
int
kRank
=
N
;
static
constexpr
int
kRank
=
D
;
using
BaseClass
=
Array
<
int64_t
,
N
>
;
using
BaseClass
=
Array
<
int64_t
,
D
>
;
inline
Dim
(
int64_t
head
,
const
Dim
<
N
-
1
>&
tail
)
{
inline
Dim
(
int64_t
head
,
const
Dim
<
D
-
1
>&
tail
)
{
(
*
this
)[
0
]
=
head
;
(
*
this
)[
0
]
=
head
;
new
(
this
->
GetMutable
()
+
1
)
Dim
<
N
-
1
>
(
tail
);
new
(
this
->
GetMutable
()
+
1
)
Dim
<
D
-
1
>
(
tail
);
}
}
template
<
typename
...
Args
>
template
<
typename
...
Args
>
...
@@ -47,7 +47,7 @@ class Dim : public Array<int64_t, N> {
...
@@ -47,7 +47,7 @@ class Dim : public Array<int64_t, N> {
/** Construct a Dim from a linear index and size. Uses Fortran order
/** Construct a Dim from a linear index and size. Uses Fortran order
* indexing. */
* indexing. */
HOSTDEVICE
Dim
(
int64_t
idx
,
const
Dim
<
N
>&
size
);
HOSTDEVICE
Dim
(
int64_t
idx
,
const
Dim
<
D
>&
size
);
/** Construct a Dim with each dimension set to the given index */
/** Construct a Dim with each dimension set to the given index */
HOSTDEVICE
explicit
Dim
(
int64_t
idx
)
{
this
->
Fill
(
idx
);
}
HOSTDEVICE
explicit
Dim
(
int64_t
idx
)
{
this
->
Fill
(
idx
);
}
...
@@ -77,42 +77,42 @@ struct FortranOrderIndexingConstructorFunctor<kStart, kEnd, true> {
...
@@ -77,42 +77,42 @@ struct FortranOrderIndexingConstructorFunctor<kStart, kEnd, true> {
};
};
}
// namespace detail
}
// namespace detail
template
<
int
N
>
template
<
int
D
>
HOSTDEVICE
Dim
<
N
>::
Dim
(
int64_t
idx
,
const
Dim
<
N
>&
size
)
{
HOSTDEVICE
Dim
<
D
>::
Dim
(
int64_t
idx
,
const
Dim
<
D
>&
size
)
{
detail
::
FortranOrderIndexingConstructorFunctor
<
0
,
N
,
N
==
0
>::
Run
(
detail
::
FortranOrderIndexingConstructorFunctor
<
0
,
D
,
D
==
0
>::
Run
(
size
.
Get
(),
&
idx
,
this
->
GetMutable
());
size
.
Get
(),
&
idx
,
this
->
GetMutable
());
}
}
template
<
int
idx
,
int
N
>
template
<
int
idx
,
int
D
>
HOSTDEVICE
inline
int64_t
get
(
const
Dim
<
N
>&
dim
)
{
HOSTDEVICE
inline
int64_t
get
(
const
Dim
<
D
>&
dim
)
{
return
dim
[
idx
];
return
dim
[
idx
];
}
}
template
<
int
idx
,
int
N
>
template
<
int
idx
,
int
D
>
HOSTDEVICE
inline
int64_t
&
get
(
Dim
<
N
>&
dim
)
{
// NOLINT
HOSTDEVICE
inline
int64_t
&
get
(
Dim
<
D
>&
dim
)
{
// NOLINT
return
dim
[
idx
];
return
dim
[
idx
];
}
}
template
<
int
N
>
template
<
int
D
>
HOSTDEVICE
inline
int64_t
get
(
const
Dim
<
N
>&
dim
,
int
idx
)
{
HOSTDEVICE
inline
int64_t
get
(
const
Dim
<
D
>&
dim
,
int
idx
)
{
return
dim
[
idx
];
return
dim
[
idx
];
}
}
template
<
int
N
>
template
<
int
D
>
HOSTDEVICE
inline
int64_t
&
get
(
Dim
<
N
>&
dim
,
int
idx
)
{
// NOLINT
HOSTDEVICE
inline
int64_t
&
get
(
Dim
<
D
>&
dim
,
int
idx
)
{
// NOLINT
return
dim
[
idx
];
return
dim
[
idx
];
}
}
// Dot product of two dims
// Dot product of two dims
template
<
int
N
>
template
<
int
D
>
HOSTDEVICE
inline
int64_t
linearize
(
const
Dim
<
N
>&
a
,
const
Dim
<
N
>&
b
)
{
HOSTDEVICE
inline
int64_t
linearize
(
const
Dim
<
D
>&
a
,
const
Dim
<
D
>&
b
)
{
return
UnrollProduct
<
N
>::
Run
(
a
.
Get
(),
b
.
Get
());
return
UnrollProduct
<
D
>::
Run
(
a
.
Get
(),
b
.
Get
());
}
}
// Product of a Dim
// Product of a Dim
template
<
int
N
>
template
<
int
D
>
HOSTDEVICE
inline
int64_t
product
(
const
Dim
<
N
>&
a
)
{
HOSTDEVICE
inline
int64_t
product
(
const
Dim
<
D
>&
a
)
{
return
UnrollProduct
<
N
>::
Run
(
a
.
Get
());
return
UnrollProduct
<
D
>::
Run
(
a
.
Get
());
}
}
// Is 0 <= idx_i < size_i for all i?
// Is 0 <= idx_i < size_i for all i?
...
@@ -135,9 +135,9 @@ struct ContainedFunctor<kStart, kEnd, true> {
...
@@ -135,9 +135,9 @@ struct ContainedFunctor<kStart, kEnd, true> {
};
};
}
// namespace detail
}
// namespace detail
template
<
int
N
>
template
<
int
D
>
HOSTDEVICE
inline
bool
contained
(
const
Dim
<
N
>&
idx
,
const
Dim
<
N
>&
size
)
{
HOSTDEVICE
inline
bool
contained
(
const
Dim
<
D
>&
idx
,
const
Dim
<
D
>&
size
)
{
return
detail
::
ContainedFunctor
<
0
,
N
,
N
==
0
>::
Run
(
idx
.
Get
(),
size
.
Get
());
return
detail
::
ContainedFunctor
<
0
,
D
,
D
==
0
>::
Run
(
idx
.
Get
(),
size
.
Get
());
}
}
/**
/**
...
@@ -160,40 +160,40 @@ struct ExPrefixMulFunctor<kStart, kEnd, true> {
...
@@ -160,40 +160,40 @@ struct ExPrefixMulFunctor<kStart, kEnd, true> {
};
};
}
// namespace detail
}
// namespace detail
template
<
int
N
>
template
<
int
D
>
HOSTDEVICE
inline
Dim
<
N
>
ex_prefix_mul
(
const
Dim
<
N
>&
src
)
{
HOSTDEVICE
inline
Dim
<
D
>
ex_prefix_mul
(
const
Dim
<
D
>&
src
)
{
Dim
<
N
>
ret
;
Dim
<
D
>
ret
;
detail
::
ExPrefixMulFunctor
<
0
,
N
,
N
==
0
>::
Run
(
src
.
Get
(),
ret
.
GetMutable
());
detail
::
ExPrefixMulFunctor
<
0
,
D
,
D
==
0
>::
Run
(
src
.
Get
(),
ret
.
GetMutable
());
return
ret
;
return
ret
;
}
}
/**
/**
* Add two dimensions together
* Add two dimensions together
*/
*/
template
<
int
N
>
template
<
int
D
>
HOSTDEVICE
inline
Dim
<
N
>
dim_plus
(
const
Dim
<
N
>&
a
,
const
Dim
<
N
>&
b
)
{
HOSTDEVICE
inline
Dim
<
D
>
dim_plus
(
const
Dim
<
D
>&
a
,
const
Dim
<
D
>&
b
)
{
Dim
<
N
>
ret
;
Dim
<
D
>
ret
;
UnrollAdd
<
N
>::
Run
(
a
.
Get
(),
b
.
Get
(),
ret
.
GetMutable
());
UnrollAdd
<
D
>::
Run
(
a
.
Get
(),
b
.
Get
(),
ret
.
GetMutable
());
return
ret
;
return
ret
;
}
}
template
<
int
N
>
template
<
int
D
>
HOSTDEVICE
inline
Dim
<
N
>
operator
+
(
const
Dim
<
N
>&
lhs
,
const
Dim
<
N
>&
rhs
)
{
HOSTDEVICE
inline
Dim
<
D
>
operator
+
(
const
Dim
<
D
>&
lhs
,
const
Dim
<
D
>&
rhs
)
{
return
dim_plus
(
lhs
,
rhs
);
return
dim_plus
(
lhs
,
rhs
);
}
}
/**
/**
* Multiply two dimensions together
* Multiply two dimensions together
*/
*/
template
<
int
N
>
template
<
int
D
>
HOSTDEVICE
inline
Dim
<
N
>
dim_mult
(
const
Dim
<
N
>&
a
,
const
Dim
<
N
>&
b
)
{
HOSTDEVICE
inline
Dim
<
D
>
dim_mult
(
const
Dim
<
D
>&
a
,
const
Dim
<
D
>&
b
)
{
Dim
<
N
>
ret
;
Dim
<
D
>
ret
;
UnrollMul
<
N
>::
Run
(
a
.
Get
(),
b
.
Get
(),
ret
.
GetMutable
());
UnrollMul
<
D
>::
Run
(
a
.
Get
(),
b
.
Get
(),
ret
.
GetMutable
());
return
ret
;
return
ret
;
}
}
template
<
int
i
>
template
<
int
D
>
HOSTDEVICE
Dim
<
i
>
operator
*
(
const
Dim
<
i
>&
lhs
,
const
Dim
<
i
>&
rhs
)
{
HOSTDEVICE
Dim
<
D
>
operator
*
(
const
Dim
<
D
>&
lhs
,
const
Dim
<
D
>&
rhs
)
{
return
dim_mult
(
lhs
,
rhs
);
return
dim_mult
(
lhs
,
rhs
);
}
}
...
@@ -224,10 +224,10 @@ struct NormalizeStridesFunctor<kStart, kEnd, true> {
...
@@ -224,10 +224,10 @@ struct NormalizeStridesFunctor<kStart, kEnd, true> {
};
};
}
// namespace detail
}
// namespace detail
template
<
int
N
>
template
<
int
D
>
HOSTDEVICE
Dim
<
N
>
normalize_strides
(
const
Dim
<
N
>&
size
,
const
Dim
<
N
>&
stride
)
{
HOSTDEVICE
Dim
<
D
>
normalize_strides
(
const
Dim
<
D
>&
size
,
const
Dim
<
D
>&
stride
)
{
Dim
<
N
>
ret
;
Dim
<
D
>
ret
;
detail
::
NormalizeStridesFunctor
<
0
,
N
,
N
==
0
>::
Run
(
size
.
Get
(),
stride
.
Get
(),
detail
::
NormalizeStridesFunctor
<
0
,
D
,
D
==
0
>::
Run
(
size
.
Get
(),
stride
.
Get
(),
ret
.
GetMutable
());
ret
.
GetMutable
());
return
ret
;
return
ret
;
}
}
...
@@ -245,10 +245,10 @@ HOSTDEVICE inline Dim<sizeof...(Args)> make_dim(Args... idxes) {
...
@@ -245,10 +245,10 @@ HOSTDEVICE inline Dim<sizeof...(Args)> make_dim(Args... idxes) {
}
}
// Allows us to output a Dim
// Allows us to output a Dim
template
<
int
N
>
template
<
int
D
>
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Dim
<
N
>&
d
)
{
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Dim
<
D
>&
d
)
{
os
<<
d
[
0
];
os
<<
d
[
0
];
for
(
int
i
=
1
;
i
<
N
;
++
i
)
{
for
(
int
i
=
1
;
i
<
D
;
++
i
)
{
os
<<
", "
<<
d
[
i
];
os
<<
", "
<<
d
[
i
];
}
}
return
os
;
return
os
;
...
@@ -258,23 +258,23 @@ inline std::ostream& operator<<(std::ostream& os, const Dim<0>& d) {
...
@@ -258,23 +258,23 @@ inline std::ostream& operator<<(std::ostream& os, const Dim<0>& d) {
return
os
;
return
os
;
}
}
template
<
int
N
>
template
<
int
D
>
HOST
std
::
string
Dim
<
N
>::
to_string
()
const
{
HOST
std
::
string
Dim
<
D
>::
to_string
()
const
{
std
::
stringstream
stream
;
std
::
stringstream
stream
;
stream
<<
*
this
;
stream
<<
*
this
;
return
stream
.
str
();
return
stream
.
str
();
}
}
template
<
int
N
>
template
<
int
D
>
HOSTDEVICE
Dim
<
N
>
linear_to_dimension
(
int
linear_index
,
const
Dim
<
N
>&
extents
)
{
HOSTDEVICE
Dim
<
D
>
linear_to_dimension
(
int
linear_index
,
const
Dim
<
D
>&
extents
)
{
Dim
<
N
>
result
;
Dim
<
D
>
result
;
for
(
int
i
=
0
;
i
<
N
-
1
;
++
i
)
{
for
(
int
i
=
0
;
i
<
D
-
1
;
++
i
)
{
result
[
i
]
=
linear_index
%
extents
[
i
];
result
[
i
]
=
linear_index
%
extents
[
i
];
linear_index
/=
extents
[
i
];
linear_index
/=
extents
[
i
];
}
}
result
[
N
-
1
]
=
linear_index
;
result
[
D
-
1
]
=
linear_index
;
return
result
;
return
result
;
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录