Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
ff9b1a0f
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ff9b1a0f
编写于
6月 07, 2018
作者:
Y
Yu Yang
提交者:
GitHub
6月 07, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #11234 from reyoung/feature/refine_code
SSA Graph Builder Factory
上级
08823146
d9af1532
变更
15
显示空白变更内容
内联
并排
Showing
15 changed file
with
295 addition
and
89 deletion
+295
-89
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-1
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+4
-0
paddle/fluid/framework/details/broadcast_op_handle.h
paddle/fluid/framework/details/broadcast_op_handle.h
+2
-2
paddle/fluid/framework/details/build_strategy.h
paddle/fluid/framework/details/build_strategy.h
+4
-0
paddle/fluid/framework/details/graph_builder_factory.cc
paddle/fluid/framework/details/graph_builder_factory.cc
+47
-0
paddle/fluid/framework/details/graph_builder_factory.h
paddle/fluid/framework/details/graph_builder_factory.h
+67
-0
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+0
-9
paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
+2
-2
paddle/fluid/framework/details/reduce_op_handle.h
paddle/fluid/framework/details/reduce_op_handle.h
+2
-2
paddle/fluid/framework/details/ssa_graph_builder.cc
paddle/fluid/framework/details/ssa_graph_builder.cc
+0
-58
paddle/fluid/framework/details/ssa_graph_builder.h
paddle/fluid/framework/details/ssa_graph_builder.h
+0
-2
paddle/fluid/framework/details/ssa_graph_printer.cc
paddle/fluid/framework/details/ssa_graph_printer.cc
+83
-0
paddle/fluid/framework/details/ssa_graph_printer.h
paddle/fluid/framework/details/ssa_graph_printer.h
+67
-0
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+10
-13
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+6
-0
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
ff9b1a0f
...
...
@@ -87,7 +87,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
framework_proto glog lod_rank_table feed_fetch_method
)
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS
multi_devices_graph_builder
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor
)
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS
graph_builder_factory
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
cc_test
(
prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context
)
...
...
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
ff9b1a0f
...
...
@@ -7,6 +7,7 @@ cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place
cc_library
(
ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base
)
cc_library
(
ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph
)
cc_library
(
ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder
)
cc_library
(
variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows
)
...
...
@@ -28,6 +29,9 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
cc_library
(
multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle rpc_op_handle
${
multi_devices_graph_builder_deps
}
reduce_op_handle broadcast_op_handle
)
cc_library
(
graph_builder_factory SRCS graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto
)
cc_library
(
threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context
)
...
...
paddle/fluid/framework/details/broadcast_op_handle.h
浏览文件 @
ff9b1a0f
...
...
@@ -59,8 +59,8 @@ struct BroadcastOpHandle : public OpHandleBase {
void
RunImpl
()
override
;
private:
const
std
::
vector
<
Scope
*>
&
local_scopes_
;
const
std
::
vector
<
platform
::
Place
>
&
places_
;
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
platform
::
Place
>
places_
;
#ifdef PADDLE_WITH_CUDA
const
platform
::
NCCLContextMap
*
nccl_ctxs_
;
#endif
...
...
paddle/fluid/framework/details/build_strategy.h
浏览文件 @
ff9b1a0f
...
...
@@ -14,6 +14,8 @@
#pragma once
#include <string>
namespace
paddle
{
namespace
framework
{
namespace
details
{
...
...
@@ -29,6 +31,8 @@ struct BuildStrategy {
ReduceStrategy
reduce_
{
ReduceStrategy
::
kAllReduce
};
GradientScaleStrategy
gradient_scale_
{
GradientScaleStrategy
::
kCoeffNumDevice
};
std
::
string
debug_graphviz_path_
{
""
};
};
}
// namespace details
...
...
paddle/fluid/framework/details/graph_builder_factory.cc
0 → 100644
浏览文件 @
ff9b1a0f
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/graph_builder_factory.h"
#include <fstream>
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
std
::
unique_ptr
<
SSAGraphBuilder
>
SSAGraphBuilderFactory
::
Create
()
{
std
::
unique_ptr
<
SSAGraphBuilder
>
res
(
#ifdef PADDLE_WITH_CUDA
new
MultiDevSSAGraphBuilder
(
places_
,
loss_var_name_
,
param_names_
,
local_scopes_
,
nccl_ctxs_
,
strategy_
)
#else
new
MultiDevSSAGraphBuilder
(
places_
,
loss_var_name_
,
param_names_
,
local_scopes_
,
strategy_
)
#endif
);
// NOLINT
if
(
!
strategy_
.
debug_graphviz_path_
.
empty
())
{
std
::
unique_ptr
<
std
::
ostream
>
fout
(
new
std
::
ofstream
(
strategy_
.
debug_graphviz_path_
));
PADDLE_ENFORCE
(
fout
->
good
());
std
::
unique_ptr
<
GraphvizSSAGraphPrinter
>
graphviz_printer
(
new
GraphvizSSAGraphPrinter
());
res
.
reset
(
new
SSAGraghBuilderWithPrinter
(
std
::
move
(
fout
),
std
::
move
(
graphviz_printer
),
std
::
move
(
res
)));
}
return
res
;
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/graph_builder_factory.h
0 → 100644
浏览文件 @
ff9b1a0f
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace
paddle
{
namespace
framework
{
class
Scope
;
namespace
details
{
class
SSAGraphBuilderFactory
{
public:
SSAGraphBuilderFactory
(
const
std
::
vector
<
platform
::
Place
>&
places
,
const
std
::
string
&
loss_var_name
,
const
std
::
unordered_set
<
std
::
string
>&
param_names
,
const
std
::
vector
<
Scope
*>&
local_scopes
,
const
BuildStrategy
&
strategy
)
:
places_
(
places
),
loss_var_name_
(
loss_var_name
),
param_names_
(
param_names
),
local_scopes_
(
local_scopes
),
strategy_
(
strategy
)
{}
#ifdef PADDLE_WITH_CUDA
void
SetNCCLContextMap
(
platform
::
NCCLContextMap
*
nccl_ctxs
)
{
nccl_ctxs_
=
nccl_ctxs
;
}
#endif
std
::
unique_ptr
<
SSAGraphBuilder
>
Create
();
private:
std
::
vector
<
platform
::
Place
>
places_
;
std
::
string
loss_var_name_
;
std
::
unordered_set
<
std
::
string
>
param_names_
;
std
::
vector
<
Scope
*>
local_scopes_
;
BuildStrategy
strategy_
;
#ifdef PADDLE_WITH_CUDA
platform
::
NCCLContextMap
*
nccl_ctxs_
;
#endif
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
ff9b1a0f
...
...
@@ -30,10 +30,6 @@
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
#endif
DEFINE_string
(
ssa_graph_path
,
"/tmp/ssa_graph.dot"
,
"the ssa graph path only print with GLOG_v=10,"
"default /tmp/graph.dot"
);
namespace
paddle
{
namespace
framework
{
namespace
details
{
...
...
@@ -277,11 +273,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
*/
AddOutputToLeafOps
(
&
result
);
if
(
VLOG_IS_ON
(
10
))
{
std
::
ofstream
fout
(
FLAGS_ssa_graph_path
);
PrintGraphviz
(
*
graph
,
fout
);
}
return
std
::
unique_ptr
<
SSAGraph
>
(
graph
);
}
...
...
paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
浏览文件 @
ff9b1a0f
...
...
@@ -41,8 +41,8 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
void
RunImpl
()
override
;
private:
const
std
::
vector
<
Scope
*>
&
local_scopes_
;
const
std
::
vector
<
platform
::
Place
>
&
places_
;
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
platform
::
Place
>
places_
;
const
platform
::
NCCLContextMap
&
nccl_ctxs_
;
};
...
...
paddle/fluid/framework/details/reduce_op_handle.h
浏览文件 @
ff9b1a0f
...
...
@@ -32,8 +32,8 @@ namespace framework {
namespace
details
{
struct
ReduceOpHandle
:
public
OpHandleBase
{
const
std
::
vector
<
Scope
*>
&
local_scopes_
;
const
std
::
vector
<
platform
::
Place
>
&
places_
;
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
platform
::
Place
>
places_
;
#ifdef PADDLE_WITH_CUDA
const
platform
::
NCCLContextMap
*
nccl_ctxs_
;
...
...
paddle/fluid/framework/details/ssa_graph_builder.cc
浏览文件 @
ff9b1a0f
...
...
@@ -73,64 +73,6 @@ void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
op_handle
->
AddOutput
(
var
);
}
template
<
typename
Callback
>
void
IterAllVar
(
const
SSAGraph
&
graph
,
Callback
callback
)
{
for
(
auto
&
each
:
graph
.
vars_
)
{
for
(
auto
&
pair1
:
each
)
{
for
(
auto
&
pair2
:
pair1
.
second
)
{
callback
(
*
pair2
);
}
}
}
for
(
auto
&
var
:
graph
.
dep_vars_
)
{
callback
(
*
var
);
}
}
void
SSAGraphBuilder
::
PrintGraphviz
(
const
SSAGraph
&
graph
,
std
::
ostream
&
sout
)
{
size_t
var_id
=
0
;
std
::
unordered_map
<
const
VarHandleBase
*
,
size_t
>
vars
;
sout
<<
"digraph G {
\n
"
;
IterAllVar
(
graph
,
[
&
](
const
VarHandleBase
&
var
)
{
auto
*
var_ptr
=
&
var
;
auto
*
var_handle_ptr
=
dynamic_cast
<
const
VarHandle
*>
(
var_ptr
);
auto
*
dummy_ptr
=
dynamic_cast
<
const
DummyVarHandle
*>
(
var_ptr
);
size_t
cur_var_id
=
var_id
++
;
vars
[
var_ptr
]
=
cur_var_id
;
if
(
var_handle_ptr
)
{
sout
<<
"var_"
<<
cur_var_id
<<
" [label=
\"
"
<<
var_handle_ptr
->
name_
<<
"
\\
n"
<<
var_handle_ptr
->
place_
<<
"
\\
n"
<<
var_handle_ptr
->
version_
<<
"
\"
]"
<<
std
::
endl
;
}
else
if
(
dummy_ptr
)
{
sout
<<
"var_"
<<
cur_var_id
<<
" [label=
\"
dummy
\"
]"
<<
std
::
endl
;
}
});
size_t
op_id
=
0
;
for
(
auto
&
op
:
graph
.
ops_
)
{
std
::
string
op_name
=
"op_"
+
std
::
to_string
(
op_id
++
);
sout
<<
op_name
<<
" [label=
\"
"
<<
op
->
Name
()
<<
"
\"
, shape=rect]"
<<
std
::
endl
;
for
(
auto
in
:
op
->
Inputs
())
{
std
::
string
var_name
=
"var_"
+
std
::
to_string
(
vars
[
in
]);
sout
<<
var_name
<<
" -> "
<<
op_name
<<
std
::
endl
;
}
for
(
auto
out
:
op
->
Outputs
())
{
std
::
string
var_name
=
"var_"
+
std
::
to_string
(
vars
[
out
]);
sout
<<
op_name
<<
" -> "
<<
var_name
<<
std
::
endl
;
}
}
sout
<<
"}
\n
"
;
}
void
SSAGraphBuilder
::
AddOutputToLeafOps
(
SSAGraph
*
graph
)
{
for
(
auto
&
op
:
graph
->
ops_
)
{
if
(
!
op
->
Outputs
().
empty
())
{
...
...
paddle/fluid/framework/details/ssa_graph_builder.h
浏览文件 @
ff9b1a0f
...
...
@@ -55,8 +55,6 @@ class SSAGraphBuilder {
const
platform
::
Place
&
place
,
size_t
place_offset
);
static
void
AddOutputToLeafOps
(
SSAGraph
*
graph
);
static
void
PrintGraphviz
(
const
SSAGraph
&
graph
,
std
::
ostream
&
sout
);
};
}
// namespace details
}
// namespace framework
...
...
paddle/fluid/framework/details/ssa_graph_printer.cc
0 → 100644
浏览文件 @
ff9b1a0f
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph_printer.h"
#include <string>
#include "paddle/fluid/framework/details/ssa_graph.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
template
<
typename
Callback
>
static
inline
void
IterAllVar
(
const
SSAGraph
&
graph
,
Callback
callback
)
{
for
(
auto
&
each
:
graph
.
vars_
)
{
for
(
auto
&
pair1
:
each
)
{
for
(
auto
&
pair2
:
pair1
.
second
)
{
callback
(
*
pair2
);
}
}
}
for
(
auto
&
var
:
graph
.
dep_vars_
)
{
callback
(
*
var
);
}
}
void
GraphvizSSAGraphPrinter
::
Print
(
const
SSAGraph
&
graph
,
std
::
ostream
&
sout
)
const
{
size_t
var_id
=
0
;
std
::
unordered_map
<
const
VarHandleBase
*
,
size_t
>
vars
;
sout
<<
"digraph G {
\n
"
;
IterAllVar
(
graph
,
[
&
](
const
VarHandleBase
&
var
)
{
auto
*
var_ptr
=
&
var
;
auto
*
var_handle_ptr
=
dynamic_cast
<
const
VarHandle
*>
(
var_ptr
);
auto
*
dummy_ptr
=
dynamic_cast
<
const
DummyVarHandle
*>
(
var_ptr
);
size_t
cur_var_id
=
var_id
++
;
vars
[
var_ptr
]
=
cur_var_id
;
if
(
var_handle_ptr
)
{
sout
<<
"var_"
<<
cur_var_id
<<
" [label=
\"
"
<<
var_handle_ptr
->
name_
<<
"
\\
n"
<<
var_handle_ptr
->
place_
<<
"
\\
n"
<<
var_handle_ptr
->
version_
<<
"
\"
]"
<<
std
::
endl
;
}
else
if
(
dummy_ptr
)
{
sout
<<
"var_"
<<
cur_var_id
<<
" [label=
\"
dummy
\"
]"
<<
std
::
endl
;
}
});
size_t
op_id
=
0
;
for
(
auto
&
op
:
graph
.
ops_
)
{
std
::
string
op_name
=
"op_"
+
std
::
to_string
(
op_id
++
);
sout
<<
op_name
<<
" [label=
\"
"
<<
op
->
Name
()
<<
"
\"
, shape=rect]"
<<
std
::
endl
;
for
(
auto
in
:
op
->
Inputs
())
{
std
::
string
var_name
=
"var_"
+
std
::
to_string
(
vars
[
in
]);
sout
<<
var_name
<<
" -> "
<<
op_name
<<
std
::
endl
;
}
for
(
auto
out
:
op
->
Outputs
())
{
std
::
string
var_name
=
"var_"
+
std
::
to_string
(
vars
[
out
]);
sout
<<
op_name
<<
" -> "
<<
var_name
<<
std
::
endl
;
}
}
sout
<<
"}
\n
"
;
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/ssa_graph_printer.h
0 → 100644
浏览文件 @
ff9b1a0f
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iosfwd>
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
class
SSAGraph
;
class
SSAGraphPrinter
{
public:
virtual
~
SSAGraphPrinter
()
{}
virtual
void
Print
(
const
SSAGraph
&
graph
,
std
::
ostream
&
sout
)
const
=
0
;
};
class
GraphvizSSAGraphPrinter
:
public
SSAGraphPrinter
{
public:
void
Print
(
const
SSAGraph
&
graph
,
std
::
ostream
&
sout
)
const
override
;
};
class
SSAGraghBuilderWithPrinter
:
public
SSAGraphBuilder
{
public:
SSAGraghBuilderWithPrinter
(
std
::
ostream
&
sout
,
std
::
unique_ptr
<
SSAGraphPrinter
>&&
printer
,
std
::
unique_ptr
<
SSAGraphBuilder
>&&
builder
)
:
printer_
(
std
::
move
(
printer
)),
builder_
(
std
::
move
(
builder
)),
stream_ref_
(
sout
)
{}
SSAGraghBuilderWithPrinter
(
std
::
unique_ptr
<
std
::
ostream
>&&
sout
,
std
::
unique_ptr
<
SSAGraphPrinter
>&&
printer
,
std
::
unique_ptr
<
SSAGraphBuilder
>&&
builder
)
:
printer_
(
std
::
move
(
printer
)),
builder_
(
std
::
move
(
builder
)),
stream_ptr_
(
std
::
move
(
sout
)),
stream_ref_
(
*
stream_ptr_
)
{}
std
::
unique_ptr
<
SSAGraph
>
Build
(
const
ProgramDesc
&
program
)
const
override
{
auto
graph
=
builder_
->
Build
(
program
);
printer_
->
Print
(
*
graph
,
stream_ref_
);
return
graph
;
}
private:
std
::
unique_ptr
<
SSAGraphPrinter
>
printer_
;
std
::
unique_ptr
<
SSAGraphBuilder
>
builder_
;
std
::
unique_ptr
<
std
::
ostream
>
stream_ptr_
;
std
::
ostream
&
stream_ref_
;
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
ff9b1a0f
...
...
@@ -22,7 +22,7 @@ limitations under the License. */
#include "paddle/fluid/platform/nccl_helper.h"
#endif
#include "paddle/fluid/framework/details/
multi_devices_graph_builder
.h"
#include "paddle/fluid/framework/details/
graph_builder_factory
.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h"
...
...
@@ -102,22 +102,19 @@ ParallelExecutor::ParallelExecutor(
var_infos
.
back
().
persistable_
=
var
->
Persistable
();
}
// Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
#ifdef PADDLE_WITH_CUDA
details
::
MultiDevSSAGraphBuilder
builder
(
// Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
details
::
SSAGraphBuilderFactory
builder_factory
(
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
member_
->
nccl_ctxs_
.
get
(),
build_strategy
);
#else
details
::
MultiDevSSAGraphBuilder
builder
(
member_
->
places_
,
loss_var_name
,
params
,
member_
->
local_scopes_
,
build_strategy
);
#ifdef PADDLE_WITH_CUDA
builder_factory
.
SetNCCLContextMap
(
member_
->
nccl_ctxs_
.
get
());
#endif
auto
graph
=
builder
.
Build
(
main_program
);
member_
->
executor_
.
reset
(
new
details
::
ThreadedSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
places
,
std
::
move
(
graph
)));
exec_strategy
,
member_
->
local_scopes_
,
places
,
builder_factory
.
Create
()
->
Build
(
main_program
)));
member_
->
executor_
.
reset
(
new
details
::
ScopeBufferedSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
std
::
move
(
var_infos
),
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
ff9b1a0f
...
...
@@ -553,6 +553,12 @@ All parameter, weight, gradient are variables in Paddle.
[](
BuildStrategy
&
self
,
BuildStrategy
::
GradientScaleStrategy
strategy
)
{
self
.
gradient_scale_
=
strategy
;
})
.
def_property
(
"debug_graphviz_path"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
debug_graphviz_path_
;
},
[](
BuildStrategy
&
self
,
const
std
::
string
&
path
)
{
self
.
debug_graphviz_path_
=
path
;
});
pe
.
def
(
py
::
init
<
const
std
::
vector
<
platform
::
Place
>
&
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录