Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
c919b2f3
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c919b2f3
编写于
1月 03, 2019
作者:
P
peizhilin
浏览文件
操作
浏览文件
下载
差异文件
Merge remote-tracking branch 'upstream/develop' into windows/fixgpuissue
上级
fd4f4d0e
a1e60ab1
变更
20
隐藏空白更改
内联
并排
Showing
20 changed file
with
436 addition
and
158 deletion
+436
-158
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-1
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+2
-0
paddle/fluid/framework/details/all_reduce_op_handle.cc
paddle/fluid/framework/details/all_reduce_op_handle.cc
+93
-82
paddle/fluid/framework/details/build_strategy.cc
paddle/fluid/framework/details/build_strategy.cc
+9
-3
paddle/fluid/framework/details/build_strategy.h
paddle/fluid/framework/details/build_strategy.h
+8
-0
paddle/fluid/framework/details/multi_devices_graph_pass.cc
paddle/fluid/framework/details/multi_devices_graph_pass.cc
+6
-5
paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
...le/fluid/framework/details/parallel_ssa_graph_executor.cc
+99
-0
paddle/fluid/framework/details/parallel_ssa_graph_executor.h
paddle/fluid/framework/details/parallel_ssa_graph_executor.h
+51
-0
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
...id/framework/details/scope_buffered_ssa_graph_executor.cc
+1
-1
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+103
-30
paddle/fluid/framework/parallel_executor.h
paddle/fluid/framework/parallel_executor.h
+10
-0
paddle/fluid/framework/threadpool.cc
paddle/fluid/framework/threadpool.cc
+0
-1
paddle/fluid/operators/reader/ctr_reader.h
paddle/fluid/operators/reader/ctr_reader.h
+1
-1
paddle/fluid/platform/nccl_helper.h
paddle/fluid/platform/nccl_helper.h
+1
-1
paddle/fluid/platform/profiler.cc
paddle/fluid/platform/profiler.cc
+6
-5
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+6
-9
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
...ddle/fluid/tests/unittests/parallel_executor_test_base.py
+0
-1
python/paddle/fluid/tests/unittests/test_dist_base.py
python/paddle/fluid/tests/unittests/test_dist_base.py
+2
-2
python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py
...addle/fluid/tests/unittests/test_parallel_executor_crf.py
+36
-16
python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py
...dle/fluid/tests/unittests/test_parallel_executor_mnist.py
+1
-0
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
c919b2f3
...
@@ -184,7 +184,7 @@ endif()
...
@@ -184,7 +184,7 @@ endif()
target_link_libraries
(
executor garbage_collector
)
target_link_libraries
(
executor garbage_collector
)
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor
parallel_ssa_graph_executor
graph build_strategy
graph build_strategy
fast_threaded_ssa_graph_executor variable_helper
)
fast_threaded_ssa_graph_executor variable_helper
)
...
...
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
c919b2f3
...
@@ -77,6 +77,8 @@ cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUT
...
@@ -77,6 +77,8 @@ cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUT
cc_library
(
threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
cc_library
(
threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context
)
simple_threadpool device_context
)
cc_library
(
parallel_ssa_graph_executor SRCS parallel_ssa_graph_executor.cc DEPS threaded_ssa_graph_executor
)
cc_test
(
broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
cc_test
(
broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
device_context broadcast_op_handle
)
device_context broadcast_op_handle
)
cc_test
(
gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
cc_test
(
gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
...
...
paddle/fluid/framework/details/all_reduce_op_handle.cc
浏览文件 @
c919b2f3
...
@@ -19,6 +19,13 @@
...
@@ -19,6 +19,13 @@
#include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler.h"
// asynchronous nccl allreduce or synchronous issue:
// https://github.com/PaddlePaddle/Paddle/issues/15049
DEFINE_bool
(
sync_nccl_allreduce
,
false
,
"If set true, will call `cudaStreamSynchronize(nccl_stream)`"
"after allreduce, this mode can get better performance in some scenarios."
);
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
...
@@ -48,100 +55,104 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
...
@@ -48,100 +55,104 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
void
AllReduceOpHandle
::
RunImpl
()
{
void
AllReduceOpHandle
::
RunImpl
()
{
platform
::
RecordEvent
record_event
(
Name
(),
dev_ctxes_
.
cbegin
()
->
second
);
platform
::
RecordEvent
record_event
(
Name
(),
dev_ctxes_
.
cbegin
()
->
second
);
// FIXME(typhoonzero): If scope0(global scope) have NCCL_ID_VAR,
WaitInputVarGenerated
();
// this is a distributed or inter-process call, find a better way.
auto
in_var_handles
=
DynamicCast
<
VarHandle
>
(
this
->
Inputs
());
auto
out_var_handles
=
DynamicCast
<
VarHandle
>
(
this
->
Outputs
());
PADDLE_ENFORCE_EQ
(
in_var_handles
.
size
(),
places_
.
size
(),
"The NoDummyInputSize should be equal to the number of places."
);
PADDLE_ENFORCE_EQ
(
in_var_handles
.
size
(),
out_var_handles
.
size
(),
"The NoDummyInputSize and NoDummyOutputSize should be equal."
);
std
::
vector
<
const
LoDTensor
*>
lod_tensors
;
for
(
size_t
i
=
0
;
i
<
local_scopes_
.
size
();
++
i
)
{
auto
*
s
=
local_scopes_
[
i
];
auto
&
local_scope
=
*
s
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
();
auto
&
lod_tensor
=
local_scope
.
FindVar
(
in_var_handles
[
i
]
->
name_
)
->
Get
<
LoDTensor
>
();
lod_tensors
.
emplace_back
(
&
lod_tensor
);
PADDLE_ENFORCE_EQ
(
in_var_handles
[
i
]
->
name_
,
out_var_handles
[
i
]
->
name_
,
"The name of input and output should be equal."
);
}
if
(
platform
::
is_gpu_place
(
lod_tensors
[
0
]
->
place
()))
{
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
if
(
NoDummyInputSize
()
==
1
&&
PADDLE_ENFORCE
(
nccl_ctxs_
,
"nccl_ctxs should not be nullptr."
);
local_scopes_
[
0
]
->
FindLocalVar
(
NCCL_ID_VARNAME
)
==
nullptr
)
{
int
dtype
=
-
1
;
#else
size_t
numel
=
0
;
if
(
NoDummyInputSize
()
==
1
)
{
std
::
vector
<
std
::
function
<
void
()
>>
all_reduce_calls
;
#endif
return
;
// No need to all reduce when GPU count = 1;
}
else
{
// Wait input done
WaitInputVarGenerated
();
auto
in_var_handles
=
DynamicCast
<
VarHandle
>
(
this
->
Inputs
());
auto
out_var_handles
=
DynamicCast
<
VarHandle
>
(
this
->
Outputs
());
PADDLE_ENFORCE_EQ
(
in_var_handles
.
size
(),
places_
.
size
(),
"The NoDummyInputSize should be equal to the number of places."
);
PADDLE_ENFORCE_EQ
(
in_var_handles
.
size
(),
out_var_handles
.
size
(),
"The NoDummyInputSize and NoDummyOutputSize should be equal."
);
std
::
vector
<
const
LoDTensor
*>
lod_tensors
;
for
(
size_t
i
=
0
;
i
<
local_scopes_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
local_scopes_
.
size
();
++
i
)
{
auto
*
s
=
local_scopes_
[
i
];
auto
&
p
=
places_
[
i
];
auto
&
local_scope
=
*
s
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
();
auto
&
lod_tensor
=
*
lod_tensors
[
i
];
auto
&
lod_tensor
=
void
*
buffer
=
const_cast
<
void
*>
(
lod_tensor
.
data
<
void
>
());
local_scope
.
FindVar
(
in_var_handles
[
i
]
->
name_
)
->
Get
<
LoDTensor
>
();
lod_tensors
.
emplace_back
(
&
lod_tensor
);
PADDLE_ENFORCE_EQ
(
in_var_handles
[
i
]
->
name_
,
out_var_handles
[
i
]
->
name_
,
"The name of input and output should be equal."
);
}
if
(
platform
::
is_gpu_place
(
lod_tensors
[
0
]
->
place
()))
{
if
(
dtype
==
-
1
)
{
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
dtype
=
platform
::
ToNCCLDataType
(
lod_tensor
.
type
());
PADDLE_ENFORCE
(
nccl_ctxs_
,
"nccl_ctxs should not be nullptr."
);
}
int
dtype
=
-
1
;
size_t
numel
=
0
;
if
(
numel
==
0
)
{
std
::
vector
<
std
::
function
<
void
()
>>
all_reduce_calls
;
numel
=
static_cast
<
size_t
>
(
lod_tensor
.
numel
());
for
(
size_t
i
=
0
;
i
<
local_scopes_
.
size
();
++
i
)
{
}
auto
&
p
=
places_
[
i
];
auto
&
lod_tensor
=
*
lod_tensors
[
i
];
void
*
buffer
=
const_cast
<
void
*>
(
lod_tensor
.
data
<
void
>
());
if
(
dtype
==
-
1
)
{
dtype
=
platform
::
ToNCCLDataType
(
lod_tensor
.
type
());
}
if
(
numel
==
0
)
{
int
dev_id
=
boost
::
get
<
platform
::
CUDAPlace
>
(
p
).
device
;
numel
=
static_cast
<
size_t
>
(
lod_tensor
.
numel
());
auto
&
nccl_ctx
=
nccl_ctxs_
->
at
(
dev_id
);
auto
stream
=
nccl_ctx
.
stream
();
auto
comm
=
nccl_ctx
.
comm_
;
all_reduce_calls
.
emplace_back
([
=
]
{
PADDLE_ENFORCE
(
platform
::
dynload
::
ncclAllReduce
(
buffer
,
buffer
,
numel
,
static_cast
<
ncclDataType_t
>
(
dtype
),
ncclSum
,
comm
,
stream
));
});
}
this
->
RunAndRecordEvent
([
&
]
{
if
(
all_reduce_calls
.
size
()
==
1UL
)
{
// Do not use NCCLGroup when manage NCCL by per thread per device
all_reduce_calls
[
0
]();
}
else
{
platform
::
NCCLGroupGuard
guard
;
for
(
auto
&
call
:
all_reduce_calls
)
{
call
();
}
}
}
});
if
(
FLAGS_sync_nccl_allreduce
)
{
for
(
auto
&
p
:
places_
)
{
int
dev_id
=
boost
::
get
<
platform
::
CUDAPlace
>
(
p
).
device
;
int
dev_id
=
boost
::
get
<
platform
::
CUDAPlace
>
(
p
).
device
;
auto
&
nccl_ctx
=
nccl_ctxs_
->
at
(
dev_id
);
auto
&
nccl_ctx
=
nccl_ctxs_
->
at
(
dev_id
);
auto
stream
=
nccl_ctx
.
stream
();
auto
stream
=
nccl_ctx
.
stream
();
auto
comm
=
nccl_ctx
.
comm_
;
cudaStreamSynchronize
(
stream
);
all_reduce_calls
.
emplace_back
([
=
]
{
PADDLE_ENFORCE
(
platform
::
dynload
::
ncclAllReduce
(
buffer
,
buffer
,
numel
,
static_cast
<
ncclDataType_t
>
(
dtype
),
ncclSum
,
comm
,
stream
));
});
}
}
this
->
RunAndRecordEvent
([
&
]
{
}
platform
::
NCCLGroupGuard
guard
;
for
(
auto
&
call
:
all_reduce_calls
)
{
call
();
}
});
#else
#else
PADDLE_THROW
(
"Not compiled with CUDA"
);
PADDLE_THROW
(
"Not compiled with CUDA"
);
#endif
#endif
}
else
{
// Special handle CPU only Operator's gradient. Like CRF
}
else
{
// Special handle CPU only Operator's gradient. Like CRF
auto
&
trg
=
*
this
->
local_scopes_
[
0
]
auto
&
trg
=
*
this
->
local_scopes_
[
0
]
->
FindVar
(
kLocalExecScopeName
)
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
()
->
Get
<
Scope
*>
()
->
FindVar
(
out_var_handles
[
0
]
->
name_
)
->
FindVar
(
out_var_handles
[
0
]
->
name_
)
->
GetMutable
<
framework
::
LoDTensor
>
();
->
GetMutable
<
framework
::
LoDTensor
>
();
// Reduce All Tensor to trg in CPU
// Reduce All Tensor to trg in CPU
ReduceLoDTensor
func
(
lod_tensors
,
&
trg
);
ReduceLoDTensor
func
(
lod_tensors
,
&
trg
);
VisitDataType
(
lod_tensors
[
0
]
->
type
(),
func
);
VisitDataType
(
lod_tensors
[
0
]
->
type
(),
func
);
for
(
size_t
i
=
1
;
i
<
local_scopes_
.
size
();
++
i
)
{
for
(
size_t
i
=
1
;
i
<
local_scopes_
.
size
();
++
i
)
{
auto
&
scope
=
auto
&
scope
=
*
local_scopes_
[
i
]
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
();
*
local_scopes_
[
i
]
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
();
auto
&
p
=
places_
[
i
];
auto
&
p
=
places_
[
i
];
auto
*
var
=
scope
.
FindVar
(
out_var_handles
[
i
]
->
name_
);
auto
*
var
=
scope
.
FindVar
(
out_var_handles
[
i
]
->
name_
);
auto
*
dev_ctx
=
dev_ctxes_
.
at
(
p
);
auto
*
dev_ctx
=
dev_ctxes_
.
at
(
p
);
RunAndRecordEvent
(
p
,
[
&
trg
,
var
,
dev_ctx
,
p
]
{
RunAndRecordEvent
(
p
,
[
&
trg
,
var
,
dev_ctx
,
p
]
{
auto
&
tensor_gpu
=
*
var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
&
tensor_gpu
=
*
var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
&
tensor_cpu
=
trg
;
auto
&
tensor_cpu
=
trg
;
TensorCopy
(
tensor_cpu
,
p
,
*
dev_ctx
,
&
tensor_gpu
);
TensorCopy
(
tensor_cpu
,
p
,
*
dev_ctx
,
&
tensor_gpu
);
});
});
}
}
}
}
}
}
}
...
...
paddle/fluid/framework/details/build_strategy.cc
浏览文件 @
c919b2f3
...
@@ -31,7 +31,11 @@ namespace framework {
...
@@ -31,7 +31,11 @@ namespace framework {
namespace
details
{
namespace
details
{
static
inline
bool
SeqOnlyAllReduceOps
(
const
BuildStrategy
&
strategy
)
{
static
inline
bool
SeqOnlyAllReduceOps
(
const
BuildStrategy
&
strategy
)
{
return
(
!
strategy
.
enable_sequential_execution_
&&
strategy
.
num_trainers_
>
1
);
// Should fix the allreduce op order if scheduling
// them in multiple threads or processes to avoid hang.
return
(
!
strategy
.
enable_sequential_execution_
&&
strategy
.
num_trainers_
>
1
)
||
strategy
.
enable_parallel_graph_
;
}
}
class
ParallelExecutorPassBuilder
:
public
ir
::
PassBuilder
{
class
ParallelExecutorPassBuilder
:
public
ir
::
PassBuilder
{
...
@@ -86,8 +90,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
...
@@ -86,8 +90,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
auto
multi_devices_pass
=
AppendPass
(
"multi_devices_pass"
);
auto
multi_devices_pass
=
AppendPass
(
"multi_devices_pass"
);
multi_devices_pass
->
SetNotOwned
<
const
BuildStrategy
>
(
"strategy"
,
multi_devices_pass
->
SetNotOwned
<
const
BuildStrategy
>
(
"strategy"
,
&
strategy_
);
&
strategy_
);
multi_devices_pass
->
Set
<
int
>
(
"num_trainers"
,
new
int
(
strategy_
.
num_trainers_
));
// Add a graph print pass to record a graph with device info.
// Add a graph print pass to record a graph with device info.
if
(
!
strategy_
.
debug_graphviz_path_
.
empty
())
{
if
(
!
strategy_
.
debug_graphviz_path_
.
empty
())
{
...
@@ -132,6 +134,7 @@ std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
...
@@ -132,6 +134,7 @@ std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
std
::
unique_ptr
<
ir
::
Graph
>
BuildStrategy
::
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
BuildStrategy
::
Apply
(
const
ProgramDesc
&
main_program
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
ProgramDesc
&
main_program
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
string
&
loss_var_name
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
size_t
&
nranks
,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const
bool
use_cuda
,
platform
::
NCCLContextMap
*
nccl_ctxs
)
const
{
const
bool
use_cuda
,
platform
::
NCCLContextMap
*
nccl_ctxs
)
const
{
#else
#else
...
@@ -150,6 +153,9 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
...
@@ -150,6 +153,9 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
pass
->
Erase
(
"local_scopes"
);
pass
->
Erase
(
"local_scopes"
);
pass
->
SetNotOwned
<
const
std
::
vector
<
Scope
*>>
(
"local_scopes"
,
pass
->
SetNotOwned
<
const
std
::
vector
<
Scope
*>>
(
"local_scopes"
,
&
local_scopes
);
&
local_scopes
);
pass
->
Erase
(
"nranks"
);
pass
->
Set
<
size_t
>
(
"nranks"
,
new
size_t
(
nranks
));
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform
::
NCCLContextMap
*
nctx
=
use_cuda
?
nccl_ctxs
:
nullptr
;
platform
::
NCCLContextMap
*
nctx
=
use_cuda
?
nccl_ctxs
:
nullptr
;
pass
->
Erase
(
"nccl_ctxs"
);
pass
->
Erase
(
"nccl_ctxs"
);
...
...
paddle/fluid/framework/details/build_strategy.h
浏览文件 @
c919b2f3
...
@@ -110,6 +110,7 @@ struct BuildStrategy {
...
@@ -110,6 +110,7 @@ struct BuildStrategy {
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
string
&
loss_var_name
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
size_t
&
nranks
,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const
bool
use_cuda
,
const
bool
use_cuda
,
platform
::
NCCLContextMap
*
nccl_ctxs
)
const
;
platform
::
NCCLContextMap
*
nccl_ctxs
)
const
;
...
@@ -117,6 +118,13 @@ struct BuildStrategy {
...
@@ -117,6 +118,13 @@ struct BuildStrategy {
const
bool
use_cuda
)
const
;
const
bool
use_cuda
)
const
;
#endif
#endif
// If set true, ParallelExecutor would build the main_program into multiple
// graphs,
// each of the graphs would run with one device. This approach can achieve
// better performance
// on some scenarios.
mutable
bool
enable_parallel_graph_
=
false
;
private:
private:
mutable
bool
is_finalized_
=
false
;
mutable
bool
is_finalized_
=
false
;
mutable
std
::
shared_ptr
<
ir
::
PassBuilder
>
pass_builder_
;
mutable
std
::
shared_ptr
<
ir
::
PassBuilder
>
pass_builder_
;
...
...
paddle/fluid/framework/details/multi_devices_graph_pass.cc
浏览文件 @
c919b2f3
...
@@ -138,7 +138,7 @@ static const char kLossVarName[] = "loss_var_name";
...
@@ -138,7 +138,7 @@ static const char kLossVarName[] = "loss_var_name";
static
const
char
kPlaces
[]
=
"places"
;
static
const
char
kPlaces
[]
=
"places"
;
static
const
char
kLocalScopes
[]
=
"local_scopes"
;
static
const
char
kLocalScopes
[]
=
"local_scopes"
;
static
const
char
kStrategy
[]
=
"strategy"
;
static
const
char
kStrategy
[]
=
"strategy"
;
static
const
char
kN
umTrainers
[]
=
"num_trainer
s"
;
static
const
char
kN
Ranks
[]
=
"nrank
s"
;
void
MultiDevSSAGraphBuilder
::
Init
()
const
{
void
MultiDevSSAGraphBuilder
::
Init
()
const
{
all_vars_
.
clear
();
all_vars_
.
clear
();
...
@@ -174,7 +174,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
...
@@ -174,7 +174,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
auto
nodes
=
graph
->
ReleaseNodes
();
auto
nodes
=
graph
->
ReleaseNodes
();
ir
::
Graph
&
result
=
*
graph
;
ir
::
Graph
&
result
=
*
graph
;
int
num_trainers
=
Get
<
int
>
(
kNumTrainer
s
);
size_t
nranks
=
Get
<
size_t
>
(
kNRank
s
);
for
(
auto
&
node
:
nodes
)
{
for
(
auto
&
node
:
nodes
)
{
if
(
node
->
IsVar
()
&&
node
->
Var
())
{
if
(
node
->
IsVar
()
&&
node
->
Var
())
{
...
@@ -251,7 +251,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
...
@@ -251,7 +251,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
CreateComputationalOps
(
&
result
,
node
,
places_
.
size
());
CreateComputationalOps
(
&
result
,
node
,
places_
.
size
());
}
}
if
(
!
is_forwarding
&&
(
places_
.
size
()
>
1
||
num_trainers
>
1
)
)
{
if
(
!
is_forwarding
&&
nranks
>
1UL
)
{
bool
is_bk_op
=
bool
is_bk_op
=
static_cast
<
bool
>
(
boost
::
get
<
int
>
(
node
->
Op
()
->
GetAttr
(
static_cast
<
bool
>
(
boost
::
get
<
int
>
(
node
->
Op
()
->
GetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
&
OpProtoAndCheckerMaker
::
OpRoleAttrName
()))
&
...
@@ -649,12 +649,13 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(
...
@@ -649,12 +649,13 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(
void
MultiDevSSAGraphBuilder
::
CreateScaleLossGradOp
(
void
MultiDevSSAGraphBuilder
::
CreateScaleLossGradOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
loss_grad_name
,
ir
::
Graph
*
result
,
const
std
::
string
&
loss_grad_name
,
ir
::
Node
*
out_var_node
,
proto
::
VarType
::
Type
dtype
)
const
{
ir
::
Node
*
out_var_node
,
proto
::
VarType
::
Type
dtype
)
const
{
size_t
nranks
=
Get
<
size_t
>
(
"nranks"
);
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
// Insert ScaleCost OpHandle
// Insert ScaleCost OpHandle
auto
*
dev_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
places_
[
i
]);
auto
*
dev_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
places_
[
i
]);
auto
*
op_handle
=
new
ScaleLossGradOpHandle
(
auto
*
op_handle
=
new
ScaleLossGradOpHandle
(
result
->
CreateEmptyNode
(
"scale_loss_grad"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"scale_loss_grad"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
.
size
()
,
local_scopes_
[
i
],
places_
[
i
],
dev_ctx
,
dtype
);
nranks
,
local_scopes_
[
i
],
places_
[
i
],
dev_ctx
,
dtype
);
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
op_handle
);
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
op_handle
);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
// FIXME: Currently ScaleLossGradOp only use device_count as scale
...
@@ -887,4 +888,4 @@ REGISTER_PASS(multi_devices_pass,
...
@@ -887,4 +888,4 @@ REGISTER_PASS(multi_devices_pass,
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kPlaces
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kPlaces
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLocalScopes
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLocalScopes
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kStrategy
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kStrategy
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kN
umTrainer
s
);
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kN
Rank
s
);
paddle/fluid/framework/details/parallel_ssa_graph_executor.cc
0 → 100644
浏览文件 @
c919b2f3
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
ParallelSSAGraphExecutor
::
ParallelSSAGraphExecutor
(
const
ExecutionStrategy
&
strategy
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
std
::
vector
<
std
::
unique_ptr
<
ir
::
Graph
>>
&&
graphs
)
:
strategy_
(
std
::
move
(
strategy
)),
local_scopes_
(
std
::
move
(
local_scopes
)),
pool_
(
places
.
size
()
>=
2
?
new
::
ThreadPool
(
places
.
size
())
:
nullptr
),
places_
(
std
::
move
(
places
)),
graphs_
(
std
::
move
(
graphs
))
{
PADDLE_ENFORCE_EQ
(
places_
.
size
(),
local_scopes_
.
size
());
// set the correct size of thread pool to each device.
strategy_
.
num_threads_
=
strategy_
.
num_threads_
<
places_
.
size
()
?
1UL
:
strategy_
.
num_threads_
/
places_
.
size
();
VLOG
(
1
)
<<
"set num_threads: "
<<
strategy_
.
num_threads_
<<
" to run the operators of the graph on each device."
;
for
(
size_t
i
=
0
;
i
<
places
.
size
();
++
i
)
{
executors_
.
emplace_back
(
new
details
::
ThreadedSSAGraphExecutor
(
strategy_
,
{
local_scopes_
[
i
]},
{
places_
[
i
]},
std
::
move
(
graphs_
[
i
])));
}
}
FeedFetchList
ParallelSSAGraphExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
{
std
::
vector
<
std
::
future
<
FeedFetchList
>>
run_futures
;
std
::
vector
<
FeedFetchList
>
fetch_data
;
FeedFetchList
ret
;
fetch_data
.
reserve
(
places_
.
size
());
ret
.
reserve
(
fetch_tensors
.
size
());
exception_holder_
.
Clear
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
call
=
[
this
,
i
,
&
fetch_tensors
]()
->
FeedFetchList
{
try
{
return
executors_
[
i
]
->
Run
(
fetch_tensors
);
}
catch
(...)
{
exception_holder_
.
Catch
(
std
::
current_exception
());
}
return
FeedFetchList
();
};
if
(
pool_
)
{
run_futures
.
emplace_back
(
pool_
->
enqueue
(
std
::
move
(
call
)));
}
else
{
fetch_data
.
emplace_back
(
std
::
move
(
call
()));
}
}
if
(
pool_
)
{
for
(
auto
&
f
:
run_futures
)
{
if
(
exception_holder_
.
IsCaught
())
{
f
.
wait
();
}
else
{
fetch_data
.
emplace_back
(
std
::
move
(
f
.
get
()));
}
}
}
if
(
exception_holder_
.
IsCaught
())
{
exception_holder_
.
ReThrow
();
}
for
(
size_t
fetch_idx
=
0
;
fetch_idx
<
fetch_tensors
.
size
();
++
fetch_idx
)
{
std
::
vector
<
const
LoDTensor
*>
lodtensor_ptrs
;
lodtensor_ptrs
.
reserve
(
local_scopes_
.
size
());
for
(
size_t
scope_idx
=
0
;
scope_idx
<
local_scopes_
.
size
();
++
scope_idx
)
{
lodtensor_ptrs
.
push_back
(
&
fetch_data
.
at
(
scope_idx
).
at
(
fetch_idx
));
}
ret
.
emplace_back
();
ret
.
back
().
MergeLoDTensor
(
lodtensor_ptrs
,
platform
::
CPUPlace
());
}
return
ret
;
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/parallel_ssa_graph_executor.h
0 → 100644
浏览文件 @
c919b2f3
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "ThreadPool.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
class
ParallelSSAGraphExecutor
:
public
SSAGraphExecutor
{
public:
ParallelSSAGraphExecutor
(
const
ExecutionStrategy
&
strategy
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
std
::
vector
<
std
::
unique_ptr
<
ir
::
Graph
>>
&&
graphs
);
~
ParallelSSAGraphExecutor
()
final
=
default
;
const
ir
::
Graph
&
Graph
()
const
override
{
return
*
graphs_
[
0
];
}
FeedFetchList
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
override
;
private:
ExecutionStrategy
strategy_
;
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
unique_ptr
<::
ThreadPool
>
pool_
{
nullptr
};
std
::
vector
<
platform
::
Place
>
places_
;
std
::
vector
<
std
::
unique_ptr
<
ir
::
Graph
>>
graphs_
;
std
::
vector
<
std
::
unique_ptr
<
details
::
ThreadedSSAGraphExecutor
>>
executors_
;
ExceptionHolder
exception_holder_
;
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
浏览文件 @
c919b2f3
...
@@ -56,7 +56,7 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
...
@@ -56,7 +56,7 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
}
}
}
}
std
::
vector
<
framework
::
LoDTensor
>
fetch_data
;
std
::
vector
<
framework
::
LoDTensor
>
fetch_data
;
std
::
exception_ptr
eptr
;
std
::
exception_ptr
eptr
=
nullptr
;
try
{
try
{
fetch_data
=
underlying_executor_
->
Run
(
fetch_tensors
);
fetch_data
=
underlying_executor_
->
Run
(
fetch_tensors
);
}
catch
(...)
{
}
catch
(...)
{
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
c919b2f3
...
@@ -21,12 +21,9 @@ limitations under the License. */
...
@@ -21,12 +21,9 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/nccl_helper.h"
#endif
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
...
@@ -38,6 +35,8 @@ limitations under the License. */
...
@@ -38,6 +35,8 @@ limitations under the License. */
DEFINE_string
(
pe_profile_fname
,
""
,
DEFINE_string
(
pe_profile_fname
,
""
,
"Profiler filename for PE, which generated by gperftools."
"Profiler filename for PE, which generated by gperftools."
"Only valid when compiled `WITH_PRIFILER=ON`. Empty if disable."
);
"Only valid when compiled `WITH_PRIFILER=ON`. Empty if disable."
);
DEFINE_bool
(
enable_parallel_graph
,
false
,
"Force disable parallel graph execution mode if set false."
);
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -106,6 +105,7 @@ class ParallelExecutorPrivate {
...
@@ -106,6 +105,7 @@ class ParallelExecutorPrivate {
bool
own_local_scope_
;
bool
own_local_scope_
;
bool
use_cuda_
;
bool
use_cuda_
;
bool
use_all_reduce_
;
bool
use_all_reduce_
;
size_t
nranks_
;
// global_ref_cnts_ is only initialized when ParallelExecutor constructs, and
// global_ref_cnts_ is only initialized when ParallelExecutor constructs, and
// then keeps unchanged
// then keeps unchanged
...
@@ -201,6 +201,7 @@ ParallelExecutor::ParallelExecutor(
...
@@ -201,6 +201,7 @@ ParallelExecutor::ParallelExecutor(
member_
->
build_strategy_
=
build_strategy
;
member_
->
build_strategy_
=
build_strategy
;
member_
->
use_all_reduce_
=
member_
->
use_all_reduce_
=
build_strategy
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
;
build_strategy
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
;
member_
->
nranks_
=
num_trainers
*
places
.
size
();
if
(
!
member_
->
use_all_reduce_
)
{
if
(
!
member_
->
use_all_reduce_
)
{
PADDLE_ENFORCE
(
places
.
size
()
>
1
,
PADDLE_ENFORCE
(
places
.
size
()
>
1
,
...
@@ -224,62 +225,98 @@ ParallelExecutor::ParallelExecutor(
...
@@ -224,62 +225,98 @@ ParallelExecutor::ParallelExecutor(
}
}
}
}
// FIXME(Yancey1989): parallel graph mode get better performance
// in GPU allreduce distributed training. Need an elegant way to
// choice the execution strategy.
build_strategy
.
enable_parallel_graph_
=
EnableParallelGraphExecution
(
main_program
,
exec_strategy
,
build_strategy
);
VLOG
(
1
)
<<
"Enable ParallelGraph Execution: "
<<
build_strategy
.
enable_parallel_graph_
;
if
(
member_
->
use_cuda_
)
{
if
(
member_
->
use_cuda_
)
{
// Bcast Parameters to all GPUs
// Bcast Parameters to all GPUs
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto
*
nccl_id_var
=
scope
->
FindVar
(
NCCL_ID_VARNAME
);
ncclUniqueId
*
nccl_id
=
nullptr
;
ncclUniqueId
*
nccl_id
=
nullptr
;
// gen_nccl_id operator can broadcast the ncclUniqueId for nccl2 collective
// distributed training
auto
*
nccl_id_var
=
scope
->
FindVar
(
NCCL_ID_VARNAME
);
if
(
nccl_id_var
!=
nullptr
)
{
if
(
nccl_id_var
!=
nullptr
)
{
nccl_id
=
nccl_id_var
->
GetMutable
<
ncclUniqueId
>
();
nccl_id
=
nccl_id_var
->
GetMutable
<
ncclUniqueId
>
();
}
}
if
(
build_strategy
.
enable_parallel_graph_
&&
member_
->
nranks_
>
1UL
)
{
if
(
nccl_id
==
nullptr
)
{
local_nccl_id_
.
reset
(
new
ncclUniqueId
());
platform
::
dynload
::
ncclGetUniqueId
(
local_nccl_id_
.
get
());
nccl_id
=
local_nccl_id_
.
get
();
}
}
member_
->
nccl_ctxs_
.
reset
(
new
platform
::
NCCLContextMap
(
member_
->
nccl_ctxs_
.
reset
(
new
platform
::
NCCLContextMap
(
member_
->
places_
,
nccl_id
,
num_trainers
,
trainer_id
));
member_
->
places_
,
nccl_id
,
num_trainers
,
trainer_id
));
#else
#else
PADDLE_THROW
(
"Not compiled with CUDA"
);
PADDLE_THROW
(
"Not compiled with CUDA"
);
#endif
#endif
}
}
if
(
member_
->
local_scopes_
.
size
()
!=
1
&&
local_scopes
.
empty
())
{
if
(
member_
->
local_scopes_
.
size
()
!=
1
&&
local_scopes
.
empty
())
{
BCastParamsToDevices
(
bcast_vars
);
BCastParamsToDevices
(
bcast_vars
);
}
}
// Startup Program has been run. All local scopes has correct parameters.
// Startup Program has been run. All local scopes has correct parameters.
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
// ncclOp
std
::
vector
<
std
::
unique_ptr
<
ir
::
Graph
>>
graphs
;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
if
(
build_strategy
.
enable_parallel_graph_
)
{
for
(
size_t
i
=
0
;
i
<
member_
->
places_
.
size
();
++
i
)
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
=
build_strategy
.
Apply
(
main_program
,
{
member_
->
places_
[
i
]},
loss_var_name
,
{
member_
->
local_scopes_
[
i
]},
member_
->
nranks_
,
member_
->
use_cuda_
,
member_
->
nccl_ctxs_
.
get
());
graphs
.
push_back
(
std
::
move
(
graph
));
}
}
else
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
=
build_strategy
.
Apply
(
main_program
,
member_
->
places_
,
loss_var_name
,
member_
->
local_scopes_
,
member_
->
nranks_
,
member_
->
use_cuda_
,
member_
->
nccl_ctxs_
.
get
());
graphs
.
push_back
(
std
::
move
(
graph
));
}
#else
std
::
unique_ptr
<
ir
::
Graph
>
graph
=
build_strategy
.
Apply
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
=
build_strategy
.
Apply
(
main_program
,
member_
->
places_
,
loss_var_name
,
member_
->
local_scopes_
,
main_program
,
member_
->
places_
,
loss_var_name
,
member_
->
local_scopes_
,
member_
->
use_cuda_
,
member_
->
nccl_ctxs_
.
get
());
member_
->
nranks_
,
member_
->
use_cuda_
);
#else
graphs
.
push_back
(
std
::
move
(
graph
));
std
::
unique_ptr
<
ir
::
Graph
>
graph
=
build_strategy
.
Apply
(
main_program
,
member_
->
places_
,
loss_var_name
,
member_
->
local_scopes_
,
member_
->
use_cuda_
);
#endif
#endif
auto
max_memory_size
=
GetEagerDeletionThreshold
();
auto
max_memory_size
=
GetEagerDeletionThreshold
();
if
(
max_memory_size
>=
0
)
{
if
(
max_memory_size
>=
0
)
{
graph
=
member_
->
PrepareGCAndRefCnts
(
std
::
move
(
graph
),
for
(
size_t
i
=
0
;
i
<
graphs
.
size
();
++
i
)
{
static_cast
<
size_t
>
(
max_memory_size
));
graphs
[
i
]
=
member_
->
PrepareGCAndRefCnts
(
std
::
move
(
graphs
[
i
]),
static_cast
<
size_t
>
(
max_memory_size
));
}
}
}
// Step 3. Create vars in each scope. Passes may also create new vars.
// Step 3. Create vars in each scope. Passes may also create new vars.
// skip control vars and empty vars
// skip control vars and empty vars
std
::
vector
<
details
::
VariableInfo
>
var_infos
;
std
::
vector
<
details
::
VariableInfo
>
var_infos
;
for
(
auto
&
node
:
graph
->
Nodes
())
{
for
(
auto
&
graph
:
graphs
)
{
if
(
node
->
IsVar
()
&&
!
node
->
IsCtrlVar
()
&&
node
->
Var
())
{
for
(
auto
&
node
:
graph
->
Nodes
())
{
var_infos
.
emplace_back
();
if
(
node
->
IsVar
()
&&
!
node
->
IsCtrlVar
()
&&
node
->
Var
())
{
var_infos
.
back
().
name_
=
node
->
Var
()
->
Name
();
var_infos
.
emplace_back
();
var_infos
.
back
().
type_
=
node
->
Var
()
->
GetType
();
var_infos
.
back
().
name_
=
node
->
Var
()
->
Name
();
var_infos
.
back
().
persistable_
=
node
->
Var
()
->
Persistable
();
var_infos
.
back
().
type_
=
node
->
Var
()
->
GetType
();
var_infos
.
back
().
persistable_
=
node
->
Var
()
->
Persistable
();
}
}
}
}
}
// If the loss_var_name is given, the number of graph should be only one.
// If the loss_var_name is given, the number of graph should be only one.
if
(
loss_var_name
.
size
())
{
if
(
loss_var_name
.
size
())
{
size_t
graph_num
=
ir
::
GraphNum
(
*
graph
);
size_t
graph_num
=
ir
::
GraphNum
(
*
graph
s
[
0
]
);
if
(
graph_num
>
1
)
{
if
(
graph_num
>
1
)
{
LOG
(
WARNING
)
LOG
(
WARNING
)
<<
"The number of graph should be only one, "
<<
"The number of graph should be only one, "
"but the current graph has "
"but the current graph has "
<<
ir
::
GraphNum
(
*
graph
)
<<
ir
::
GraphNum
(
*
graph
s
[
0
]
)
<<
" sub_graphs. If you want to see the nodes of the "
<<
" sub_graphs. If you want to see the nodes of the "
"sub_graphs, you should use 'FLAGS_print_sub_graph_dir' "
"sub_graphs, you should use 'FLAGS_print_sub_graph_dir' "
"to specify the output dir. NOTES: if you not do training, "
"to specify the output dir. NOTES: if you not do training, "
...
@@ -287,14 +324,20 @@ ParallelExecutor::ParallelExecutor(
...
@@ -287,14 +324,20 @@ ParallelExecutor::ParallelExecutor(
}
}
}
}
if
(
exec_strategy
.
type_
==
ExecutionStrategy
::
kDefault
)
{
if
(
build_strategy
.
enable_parallel_graph_
)
{
member_
->
executor_
.
reset
(
new
details
::
Threaded
SSAGraphExecutor
(
member_
->
executor_
.
reset
(
new
details
::
Parallel
SSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
member_
->
places_
,
exec_strategy
,
member_
->
local_scopes_
,
member_
->
places_
,
std
::
move
(
graph
)));
std
::
move
(
graph
s
)));
}
else
{
}
else
{
member_
->
executor_
.
reset
(
new
details
::
FastThreadedSSAGraphExecutor
(
if
(
exec_strategy
.
type_
==
ExecutionStrategy
::
kDefault
)
{
exec_strategy
,
member_
->
local_scopes_
,
member_
->
places_
,
member_
->
executor_
.
reset
(
new
details
::
ThreadedSSAGraphExecutor
(
std
::
move
(
graph
)));
exec_strategy
,
member_
->
local_scopes_
,
member_
->
places_
,
std
::
move
(
graphs
[
0
])));
}
else
{
member_
->
executor_
.
reset
(
new
details
::
FastThreadedSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
member_
->
places_
,
std
::
move
(
graphs
[
0
])));
}
}
}
member_
->
executor_
.
reset
(
new
details
::
ScopeBufferedSSAGraphExecutor
(
member_
->
executor_
.
reset
(
new
details
::
ScopeBufferedSSAGraphExecutor
(
...
@@ -423,6 +466,36 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
...
@@ -423,6 +466,36 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
}
}
}
}
bool
ParallelExecutor
::
EnableParallelGraphExecution
(
const
ProgramDesc
&
main_program
,
const
ExecutionStrategy
&
exec_strategy
,
const
BuildStrategy
&
build_strategy
)
const
{
if
(
!
FLAGS_enable_parallel_graph
)
return
false
;
bool
enable_parallel_graph
=
true
;
// TODO(Yancey1989): support sparse update in ParallelGraph mode.
for
(
auto
&
var_desc
:
main_program
.
Block
(
0
).
AllVars
())
{
if
(
var_desc
->
GetType
()
==
proto
::
VarType
::
SELECTED_ROWS
)
{
enable_parallel_graph
=
false
;
}
}
// TODO(Yancey1989): support pserver mode
for
(
auto
&
op_desc
:
main_program
.
Block
(
0
).
AllOps
())
{
if
(
op_desc
->
Type
()
==
"send"
||
op_desc
->
Type
()
==
"recv"
)
{
enable_parallel_graph
=
false
;
break
;
}
}
if
(
!
member_
->
use_all_reduce_
||
!
member_
->
use_cuda_
)
enable_parallel_graph
=
false
;
if
(
build_strategy
.
enable_sequential_execution_
||
exec_strategy
.
type_
==
ExecutionStrategy
::
ExecutorType
::
kExperimental
)
enable_parallel_graph
=
false
;
return
enable_parallel_graph
;
}
ParallelExecutor
::~
ParallelExecutor
()
{
ParallelExecutor
::~
ParallelExecutor
()
{
for
(
auto
&
p
:
member_
->
places_
)
{
for
(
auto
&
p
:
member_
->
places_
)
{
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
)
->
Wait
();
platform
::
DeviceContextPool
::
Instance
().
Get
(
p
)
->
Wait
();
...
...
paddle/fluid/framework/parallel_executor.h
浏览文件 @
c919b2f3
...
@@ -28,6 +28,10 @@ limitations under the License. */
...
@@ -28,6 +28,10 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -68,8 +72,14 @@ class ParallelExecutor {
...
@@ -68,8 +72,14 @@ class ParallelExecutor {
private:
private:
void
BCastParamsToDevices
(
const
std
::
unordered_set
<
std
::
string
>
&
vars
)
const
;
void
BCastParamsToDevices
(
const
std
::
unordered_set
<
std
::
string
>
&
vars
)
const
;
bool
EnableParallelGraphExecution
(
const
ProgramDesc
&
main_program
,
const
ExecutionStrategy
&
exec_strategy
,
const
BuildStrategy
&
build_strategy
)
const
;
ParallelExecutorPrivate
*
member_
;
ParallelExecutorPrivate
*
member_
;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
std
::
unique_ptr
<
ncclUniqueId
>
local_nccl_id_
;
#endif
};
};
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/threadpool.cc
浏览文件 @
c919b2f3
...
@@ -89,7 +89,6 @@ void ThreadPool::TaskLoop() {
...
@@ -89,7 +89,6 @@ void ThreadPool::TaskLoop() {
task
=
std
::
move
(
tasks_
.
front
());
task
=
std
::
move
(
tasks_
.
front
());
tasks_
.
pop
();
tasks_
.
pop
();
}
}
// run the task
// run the task
task
();
task
();
}
}
...
...
paddle/fluid/operators/reader/ctr_reader.h
浏览文件 @
c919b2f3
...
@@ -49,7 +49,7 @@ void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
...
@@ -49,7 +49,7 @@ void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
class
CTRReader
:
public
framework
::
FileReader
{
class
CTRReader
:
public
framework
::
FileReader
{
public:
public:
explicit
CTRReader
(
const
std
::
shared_ptr
<
LoDTensorBlockingQueue
>&
queue
,
explicit
CTRReader
(
const
std
::
shared_ptr
<
LoDTensorBlockingQueue
>&
queue
,
int
batch_size
,
in
t
thread_num
,
int
batch_size
,
size_
t
thread_num
,
const
std
::
vector
<
std
::
string
>&
slots
,
const
std
::
vector
<
std
::
string
>&
slots
,
const
std
::
vector
<
std
::
string
>&
file_list
)
const
std
::
vector
<
std
::
string
>&
file_list
)
:
batch_size_
(
batch_size
),
slots_
(
slots
),
file_list_
(
file_list
)
{
:
batch_size_
(
batch_size
),
slots_
(
slots
),
file_list_
(
file_list
)
{
...
...
paddle/fluid/platform/nccl_helper.h
浏览文件 @
c919b2f3
...
@@ -106,7 +106,7 @@ struct NCCLContextMap {
...
@@ -106,7 +106,7 @@ struct NCCLContextMap {
}
}
std
::
unique_ptr
<
ncclComm_t
[]
>
comms
(
new
ncclComm_t
[
order_
.
size
()]);
std
::
unique_ptr
<
ncclComm_t
[]
>
comms
(
new
ncclComm_t
[
order_
.
size
()]);
// if num_trainers == 1, should create a new nccl id for local comms.
// if num_trainers == 1, should create a new nccl id for local comms.
if
(
num_trainers
==
1
)
{
if
(
num_trainers
==
1
&&
nccl_id
==
nullptr
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
NCCLGroupGuard
::
NCCLMutex
());
std
::
lock_guard
<
std
::
mutex
>
guard
(
NCCLGroupGuard
::
NCCLMutex
());
PADDLE_ENFORCE
(
platform
::
dynload
::
ncclCommInitAll
(
PADDLE_ENFORCE
(
platform
::
dynload
::
ncclCommInitAll
(
comms
.
get
(),
static_cast
<
int
>
(
order_
.
size
()),
order_
.
data
()));
comms
.
get
(),
static_cast
<
int
>
(
order_
.
size
()),
order_
.
data
()));
...
...
paddle/fluid/platform/profiler.cc
浏览文件 @
c919b2f3
...
@@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/port.h"
#include <algorithm>
#include <algorithm>
#include <iomanip>
#include <iomanip>
#include <limits>
#include <limits>
...
@@ -25,9 +22,12 @@ limitations under the License. */
...
@@ -25,9 +22,12 @@ limitations under the License. */
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <cuda.h>
#endif // PADDLE_WITH_CUDA
#endif // PADDLE_WITH_CUDA
#include "glog/logging.h"
#include "glog/logging.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/platform/device_tracer.h"
#include "paddle/fluid/platform/device_tracer.h"
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/printf.h"
DEFINE_bool
(
enable_rpc_profiler
,
false
,
"Enable rpc profiler or not."
);
DEFINE_bool
(
enable_rpc_profiler
,
false
,
"Enable rpc profiler or not."
);
...
@@ -173,8 +173,9 @@ void PopEvent(const std::string& name, const DeviceContext* dev_ctx) {
...
@@ -173,8 +173,9 @@ void PopEvent(const std::string& name, const DeviceContext* dev_ctx) {
RecordEvent
::
RecordEvent
(
const
std
::
string
&
name
,
const
DeviceContext
*
dev_ctx
)
RecordEvent
::
RecordEvent
(
const
std
::
string
&
name
,
const
DeviceContext
*
dev_ctx
)
:
is_enabled_
(
false
),
start_ns_
(
PosixInNsec
())
{
:
is_enabled_
(
false
),
start_ns_
(
PosixInNsec
())
{
std
::
lock_guard
<
std
::
mutex
>
l
(
profiler_mu
);
if
(
g_state
==
ProfilerState
::
kDisabled
)
return
;
if
(
g_state
==
ProfilerState
::
kDisabled
)
return
;
std
::
lock_guard
<
std
::
mutex
>
l
(
profiler_mu
);
is_enabled_
=
true
;
is_enabled_
=
true
;
dev_ctx_
=
dev_ctx
;
dev_ctx_
=
dev_ctx
;
name_
=
name
;
name_
=
name
;
...
@@ -184,8 +185,8 @@ RecordEvent::RecordEvent(const std::string& name, const DeviceContext* dev_ctx)
...
@@ -184,8 +185,8 @@ RecordEvent::RecordEvent(const std::string& name, const DeviceContext* dev_ctx)
}
}
RecordEvent
::~
RecordEvent
()
{
RecordEvent
::~
RecordEvent
()
{
std
::
lock_guard
<
std
::
mutex
>
l
(
profiler_mu
);
if
(
g_state
==
ProfilerState
::
kDisabled
||
!
is_enabled_
)
return
;
if
(
g_state
==
ProfilerState
::
kDisabled
||
!
is_enabled_
)
return
;
std
::
lock_guard
<
std
::
mutex
>
l
(
profiler_mu
);
DeviceTracer
*
tracer
=
GetDeviceTracer
();
DeviceTracer
*
tracer
=
GetDeviceTracer
();
if
(
tracer
)
{
if
(
tracer
)
{
tracer
->
AddCPURecords
(
CurAnnotation
(),
start_ns_
,
PosixInNsec
(),
tracer
->
AddCPURecords
(
CurAnnotation
(),
start_ns_
,
PosixInNsec
(),
...
...
python/paddle/fluid/__init__.py
浏览文件 @
c919b2f3
...
@@ -135,7 +135,8 @@ def __bootstrap__():
...
@@ -135,7 +135,8 @@ def __bootstrap__():
'free_idle_memory'
,
'paddle_num_threads'
,
"dist_threadpool_size"
,
'free_idle_memory'
,
'paddle_num_threads'
,
"dist_threadpool_size"
,
'eager_delete_tensor_gb'
,
'fast_eager_deletion_mode'
,
'eager_delete_tensor_gb'
,
'fast_eager_deletion_mode'
,
'allocator_strategy'
,
'reader_queue_speed_test_mode'
,
'allocator_strategy'
,
'reader_queue_speed_test_mode'
,
'print_sub_graph_dir'
,
'pe_profile_fname'
,
'warpctc_dir'
'print_sub_graph_dir'
,
'pe_profile_fname'
,
'warpctc_dir'
,
'enable_parallel_graph'
]
]
if
'Darwin'
not
in
sysstr
:
if
'Darwin'
not
in
sysstr
:
read_env_flags
.
append
(
'use_pinned_memory'
)
read_env_flags
.
append
(
'use_pinned_memory'
)
...
@@ -158,14 +159,10 @@ def __bootstrap__():
...
@@ -158,14 +159,10 @@ def __bootstrap__():
if
core
.
is_compiled_with_cuda
():
if
core
.
is_compiled_with_cuda
():
read_env_flags
+=
[
read_env_flags
+=
[
'fraction_of_gpu_memory_to_use'
,
'fraction_of_gpu_memory_to_use'
,
'cudnn_deterministic'
,
'cudnn_deterministic'
,
'enable_cublas_tensor_op_math'
,
'conv_workspace_size_limit'
,
'enable_cublas_tensor_op_math'
,
'cudnn_exhaustive_search'
,
'memory_optimize_debug'
,
'selected_gpus'
,
'conv_workspace_size_limit'
,
'cudnn_exhaustive_search_times'
,
'sync_nccl_allreduce'
'cudnn_exhaustive_search'
,
'memory_optimize_debug'
,
'selected_gpus'
,
'cudnn_exhaustive_search_times'
,
]
]
core
.
init_gflags
([
sys
.
argv
[
0
]]
+
core
.
init_gflags
([
sys
.
argv
[
0
]]
+
...
...
python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
浏览文件 @
c919b2f3
...
@@ -78,7 +78,6 @@ class TestParallelExecutorBase(unittest.TestCase):
...
@@ -78,7 +78,6 @@ class TestParallelExecutorBase(unittest.TestCase):
exec_strategy
.
allow_op_delay
=
allow_op_delay
exec_strategy
.
allow_op_delay
=
allow_op_delay
if
use_fast_executor
:
if
use_fast_executor
:
exec_strategy
.
use_experimental_executor
=
True
exec_strategy
.
use_experimental_executor
=
True
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
.
reduce_strategy
=
fluid
.
BuildStrategy
.
ReduceStrategy
.
Reduce
\
build_strategy
.
reduce_strategy
=
fluid
.
BuildStrategy
.
ReduceStrategy
.
Reduce
\
if
use_reduce
else
fluid
.
BuildStrategy
.
ReduceStrategy
.
AllReduce
if
use_reduce
else
fluid
.
BuildStrategy
.
ReduceStrategy
.
AllReduce
...
...
python/paddle/fluid/tests/unittests/test_dist_base.py
浏览文件 @
c919b2f3
...
@@ -442,10 +442,10 @@ class TestDistBase(unittest.TestCase):
...
@@ -442,10 +442,10 @@ class TestDistBase(unittest.TestCase):
tr_cmd
=
"%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method nccl2 --lr %f"
tr_cmd
=
"%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method nccl2 --lr %f"
tr0_cmd
=
tr_cmd
%
\
tr0_cmd
=
tr_cmd
%
\
(
self
.
_python_interp
,
model
,
self
.
_ps_endpoints
,
(
self
.
_python_interp
,
model
,
self
.
_ps_endpoints
,
0
,
w0_ep
,
self
.
_lr
/
2
)
0
,
w0_ep
,
self
.
_lr
)
tr1_cmd
=
tr_cmd
%
\
tr1_cmd
=
tr_cmd
%
\
(
self
.
_python_interp
,
model
,
self
.
_ps_endpoints
,
(
self
.
_python_interp
,
model
,
self
.
_ps_endpoints
,
1
,
w1_ep
,
self
.
_lr
/
2
)
1
,
w1_ep
,
self
.
_lr
)
if
self
.
_mem_opt
:
if
self
.
_mem_opt
:
tr0_cmd
+=
" --mem_opt"
tr0_cmd
+=
" --mem_opt"
...
...
python/paddle/fluid/tests/unittests/test_parallel_executor_crf.py
浏览文件 @
c919b2f3
...
@@ -175,41 +175,61 @@ class TestCRFModel(unittest.TestCase):
...
@@ -175,41 +175,61 @@ class TestCRFModel(unittest.TestCase):
print
(
pe
.
run
(
feed
=
feeder
.
feed
(
cur_batch
),
print
(
pe
.
run
(
feed
=
feeder
.
feed
(
cur_batch
),
fetch_list
=
[
avg_cost
.
name
])[
0
])
fetch_list
=
[
avg_cost
.
name
])[
0
])
def
test_update_sparse_parameter_all_reduce
(
self
):
def
_new_build_strategy
(
self
,
use_reduce
=
False
):
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
.
reduce_strategy
=
fluid
.
BuildStrategy
.
ReduceStrategy
.
AllReduce
if
use_reduce
:
build_strategy
.
reduce_strategy
=
fluid
.
BuildStrategy
.
ReduceStrategy
.
Reduce
else
:
build_strategy
.
reduce_strategy
=
fluid
.
BuildStrategy
.
ReduceStrategy
.
AllReduce
return
build_strategy
def
test_update_sparse_parameter_all_reduce
(
self
):
if
core
.
is_compiled_with_cuda
():
if
core
.
is_compiled_with_cuda
():
self
.
check_network_convergence
(
self
.
check_network_convergence
(
is_sparse
=
True
,
build_strategy
=
build_strategy
,
use_cuda
=
True
)
is_sparse
=
True
,
build_strategy
=
self
.
_new_build_strategy
(),
use_cuda
=
True
)
self
.
check_network_convergence
(
self
.
check_network_convergence
(
is_sparse
=
True
,
build_strategy
=
build_strategy
,
use_cuda
=
False
)
is_sparse
=
True
,
build_strategy
=
self
.
_new_build_strategy
(),
use_cuda
=
False
)
def
test_update_dense_parameter_all_reduce
(
self
):
def
test_update_dense_parameter_all_reduce
(
self
):
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
.
reduce_strategy
=
fluid
.
BuildStrategy
.
ReduceStrategy
.
AllReduce
if
core
.
is_compiled_with_cuda
():
if
core
.
is_compiled_with_cuda
():
self
.
check_network_convergence
(
self
.
check_network_convergence
(
is_sparse
=
False
,
build_strategy
=
build_strategy
,
use_cuda
=
True
)
is_sparse
=
False
,
build_strategy
=
self
.
_new_build_strategy
(),
use_cuda
=
True
)
self
.
check_network_convergence
(
self
.
check_network_convergence
(
is_sparse
=
False
,
build_strategy
=
build_strategy
,
use_cuda
=
False
)
is_sparse
=
False
,
build_strategy
=
self
.
_new_build_strategy
(),
use_cuda
=
False
)
def
test_update_sparse_parameter_reduce
(
self
):
def
test_update_sparse_parameter_reduce
(
self
):
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
.
reduce_strategy
=
fluid
.
BuildStrategy
.
ReduceStrategy
.
Reduce
if
core
.
is_compiled_with_cuda
():
if
core
.
is_compiled_with_cuda
():
self
.
check_network_convergence
(
self
.
check_network_convergence
(
is_sparse
=
True
,
build_strategy
=
build_strategy
,
use_cuda
=
True
)
is_sparse
=
True
,
build_strategy
=
self
.
_new_build_strategy
(
use_reduce
=
True
),
use_cuda
=
True
)
self
.
check_network_convergence
(
self
.
check_network_convergence
(
is_sparse
=
True
,
build_strategy
=
build_strategy
,
use_cuda
=
False
)
is_sparse
=
True
,
build_strategy
=
self
.
_new_build_strategy
(
use_reduce
=
True
),
use_cuda
=
False
)
def
test_update_dense_parameter_reduce
(
self
):
def
test_update_dense_parameter_reduce
(
self
):
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
.
reduce_strategy
=
fluid
.
BuildStrategy
.
ReduceStrategy
.
Reduce
if
core
.
is_compiled_with_cuda
():
if
core
.
is_compiled_with_cuda
():
self
.
check_network_convergence
(
self
.
check_network_convergence
(
is_sparse
=
False
,
build_strategy
=
build_strategy
,
use_cuda
=
True
)
is_sparse
=
False
,
build_strategy
=
self
.
_new_build_strategy
(
use_reduce
=
True
),
use_cuda
=
True
)
self
.
check_network_convergence
(
self
.
check_network_convergence
(
is_sparse
=
False
,
build_strategy
=
build_strategy
,
use_cuda
=
False
)
is_sparse
=
False
,
build_strategy
=
self
.
_new_build_strategy
(
use_reduce
=
True
),
use_cuda
=
False
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/unittests/test_parallel_executor_mnist.py
浏览文件 @
c919b2f3
...
@@ -86,6 +86,7 @@ class TestMNIST(TestParallelExecutorBase):
...
@@ -86,6 +86,7 @@ class TestMNIST(TestParallelExecutorBase):
"label"
:
label
},
"label"
:
label
},
use_cuda
=
use_cuda
,
use_cuda
=
use_cuda
,
use_reduce
=
False
)
use_reduce
=
False
)
reduce_first_loss
,
reduce_last_loss
=
self
.
check_network_convergence
(
reduce_first_loss
,
reduce_last_loss
=
self
.
check_network_convergence
(
model
,
model
,
feed_dict
=
{
"image"
:
img
,
feed_dict
=
{
"image"
:
img
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录