Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
83ac8515
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
83ac8515
编写于
12月 20, 2018
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
polish code
test=develop
上级
045dc127
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
48 addition
and
85 deletion
+48
-85
paddle/fluid/framework/ddim.cc
paddle/fluid/framework/ddim.cc
+7
-19
paddle/fluid/framework/ddim.h
paddle/fluid/framework/ddim.h
+39
-64
paddle/fluid/operators/detail/strided_memcpy.h
paddle/fluid/operators/detail/strided_memcpy.h
+2
-2
未找到文件。
paddle/fluid/framework/ddim.cc
浏览文件 @
83ac8515
...
@@ -18,13 +18,6 @@ limitations under the License. */
...
@@ -18,13 +18,6 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
template
<
int
N
>
Dim
<
N
>
make_dim
(
const
int64_t
*
d
)
{
Dim
<
N
>
ret
;
fix_dim_assign
(
d
,
ret
.
GetMutable
());
return
ret
;
}
DDim
make_ddim
(
std
::
initializer_list
<
int64_t
>
dims
)
{
DDim
make_ddim
(
std
::
initializer_list
<
int64_t
>
dims
)
{
return
DDim
(
dims
.
begin
(),
dims
.
size
());
return
DDim
(
dims
.
begin
(),
dims
.
size
());
}
}
...
@@ -69,8 +62,7 @@ struct DDimPlusVisitor {
...
@@ -69,8 +62,7 @@ struct DDimPlusVisitor {
DDim
DDim
::
operator
+
(
const
DDim
&
d
)
const
{
DDim
DDim
::
operator
+
(
const
DDim
&
d
)
const
{
PADDLE_ENFORCE
(
rank_
==
d
.
rank_
);
PADDLE_ENFORCE
(
rank_
==
d
.
rank_
);
DDim
ret
;
DDim
ret
(
rank_
);
ret
.
rank_
=
rank_
;
ret
.
apply_visitor
(
DDimPlusVisitor
(
Get
(),
d
.
Get
()));
ret
.
apply_visitor
(
DDimPlusVisitor
(
Get
(),
d
.
Get
()));
return
ret
;
return
ret
;
}
}
...
@@ -90,8 +82,7 @@ struct DDimMulVisitor {
...
@@ -90,8 +82,7 @@ struct DDimMulVisitor {
DDim
DDim
::
operator
*
(
const
DDim
&
d
)
const
{
DDim
DDim
::
operator
*
(
const
DDim
&
d
)
const
{
PADDLE_ENFORCE
(
rank_
==
d
.
rank_
);
PADDLE_ENFORCE
(
rank_
==
d
.
rank_
);
DDim
ret
;
DDim
ret
(
rank_
);
ret
.
rank_
=
rank_
;
ret
.
apply_visitor
(
DDimMulVisitor
(
Get
(),
d
.
Get
()));
ret
.
apply_visitor
(
DDimMulVisitor
(
Get
(),
d
.
Get
()));
return
ret
;
return
ret
;
}
}
...
@@ -118,7 +109,7 @@ std::vector<int> vectorize2int(const DDim& ddim) {
...
@@ -118,7 +109,7 @@ std::vector<int> vectorize2int(const DDim& ddim) {
struct
ProductVisitor
{
struct
ProductVisitor
{
template
<
int
D
>
template
<
int
D
>
int64_t
operator
()(
const
Dim
<
D
>&
dim
)
{
in
line
in
t64_t
operator
()(
const
Dim
<
D
>&
dim
)
{
return
product
(
dim
);
return
product
(
dim
);
}
}
};
};
...
@@ -130,8 +121,7 @@ int64_t product(const DDim& ddim) {
...
@@ -130,8 +121,7 @@ int64_t product(const DDim& ddim) {
DDim
slice_ddim
(
const
DDim
&
dim
,
int
begin
,
int
end
)
{
DDim
slice_ddim
(
const
DDim
&
dim
,
int
begin
,
int
end
)
{
PADDLE_ENFORCE
(
begin
>=
0
,
PADDLE_ENFORCE
(
begin
>=
0
,
"Begin index can't be less than zero in ddim slice."
);
"Begin index can't be less than zero in ddim slice."
);
DDim
ret
;
DDim
ret
(
end
-
begin
);
ret
.
rank_
=
end
-
begin
;
dynamic_dim_assign
(
dim
.
Get
()
+
begin
,
ret
.
GetMutable
(),
ret
.
rank_
);
dynamic_dim_assign
(
dim
.
Get
()
+
begin
,
ret
.
GetMutable
(),
ret
.
rank_
);
return
ret
;
return
ret
;
}
}
...
@@ -166,8 +156,7 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims) {
...
@@ -166,8 +156,7 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims) {
DDim
flatten_to_1d
(
const
DDim
&
src
)
{
return
make_ddim
({
product
(
src
)});
}
DDim
flatten_to_1d
(
const
DDim
&
src
)
{
return
make_ddim
({
product
(
src
)});
}
DDim
stride
(
const
DDim
&
ddim
)
{
DDim
stride
(
const
DDim
&
ddim
)
{
DDim
strides
;
DDim
strides
(
ddim
.
size
());
strides
.
rank_
=
ddim
.
size
();
strides
[
ddim
.
size
()
-
1
]
=
1
;
strides
[
ddim
.
size
()
-
1
]
=
1
;
for
(
int
i
=
ddim
.
size
()
-
2
;
i
>=
0
;
--
i
)
{
for
(
int
i
=
ddim
.
size
()
-
2
;
i
>=
0
;
--
i
)
{
strides
[
i
]
=
strides
[
i
+
1
]
*
ddim
[
i
+
1
];
strides
[
i
]
=
strides
[
i
+
1
]
*
ddim
[
i
+
1
];
...
@@ -175,9 +164,8 @@ DDim stride(const DDim& ddim) {
...
@@ -175,9 +164,8 @@ DDim stride(const DDim& ddim) {
return
strides
;
return
strides
;
}
}
DDim
stride_numel
(
const
framework
::
DDim
&
ddim
)
{
DDim
stride_numel
(
const
DDim
&
ddim
)
{
DDim
strides
;
DDim
strides
(
ddim
.
size
());
strides
.
rank_
=
ddim
.
size
();
strides
[
ddim
.
size
()
-
1
]
=
ddim
[
ddim
.
size
()
-
1
];
strides
[
ddim
.
size
()
-
1
]
=
ddim
[
ddim
.
size
()
-
1
];
for
(
int
i
=
ddim
.
size
()
-
2
;
i
>=
0
;
--
i
)
{
for
(
int
i
=
ddim
.
size
()
-
2
;
i
>=
0
;
--
i
)
{
strides
[
i
]
=
strides
[
i
+
1
]
*
ddim
[
i
];
strides
[
i
]
=
strides
[
i
+
1
]
*
ddim
[
i
];
...
...
paddle/fluid/framework/ddim.h
浏览文件 @
83ac8515
...
@@ -22,27 +22,31 @@ limitations under the License. */
...
@@ -22,27 +22,31 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
#define PADDLE_VISIT_DDIM_BASE(rank, callback) \
case (rank): { \
constexpr auto kRank = (rank); \
return (callback); \
}
#define PADDLE_VISIT_DDIM(rank, callback) \
switch (rank) { \
PADDLE_VISIT_DDIM_BASE(0, callback); \
PADDLE_VISIT_DDIM_BASE(1, callback); \
PADDLE_VISIT_DDIM_BASE(2, callback); \
PADDLE_VISIT_DDIM_BASE(3, callback); \
PADDLE_VISIT_DDIM_BASE(4, callback); \
PADDLE_VISIT_DDIM_BASE(5, callback); \
PADDLE_VISIT_DDIM_BASE(6, callback); \
PADDLE_VISIT_DDIM_BASE(7, callback); \
PADDLE_VISIT_DDIM_BASE(8, callback); \
PADDLE_VISIT_DDIM_BASE(9, callback); \
default: \
PADDLE_THROW("Invalid rank %d", rank); \
}
template
<
typename
T1
,
typename
T2
>
template
<
typename
T1
,
typename
T2
>
inline
void
dynamic_dim_assign
(
const
T1
*
in
,
T2
*
out
,
int
n
)
{
inline
void
dynamic_dim_assign
(
const
T1
*
in
,
T2
*
out
,
int
n
)
{
#define STATIC_DIM_ASSIGN_CASE(rank) \
PADDLE_VISIT_DDIM
(
n
,
(
static_dim_assign
<
kRank
,
T1
,
T2
>
(
in
,
out
)));
case rank: \
static_dim_assign<rank, T1, T2>(in, out); \
return
switch
(
n
)
{
STATIC_DIM_ASSIGN_CASE
(
0
);
STATIC_DIM_ASSIGN_CASE
(
1
);
STATIC_DIM_ASSIGN_CASE
(
2
);
STATIC_DIM_ASSIGN_CASE
(
3
);
STATIC_DIM_ASSIGN_CASE
(
4
);
STATIC_DIM_ASSIGN_CASE
(
5
);
STATIC_DIM_ASSIGN_CASE
(
6
);
STATIC_DIM_ASSIGN_CASE
(
7
);
STATIC_DIM_ASSIGN_CASE
(
8
);
STATIC_DIM_ASSIGN_CASE
(
9
);
default:
PADDLE_THROW
(
"Invalid rank %d"
,
n
);
}
#undef STATIC_DIM_ASSIGN_CASE
}
}
/**
/**
...
@@ -84,22 +88,26 @@ class DDim {
...
@@ -84,22 +88,26 @@ class DDim {
inline
int64_t
operator
[](
int
idx
)
const
{
return
dim_
[
idx
];
}
inline
int64_t
operator
[](
int
idx
)
const
{
return
dim_
[
idx
];
}
inline
int64_t
&
at
(
int
idx
)
{
inline
int64_t
&
at
(
int
idx
)
{
PADDLE_ENFORCE
(
idx
>=
0
&&
idx
<
rank_
);
PADDLE_ENFORCE
(
idx
>=
0
&&
idx
<
rank_
,
"Invalid idx %d"
,
idx
);
return
dim_
[
idx
];
return
dim_
[
idx
];
}
}
inline
int64_t
at
(
int
idx
)
const
{
inline
int64_t
at
(
int
idx
)
const
{
PADDLE_ENFORCE
(
idx
>=
0
&&
idx
<
rank_
);
PADDLE_ENFORCE
(
idx
>=
0
&&
idx
<
rank_
,
"Invalid idx %d"
,
idx
);
return
dim_
[
idx
];
return
dim_
[
idx
];
}
}
template
<
typename
Visitor
>
template
<
typename
Visitor
>
typename
std
::
result_of
<
Visitor
(
Dim
<
0
>&
)
>::
type
apply_visitor
(
typename
std
::
result_of
<
Visitor
(
Dim
<
0
>&
)
>::
type
apply_visitor
(
Visitor
&&
visitor
);
Visitor
&&
visitor
)
{
PADDLE_VISIT_DDIM
(
rank_
,
visitor
(
UnsafeCast
<
kRank
>
()));
}
template
<
typename
Visitor
>
template
<
typename
Visitor
>
typename
std
::
result_of
<
Visitor
(
const
Dim
<
0
>&
)
>::
type
apply_visitor
(
typename
std
::
result_of
<
Visitor
(
const
Dim
<
0
>&
)
>::
type
apply_visitor
(
Visitor
&&
visitor
)
const
;
Visitor
&&
visitor
)
const
{
PADDLE_VISIT_DDIM
(
rank_
,
visitor
(
UnsafeCast
<
kRank
>
()));
}
bool
operator
==
(
const
DDim
&
d
)
const
;
bool
operator
==
(
const
DDim
&
d
)
const
;
...
@@ -128,55 +136,22 @@ class DDim {
...
@@ -128,55 +136,22 @@ class DDim {
return
*
reinterpret_cast
<
const
Dim
<
M
>*>
(
p
);
return
*
reinterpret_cast
<
const
Dim
<
M
>*>
(
p
);
}
}
// Construct DDim with given rank
// Only used in friend functions
explicit
DDim
(
int
rank
)
:
rank_
(
rank
)
{
PADDLE_ENFORCE
(
rank_
>=
0
&&
rank_
<
kMaxRank
,
"Invalid rank %d"
,
rank
);
}
friend
DDim
slice_ddim
(
const
DDim
&
dim
,
int
begin
,
int
end
);
friend
DDim
slice_ddim
(
const
DDim
&
dim
,
int
begin
,
int
end
);
friend
DDim
stride
(
const
DDim
&
ddim
);
friend
DDim
stride
(
const
DDim
&
ddim
);
friend
DDim
stride_numel
(
const
DDim
&
ddim
);
friend
DDim
stride_numel
(
const
DDim
&
ddim
);
private:
Dim
<
kMaxRank
>
dim_
;
Dim
<
kMaxRank
>
dim_
;
int
rank_
;
int
rank_
;
};
};
#define PADDLE_VISIT_DDIM(rank) \
#undef PADDLE_VISIT_DDIM_BASE
case rank: \
return visitor(UnsafeCast<rank>())
template
<
typename
Visitor
>
typename
std
::
result_of
<
Visitor
(
Dim
<
0
>&
)
>::
type
DDim
::
apply_visitor
(
Visitor
&&
visitor
)
{
switch
(
rank_
)
{
PADDLE_VISIT_DDIM
(
0
);
PADDLE_VISIT_DDIM
(
1
);
PADDLE_VISIT_DDIM
(
2
);
PADDLE_VISIT_DDIM
(
3
);
PADDLE_VISIT_DDIM
(
4
);
PADDLE_VISIT_DDIM
(
5
);
PADDLE_VISIT_DDIM
(
6
);
PADDLE_VISIT_DDIM
(
7
);
PADDLE_VISIT_DDIM
(
8
);
PADDLE_VISIT_DDIM
(
9
);
default:
PADDLE_THROW
(
"Invalid rank %d"
,
rank_
);
}
}
template
<
typename
Visitor
>
typename
std
::
result_of
<
Visitor
(
const
Dim
<
0
>&
)
>::
type
DDim
::
apply_visitor
(
Visitor
&&
visitor
)
const
{
switch
(
rank_
)
{
PADDLE_VISIT_DDIM
(
0
);
PADDLE_VISIT_DDIM
(
1
);
PADDLE_VISIT_DDIM
(
2
);
PADDLE_VISIT_DDIM
(
3
);
PADDLE_VISIT_DDIM
(
4
);
PADDLE_VISIT_DDIM
(
5
);
PADDLE_VISIT_DDIM
(
6
);
PADDLE_VISIT_DDIM
(
7
);
PADDLE_VISIT_DDIM
(
8
);
PADDLE_VISIT_DDIM
(
9
);
default:
PADDLE_THROW
(
"Invalid rank %d"
,
rank_
);
}
}
#undef PADDLE_VISIT_DDIM
#undef PADDLE_VISIT_DDIM
/**
/**
...
...
paddle/fluid/operators/detail/strided_memcpy.h
浏览文件 @
83ac8515
...
@@ -98,8 +98,8 @@ struct StridedCopyDimVisitor {
...
@@ -98,8 +98,8 @@ struct StridedCopyDimVisitor {
template
<
int
D
>
template
<
int
D
>
void
operator
()(
const
framework
::
Dim
<
D
>&
dst_dim
)
const
{
void
operator
()(
const
framework
::
Dim
<
D
>&
dst_dim
)
const
{
StridedMemcpyFunctor
<
T
,
D
>
functor
;
StridedMemcpyFunctor
<
T
,
D
>
functor
;
functor
(
dev_ctx_
,
src_
,
src_stride_
.
data
(),
dst_dim
.
data
(),
functor
(
dev_ctx_
,
src_
,
src_stride_
.
Get
(),
dst_dim
.
Get
(),
dst_stride_
.
Get
(),
dst_
stride_
.
data
(),
dst_
);
dst_
);
}
}
const
platform
::
DeviceContext
&
dev_ctx_
;
const
platform
::
DeviceContext
&
dev_ctx_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录