Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
7d56c6d0
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7d56c6d0
编写于
2月 22, 2018
作者:
X
xuwei06
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Adding Dim<0>
Dim<0> is for scalar (rank-0 tensor). Adding Dim<0> can simplify a lot of code.
上级
a67cebaf
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
73 addition
and
49 deletion
+73
-49
paddle/fluid/framework/ddim.cc
paddle/fluid/framework/ddim.cc
+8
-6
paddle/fluid/framework/ddim.h
paddle/fluid/framework/ddim.h
+2
-2
paddle/fluid/framework/dim.h
paddle/fluid/framework/dim.h
+40
-41
paddle/fluid/operators/detail/strided_memcpy.h
paddle/fluid/operators/detail/strided_memcpy.h
+23
-0
未找到文件。
paddle/fluid/framework/ddim.cc
浏览文件 @
7d56c6d0
...
@@ -26,12 +26,15 @@ Dim<i> make_dim(const int64_t* d) {
...
@@ -26,12 +26,15 @@ Dim<i> make_dim(const int64_t* d) {
}
}
template
<
>
template
<
>
Dim
<
1
>
make_dim
<
1
>
(
const
int64_t
*
d
)
{
Dim
<
0
>
make_dim
<
0
>
(
const
int64_t
*
d
)
{
return
Dim
<
1
>
(
*
d
);
return
Dim
<
0
>
(
*
d
);
}
}
void
make_ddim
(
DDim
&
ddim
,
const
int64_t
*
dims
,
int
n
)
{
void
make_ddim
(
DDim
&
ddim
,
const
int64_t
*
dims
,
int
n
)
{
switch
(
n
)
{
switch
(
n
)
{
case
0
:
ddim
=
make_dim
<
0
>
(
dims
);
break
;
case
1
:
case
1
:
ddim
=
make_dim
<
1
>
(
dims
);
ddim
=
make_dim
<
1
>
(
dims
);
break
;
break
;
...
@@ -190,7 +193,7 @@ struct VectorizeVisitor : public boost::static_visitor<> {
...
@@ -190,7 +193,7 @@ struct VectorizeVisitor : public boost::static_visitor<> {
this
->
operator
()(
t
.
tail
);
this
->
operator
()(
t
.
tail
);
}
}
void
operator
()(
const
Dim
<
1
>&
t
)
{
vector
.
push_back
(
t
.
head
);
}
void
operator
()(
const
Dim
<
0
>&
t
)
{
}
};
};
/// @endcond
/// @endcond
...
@@ -247,9 +250,8 @@ struct SliceVectorizeVisitor : public boost::static_visitor<> {
...
@@ -247,9 +250,8 @@ struct SliceVectorizeVisitor : public boost::static_visitor<> {
}
}
}
}
void
operator
()(
const
Dim
<
1
>&
dim
)
{
void
operator
()(
const
Dim
<
0
>&
dim
)
{
PADDLE_ENFORCE
(
end
==
1
,
"End index in ddim slice is out of bound."
);
PADDLE_ENFORCE
(
end
==
0
,
"End index in ddim slice is out of bound."
);
vector
.
push_back
(
dim
.
head
);
}
}
};
};
...
...
paddle/fluid/framework/ddim.h
浏览文件 @
7d56c6d0
...
@@ -30,8 +30,8 @@ namespace framework {
...
@@ -30,8 +30,8 @@ namespace framework {
* The number of dimensions must be between [1, 9].
* The number of dimensions must be between [1, 9].
*/
*/
struct
DDim
{
struct
DDim
{
typedef
boost
::
variant
<
Dim
<
1
>
,
Dim
<
2
>
,
Dim
<
3
>
,
Dim
<
4
>
,
Dim
<
5
>
,
Dim
<
6
>
,
Dim
<
7
>
,
typedef
boost
::
variant
<
Dim
<
0
>
,
Dim
<
1
>
,
Dim
<
2
>
,
Dim
<
3
>
,
Dim
<
4
>
,
Dim
<
5
>
,
Dim
<
6
>
,
Dim
<
8
>
,
Dim
<
9
>>
Dim
<
7
>
,
Dim
<
8
>
,
Dim
<
9
>>
DDimVar
;
DDimVar
;
DDimVar
var
;
DDimVar
var
;
...
...
paddle/fluid/framework/dim.h
浏览文件 @
7d56c6d0
...
@@ -72,38 +72,36 @@ struct Dim {
...
@@ -72,38 +72,36 @@ struct Dim {
// Base case specialization
// Base case specialization
template
<
>
template
<
>
struct
Dim
<
1
>
{
struct
Dim
<
0
>
{
static
constexpr
int
dimensions
=
1
;
static
constexpr
int
dimensions
=
0
;
HOSTDEVICE
HOSTDEVICE
Dim
(
int64_t
_head
)
:
head
(
_head
)
{}
Dim
(
int64_t
_head
)
{}
HOSTDEVICE
HOSTDEVICE
Dim
()
:
head
(
0
)
{}
Dim
()
{}
HOSTDEVICE
HOSTDEVICE
Dim
(
int
idx
,
const
Dim
<
1
>&
size
)
:
head
(
idx
)
{
Dim
(
int
idx
,
const
Dim
<
0
>&
size
)
{
#ifndef __CUDA_ARCH__
#ifndef __CUDA_ARCH__
if
(
idx
>
=
size
.
head
)
{
if
(
idx
>
0
)
{
throw
std
::
invalid_argument
(
"Index out of range."
);
throw
std
::
invalid_argument
(
"Index out of range."
);
}
}
#else
#else
PADDLE_ASSERT
(
idx
<
size
.
head
);
PADDLE_ASSERT
(
idx
==
0
);
#endif
#endif
}
}
HOSTDEVICE
HOSTDEVICE
bool
operator
==
(
const
Dim
<
1
>&
o
)
const
{
return
(
head
==
o
.
head
)
;
}
bool
operator
==
(
const
Dim
<
0
>&
o
)
const
{
return
true
;
}
HOSTDEVICE
HOSTDEVICE
bool
operator
!=
(
const
Dim
<
1
>&
o
)
const
{
return
!
(
*
this
==
o
)
;
}
bool
operator
!=
(
const
Dim
<
0
>&
o
)
const
{
return
false
;
}
HOSTDEVICE
HOSTDEVICE
int64_t
&
operator
[](
int
idx
);
int64_t
&
operator
[](
int
idx
);
HOSTDEVICE
HOSTDEVICE
int64_t
operator
[](
int
idx
)
const
;
int64_t
operator
[](
int
idx
)
const
;
int64_t
head
;
};
};
namespace
{
namespace
{
...
@@ -154,15 +152,14 @@ HOSTDEVICE int64_t& indexer(Dim<D>& dim, int idx) {
...
@@ -154,15 +152,14 @@ HOSTDEVICE int64_t& indexer(Dim<D>& dim, int idx) {
}
}
template
<
>
template
<
>
HOSTDEVICE
int64_t
&
indexer
<
1
>
(
Dim
<
1
>&
dim
,
int
idx
)
{
HOSTDEVICE
int64_t
&
indexer
<
0
>
(
Dim
<
0
>&
dim
,
int
idx
)
{
#ifndef __CUDA_ARCH__
#ifndef __CUDA_ARCH__
if
(
idx
!=
0
)
{
throw
std
::
invalid_argument
(
"Invalid index"
);
throw
std
::
invalid_argument
(
"Invalid index"
);
}
#else
#else
PADDLE_ASSERT
(
idx
==
0
);
PADDLE_ASSERT
(
false
);
#endif
#endif
return
dim
.
head
;
static
int64_t
head
=
0
;
return
head
;
}
}
template
<
int
D
>
template
<
int
D
>
...
@@ -181,15 +178,14 @@ HOSTDEVICE int64_t indexer(const Dim<D>& dim, int idx) {
...
@@ -181,15 +178,14 @@ HOSTDEVICE int64_t indexer(const Dim<D>& dim, int idx) {
}
}
template
<
>
template
<
>
HOSTDEVICE
int64_t
indexer
<
1
>
(
const
Dim
<
1
>&
dim
,
int
idx
)
{
HOSTDEVICE
int64_t
indexer
<
0
>
(
const
Dim
<
0
>&
dim
,
int
idx
)
{
#ifndef __CUDA_ARCH__
#ifndef __CUDA_ARCH__
if
(
idx
!=
0
)
{
throw
std
::
invalid_argument
(
"Invalid index"
);
throw
std
::
invalid_argument
(
"Invalid index"
);
}
#else
#else
PADDLE_ASSERT
(
idx
==
0
);
PADDLE_ASSERT
(
false
);
#endif
#endif
return
dim
.
head
;
static
int64_t
head
=
0
;
return
head
;
}
}
}
// namespace
}
// namespace
...
@@ -218,12 +214,12 @@ HOSTDEVICE int64_t& Dim<l>::operator[](int i) {
...
@@ -218,12 +214,12 @@ HOSTDEVICE int64_t& Dim<l>::operator[](int i) {
}
}
// Dynamic access to constant Dim
// Dynamic access to constant Dim
inline
HOSTDEVICE
int64_t
Dim
<
1
>::
operator
[](
int
i
)
const
{
inline
HOSTDEVICE
int64_t
Dim
<
0
>::
operator
[](
int
i
)
const
{
return
indexer
(
*
this
,
i
);
return
indexer
(
*
this
,
i
);
}
}
// Dynamic access to mutable Dim
// Dynamic access to mutable Dim
inline
HOSTDEVICE
int64_t
&
Dim
<
1
>::
operator
[](
int
i
)
{
inline
HOSTDEVICE
int64_t
&
Dim
<
0
>::
operator
[](
int
i
)
{
return
indexer
(
*
this
,
i
);
return
indexer
(
*
this
,
i
);
}
}
...
@@ -251,8 +247,8 @@ HOSTDEVICE int64_t linearize(const Dim<i>& a, const Dim<i>& b) {
...
@@ -251,8 +247,8 @@ HOSTDEVICE int64_t linearize(const Dim<i>& a, const Dim<i>& b) {
// Base case dot product of two Dims
// Base case dot product of two Dims
// Notice it is inline because it is no longer a template
// Notice it is inline because it is no longer a template
template
<
>
template
<
>
HOSTDEVICE
inline
int64_t
linearize
(
const
Dim
<
1
>&
a
,
const
Dim
<
1
>&
b
)
{
HOSTDEVICE
inline
int64_t
linearize
(
const
Dim
<
0
>&
a
,
const
Dim
<
0
>&
b
)
{
return
a
.
head
*
b
.
head
;
return
0
;
}
}
// Product of a Dim
// Product of a Dim
...
@@ -264,8 +260,8 @@ HOSTDEVICE int64_t product(const Dim<i>& a, int prod = 1) {
...
@@ -264,8 +260,8 @@ HOSTDEVICE int64_t product(const Dim<i>& a, int prod = 1) {
// Base case product of a Dim
// Base case product of a Dim
// Notice it is inline because it is no longer a template
// Notice it is inline because it is no longer a template
template
<
>
template
<
>
HOSTDEVICE
inline
int64_t
product
(
const
Dim
<
1
>&
a
,
int
prod
)
{
HOSTDEVICE
inline
int64_t
product
(
const
Dim
<
0
>&
a
,
int
prod
)
{
return
prod
*
a
.
head
;
return
prod
;
}
}
// Is 0 <= idx_i < size_i for all i?
// Is 0 <= idx_i < size_i for all i?
...
@@ -278,8 +274,8 @@ HOSTDEVICE bool contained(const Dim<i>& idx, const Dim<i>& size) {
...
@@ -278,8 +274,8 @@ HOSTDEVICE bool contained(const Dim<i>& idx, const Dim<i>& size) {
// Base case of is 0 <= idx_i < size_i ?
// Base case of is 0 <= idx_i < size_i ?
// Notice it is inline because it is no longer a template
// Notice it is inline because it is no longer a template
template
<
>
template
<
>
HOSTDEVICE
inline
bool
contained
(
const
Dim
<
1
>&
idx
,
const
Dim
<
1
>&
size
)
{
HOSTDEVICE
inline
bool
contained
(
const
Dim
<
0
>&
idx
,
const
Dim
<
0
>&
size
)
{
return
((
0
<=
idx
.
head
)
&&
(
idx
.
head
<
size
.
head
))
;
return
true
;
}
}
/**
/**
...
@@ -294,8 +290,8 @@ HOSTDEVICE Dim<i> ex_prefix_mul(const Dim<i>& src, int mul = 1) {
...
@@ -294,8 +290,8 @@ HOSTDEVICE Dim<i> ex_prefix_mul(const Dim<i>& src, int mul = 1) {
// Base case of ex_prefix_mul
// Base case of ex_prefix_mul
// Notice it is inline because it is no longer a template
// Notice it is inline because it is no longer a template
template
<
>
template
<
>
HOSTDEVICE
inline
Dim
<
1
>
ex_prefix_mul
(
const
Dim
<
1
>&
src
,
int
mul
)
{
HOSTDEVICE
inline
Dim
<
0
>
ex_prefix_mul
(
const
Dim
<
0
>&
src
,
int
mul
)
{
return
Dim
<
1
>
(
mul
);
return
Dim
<
0
>
(
);
}
}
///\endcond
///\endcond
...
@@ -309,8 +305,8 @@ HOSTDEVICE Dim<i> dim_plus(const Dim<i>& a, const Dim<i>& b) {
...
@@ -309,8 +305,8 @@ HOSTDEVICE Dim<i> dim_plus(const Dim<i>& a, const Dim<i>& b) {
// Base case
// Base case
template
<
>
template
<
>
HOSTDEVICE
inline
Dim
<
1
>
dim_plus
(
const
Dim
<
1
>&
a
,
const
Dim
<
1
>&
b
)
{
HOSTDEVICE
inline
Dim
<
0
>
dim_plus
(
const
Dim
<
0
>&
a
,
const
Dim
<
0
>&
b
)
{
return
Dim
<
1
>
(
a
.
head
+
b
.
head
);
return
Dim
<
0
>
(
);
}
}
template
<
int
i
>
template
<
int
i
>
...
@@ -328,8 +324,8 @@ HOSTDEVICE Dim<i> dim_mult(const Dim<i>& a, const Dim<i>& b) {
...
@@ -328,8 +324,8 @@ HOSTDEVICE Dim<i> dim_mult(const Dim<i>& a, const Dim<i>& b) {
// Base case
// Base case
template
<
>
template
<
>
HOSTDEVICE
inline
Dim
<
1
>
dim_mult
(
const
Dim
<
1
>&
a
,
const
Dim
<
1
>&
b
)
{
HOSTDEVICE
inline
Dim
<
0
>
dim_mult
(
const
Dim
<
0
>&
a
,
const
Dim
<
0
>&
b
)
{
return
Dim
<
1
>
(
a
.
head
*
b
.
head
);
return
Dim
<
0
>
(
);
}
}
template
<
int
i
>
template
<
int
i
>
...
@@ -356,10 +352,9 @@ HOSTDEVICE Dim<i> normalize_strides(const Dim<i>& size, const Dim<i>& stride) {
...
@@ -356,10 +352,9 @@ HOSTDEVICE Dim<i> normalize_strides(const Dim<i>& size, const Dim<i>& stride) {
///\cond HIDDEN
///\cond HIDDEN
template
<
>
template
<
>
HOSTDEVICE
inline
Dim
<
1
>
normalize_strides
(
const
Dim
<
1
>&
size
,
HOSTDEVICE
inline
Dim
<
0
>
normalize_strides
(
const
Dim
<
0
>&
size
,
const
Dim
<
1
>&
stride
)
{
const
Dim
<
0
>&
stride
)
{
int
norm_stride
=
size
.
head
==
1
?
0
:
stride
.
head
;
return
Dim
<
0
>
();
return
Dim
<
1
>
(
norm_stride
);
}
}
///\endcond
///\endcond
...
@@ -394,6 +389,10 @@ typename std::enable_if<(i == 1), std::ostream&>::type operator<<(
...
@@ -394,6 +389,10 @@ typename std::enable_if<(i == 1), std::ostream&>::type operator<<(
return
os
;
return
os
;
}
}
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Dim
<
0
>&
d
)
{
return
os
;
}
template
<
int
i
>
template
<
int
i
>
HOST
std
::
string
Dim
<
i
>::
to_string
()
const
{
HOST
std
::
string
Dim
<
i
>::
to_string
()
const
{
std
::
stringstream
stream
;
std
::
stringstream
stream
;
...
...
paddle/fluid/operators/detail/strided_memcpy.h
浏览文件 @
7d56c6d0
...
@@ -24,6 +24,29 @@ namespace detail {
...
@@ -24,6 +24,29 @@ namespace detail {
template
<
typename
T
,
int
Rank
>
template
<
typename
T
,
int
Rank
>
struct
StridedMemcpyFunctor
;
struct
StridedMemcpyFunctor
;
template
<
typename
T
>
struct
StridedMemcpyFunctor
<
T
,
0
>
{
void
operator
()(
const
platform
::
DeviceContext
&
dev_ctx
,
const
T
*
src
,
framework
::
Dim
<
0
>
src_stride
,
framework
::
Dim
<
0
>
dst_dim
,
framework
::
Dim
<
0
>
dst_stride
,
T
*
dst
)
const
{
auto
place
=
dev_ctx
.
GetPlace
();
if
(
platform
::
is_cpu_place
(
place
))
{
auto
&
cpu_place
=
boost
::
get
<
platform
::
CPUPlace
>
(
place
);
memory
::
Copy
(
cpu_place
,
dst
,
cpu_place
,
src
,
sizeof
(
T
));
}
else
{
#ifdef PADDLE_WITH_CUDA
auto
&
gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
place
);
auto
&
cuda_ctx
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
dev_ctx
);
memory
::
Copy
(
gpu_place
,
dst
,
gpu_place
,
src
,
sizeof
(
T
),
cuda_ctx
.
stream
());
#else
PADDLE_THROW
(
"Paddle is not compiled with GPU"
);
#endif
}
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
StridedMemcpyFunctor
<
T
,
1
>
{
struct
StridedMemcpyFunctor
<
T
,
1
>
{
void
operator
()(
const
platform
::
DeviceContext
&
dev_ctx
,
const
T
*
src
,
void
operator
()(
const
platform
::
DeviceContext
&
dev_ctx
,
const
T
*
src
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录