Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
56b04e5b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
56b04e5b
编写于
4月 10, 2018
作者:
_青葱
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into initializer
上级
93940642
b1224da8
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
122 addition
and
105 deletion
+122
-105
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+13
-8
paddle/fluid/framework/details/ssa_graph.h
paddle/fluid/framework/details/ssa_graph.h
+5
-1
paddle/fluid/framework/details/ssa_graph_builder.cc
paddle/fluid/framework/details/ssa_graph_builder.cc
+16
-14
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+2
-2
paddle/fluid/operators/elementwise_op_function.h
paddle/fluid/operators/elementwise_op_function.h
+75
-25
paddle/fluid/platform/cuda_helper.h
paddle/fluid/platform/cuda_helper.h
+0
-48
python/paddle/fluid/distribute_transpiler.py
python/paddle/fluid/distribute_transpiler.py
+10
-6
python/setup.py.in
python/setup.py.in
+1
-1
未找到文件。
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
56b04e5b
...
...
@@ -59,7 +59,11 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
auto
graph
=
new
SSAGraph
();
SSAGraph
&
result
=
*
graph
;
std
::
unordered_set
<
std
::
string
>
og_has_been_broadcast
;
result
.
vars_
.
resize
(
places_
.
size
());
// We cannot invoke resize. It is a bug of GCC 4.8
result
.
vars_
=
std
::
vector
<
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
unique_ptr
<
VarHandle
>>>>
(
places_
.
size
());
bool
is_forwarding
=
true
;
for
(
auto
*
op
:
program
.
Block
(
0
).
AllOps
())
{
...
...
@@ -147,15 +151,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
if
(
vars
.
empty
())
{
// This device has no data. continue.
continue
;
}
auto
*
prev_grad
=
&
vars
[
vars
.
size
()
-
1
];
op_handle
->
AddInput
(
prev_grad
);
auto
&
prev_grad
=
vars
[
vars
.
size
()
-
1
];
op_handle
->
AddInput
(
prev_grad
.
get
()
);
auto
&
var
=
vars
[
vars
.
size
()];
var
.
place_
=
p
;
var
.
name_
=
og
;
var
.
version_
=
vars
.
size
()
-
1
;
vars
.
emplace_back
(
new
VarHandle
);
auto
&
var
=
vars
.
back
();
var
->
place_
=
p
;
var
->
name_
=
og
;
var
->
version_
=
vars
.
size
()
-
1
;
op_handle
->
AddOutput
(
&
var
);
op_handle
->
AddOutput
(
var
.
get
()
);
}
#else
PADDLE_ENFORCE
(
"Not implemented"
);
...
...
paddle/fluid/framework/details/ssa_graph.h
浏览文件 @
56b04e5b
...
...
@@ -16,6 +16,8 @@
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h"
...
...
@@ -24,7 +26,9 @@ namespace framework {
namespace
details
{
struct
SSAGraph
{
std
::
vector
<
std
::
unordered_map
<
std
::
string
,
std
::
map
<
int
,
VarHandle
>>>
vars_
;
std
::
vector
<
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
unique_ptr
<
VarHandle
>>>>
vars_
;
// aux variables to represent dependency. Useful to resolve data hazard.
std
::
unordered_set
<
std
::
unique_ptr
<
VarHandleBase
>>
dep_vars_
;
std
::
vector
<
std
::
unique_ptr
<
OpHandleBase
>>
ops_
;
...
...
paddle/fluid/framework/details/ssa_graph_builder.cc
浏览文件 @
56b04e5b
...
...
@@ -27,8 +27,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
auto
it_old
=
name_pair
.
second
.
rbegin
();
++
it_old
;
for
(;
it_old
!=
name_pair
.
second
.
rend
();
it_new
=
it_old
,
++
it_old
)
{
auto
*
write_op
=
it_new
->
second
.
generated_op_
;
auto
&
read_ops
=
it_old
->
second
.
pending_ops_
;
auto
*
write_op
=
(
*
it_new
)
->
generated_op_
;
auto
&
read_ops
=
(
*
it_old
)
->
pending_ops_
;
for
(
auto
*
read_op
:
read_ops
)
{
// Manually add a dependency var from read_op to write_op;
...
...
@@ -54,14 +54,15 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
auto
&
var_holder
=
var_holders
[
each_var_name
];
VarHandle
*
var
=
nullptr
;
if
(
var_holder
.
empty
())
{
var_holder
.
emplace_back
(
new
VarHandle
);
auto
&
init_var
=
var_holder
[
0
];
init_var
.
place_
=
place
;
init_var
.
name_
=
each_var_name
;
init_var
.
generated_op_
=
nullptr
;
init_var
.
version_
=
0
;
var
=
&
init_var
;
init_var
->
place_
=
place
;
init_var
->
name_
=
each_var_name
;
init_var
->
generated_op_
=
nullptr
;
init_var
->
version_
=
0
;
var
=
init_var
.
get
()
;
}
else
{
var
=
&
var_holder
.
rbegin
()
->
second
;
var
=
var_holder
.
rbegin
()
->
get
()
;
}
return
var
;
}
...
...
@@ -72,11 +73,12 @@ void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
size_t
place_offset
)
{
auto
&
vars
=
graph
->
vars_
[
place_offset
][
each_var_name
];
size_t
version
=
vars
.
size
();
auto
&
var
=
vars
[
version
];
var
.
version_
=
version
;
var
.
name_
=
each_var_name
;
var
.
place_
=
place
;
op_handle
->
AddOutput
(
&
var
);
vars
.
emplace_back
(
new
VarHandle
());
auto
&
var
=
vars
.
back
();
var
->
version_
=
version
;
var
->
name_
=
each_var_name
;
var
->
place_
=
place
;
op_handle
->
AddOutput
(
var
.
get
());
}
template
<
typename
Callback
>
...
...
@@ -84,7 +86,7 @@ void IterAllVar(const SSAGraph &graph, Callback callback) {
for
(
auto
&
each
:
graph
.
vars_
)
{
for
(
auto
&
pair1
:
each
)
{
for
(
auto
&
pair2
:
pair1
.
second
)
{
callback
(
pair2
.
second
);
callback
(
*
pair2
);
}
}
}
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
56b04e5b
...
...
@@ -69,7 +69,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for
(
auto
&
var_map
:
graph_
->
vars_
)
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
version_pair
:
name_pair
.
second
)
{
InsertPendingVar
(
version_pair
.
second
);
InsertPendingVar
(
*
version_pair
);
}
}
}
...
...
@@ -95,7 +95,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
for
(
auto
&
var_map
:
graph_
->
vars_
)
{
auto
it
=
var_map
.
find
(
fetch_var_name
);
if
(
it
!=
var_map
.
end
())
{
fetched_vars
[
fetch_var_name
].
push_back
(
&
it
->
second
.
rbegin
()
->
second
);
fetched_vars
[
fetch_var_name
].
push_back
(
it
->
second
.
rbegin
()
->
get
()
);
}
}
}
...
...
paddle/fluid/operators/elementwise_op_function.h
浏览文件 @
56b04e5b
...
...
@@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/transform.h"
#ifdef __NVCC__
#include <cuda.h>
#include <thrust/iterator/iterator_adaptor.h>
#include "paddle/fluid/platform/cuda_helper.h"
constexpr
int
ELEMWISE_MAX_BLOCK_DIM
=
1024
;
#endif
...
...
@@ -43,35 +44,35 @@ namespace operators {
*/
inline
void
get_mid_dims
(
const
framework
::
DDim
&
x_dims
,
const
framework
::
DDim
&
y_dims
,
const
int
axis
,
int
&
pre
,
int
&
n
,
int
&
post
)
{
pre
=
1
;
n
=
1
;
post
=
1
;
int
*
pre
,
int
*
n
,
int
*
post
)
{
*
pre
=
1
;
*
n
=
1
;
*
post
=
1
;
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
pre
*=
x_dims
[
i
];
(
*
pre
)
*=
x_dims
[
i
];
}
for
(
int
i
=
0
;
i
<
y_dims
.
size
();
++
i
)
{
PADDLE_ENFORCE_EQ
(
x_dims
[
i
+
axis
],
y_dims
[
i
],
"Broadcast dimension mismatch."
);
n
*=
y_dims
[
i
];
(
*
n
)
*=
y_dims
[
i
];
}
for
(
int
i
=
axis
+
y_dims
.
size
();
i
<
x_dims
.
size
();
++
i
)
{
post
*=
x_dims
[
i
];
(
*
post
)
*=
x_dims
[
i
];
}
}
inline
void
trim_trailing_singular_dims
(
framework
::
DDim
&
dims
)
{
inline
void
trim_trailing_singular_dims
(
framework
::
DDim
*
dims
)
{
// Remove trailing dimensions of size 1 for y
auto
actual_dims_size
=
dims
.
size
();
auto
actual_dims_size
=
dims
->
size
();
for
(;
actual_dims_size
!=
0
;
--
actual_dims_size
)
{
if
(
dims
[
actual_dims_size
-
1
]
!=
1
)
break
;
if
(
(
*
dims
)
[
actual_dims_size
-
1
]
!=
1
)
break
;
}
if
(
actual_dims_size
!=
dims
.
size
())
{
auto
actual_dims
=
framework
::
vectorize
(
dims
);
if
(
actual_dims_size
!=
dims
->
size
())
{
auto
actual_dims
=
framework
::
vectorize
(
*
dims
);
actual_dims
.
resize
(
actual_dims_size
);
dims
=
framework
::
make_ddim
(
actual_dims
);
*
dims
=
framework
::
make_ddim
(
actual_dims
);
}
}
...
...
@@ -159,7 +160,7 @@ class RowwiseTransformIterator<T, platform::CUDADeviceContext>
RowwiseTransformIterator
<
T
,
platform
::
CUDADeviceContext
>
,
const
T
*>
super_t
;
HOSTDEVICE
RowwiseTransformIterator
(
const
T
*
x
,
int
n
)
:
super_t
(
x
),
begin_
(
x
),
n_
(
n
)
{};
:
super_t
(
x
),
begin_
(
x
),
n_
(
n
)
{}
friend
class
thrust
::
iterator_core_access
;
private:
...
...
@@ -179,7 +180,7 @@ class MidWiseTransformIterator<T, platform::CUDADeviceContext>
MidWiseTransformIterator
<
T
,
platform
::
CUDADeviceContext
>
,
const
T
*>
super_t
;
HOSTDEVICE
MidWiseTransformIterator
(
const
T
*
x
,
int
n
,
int
post
)
:
super_t
(
x
),
begin_
(
x
),
n_
(
n
),
post_
(
post
)
{};
:
super_t
(
x
),
begin_
(
x
),
n_
(
n
),
post_
(
post
)
{}
friend
class
thrust
::
iterator_core_access
;
private:
...
...
@@ -333,6 +334,55 @@ static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
}
}
#ifdef __NVCC__
// __shfl_down has been deprecated as of CUDA 9.0.
#if CUDA_VERSION < 9000
template
<
typename
T
>
__forceinline__
__device__
T
__shfl_down_sync
(
unsigned
,
T
val
,
int
delta
)
{
return
__shfl_down
(
val
,
delta
);
}
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif
template
<
typename
T
>
__device__
T
reduceSum
(
T
val
,
int
tid
,
int
len
)
{
// TODO(zcd): The warp size should be taken from the
// parameters of the GPU but not specified as 32 simply.
// To make the reduceSum more efficiently,
// I use Warp-Level Parallelism and assume the Warp size
// is 32 which may be different for different GPU,
// but most card's warp size is 32.
__shared__
T
shm
[
32
];
const
int
warpSize
=
32
;
unsigned
mask
=
0u
;
CREATE_SHFL_MASK
(
mask
,
tid
<
len
);
for
(
int
offset
=
warpSize
/
2
;
offset
>
0
;
offset
/=
2
)
val
+=
__shfl_down_sync
(
mask
,
val
,
offset
);
if
(
tid
<
warpSize
)
shm
[
tid
]
=
0
;
__syncthreads
();
if
(
tid
%
warpSize
==
0
)
{
shm
[
tid
/
warpSize
]
=
val
;
}
CREATE_SHFL_MASK
(
mask
,
tid
<
warpSize
);
if
(
tid
<
warpSize
)
{
val
=
shm
[
tid
];
for
(
int
offset
=
warpSize
/
2
;
offset
>
0
;
offset
/=
2
)
val
+=
__shfl_down_sync
(
mask
,
val
,
offset
);
}
return
val
;
}
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
static
__global__
void
ElemwiseGradBroadcast1CUDAKernel
(
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
h
,
int
w
,
...
...
@@ -355,7 +405,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
if
(
dy
)
{
h
=
h
>
ELEMWISE_MAX_BLOCK_DIM
?
ELEMWISE_MAX_BLOCK_DIM
:
h
;
val
=
platform
::
reduceSum
(
val
,
tid
,
h
);
val
=
reduceSum
(
val
,
tid
,
h
);
if
(
threadIdx
.
x
==
0
)
{
dy
[
j
]
=
val
;
}
...
...
@@ -432,7 +482,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
if
(
dy
)
{
int
h
=
pre
*
post
;
h
=
h
>
ELEMWISE_MAX_BLOCK_DIM
?
ELEMWISE_MAX_BLOCK_DIM
:
h
;
val
=
platform
::
reduceSum
(
val
,
tid
,
h
);
val
=
reduceSum
(
val
,
tid
,
h
);
if
(
threadIdx
.
x
==
0
)
{
dy
[
j
]
=
val
;
}
...
...
@@ -472,11 +522,11 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
auto
y_dim
=
y
.
dims
();
axis
=
(
axis
==
-
1
?
x_dim
.
size
()
-
y_dim
.
size
()
:
axis
);
trim_trailing_singular_dims
(
y_dim
);
trim_trailing_singular_dims
(
&
y_dim
);
axis
=
(
y_dim
.
size
()
==
0
)
?
x_dim
.
size
()
:
axis
;
int
pre
,
n
,
post
;
get_mid_dims
(
x_dim
,
y_dim
,
axis
,
pre
,
n
,
post
);
get_mid_dims
(
x_dim
,
y_dim
,
axis
,
&
pre
,
&
n
,
&
post
);
if
(
post
==
1
)
{
int
h
=
pre
;
int
w
=
n
;
...
...
@@ -514,7 +564,7 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
}
}
}
}
;
}
template
<
typename
DeviceContext
,
typename
T
,
typename
functor
,
typename
broadcastfunctor
,
typename
broadcast2functor
>
...
...
@@ -543,11 +593,11 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
}
axis
=
(
axis
==
-
1
?
x_dims
.
size
()
-
y_dims
.
size
()
:
axis
);
trim_trailing_singular_dims
(
y_dims
);
trim_trailing_singular_dims
(
&
y_dims
);
axis
=
(
y_dims
.
size
()
==
0
)
?
x_dims
.
size
()
:
axis
;
int
pre
,
n
,
post
;
get_mid_dims
(
x_dims
,
y_dims
,
axis
,
pre
,
n
,
post
);
get_mid_dims
(
x_dims
,
y_dims
,
axis
,
&
pre
,
&
n
,
&
post
);
if
(
post
==
1
)
{
broadcastfunctor
f
;
...
...
@@ -582,11 +632,11 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
axis
=
(
axis
==
-
1
?
x_dims
.
size
()
-
y_dims
.
size
()
:
axis
);
PADDLE_ENFORCE
(
axis
>=
0
&&
axis
<
x_dims
.
size
(),
"Axis should be in range [0, x_dims)"
);
trim_trailing_singular_dims
(
y_dims
);
trim_trailing_singular_dims
(
&
y_dims
);
axis
=
(
y_dims
.
size
()
==
0
)
?
x_dims
.
size
()
:
axis
;
int
pre
,
n
,
post
;
get_mid_dims
(
x_dims
,
y_dims
,
axis
,
pre
,
n
,
post
);
get_mid_dims
(
x_dims
,
y_dims
,
axis
,
&
pre
,
&
n
,
&
post
);
if
(
post
==
1
)
{
functor
.
RunRowWise
(
n
,
pre
);
return
;
...
...
paddle/fluid/platform/cuda_helper.h
浏览文件 @
56b04e5b
...
...
@@ -62,53 +62,5 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
}
#endif
// __shfl_down has been deprecated as of CUDA 9.0.
#if CUDA_VERSION < 9000
template
<
typename
T
>
__forceinline__
__device__
T
__shfl_down_sync
(
unsigned
,
T
val
,
int
delta
)
{
return
__shfl_down
(
val
,
delta
);
}
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif
template
<
typename
T
>
__device__
T
reduceSum
(
T
val
,
int
tid
,
int
len
)
{
// TODO(zcd): The warp size should be taken from the
// parameters of the GPU but not specified as 32 simply.
// To make the reduceSum more efficiently,
// I use Warp-Level Parallelism and assume the Warp size
// is 32 which may be different for different GPU,
// but most card's warp size is 32.
__shared__
T
shm
[
32
];
const
int
warpSize
=
32
;
unsigned
mask
=
0u
;
CREATE_SHFL_MASK
(
mask
,
tid
<
len
);
for
(
int
offset
=
warpSize
/
2
;
offset
>
0
;
offset
/=
2
)
val
+=
__shfl_down_sync
(
mask
,
val
,
offset
);
if
(
tid
<
warpSize
)
shm
[
tid
]
=
0
;
__syncthreads
();
if
(
tid
%
warpSize
==
0
)
{
shm
[
tid
/
warpSize
]
=
val
;
}
CREATE_SHFL_MASK
(
mask
,
tid
<
warpSize
);
if
(
tid
<
warpSize
)
{
val
=
shm
[
tid
];
for
(
int
offset
=
warpSize
/
2
;
offset
>
0
;
offset
/=
2
)
val
+=
__shfl_down_sync
(
mask
,
val
,
offset
);
}
return
val
;
}
}
// namespace platform
}
// namespace paddle
python/paddle/fluid/distribute_transpiler.py
浏览文件 @
56b04e5b
...
...
@@ -278,11 +278,21 @@ class DistributeTranspiler:
# we don't need to create them when grad arrives.
# change client side var name to origin name by
# removing ".trainer_%d" suffix
suff_idx
=
v
.
name
.
find
(
".trainer_"
)
if
suff_idx
>=
0
:
orig_var_name
=
v
.
name
[:
suff_idx
]
else
:
orig_var_name
=
v
.
name
# NOTE: single_trainer_var must be created for multi-trainer
# case to merge grads from multiple trainers
single_trainer_var
=
\
pserver_program
.
global_block
().
create_var
(
name
=
orig_var_name
,
persistable
=
True
,
type
=
v
.
type
,
dtype
=
v
.
dtype
,
shape
=
v
.
shape
)
if
self
.
trainers
>
1
:
for
trainer_id
in
xrange
(
self
.
trainers
):
var
=
pserver_program
.
global_block
().
create_var
(
...
...
@@ -293,12 +303,6 @@ class DistributeTranspiler:
shape
=
v
.
shape
)
recv_inputs
.
append
(
var
)
else
:
single_trainer_var
=
pserver_program
.
global_block
().
create_var
(
name
=
orig_var_name
,
persistable
=
True
,
type
=
v
.
type
,
dtype
=
v
.
dtype
,
shape
=
v
.
shape
)
recv_inputs
.
append
(
single_trainer_var
)
# step3
...
...
python/setup.py.in
浏览文件 @
56b04e5b
...
...
@@ -102,7 +102,7 @@ if '${WITH_FLUID_ONLY}'== 'OFF':
package_data['py_paddle']=['*.py','_swig_paddle.so']
package_dir={
'': '${
CMAKE_CURRENT_SOURCE_DIR}
',
'': '${
PADDLE_BINARY_DIR}/python
',
# The paddle.fluid.proto will be generated while compiling.
# So that package points to other directory.
'paddle.fluid.proto.profiler': '${PADDLE_BINARY_DIR}/paddle/fluid/platform',
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录