Commit 04bd413a (unverified)
Authored May 08, 2019 by chengduo; committed via GitHub on May 08, 2019

Code Clean: Move all pass to paddle::framework::ir (#17228)

* move pass to ir
* polish code  test=develop
* fix dependency  test=develop

Parent: 648320bb
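Every file below follows the same refactor: graph passes that used to live in paddle::framework::details move under paddle::framework::ir. As a minimal sketch of what a post-move pass looks like (not taken from this commit; "demo_noop_pass" and DemoNoopPass are hypothetical names), a pass subclasses ir::Pass, overrides ApplyImpl, and self-registers with REGISTER_PASS:

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {

// Hypothetical pass used only to illustrate the post-move shape.
class DemoNoopPass : public Pass {
 protected:
  // A real pass mutates `graph` here; this sketch leaves it untouched.
  void ApplyImpl(Graph *graph) const override {}
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(demo_noop_pass, paddle::framework::ir::DemoNoopPass);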
Showing 51 changed files with 1,094 additions and 1,211 deletions.
paddle/fluid/framework/details/CMakeLists.txt  +2 −35
paddle/fluid/framework/details/alloc_continuous_space_for_grad_pass.cc  +0 −411
paddle/fluid/framework/details/alloc_continuous_space_for_grad_pass.h  +0 −79
paddle/fluid/framework/details/build_strategy.cc  +13 −14
paddle/fluid/framework/details/eager_deletion_op_handle.cc  +1 −1
paddle/fluid/framework/details/eager_deletion_op_handle.h  +4 −4
paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h  +0 −31
paddle/fluid/framework/details/sequential_execution_pass.cc  +0 −108
paddle/fluid/framework/details/sequential_execution_pass.h  +0 −31
paddle/fluid/framework/inplace_op_inference.h  +1 −1
paddle/fluid/framework/inplace_op_inference_test.cc  +6 −6
paddle/fluid/framework/ir/CMakeLists.txt  +5 −1
paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.cc  +414 −0
paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.h  +9 −11
paddle/fluid/framework/ir/fuse_optimizer_ops_pass/CMakeLists.txt  +4 −0
paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_adam_op_pass.cc  +5 −8
paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_momentum_op_pass.cc  +5 −7
paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.cc  +19 −18
paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h  +2 −2
paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_sgd_op_pass.cc  +4 −9
paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt  +18 −0
paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc  +32 −27
paddle/fluid/framework/ir/memory_optimize_pass/inplace_op_pass.cc  +8 −9
paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.cc  +6 −5
paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h  +2 −2
paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper_test.cc  +5 −6
paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_pass.cc  +6 −7
paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_pass.h  +3 −3
paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.cc  +13 −11
paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h  +17 −14
paddle/fluid/framework/ir/memory_optimize_pass/record_skip_memory_opt_vars_pass.cc  +4 −4
paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass.cc  +39 −30
paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.cc  +9 −8
paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h  +7 −7
paddle/fluid/framework/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc  +5 −5
paddle/fluid/framework/ir/multi_devices_graph_pass/CMakeLists.txt  +16 −0
paddle/fluid/framework/ir/multi_devices_graph_pass/all_reduce_deps_pass.cc  +32 −29
paddle/fluid/framework/ir/multi_devices_graph_pass/fuse_all_reduce_op_pass.cc  +35 −30
paddle/fluid/framework/ir/multi_devices_graph_pass/modify_op_lock_and_record_event_pass.cc  +22 −21
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_check_pass.cc  +13 −12
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc  +132 −110
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h  +6 −6
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.cc  +22 −12
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h  +2 −12
paddle/fluid/framework/ir/multi_devices_graph_pass/sequential_execution_pass.cc  +113 −0
paddle/fluid/framework/ir/sync_batch_norm_pass.cc  +15 −13
paddle/fluid/framework/ir/sync_batch_norm_pass.h  +0 −31
paddle/fluid/framework/ir/sync_batch_norm_pass_tester.cc  +2 −2
paddle/fluid/framework/parallel_executor.cc  +11 −13
paddle/fluid/pybind/const_value.cc  +2 −2
paddle/fluid/pybind/pybind.cc  +3 −3
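Most of the listed files change only mechanically: a header or symbol that used to live under details/ is now found under an ir/ subdirectory grouped by pass family, and the enclosing namespace changes to match. The recurring include rename, shown below as a sketch using two paths that appear in the list above, is the pattern to look for in the hunks that follow:

// before: the helper lived beside the op handles in details/
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
// after: the helper lives with the other memory-optimize passes under ir/
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"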
paddle/fluid/framework/details/CMakeLists.txt  (+2 −35)

 cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto node)
 cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor)
-cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base)
 cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
 cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
 cc_library(fetch_barrier_op_handle SRCS fetch_barrier_op_handle.cc DEPS framework_proto scope place operator op_registry)
 cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper)
-cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper)
-cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_helper)
-cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_grad_pass.cc DEPS graph graph_helper)
-cc_library(fuse_adam_op_pass SRCS fuse_adam_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
-cc_library(fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
-cc_library(fuse_momentum_op_pass SRCS fuse_momentum_op_pass.cc fuse_optimizer_op_pass.cc DEPS graph graph_helper)
-cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper)
 cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
...
@@ -27,7 +17,7 @@ if(WITH_DISTRIBUTE)
   endif()
 endif()
-set(all_reduce_deps all_reduce_op_handle)
 if(WITH_GPU)
   nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
     dynload_cuda variable_visitor)
...
@@ -37,7 +27,6 @@ if(WITH_GPU)
   if(WITH_DGC)
     nv_library(sparse_all_reduce_op_handle SRCS sparse_all_reduce_op_handle.cc DEPS op_handle_base scope
       lod_tensor ddim memory dynload_cuda variable_visitor dgc all_reduce_op_handle)
-    set(all_reduce_deps sparse_all_reduce_op_handle)
   endif()
 endif()
 if(WITH_DISTRIBUTE)
...
@@ -68,34 +57,12 @@ endif()
 cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
-if(WITH_GPU)
-  cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper gpu_info)
-else()
-  cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper cpu_info)
-endif()
-cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass)
-cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info)
-cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
-cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle)
 cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
-cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle)
-cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass)
-cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)
-cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
-cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass)
-cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
-           scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle ${all_reduce_deps}
-           reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
-cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
 set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass inplace_op_pass)
 if(WITH_GPU)
   list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
 endif()
-cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry)
 cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
...
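Each pass target above ultimately registers itself by name via REGISTER_PASS, so callers never link against the class directly; they fetch the pass from the registry and apply it. A hedged sketch of that usage (the wrapper function is illustrative; the Get/Apply calls themselves appear verbatim in inplace_op_inference_test.cc and build_strategy.cc later in this diff):

#include <string>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

// Illustrative helper: look up a registered pass by its string name and run it.
paddle::framework::ir::Graph *RunPassByName(
    const std::string &name, paddle::framework::ir::Graph *graph) {
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(name);
  return pass->Apply(graph);  // returns the (possibly rewritten) graph
}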
paddle/fluid/framework/details/alloc_continuous_space_for_grad_pass.cc  (deleted, 100644 → 0)

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/alloc_continuous_space_for_grad_pass.h"
#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"

DEFINE_uint64(fuse_parameter_memory_size, 0,  // 0 KB
              "fuse_parameter_memory_size is up limited memory size "
              "of one group parameters' gradient which is the input "
              "of communication calling(e.g NCCLAllReduce). "
              "The default value is 0, it means that "
              "not set group according to memory_size.");
DEFINE_int32(
    fuse_parameter_groups_size, 3,
    "fuse_parameter_groups_size is the size of one group parameters' gradient. "
    "The default value is a experimental result. If the "
    "fuse_parameter_groups_size is 1, it means that the groups size is "
    "the number of parameters' gradient. If the fuse_parameter_groups_size is "
    "-1, it means that there are only one group. The default value is 3, it is "
    "an experimental value.");

namespace paddle {
namespace framework {
namespace details {

// SetFuseParameterGroupsSize and SetFuseParameterMemorySize are used in unit
// test, because it is invalid that seting 'FLAGS_fuse_parameter_memory_size'
// and 'FLAGS_fuse_parameter_groups_size' in unit test.
void SetFuseParameterGroupsSize(int group_size) {
  FLAGS_fuse_parameter_groups_size = group_size;
}

int GetFuseParameterGroupsSize() { return FLAGS_fuse_parameter_groups_size; }

void SetFuseParameterMemorySize(uint64_t memory_size) {
  FLAGS_fuse_parameter_memory_size = memory_size;
}

uint64_t GetFuseParameterMemorySize() {
  return FLAGS_fuse_parameter_memory_size;
}

static const char kUnKnow[] = "@UNKNOW@";
static framework::proto::VarType::Type kDefaultDtype =
    framework::proto::VarType::Type::VarType_Type_BOOL;

void AllocContinuousSpaceForGradPass::ApplyImpl(ir::Graph *graph) const {
  ir::Graph &result = *graph;

  auto &places = Get<const std::vector<platform::Place>>(kPlaces);
  auto &local_scopes = Get<const std::vector<Scope *>>(kLocalScopes);

  ResetAttribute<ParamsAndGrads>(kParamsAndGrads, &result);
  ResetAttribute<GroupGradsAndParams>(kGroupGradsAndParams, &result);

  // NOTE: The operator nodes should be in topology order.
  std::vector<ir::Node *> topo_nodes = ir::TopologySortOperations(result);
  auto &params_grads = result.Get<ParamsAndGrads>(kParamsAndGrads);
  for (auto &node : topo_nodes) {
    RecordParamsAndGrads(node, &params_grads);
  }

  if (params_grads.size() == 0) {
    VLOG(10) << "Doesn't find gradients";
    return;
  }

  std::unordered_map<std::string, ir::Node *> vars;
  for (ir::Node *node : result.Nodes()) {
    if (node->IsVar() && node->Var()) {
      // Note: The graph may have the same name node. For example, parameter
      // is the input of operator and it also is the output of optimizer;
      vars.emplace(node->Var()->Name(), node);
    }
  }

  auto &group_grads_params =
      result.Get<GroupGradsAndParams>(kGroupGradsAndParams);

  // Note: the order of params_grads may be changed by SetGroupGradsAndParams.
  SetGroupGradsAndParams(vars, params_grads, &group_grads_params);

  params_grads.clear();
  for (auto &group_p_g : group_grads_params) {
    params_grads.insert(params_grads.begin(), group_p_g.begin(),
                        group_p_g.end());
  }
  for (auto &p_g : params_grads) {
    std::swap(p_g.first, p_g.second);
  }

  // Set Gradients as Persistable to prevent this var becoming reusable.
  auto dtype = kDefaultDtype;
  for (auto &p_g : params_grads) {
    // Get gradient var
    auto iter = vars.find(p_g.second);
    PADDLE_ENFORCE(iter != vars.end(), "%s is not found.", p_g.second);
    iter->second->Var()->SetPersistable(true);

    PADDLE_ENFORCE(IsSupportedVarType(iter->second->Var()->GetType()));

    // Get Dtype
    auto ele_dtype = iter->second->Var()->GetDataType();
    if (dtype == kDefaultDtype) {
      dtype = ele_dtype;
      PADDLE_ENFORCE_NE(ele_dtype, kDefaultDtype,
                        "The data type should not be bool.");
    }
    PADDLE_ENFORCE_EQ(ele_dtype, dtype,
                      "The data type of input is not consistent.");
  }

  // Create a FusedVarsSet to avoid duplicating names for fused_var in other
  // pass.
  if (!result.Has(kFusedVars)) {
    result.Set(kFusedVars, new FusedVars);
  }
  // the kFusedGrads is used be fuse_optimizer_op_pass.
  result.Set(kFusedGrads, new FusedGrads);

  // the fused_var_name should be unique, so it appends
  // params_grads.begin()->second.
  auto fused_var_name = std::string(kFusedVarNamePrefix) + "@GRAD@" +
                        params_grads.begin()->second;
  result.Get<FusedGrads>(kFusedGrads) = fused_var_name;
  auto &fused_var_set = result.Get<FusedVars>(kFusedVars);
  PADDLE_ENFORCE_EQ(fused_var_set.count(fused_var_name), 0,
                    "%s is duplicate in FusedVars.", fused_var_name);
  fused_var_set.insert(fused_var_name);

  InitFusedVarsAndAllocSpaceForVars(places, local_scopes, vars, fused_var_name,
                                    params_grads);
}

template <typename AttrType>
void AllocContinuousSpaceForGradPass::ResetAttribute(
    const std::string &attr_name, ir::Graph *graph) const {
  if (graph->Has(attr_name)) {
    VLOG(10) << attr_name << " is reset.";
    graph->Erase(attr_name);
  }
  graph->Set(attr_name, new AttrType);
}

void AllocContinuousSpaceForGradPass::SetGroupGradsAndParams(
    const std::unordered_map<std::string, ir::Node *> &var_nodes,
    const ParamsAndGrads &params_grads,
    GroupGradsAndParams *group_grads_params) const {
  SetGroupAccordingToLayers(var_nodes, params_grads, group_grads_params);
  SetGroupAccordingToMemorySize(var_nodes, group_grads_params);
  SetGroupAccordingToGroupSize(var_nodes, group_grads_params);
}

void AllocContinuousSpaceForGradPass::SetGroupAccordingToLayers(
    const std::unordered_map<std::string, ir::Node *> &var_nodes,
    const ParamsAndGrads &params_grads,
    GroupGradsAndParams *group_grads_params) const {
  std::unordered_map<std::string, std::vector<int>> layer_params;

  for (size_t i = 0; i < params_grads.size(); ++i) {
    auto pos = params_grads[i].first.find_first_of(".");
    if (pos == std::string::npos) {
      layer_params[std::string(kUnKnow)].emplace_back(i);
    } else {
      layer_params[params_grads[i].first.substr(0, pos)].emplace_back(i);
    }
  }

  group_grads_params->reserve(layer_params.size());
  for (size_t i = 0; i < params_grads.size(); ++i) {
    auto pos = params_grads[i].first.find_first_of(".");
    std::string key = kUnKnow;
    if (pos != std::string::npos) {
      key = params_grads[i].first.substr(0, pos);
    }
    auto iter = layer_params.find(key);
    if (iter == layer_params.end()) continue;

    group_grads_params->emplace_back();
    auto &local_group_grads_params = group_grads_params->back();
    for (auto &idx : iter->second) {
      local_group_grads_params.emplace_back(
          std::make_pair(params_grads[idx].second, params_grads[idx].first));
    }
    layer_params.erase(iter);
  }

  VLOG(10) << "SetGroupAccordingToLayers: ";
  for (size_t i = 0; i < group_grads_params->size(); ++i) {
    VLOG(10) << "group " << i;
    std::stringstream out;
    for (auto &p_g : group_grads_params->at(i)) {
      out << "(" << p_g.second << ", " << p_g.first << "), ";
    }
    VLOG(10) << out.str();
  }
}

void AllocContinuousSpaceForGradPass::SetGroupAccordingToMemorySize(
    const std::unordered_map<std::string, ir::Node *> &var_nodes,
    GroupGradsAndParams *group_grads_params) const {
  const uint64_t group_memory_size = GetFuseParameterMemorySize();
  if (group_memory_size == 0) {
    return;
  }
  GroupGradsAndParams local_group_grads_params;
  size_t j = 0;
  while (j < group_grads_params->size()) {
    local_group_grads_params.emplace_back();
    auto &group_p_g = local_group_grads_params.back();
    size_t local_group_memory_size = 0;
    while (j < group_grads_params->size()) {
      std::for_each(
          group_grads_params->at(j).begin(), group_grads_params->at(j).end(),
          [&local_group_memory_size,
           &var_nodes](const std::pair<std::string, std::string> &g_p) {
            auto iter = var_nodes.find(g_p.second);
            PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.",
                           g_p.second);
            auto shape = iter->second->Var()->GetShape();
            size_t size =
                framework::SizeOfType(iter->second->Var()->GetDataType());
            std::for_each(shape.begin(), shape.end(),
                          [&size](const int64_t &n) { size *= n; });
            local_group_memory_size += size;
          });
      group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(),
                       group_grads_params->at(j).end());
      ++j;
      if (local_group_memory_size >= group_memory_size) {
        break;
      }
    }
  }

  std::swap(*group_grads_params, local_group_grads_params);

  VLOG(10) << string::Sprintf(
      "SetGroupAccordingToMemorySize(memory_size: %d):", group_memory_size);
  for (size_t i = 0; i < group_grads_params->size(); ++i) {
    VLOG(10) << "group " << i;
    std::stringstream out;
    for (auto &g_p : group_grads_params->at(i)) {
      auto iter = var_nodes.find(g_p.second);
      PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.", g_p.second);
      auto shape = iter->second->Var()->GetShape();
      size_t size = framework::SizeOfType(iter->second->Var()->GetDataType());
      std::for_each(shape.begin(), shape.end(),
                    [&size](const int64_t &n) { size *= n; });
      out << string::Sprintf("(%s(%d), %s)", g_p.second, size, g_p.first);
    }
    VLOG(10) << out.str();
  }
}

void AllocContinuousSpaceForGradPass::SetGroupAccordingToGroupSize(
    const std::unordered_map<std::string, ir::Node *> &var_nodes,
    GroupGradsAndParams *group_grads_params) const {
  if (GetFuseParameterGroupsSize() == 1) {
    return;
  }
  const int group_size = GetFuseParameterGroupsSize() == -1
                             ? static_cast<int>(group_grads_params->size())
                             : GetFuseParameterGroupsSize();
  PADDLE_ENFORCE_GT(group_size, 1);
  size_t groups = (group_grads_params->size() + group_size - 1) / group_size;
  GroupGradsAndParams local_group_grads_params;
  local_group_grads_params.reserve(groups);

  size_t j = 0;
  for (size_t i = 0; i < groups; ++i) {
    local_group_grads_params.emplace_back();
    auto &group_p_g = local_group_grads_params.back();
    group_p_g.reserve(group_size);
    while (j < group_grads_params->size()) {
      group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(),
                       group_grads_params->at(j).end());
      ++j;
      if (j % group_size == 0) break;
    }
  }
  std::swap(*group_grads_params, local_group_grads_params);

  VLOG(10) << string::Sprintf("SetGroupAccordingToGroupSize(group_size: %d):",
                              group_size);
  for (size_t i = 0; i < group_grads_params->size(); ++i) {
    VLOG(10) << "group " << i;
    std::stringstream out;
    for (auto &p_g : group_grads_params->at(i)) {
      out << "(" << p_g.second << ", " << p_g.first << "), ";
    }
    VLOG(10) << out.str();
  }
}

bool AllocContinuousSpaceForGradPass::IsSupportedVarType(
    const proto::VarType::Type &type) const {
  // Current only support LOD_TENSOR.
  return type == proto::VarType::LOD_TENSOR;
}

void AllocContinuousSpaceForGradPass::RecordParamsAndGrads(
    ir::Node *node, ParamsAndGrads *params_grads) const {
  try {
    bool is_bk_op =
        static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
                              OpProtoAndCheckerMaker::OpRoleAttrName())) &
                          static_cast<int>(OpRole::kBackward));
    if (!is_bk_op) return;

    // Currently, we assume that once gradient is generated, it can be
    // broadcast, and each gradient is only broadcast once.
    auto backward_vars =
        boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
            OpProtoAndCheckerMaker::OpRoleVarAttrName()));
    PADDLE_ENFORCE_EQ(backward_vars.size() % 2, static_cast<size_t>(0));

    for (size_t i = 0; i < backward_vars.size(); i += 2) {
      VLOG(10) << "Trainable parameter: " << backward_vars[i]
               << ", gradient: " << backward_vars[i + 1];

      params_grads->emplace_back(std::make_pair(
          backward_vars[i] /*param*/, backward_vars[i + 1] /*grad*/));
    }
  } catch (boost::bad_get e) {
  }
}

void AllocContinuousSpaceForGradPass::InitFusedVarsAndAllocSpaceForVars(
    const std::vector<platform::Place> &places,
    const std::vector<Scope *> &local_scopes,
    const std::unordered_map<std::string, ir::Node *> &vars,
    const std::string &fused_var_name,
    const ParamsAndGrads &params_grads) const {
  // Init Gradients and FusedVars
  VLOG(10) << "Init FusedVars and Gradients.";
  for (auto it = local_scopes.rbegin(); it != local_scopes.rend(); ++it) {
    auto &scope = *it;

    PADDLE_ENFORCE(scope->FindVar(fused_var_name) == nullptr,
                   "%s has existed in scope.", fused_var_name);
    scope->Var(fused_var_name)->GetMutable<LoDTensor>();

    for (auto &p_g : params_grads) {
      auto iter = vars.find(p_g.second);
      PADDLE_ENFORCE(iter != vars.end());
      PADDLE_ENFORCE_NOT_NULL(iter->second->Var());
      PADDLE_ENFORCE_EQ(iter->second->Var()->GetType(),
                        proto::VarType::LOD_TENSOR);
      scope->Var(p_g.second)->GetMutable<LoDTensor>();
    }
  }

  // Alloc continuous space for vars.
  std::vector<std::string> grads_name;
  std::vector<std::string> params_name;
  grads_name.reserve(params_grads.size());
  params_name.reserve(params_grads.size());
  for (auto &p_g : params_grads) {
    params_name.emplace_back(p_g.first);
    grads_name.emplace_back(p_g.second);
  }
  framework::ProgramDesc program_desc;
  AppendAllocSpaceForVarsOp(params_name, grads_name, fused_var_name,
                            program_desc.MutableBlock(0));

  for (size_t i = 0; i < local_scopes.size(); ++i) {
    for (auto &op_desc : program_desc.Block(0).AllOps()) {
      auto op = OpRegistry::CreateOp(*op_desc);
      op->Run(*local_scopes[i], places[i]);
    }
  }
}

void AllocContinuousSpaceForGradPass::AppendAllocSpaceForVarsOp(
    const std::vector<std::string> &params_name,
    const std::vector<std::string> &grads_name,
    const std::string &fused_var_name, BlockDesc *global_block) const {
  auto op_desc = global_block->AppendOp();
  op_desc->SetType("alloc_continuous_space");
  op_desc->SetInput("Input", params_name);
  op_desc->SetOutput("Output", grads_name);
  op_desc->SetOutput("FusedOutput", {fused_var_name});
}

}  // namespace details
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(alloc_continuous_space_for_grad_pass,
              paddle::framework::details::AllocContinuousSpaceForGradPass)
    .RequirePassAttr(paddle::framework::details::kPlaces)
    .RequirePassAttr(paddle::framework::details::kLocalScopes);
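The grouping in SetGroupAccordingToGroupSize above is plain ceil-division chunking: n per-layer groups are repacked into ceil(n / group_size) coarser groups, each closed when j % group_size == 0. A self-contained sketch of the same arithmetic, with illustrative values only (no Paddle types involved):

#include <cstdio>
#include <vector>

int main() {
  const int group_size = 3;  // stands in for FLAGS_fuse_parameter_groups_size
  const size_t n = 8;        // number of per-layer groups before repacking
  size_t groups = (n + group_size - 1) / group_size;  // ceil(8 / 3) == 3
  std::vector<std::vector<size_t>> packed(groups);
  size_t j = 0;
  for (size_t i = 0; i < groups; ++i) {
    while (j < n) {
      packed[i].push_back(j);
      ++j;
      if (j % group_size == 0) break;  // same stop condition as the pass
    }
  }
  for (size_t i = 0; i < groups; ++i)
    std::printf("group %zu has %zu members\n", i, packed[i].size());
  return 0;  // prints sizes 3, 3, 2
}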
paddle/fluid/framework/details/alloc_continuous_space_for_grad_pass.h  (deleted, 100644 → 0)

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"

namespace paddle {
namespace framework {
namespace details {

void SetFuseParameterGroupsSize(int group_size);
int GetFuseParameterGroupsSize();

void SetFuseParameterMemorySize(uint64_t memory_size);
uint64_t GetFuseParameterMemorySize();

class AllocContinuousSpaceForGradPass : public ir::Pass {
 protected:
  void ApplyImpl(ir::Graph *graph) const override;

  template <typename AttrType>
  void ResetAttribute(const std::string &attr_name, ir::Graph *graph) const;

  void SetGroupGradsAndParams(
      const std::unordered_map<std::string, ir::Node *> &var_nodes,
      const ParamsAndGrads &params_grads,
      GroupGradsAndParams *group_grads_params) const;

  void SetGroupAccordingToLayers(
      const std::unordered_map<std::string, ir::Node *> &var_nodes,
      const ParamsAndGrads &params_grads,
      GroupGradsAndParams *group_grads_params) const;

  void SetGroupAccordingToMemorySize(
      const std::unordered_map<std::string, ir::Node *> &var_nodes,
      GroupGradsAndParams *group_grads_params) const;

  void SetGroupAccordingToGroupSize(
      const std::unordered_map<std::string, ir::Node *> &var_nodes,
      GroupGradsAndParams *group_grads_params) const;

 private:
  bool IsSupportedVarType(const proto::VarType::Type &type) const;

  void RecordParamsAndGrads(ir::Node *node,
                            ParamsAndGrads *params_grads) const;

  void InitFusedVarsAndAllocSpaceForVars(
      const std::vector<platform::Place> &places,
      const std::vector<Scope *> &local_scopes,
      const std::unordered_map<std::string, ir::Node *> &vars,
      const std::string &fused_var_name,
      const ParamsAndGrads &params_grads) const;

  void AppendAllocSpaceForVarsOp(const std::vector<std::string> &params_name,
                                 const std::vector<std::string> &grads_name,
                                 const std::string &fused_var_name,
                                 BlockDesc *global_block) const;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
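Because the registration in the deleted .cc chained .RequirePassAttr(kPlaces).RequirePassAttr(kLocalScopes), the pass refuses to run until a caller injects those attributes. The injection pattern below is the one build_strategy.cc uses elsewhere in this very diff, shown out of context as a sketch (SetNotOwned leaves ownership of the vectors with the caller):

// Sketch: attach the required attributes, then apply the pass.
pass->SetNotOwned<const std::vector<platform::Place>>(details::kPlaces,
                                                      &places);
pass->SetNotOwned<const std::vector<Scope *>>(details::kLocalScopes,
                                              &local_scopes);
graph = pass->Apply(graph);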
paddle/fluid/framework/details/build_strategy.cc  (+13 −14)

@@ -17,15 +17,14 @@ limitations under the License. */
 #include <glog/logging.h>
 #include <memory>
 #include <utility>
-#include "paddle/fluid/framework/details/memory_optimize_helper.h"
-#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
-#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
 #include "paddle/fluid/framework/details/reduce_op_handle.h"
-#include "paddle/fluid/framework/details/sequential_execution_pass.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/ir/graph_to_program_pass.h"
 #include "paddle/fluid/framework/ir/graph_viz_pass.h"
+#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
+#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
+#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"

 namespace paddle {
 namespace framework {
...
@@ -173,10 +172,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
       const std::string graph_path =
           string::Sprintf("%s%s", strategy_.debug_graphviz_path_.c_str(),
                           "_multi_devices_graph");
-      multi_devices_print_pass->Set<std::string>(kGraphvizPath,
-                                                 new std::string(graph_path));
-      multi_devices_print_pass->Set<details::GraphvizSSAGraphPrinter>(
-          "graph_printer", new details::GraphvizSSAGraphPrinter);
+      multi_devices_print_pass->Set<std::string>(ir::kGraphvizPath,
+                                                 new std::string(graph_path));
+      multi_devices_print_pass->Set<ir::GraphvizSSAGraphPrinter>(
+          "graph_printer", new ir::GraphvizSSAGraphPrinter);
     }

     // experimental shows that the program will be faster if append
...
@@ -240,7 +239,7 @@ std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
 }

 bool BuildStrategy::IsMultiDevPass(const std::string &pass_name) const {
-  return framework::details::MultiDevSSAGraphBuilder().count(pass_name) > 0;
+  return framework::ir::MultiDevSSAGraphBuilder().count(pass_name) > 0;
 }

 ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
...
@@ -263,13 +262,13 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
     if (IsMultiDevPass(pass->Type())) {
       pass->Erase(kPlaces);
       pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
-      pass->Erase(kLossVarName);
-      pass->SetNotOwned<const std::string>(kLossVarName, &loss_var_name);
+      pass->Erase(ir::kLossVarName);
+      pass->SetNotOwned<const std::string>(ir::kLossVarName, &loss_var_name);
       pass->Erase(kLocalScopes);
       pass->SetNotOwned<const std::vector<Scope *>>(kLocalScopes,
                                                     &local_scopes);
-      pass->Erase(kNRanks);
-      pass->Set<size_t>(kNRanks, new size_t(nranks));
+      pass->Erase(ir::kNRanks);
+      pass->Set<size_t>(ir::kNRanks, new size_t(nranks));

 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
       platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr;
...
@@ -312,8 +311,8 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
         continue;
       }
     } else if (pass->Type() == "inplace_pass") {
-      pass->Erase(kUseCuda);
-      pass->Set<bool>(kUseCuda, new bool(use_cuda));
+      pass->Erase(ir::kUseCuda);
+      pass->Set<bool>(ir::kUseCuda, new bool(use_cuda));
     }
     VLOG(3) << "Start Apply Pass " << pass->Type();
     graph = pass->Apply(graph);
...
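One pattern worth noting in the hunks above: every attribute rewrite is an Erase followed by a Set. This appears to be because ir::Pass attributes are write-once (Set on an already-present key fails an enforce), so a pass that may be applied more than once must drop the stale value first; this is a reading of the code in this diff, not a documented guarantee. The shape of the idiom:

// Discard any stale value, then hand a fresh one to the pass (which owns it).
pass->Erase(ir::kNRanks);
pass->Set<size_t>(ir::kNRanks, new size_t(nranks));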
paddle/fluid/framework/details/eager_deletion_op_handle.cc  (+1 −1)

@@ -31,7 +31,7 @@ namespace details {
 EagerDeletionOpHandle::EagerDeletionOpHandle(
     ir::Node *node, const Scope *scope, const platform::Place &place,
     const std::unordered_set<std::string> &var_names, GarbageCollector *gc,
-    AtomicReferenceCountMap *ref_cnts)
+    ir::AtomicReferenceCountMap *ref_cnts)
     : OpHandleBase(node),
       scope_(scope),
       var_names_(var_names.begin(), var_names.end()),
...
paddle/fluid/framework/details/eager_deletion_op_handle.h  (+4 −4)

@@ -20,7 +20,7 @@
 #include <unordered_set>
 #include <vector>
 #include "paddle/fluid/framework/details/op_handle_base.h"
-#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
+#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"

 namespace paddle {
 namespace framework {
...
@@ -34,7 +34,7 @@ class EagerDeletionOpHandle : public OpHandleBase {
                        const platform::Place &place,
                        const std::unordered_set<std::string> &var_names,
                        GarbageCollector *gc,
-                       AtomicReferenceCountMap *ref_cnts);
+                       ir::AtomicReferenceCountMap *ref_cnts);

   ~EagerDeletionOpHandle();
...
@@ -55,8 +55,8 @@ class EagerDeletionOpHandle : public OpHandleBase {
   const Scope *scope_;
   std::vector<std::string> var_names_;
   GarbageCollector *gc_;                   // not own
-  AtomicReferenceCountMap *ref_cnts_;      // not own
+  ir::AtomicReferenceCountMap *ref_cnts_;  // not own

 #ifdef PADDLE_WITH_CUDA
   platform::CUDADeviceContext *dev_ctx_{nullptr};
   cudaEvent_t event_{nullptr};
...
paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h  (deleted, 100644 → 0)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace details {

class ModifyOpLockAndRecordEventPass : public ir::Pass {
 protected:
  void ApplyImpl(ir::Graph *graph) const override;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/details/sequential_execution_pass.cc  (deleted, 100644 → 0)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/op_proto_maker.h"

namespace paddle {
namespace framework {
namespace details {

static bool IsSameOpDesc(OpDesc *op1, OpDesc *op2) {
  return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
         op1->Outputs() == op2->Outputs();
}

void SequentialExecutionPass::ApplyImpl(ir::Graph *graph) const {
  // FIXME(zjl): Insert dependencies between some distributed ops may cause
  // the multi_devices_graph_pass fails. So we skip these ops here.
  // Indeed, maybe we should not insert dependencies between these ops
  // casually, which may cause deadlock easily.
  // We should add more skipped distributed ops when found errors in
  // multi_devices_graph_pass
  static std::unordered_set<std::string> skip_dist_ops{
      "send", "recv", "send_barrier", "fetch_barrier"};

  auto &ops = graph->Get<const std::vector<OpDesc *>>(kStaleProgramOpDescs);
  std::vector<ir::Node *> op_node_list;
  op_node_list.reserve(ops.size());

  std::unordered_map<ir::Node *, size_t> op_deps;
  std::unordered_map<ir::Node *, std::unordered_set<ir::Node *>> pending_ops;
  std::unordered_set<ir::Node *> ready_ops;

  for (ir::Node *node : graph->Nodes()) {
    if (!node->IsOp()) continue;
    std::unordered_set<ir::Node *> preceding_ops;
    for (auto *in : node->inputs) {
      PADDLE_ENFORCE(in->IsVar(),
                     "Preceding Node of Op Nodes must be Var Node");
      if (in->inputs.empty()) continue;
      PADDLE_ENFORCE(in->inputs.size() == 1 && in->inputs[0]->IsOp(),
                     "Preceding Op Node of Var Node must be unique");
      preceding_ops.insert(in->inputs[0]);
      pending_ops[in->inputs[0]].insert(node);
    }
    op_deps[node] = preceding_ops.size();
    if (preceding_ops.empty()) {
      ready_ops.insert(node);
    }
  }

  for (auto *op_desc : ops) {
    ir::Node *found_node = nullptr;
    for (auto *node : ready_ops) {
      if (IsSameOpDesc(op_desc, node->Op())) {
        PADDLE_ENFORCE(found_node == nullptr,
                       "Found multiple op_desc in graph: %s", op_desc->Type());
        found_node = node;
      }
    }

    PADDLE_ENFORCE_NOT_NULL(found_node, "Cannot find op_desc in graph: %s",
                            op_desc->Type());
    for (auto *pending_op : pending_ops[found_node]) {
      if (--op_deps.at(pending_op) == 0) {
        ready_ops.insert(pending_op);
      }
    }
    ready_ops.erase(found_node);
    if (skip_dist_ops.count(op_desc->Type()) == 0) {
      op_node_list.push_back(found_node);
    }
  }

  for (size_t i = 1; i < op_node_list.size(); ++i) {
    auto *dep_var = graph->CreateControlDepVar();
    op_node_list[i]->inputs.push_back(dep_var);
    op_node_list[i - 1]->outputs.push_back(dep_var);
    dep_var->outputs.push_back(op_node_list[i]);
    dep_var->inputs.push_back(op_node_list[i - 1]);
    VLOG(10) << "Add dependencies between " << op_node_list[i - 1]->Name()
             << " and " << op_node_list[i]->Name();
  }
}

}  // namespace details
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(sequential_execution_pass,
              paddle::framework::details::SequentialExecutionPass)
    .RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
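The core of this deleted pass (recreated under ir/multi_devices_graph_pass/ in this commit) is standard ready-set scheduling: count each op's unfinished predecessors, emit ops whose count is zero, and decrement each successor's count as its predecessors are emitted. A self-contained sketch of just that bookkeeping, with toy integer node IDs instead of Paddle graph nodes (the order among simultaneously-ready nodes is arbitrary here, whereas the real pass matches against the program's op order):

#include <cstdio>
#include <unordered_map>
#include <unordered_set>

int main() {
  // A tiny DAG: 0 -> 1, 0 -> 2, 1 -> 3, 2 -> 3
  std::unordered_map<int, std::unordered_set<int>> pending = {
      {0, {1, 2}}, {1, {3}}, {2, {3}}};
  std::unordered_map<int, int> deps = {{0, 0}, {1, 1}, {2, 1}, {3, 2}};
  std::unordered_set<int> ready = {0};

  while (!ready.empty()) {
    int node = *ready.begin();
    ready.erase(node);
    std::printf("emit op %d\n", node);  // the pass appends to op_node_list
    for (int next : pending[node]) {
      if (--deps.at(next) == 0) ready.insert(next);  // now unblocked
    }
  }
  return 0;
}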
paddle/fluid/framework/details/sequential_execution_pass.h  (deleted, 100644 → 0)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace details {

class SequentialExecutionPass : public ir::Pass {
 protected:
  void ApplyImpl(ir::Graph *graph) const override;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/inplace_op_inference.h  (+1 −1)

@@ -19,7 +19,7 @@
 #include <unordered_map>
 #include <unordered_set>
 #include "glog/logging.h"
-#include "paddle/fluid/framework/details/memory_optimize_helper.h"
+#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/type_defs.h"
...
paddle/fluid/framework/inplace_op_inference_test.cc  (+6 −6)

@@ -18,7 +18,7 @@
 #include <string>
 #include <vector>
 #include "gtest/gtest.h"
-#include "paddle/fluid/framework/details/memory_optimize_helper.h"
+#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
 #include "paddle/fluid/framework/ir/pass_builder.h"
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/op_registry.h"
...
@@ -33,7 +33,7 @@ namespace framework {
 std::unique_ptr<ir::Pass> CreateInplacePass() {
   auto pass = ir::PassRegistry::Instance().Get("inplace_pass");
-  pass->Set<bool>(details::kUseCuda, new bool(true));
+  pass->Set<bool>(ir::kUseCuda, new bool(true));
   return pass;
 }
...
@@ -225,7 +225,7 @@ TEST(InferInplace, SingleOpInplaceInToOut) {
   FakeSuccData(&prog);
   std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
-  g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
+  g->Set(ir::kMemOptSkipVars, new std::unordered_set<std::string>());
   g = test_SingleOpInplaceInToOut(std::move(g));
   auto op_node = GetNodeFromGraph(g.get(), "single_op");
...
@@ -241,7 +241,7 @@ TEST(InferInplace, SingleOpInplaceInToOutNoInplace) {
   FakeNoInplaceData(&prog);
   std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
-  g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
+  g->Set(ir::kMemOptSkipVars, new std::unordered_set<std::string>());
   g = test_SingleOpInplaceInToOut(std::move(g));
   auto op_node = GetNodeFromGraph(g.get(), "single_op");
...
@@ -274,7 +274,7 @@ TEST(InferInplace, MultiOutInplaceInToOut) {
   prog.MutableBlock(0)->Var("z0")->SetShape({32, 16, 1024, 1024});
   std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
-  g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
+  g->Set(ir::kMemOptSkipVars, new std::unordered_set<std::string>());
   auto pass = CreateInplacePass();
   pass->Apply(g.get());
   auto op_node = GetNodeFromGraph(g.get(), "multi_out_op");
...
@@ -310,7 +310,7 @@ TEST(InferInplace, MultiGradInplaceInToOut) {
   prog.MutableBlock(0)->Var("z0")->SetShape({32, 15, 1024, 1024});
   std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
-  g->Set(details::kMemOptSkipVars, new std::unordered_set<std::string>());
+  g->Set(ir::kMemOptSkipVars, new std::unordered_set<std::string>());
   auto pass = CreateInplacePass();
   pass->Apply(g.get());
   auto op_node = GetNodeFromGraph(g.get(), "multi_out_grad");
...
paddle/fluid/framework/ir/CMakeLists.txt  (+5 −1)

@@ -3,6 +3,9 @@ file(WRITE ${pass_file} "// Generated by the paddle/fluid/framework/ir/CMakeList
 file(APPEND ${pass_file} "\#pragma once\n")
 file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n")
+add_subdirectory(fuse_optimizer_ops_pass)
+add_subdirectory(memory_optimize_pass)
+add_subdirectory(multi_devices_graph_pass)

 # Usage: pass_library(target inference) will append to paddle_inference_pass.h
 unset(INFER_IR_PASSES CACHE) # clear the global variable
...
@@ -34,7 +37,6 @@ function(pass_library TARGET DEST)
   endif()
 endfunction()

 cc_library(node SRCS node.cc DEPS proto_desc)
 cc_library(graph SRCS graph.cc DEPS node pretty_log)
 cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
...
@@ -43,6 +45,8 @@ cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
 cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)
 cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
+cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_grad_pass.cc DEPS graph graph_helper)
 pass_library(graph_to_program_pass base)
 pass_library(graph_viz_pass base)
 pass_library(lock_free_optimize_pass base)
...
paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.cc
0 → 100644
浏览文件 @
04bd413a
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.h"
#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"
DEFINE_uint64
(
fuse_parameter_memory_size
,
0
,
// 0 KB
"fuse_parameter_memory_size is up limited memory size "
"of one group parameters' gradient which is the input "
"of communication calling(e.g NCCLAllReduce). "
"The default value is 0, it means that "
"not set group according to memory_size."
);
DEFINE_int32
(
fuse_parameter_groups_size
,
3
,
"fuse_parameter_groups_size is the size of one group parameters' gradient. "
"The default value is a experimental result. If the "
"fuse_parameter_groups_size is 1, it means that the groups size is "
"the number of parameters' gradient. If the fuse_parameter_groups_size is "
"-1, it means that there are only one group. The default value is 3, it is "
"an experimental value."
);
namespace paddle {
namespace framework {
namespace ir {
// SetFuseParameterGroupsSize and SetFuseParameterMemorySize are used in unit
// tests, because it is invalid to set 'FLAGS_fuse_parameter_memory_size'
// and 'FLAGS_fuse_parameter_groups_size' in a unit test.
void SetFuseParameterGroupsSize(int group_size) {
  FLAGS_fuse_parameter_groups_size = group_size;
}

int GetFuseParameterGroupsSize() { return FLAGS_fuse_parameter_groups_size; }

void SetFuseParameterMemorySize(uint64_t memory_size) {
  FLAGS_fuse_parameter_memory_size = memory_size;
}

uint64_t GetFuseParameterMemorySize() {
  return FLAGS_fuse_parameter_memory_size;
}

static const char kUnKnow[] = "@UNKNOW@";
static framework::proto::VarType::Type kDefaultDtype =
    framework::proto::VarType::Type::VarType_Type_BOOL;
class AllocContinuousSpaceForGradPass : public ir::Pass {
 protected:
  void ApplyImpl(ir::Graph *graph) const {
    ir::Graph &result = *graph;

    auto &places = Get<const std::vector<platform::Place>>(details::kPlaces);
    auto &local_scopes = Get<const std::vector<Scope *>>(details::kLocalScopes);

    ResetAttribute<details::ParamsAndGrads>(details::kParamsAndGrads, &result);
    ResetAttribute<details::GroupGradsAndParams>(details::kGroupGradsAndParams,
                                                 &result);

    // NOTE: The operator nodes should be in topology order.
    std::vector<ir::Node *> topo_nodes = ir::TopologySortOperations(result);
    auto &params_grads =
        result.Get<details::ParamsAndGrads>(details::kParamsAndGrads);
    for (auto &node : topo_nodes) {
      RecordParamsAndGrads(node, &params_grads);
    }

    if (params_grads.size() == 0) {
      VLOG(10) << "Doesn't find gradients";
      return;
    }

    std::unordered_map<std::string, ir::Node *> vars;
    for (ir::Node *node : result.Nodes()) {
      if (node->IsVar() && node->Var()) {
        // Note: The graph may have the same name node. For example, parameter
        // is the input of operator and it also is the output of optimizer;
        vars.emplace(node->Var()->Name(), node);
      }
    }

    auto &group_grads_params =
        result.Get<details::GroupGradsAndParams>(details::kGroupGradsAndParams);

    // Note: the order of params_grads may be changed by SetGroupGradsAndParams.
    SetGroupGradsAndParams(vars, params_grads, &group_grads_params);

    params_grads.clear();
    for (auto &group_p_g : group_grads_params) {
      params_grads.insert(params_grads.begin(), group_p_g.begin(),
                          group_p_g.end());
    }
    for (auto &p_g : params_grads) {
      std::swap(p_g.first, p_g.second);
    }

    // Set Gradients as Persistable to prevent this var becoming reusable.
    auto dtype = kDefaultDtype;
    for (auto &p_g : params_grads) {
      // Get gradient var
      auto iter = vars.find(p_g.second);
      PADDLE_ENFORCE(iter != vars.end(), "%s is not found.", p_g.second);
      iter->second->Var()->SetPersistable(true);

      PADDLE_ENFORCE(IsSupportedVarType(iter->second->Var()->GetType()));

      // Get Dtype
      auto ele_dtype = iter->second->Var()->GetDataType();
      if (dtype == kDefaultDtype) {
        dtype = ele_dtype;
        PADDLE_ENFORCE_NE(ele_dtype, kDefaultDtype,
                          "The data type should not be bool.");
      }
      PADDLE_ENFORCE_EQ(ele_dtype, dtype,
                        "The data type of input is not consistent.");
    }

    // Create a FusedVarsSet to avoid duplicating names for fused_var in other
    // pass.
    if (!result.Has(details::kFusedVars)) {
      result.Set(details::kFusedVars, new details::FusedVars);
    }
    // the kFusedGrads is used by fuse_optimizer_op_pass.
    result.Set(details::kFusedGrads, new details::FusedGrads);

    // the fused_var_name should be unique, so it appends
    // params_grads.begin()->second.
    auto fused_var_name = std::string(details::kFusedVarNamePrefix) + "@GRAD@" +
                          params_grads.begin()->second;
    result.Get<details::FusedGrads>(details::kFusedGrads) = fused_var_name;
    auto &fused_var_set = result.Get<details::FusedVars>(details::kFusedVars);
    PADDLE_ENFORCE_EQ(fused_var_set.count(fused_var_name), 0,
                      "%s is duplicate in FusedVars.", fused_var_name);
    fused_var_set.insert(fused_var_name);

    InitFusedVarsAndAllocSpaceForVars(places, local_scopes, vars,
                                      fused_var_name, params_grads);
  }
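For intuition about the naming above: the fused gradient's name is the prefix, the literal "@GRAD@", and the first gradient's name, which keeps it unique per parameter set. A small sketch — the "@FUSED@" prefix is a stand-in; the real value comes from details::kFusedVarNamePrefix:

#include <cassert>
#include <string>

// Sketch of the fused-gradient naming scheme used in ApplyImpl above.
// "@FUSED@" is an assumed placeholder for details::kFusedVarNamePrefix.
static std::string FusedGradName(const std::string &prefix,
                                 const std::string &first_grad) {
  return prefix + "@GRAD@" + first_grad;
}

int main() {
  // Appending the first gradient's name keeps the fused name unique,
  // as the comment in ApplyImpl explains.
  assert(FusedGradName("@FUSED@", "fc_0.w_0@GRAD") ==
         "@FUSED@@GRAD@fc_0.w_0@GRAD");
}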
  template <typename AttrType>
  void ResetAttribute(const std::string &attr_name, ir::Graph *graph) const {
    if (graph->Has(attr_name)) {
      VLOG(10) << attr_name << " is reset.";
      graph->Erase(attr_name);
    }
    graph->Set(attr_name, new AttrType);
  }
  void SetGroupGradsAndParams(
      const std::unordered_map<std::string, ir::Node *> &var_nodes,
      const details::ParamsAndGrads &params_grads,
      details::GroupGradsAndParams *group_grads_params) const {
    SetGroupAccordingToLayers(var_nodes, params_grads, group_grads_params);
    SetGroupAccordingToMemorySize(var_nodes, group_grads_params);
    SetGroupAccordingToGroupSize(var_nodes, group_grads_params);
  }
  void SetGroupAccordingToLayers(
      const std::unordered_map<std::string, ir::Node *> &var_nodes,
      const details::ParamsAndGrads &params_grads,
      details::GroupGradsAndParams *group_grads_params) const {
    std::unordered_map<std::string, std::vector<int>> layer_params;

    for (size_t i = 0; i < params_grads.size(); ++i) {
      auto pos = params_grads[i].first.find_first_of(".");
      if (pos == std::string::npos) {
        layer_params[std::string(kUnKnow)].emplace_back(i);
      } else {
        layer_params[params_grads[i].first.substr(0, pos)].emplace_back(i);
      }
    }

    group_grads_params->reserve(layer_params.size());
    for (size_t i = 0; i < params_grads.size(); ++i) {
      auto pos = params_grads[i].first.find_first_of(".");
      std::string key = kUnKnow;
      if (pos != std::string::npos) {
        key = params_grads[i].first.substr(0, pos);
      }
      auto iter = layer_params.find(key);
      if (iter == layer_params.end()) continue;

      group_grads_params->emplace_back();
      auto &local_group_grads_params = group_grads_params->back();
      for (auto &idx : iter->second) {
        local_group_grads_params.emplace_back(
            std::make_pair(params_grads[idx].second, params_grads[idx].first));
      }
      layer_params.erase(iter);
    }

    VLOG(10) << "SetGroupAccordingToLayers: ";
    for (size_t i = 0; i < group_grads_params->size(); ++i) {
      VLOG(10) << "group " << i;
      std::stringstream out;
      for (auto &p_g : group_grads_params->at(i)) {
        out << "(" << p_g.second << ", " << p_g.first << "), ";
      }
      VLOG(10) << out.str();
    }
  }
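The grouping key above is the substring of the parameter name before the first '.', i.e. the layer prefix. A minimal, std-only sketch of that bucketing (GroupByLayerPrefix and the sample names are hypothetical, independent of the pass):

#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Bucket parameter names by the text before the first '.', mirroring the
// layer-based grouping in SetGroupAccordingToLayers. Names without a '.'
// fall into a shared "@UNKNOW@" bucket, as in the pass.
static std::unordered_map<std::string, std::vector<std::string>>
GroupByLayerPrefix(const std::vector<std::string> &params) {
  std::unordered_map<std::string, std::vector<std::string>> groups;
  for (const auto &p : params) {
    auto pos = p.find_first_of(".");
    std::string key =
        (pos == std::string::npos) ? "@UNKNOW@" : p.substr(0, pos);
    groups[key].push_back(p);
  }
  return groups;
}

int main() {
  auto g = GroupByLayerPrefix({"fc_0.w_0", "fc_0.b_0", "fc_1.w_0", "bias"});
  assert(g["fc_0"].size() == 2);      // both fc_0 parameters share one group
  assert(g["fc_1"].size() == 1);
  assert(g["@UNKNOW@"].size() == 1);  // "bias" has no layer prefix
}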
  void SetGroupAccordingToMemorySize(
      const std::unordered_map<std::string, ir::Node *> &var_nodes,
      details::GroupGradsAndParams *group_grads_params) const {
    const uint64_t group_memory_size = GetFuseParameterMemorySize();
    if (group_memory_size == 0) {
      return;
    }
    details::GroupGradsAndParams local_group_grads_params;
    size_t j = 0;
    while (j < group_grads_params->size()) {
      local_group_grads_params.emplace_back();
      auto &group_p_g = local_group_grads_params.back();
      size_t local_group_memory_size = 0;
      while (j < group_grads_params->size()) {
        std::for_each(
            group_grads_params->at(j).begin(), group_grads_params->at(j).end(),
            [&local_group_memory_size,
             &var_nodes](const std::pair<std::string, std::string> &g_p) {
              auto iter = var_nodes.find(g_p.second);
              PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.",
                             g_p.second);
              auto shape = iter->second->Var()->GetShape();
              size_t size =
                  framework::SizeOfType(iter->second->Var()->GetDataType());
              std::for_each(shape.begin(), shape.end(),
                            [&size](const int64_t &n) { size *= n; });
              local_group_memory_size += size;
            });
        group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(),
                         group_grads_params->at(j).end());
        ++j;
        if (local_group_memory_size >= group_memory_size) {
          break;
        }
      }
    }
    std::swap(*group_grads_params, local_group_grads_params);

    VLOG(10) << string::Sprintf(
        "SetGroupAccordingToMemorySize(memory_size: %d):", group_memory_size);
    for (size_t i = 0; i < group_grads_params->size(); ++i) {
      VLOG(10) << "group " << i;
      std::stringstream out;
      for (auto &g_p : group_grads_params->at(i)) {
        auto iter = var_nodes.find(g_p.second);
        PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.", g_p.second);
        auto shape = iter->second->Var()->GetShape();
        size_t size = framework::SizeOfType(iter->second->Var()->GetDataType());
        std::for_each(shape.begin(), shape.end(),
                      [&size](const int64_t &n) { size *= n; });
        out << string::Sprintf("(%s(%d), %s)", g_p.second, size, g_p.first);
      }
      VLOG(10) << out.str();
    }
  }
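Each variable's footprint above is its element size times the product of its shape, and a group is closed once the running total reaches the flag value. A std-only sketch of the same size computation (TensorBytes is an assumed helper name):

#include <cassert>
#include <cstdint>
#include <vector>

// Bytes occupied by a dense tensor: element size times the product of all
// shape dimensions, as computed inside SetGroupAccordingToMemorySize.
static size_t TensorBytes(const std::vector<int64_t> &shape,
                          size_t element_size) {
  size_t size = element_size;
  for (int64_t n : shape) size *= static_cast<size_t>(n);
  return size;
}

int main() {
  // A float32 tensor of shape [128, 256] occupies 128 * 256 * 4 bytes.
  assert(TensorBytes({128, 256}, 4) == 131072);
  // The pass keeps appending per-layer groups until the running sum reaches
  // FLAGS_fuse_parameter_memory_size, then starts a new fused group.
}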
  void SetGroupAccordingToGroupSize(
      const std::unordered_map<std::string, ir::Node *> &var_nodes,
      details::GroupGradsAndParams *group_grads_params) const {
    if (GetFuseParameterGroupsSize() == 1) {
      return;
    }
    const int group_size = GetFuseParameterGroupsSize() == -1
                               ? static_cast<int>(group_grads_params->size())
                               : GetFuseParameterGroupsSize();
    PADDLE_ENFORCE_GT(group_size, 1);
    size_t groups = (group_grads_params->size() + group_size - 1) / group_size;
    details::GroupGradsAndParams local_group_grads_params;
    local_group_grads_params.reserve(groups);

    size_t j = 0;
    for (size_t i = 0; i < groups; ++i) {
      local_group_grads_params.emplace_back();
      auto &group_p_g = local_group_grads_params.back();
      group_p_g.reserve(group_size);
      while (j < group_grads_params->size()) {
        group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(),
                         group_grads_params->at(j).end());
        ++j;
        if (j % group_size == 0) break;
      }
    }
    std::swap(*group_grads_params, local_group_grads_params);

    VLOG(10) << string::Sprintf(
        "SetGroupAccordingToGroupSize(group_size: %d):", group_size);
    for (size_t i = 0; i < group_grads_params->size(); ++i) {
      VLOG(10) << "group " << i;
      std::stringstream out;
      for (auto &p_g : group_grads_params->at(i)) {
        out << "(" << p_g.second << ", " << p_g.first << "), ";
      }
      VLOG(10) << out.str();
    }
  }
 private:
  bool IsSupportedVarType(const proto::VarType::Type &type) const {
    // Currently only LOD_TENSOR is supported.
    return type == proto::VarType::LOD_TENSOR;
  }
  void RecordParamsAndGrads(ir::Node *node,
                            details::ParamsAndGrads *params_grads) const {
    try {
      bool is_bk_op =
          static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
                                OpProtoAndCheckerMaker::OpRoleAttrName())) &
                            static_cast<int>(OpRole::kBackward));
      if (!is_bk_op) return;

      // Currently, we assume that once gradient is generated, it can be
      // broadcast, and each gradient is only broadcast once.
      auto backward_vars =
          boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
              OpProtoAndCheckerMaker::OpRoleVarAttrName()));
      PADDLE_ENFORCE_EQ(backward_vars.size() % 2, static_cast<size_t>(0));

      for (size_t i = 0; i < backward_vars.size(); i += 2) {
        VLOG(10) << "Trainable parameter: " << backward_vars[i]
                 << ", gradient: " << backward_vars[i + 1];

        params_grads->emplace_back(std::make_pair(
            backward_vars[i] /*param*/, backward_vars[i + 1] /*grad*/));
      }
    } catch (boost::bad_get e) {
    }
  }
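The op-role-var attribute read above is a flat list of [param0, grad0, param1, grad1, ...]; the loop walks it in steps of two. A std-only sketch of that unpacking (assumed sample data, not Paddle's attribute API):

#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Unpack a flat [param, grad, param, grad, ...] list into pairs, mirroring
// the stride-2 loop in RecordParamsAndGrads. An odd-length list is invalid.
static std::vector<std::pair<std::string, std::string>> UnpackParamGradPairs(
    const std::vector<std::string> &backward_vars) {
  assert(backward_vars.size() % 2 == 0);
  std::vector<std::pair<std::string, std::string>> out;
  for (size_t i = 0; i < backward_vars.size(); i += 2) {
    out.emplace_back(backward_vars[i], backward_vars[i + 1]);
  }
  return out;
}

int main() {
  auto pairs = UnpackParamGradPairs(
      {"fc_0.w_0", "fc_0.w_0@GRAD", "fc_0.b_0", "fc_0.b_0@GRAD"});
  assert(pairs.size() == 2);
  assert(pairs[1].first == "fc_0.b_0" && pairs[1].second == "fc_0.b_0@GRAD");
}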
  void InitFusedVarsAndAllocSpaceForVars(
      const std::vector<platform::Place> &places,
      const std::vector<Scope *> &local_scopes,
      const std::unordered_map<std::string, ir::Node *> &vars,
      const std::string &fused_var_name,
      const details::ParamsAndGrads &params_grads) const {
    // Init Gradients and FusedVars
    VLOG(10) << "Init FusedVars and Gradients.";
    for (auto it = local_scopes.rbegin(); it != local_scopes.rend(); ++it) {
      auto &scope = *it;

      PADDLE_ENFORCE(scope->FindVar(fused_var_name) == nullptr,
                     "%s has existed in scope.", fused_var_name);
      scope->Var(fused_var_name)->GetMutable<LoDTensor>();

      for (auto &p_g : params_grads) {
        auto iter = vars.find(p_g.second);
        PADDLE_ENFORCE(iter != vars.end());
        PADDLE_ENFORCE_NOT_NULL(iter->second->Var());
        PADDLE_ENFORCE_EQ(iter->second->Var()->GetType(),
                          proto::VarType::LOD_TENSOR);
        scope->Var(p_g.second)->GetMutable<LoDTensor>();
      }
    }

    // Alloc continuous space for vars.
    std::vector<std::string> grads_name;
    std::vector<std::string> params_name;
    grads_name.reserve(params_grads.size());
    params_name.reserve(params_grads.size());
    for (auto &p_g : params_grads) {
      params_name.emplace_back(p_g.first);
      grads_name.emplace_back(p_g.second);
    }
    framework::ProgramDesc program_desc;
    AppendAllocSpaceForVarsOp(params_name, grads_name, fused_var_name,
                              program_desc.MutableBlock(0));

    for (size_t i = 0; i < local_scopes.size(); ++i) {
      for (auto &op_desc : program_desc.Block(0).AllOps()) {
        auto op = OpRegistry::CreateOp(*op_desc);
        op->Run(*local_scopes[i], places[i]);
      }
    }
  }
  void AppendAllocSpaceForVarsOp(const std::vector<std::string> &params_name,
                                 const std::vector<std::string> &grads_name,
                                 const std::string &fused_var_name,
                                 BlockDesc *global_block) const {
    auto op_desc = global_block->AppendOp();
    op_desc->SetType("alloc_continuous_space");
    op_desc->SetInput("Input", params_name);
    op_desc->SetOutput("Output", grads_name);
    op_desc->SetOutput("FusedOutput", {fused_var_name});
  }
};
}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(alloc_continuous_space_for_grad_pass,
              paddle::framework::ir::AllocContinuousSpaceForGradPass)
    .RequirePassAttr(paddle::framework::details::kPlaces)
    .RequirePassAttr(paddle::framework::details::kLocalScopes);
paddle/fluid/framework/details/reference_count_pass.h → paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.h
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
...
@@ -11,21 +11,19 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.

 #pragma once
-#include <algorithm>
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/pass.h"

 namespace paddle {
 namespace framework {
-namespace details {
+namespace ir {

-class ReferenceCountPass : public ir::Pass {
- protected:
-  void ApplyImpl(ir::Graph *graph) const override;
-};
+void SetFuseParameterGroupsSize(int group_size);
+int GetFuseParameterGroupsSize();
+
+void SetFuseParameterMemorySize(uint64_t memory_size);
+uint64_t GetFuseParameterMemorySize();

-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/ir/fuse_optimizer_ops_pass/CMakeLists.txt
0 → 100644

cc_library(fuse_optimizer_op_pass SRCS fuse_optimizer_op_pass.cc DEPS graph graph_helper)
cc_library(fuse_adam_op_pass SRCS fuse_adam_op_pass.cc DEPS fuse_optimizer_op_pass)
cc_library(fuse_sgd_op_pass SRCS fuse_sgd_op_pass.cc DEPS fuse_optimizer_op_pass)
cc_library(fuse_momentum_op_pass SRCS fuse_momentum_op_pass.cc DEPS fuse_optimizer_op_pass)
paddle/fluid/framework/details/fuse_adam_op_pass.cc → paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_adam_op_pass.cc
...
@@ -16,16 +16,13 @@
 #include <unordered_map>
 #include <utility>
 #include <vector>
-#include "paddle/fluid/framework/details/build_strategy.h"
-#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
-#include "paddle/fluid/framework/details/multi_devices_helper.h"
-#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/op_registry.h"

 namespace paddle {
 namespace framework {
-namespace details {
+namespace ir {

 class FuseAdamOpPass : public FuseOptimizerOpPass {
  private:
...
@@ -203,10 +200,10 @@ class FuseAdamOpPass : public FuseOptimizerOpPass {
   }
 };

-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle

-REGISTER_PASS(fuse_adam_op_pass, paddle::framework::details::FuseAdamOpPass)
+REGISTER_PASS(fuse_adam_op_pass, paddle::framework::ir::FuseAdamOpPass)
     .RequirePassAttr(paddle::framework::details::kPlaces)
     .RequirePassAttr(paddle::framework::details::kLocalScopes);
paddle/fluid/framework/details/fuse_momentum_op_pass.cc → paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_momentum_op_pass.cc
...
@@ -16,14 +16,13 @@
 #include <string>
 #include <unordered_map>
 #include <vector>
-#include "paddle/fluid/framework/details/build_strategy.h"
-#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
+#include "paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/op_registry.h"

 namespace paddle {
 namespace framework {
-namespace details {
+namespace ir {

 class FuseMomentumOpPass : public FuseOptimizerOpPass {
  private:
...
@@ -84,11 +83,10 @@ class FuseMomentumOpPass : public FuseOptimizerOpPass {
   }
 };

-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle

 REGISTER_PASS(fuse_momentum_op_pass,
-              paddle::framework::details::FuseMomentumOpPass)
+              paddle::framework::ir::FuseMomentumOpPass)
     .RequirePassAttr(paddle::framework::details::kPlaces)
     .RequirePassAttr(paddle::framework::details::kLocalScopes);
paddle/fluid/framework/details/fuse_optimizer_op_pass.cc → paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.cc
...
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
+#include "paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h"
 #include <algorithm>
 #include <unordered_set>
 #include "paddle/fluid/framework/ir/graph_helper.h"
...
@@ -20,13 +20,13 @@
 namespace paddle {
 namespace framework {
-namespace details {
+namespace ir {

 void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
   ir::Graph &result = *graph;

-  auto &places = Get<const std::vector<platform::Place>>(kPlaces);
-  auto &local_scopes = Get<const std::vector<Scope *>>(kLocalScopes);
+  auto &places = Get<const std::vector<platform::Place>>(details::kPlaces);
+  auto &local_scopes = Get<const std::vector<Scope *>>(details::kLocalScopes);

   const std::string fuse_op_type = GetOpType();
   std::vector<std::string> aux_var_names = GetAuxiliaryVarNames();
...
@@ -47,24 +47,24 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
     return;
   }

-  if (result.Has(kFusedOptType)) {
+  if (result.Has(details::kFusedOptType)) {
     VLOG(6) << "Currently only support fusing one type optimizer op. Has fused "
-            << result.Get<FusedOptType>(kFusedOptType);
+            << result.Get<details::FusedOptType>(details::kFusedOptType);
     return;
   } else {
-    result.Set(kFusedOptType, new FusedOptType);
+    result.Set(details::kFusedOptType, new details::FusedOptType);
   }
-  result.Get<FusedOptType>(kFusedOptType) = fuse_op_type;
+  result.Get<details::FusedOptType>(details::kFusedOptType) = fuse_op_type;

   // Step 2: Insert fused_var_name to FusedVars, and the FusedVars need be
   // initialized in scopes before execution.
-  if (!result.Has(kFusedVars)) {
-    result.Set(kFusedVars, new FusedVars);
+  if (!result.Has(details::kFusedVars)) {
+    result.Set(details::kFusedVars, new details::FusedVars);
   }
   std::unordered_map<std::string, std::string> fused_vars_name;
   fused_vars_name.reserve(aux_var_names.size());
-  auto &fused_var_set = result.Get<FusedVars>(kFusedVars);
-  const std::string prefix(kFusedVarNamePrefix);
+  auto &fused_var_set = result.Get<details::FusedVars>(details::kFusedVars);
+  const std::string prefix(details::kFusedVarNamePrefix);
   // NOTE: the fused_var_name should be unique.
   for (auto &var_name : aux_var_names) {
     auto fused_var_name = prefix + "_" + fuse_op_type + "_" + var_name + "_" +
...
@@ -77,8 +77,9 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
   // Step 3: Get the fused Gradient's name
   bool grad_fused = false;
-  if (result.Has(kParamsAndGrads)) {
-    auto &params_grads = result.Get<ParamsAndGrads>(kParamsAndGrads);
+  if (result.Has(details::kParamsAndGrads)) {
+    auto &params_grads =
+        result.Get<details::ParamsAndGrads>(details::kParamsAndGrads);
     PADDLE_ENFORCE_EQ(
         params_grads.size(), aux_var_set.at(kGrad).size(),
         "The number of gradients and optimizer ops is not equal.");
...
@@ -94,13 +95,13 @@ void FuseOptimizerOpPass::ApplyImpl(ir::Graph *graph) const {
     // NOTE(zcd): the gradient of kParamsAndGrads may be different with the
     // kGrad.
     if (same_grad_num == aux_var_set.at(kGrad).size()) {
-      if (!result.Has(kFusedGrads)) {
+      if (!result.Has(details::kFusedGrads)) {
         PADDLE_THROW(
             "The alloc_continuous_space_for_grad_pass should be called before "
             "this pass.");
       }
-      auto &fused_grad = result.Get<FusedGrads>(kFusedGrads);
-      auto &fused_vars = result.Get<FusedVars>(kFusedVars);
+      auto &fused_grad = result.Get<details::FusedGrads>(details::kFusedGrads);
+      auto &fused_vars = result.Get<details::FusedVars>(details::kFusedVars);
       auto iter = std::find(fused_vars.begin(), fused_vars.end(), fused_grad);
       PADDLE_ENFORCE(iter != fused_vars.end(), "Not find the fused_grad.");
       fused_vars_name[kGrad] = fused_grad;
...
@@ -323,6 +324,6 @@ void FuseOptimizerOpPass::InserInputAndOutputForOptOps(
   opt_node->outputs.insert(opt_node->outputs.begin(), outputs.begin(),
                            outputs.end());
 }
-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/details/fuse_optimizer_op_pass.h → paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h
...
@@ -25,7 +25,7 @@
 namespace paddle {
 namespace framework {
-namespace details {
+namespace ir {

 constexpr char kGrad[] = "Grad";
 constexpr char kParam[] = "Param";
...
@@ -90,6 +90,6 @@ class FuseOptimizerOpPass : public ir::Pass {
                          const std::string &fused_var_name) const;
 };
-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/details/fuse_sgd_op_pass.cc → paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_sgd_op_pass.cc
...
@@ -14,18 +14,13 @@
 #include <algorithm>
 #include <string>
 #include <unordered_map>
-#include <utility>
 #include <vector>
-#include "paddle/fluid/framework/details/build_strategy.h"
-#include "paddle/fluid/framework/details/fuse_optimizer_op_pass.h"
-#include "paddle/fluid/framework/details/multi_devices_helper.h"
-#include "paddle/fluid/framework/ir/graph.h"
-#include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h"
 #include "paddle/fluid/framework/op_registry.h"

 namespace paddle {
 namespace framework {
-namespace details {
+namespace ir {

 class FuseSgdOpPass : public FuseOptimizerOpPass {
  private:
...
@@ -66,10 +61,10 @@ class FuseSgdOpPass : public FuseOptimizerOpPass {
     InserInputAndOutputForOptOps(sgd_ops, sgd_node);
   }
 };

-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle

-REGISTER_PASS(fuse_sgd_op_pass, paddle::framework::details::FuseSgdOpPass)
+REGISTER_PASS(fuse_sgd_op_pass, paddle::framework::ir::FuseSgdOpPass)
     .RequirePassAttr(paddle::framework::details::kPlaces)
     .RequirePassAttr(paddle::framework::details::kLocalScopes);
paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt
0 → 100644

cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base)
cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle)
cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle var_handle)
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)

if(WITH_GPU)
  cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper gpu_info)
else()
  cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper cpu_info)
endif()

cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass)
cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info)
cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry)
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass reference_count_pass_helper)
cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper)
paddle/fluid/framework/details/eager_deletion_pass.cc → paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc
...
@@ -27,11 +27,11 @@
 namespace paddle {
 namespace framework {
-namespace details {
+namespace ir {

 // op -> variables which can be deleted after op runs
 using OpToVarNameSetMap =
-    std::unordered_map<ComputationOpHandle *,
+    std::unordered_map<details::ComputationOpHandle *,
                        std::unordered_set<std::string>>;

 static std::map<size_t, std::unordered_set<std::string>> VarsGroupByScopeIdx(
     const OpToVarNameSetMap &map) {
...
@@ -53,7 +53,8 @@ static bool IsLoDTensor(VarDesc *var) {
 // Get memory size of LoDTensor
 static int64_t GetMemorySize(
-    const std::unordered_map<std::string, std::vector<VarHandle *>> &vars,
+    const std::unordered_map<std::string, std::vector<details::VarHandle *>>
+        &vars,
     const std::string &var_name) {
   auto *var_desc = TryGetLatestVarDesc(vars.at(var_name));
   PADDLE_ENFORCE_NOT_NULL(var_desc);
...
@@ -69,13 +70,13 @@ static int64_t GetMemorySize(
 // Since partial GC is based on static analysis of memory size of each variable
 // So we should skip SelectedRows and LoDTensorArray here
 static void SplitIntoLoDTensorAndNonLoDTensorVars(
-    const OpToVarNameSetMap &m, const GraphVars &vars,
+    const OpToVarNameSetMap &m, const details::GraphVars &vars,
     OpToVarNameSetMap *lod_tensors, OpToVarNameSetMap *other_vars) {
   lod_tensors->clear();
   other_vars->clear();

   for (auto &op_vars_pair : m) {
-    for (auto &var_name : op_vars_pair.second) {
+    for (auto var_name : op_vars_pair.second) {
       auto *var_desc = TryGetLatestVarDesc(
           vars[op_vars_pair.first->GetScopeIdx()].at(var_name));
       if (IsLoDTensor(var_desc)) {
...
@@ -89,23 +90,24 @@ static void SplitIntoLoDTensorAndNonLoDTensorVars(
 struct GCVarInfo {
   GCVarInfo(const std::string &name, int64_t memory_size,
-            ComputationOpHandle *op, size_t scope_idx)
+            details::ComputationOpHandle *op, size_t scope_idx)
       : name_(name),
         memory_size_(memory_size),
         op_(op),
         scope_idx_(scope_idx) {}

-  std::string name_;         // variable name
-  int64_t memory_size_;      // memory size
-  ComputationOpHandle *op_;  // op after which the variable could be deleted
-  size_t scope_idx_;         // scope index where the variable locates
+  std::string name_;     // variable name
+  int64_t memory_size_;  // memory size
+  details::ComputationOpHandle
+      *op_;              // op after which the variable could be deleted
+  size_t scope_idx_;     // scope index where the variable locates

   int64_t AbsMemorySize() const { return std::abs(memory_size_); }
 };

 // Delete delete_lod_tensor_only is not used currently
 static OpToVarNameSetMap ShrinkGCVars(
-    const OpToVarNameSetMap &m, const GraphVars &vars,
+    const OpToVarNameSetMap &m, const details::GraphVars &vars,
     const std::vector<platform::Place> &places, double fraction_of_memory_size,
     bool delete_lod_tensor_only = false) {
   // Do not perform gc when fraction_of_memory_size = 0
...
@@ -192,7 +194,7 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
   PADDLE_ENFORCE(ref_cnts.empty(),
                  "kRuntimeReferenceCount should be initialized here!");

-  const auto &vars = graph->Get<GraphVars>(kGraphVars);
+  const auto &vars = graph->Get<details::GraphVars>(details::kGraphVars);
   ref_cnts.resize(vars.size());

   const auto &last_live_ops =
...
@@ -222,27 +224,31 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
     auto *eager_deletion_node =
         graph->CreateEmptyNode("eager_deletion", ir::Node::Type::kOperation);
-    auto *eager_deletion_op = new EagerDeletionOpHandle(
+    auto *eager_deletion_op = new details::EagerDeletionOpHandle(
         eager_deletion_node, op->GetScope(), op->GetPlace(), var_names,
         gcs.at(places[op->GetScopeIdx()]).get(),
         &(ref_cnts[op->GetScopeIdx()]));

     auto it = std::find_if(
         op->Outputs().begin(), op->Outputs().end(),
-        [](VarHandleBase *var) {
-          return dynamic_cast<DummyVarHandle *>(var) != nullptr;
+        [](details::VarHandleBase *var) {
+          return dynamic_cast<details::DummyVarHandle *>(var) != nullptr;
         });

     if (it != op->Outputs().end()) {
       eager_deletion_op->AddInput(*it);
     } else {
-      auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
-      graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
+      auto *dep_var = new details::DummyVarHandle(graph->CreateControlDepVar());
+      graph->Get<details::GraphDepVars>(details::kGraphDepVars)
+          .emplace(dep_var);
       op->AddOutput(dep_var);
       eager_deletion_op->AddInput(dep_var);
     }

-    auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
-    graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
+    auto *dummy_leaf =
+        new details::DummyVarHandle(graph->CreateControlDepVar());
+    graph->Get<details::GraphDepVars>(details::kGraphDepVars)
+        .emplace(dummy_leaf);
     eager_deletion_op->AddOutput(dummy_leaf);
   }
...
@@ -262,15 +268,14 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
   while_op_eager_deletion_pass->Apply(graph);
 }

-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle

 REGISTER_PASS(eager_deletion_pass,
-              paddle::framework::details::EagerDeletionPass)
-    .RequirePassAttr(paddle::framework::details::kRuntimeReferenceCount)
-    .RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars)
-    .RequirePassAttr(paddle::framework::details::kAllPlaces)
-    .RequirePassAttr(paddle::framework::details::kGarbageCollector);
+              paddle::framework::ir::EagerDeletionPass)
+    .RequirePassAttr(paddle::framework::ir::kRuntimeReferenceCount)
+    .RequirePassAttr(paddle::framework::ir::kLastLiveOpsOfVars)
+    .RequirePassAttr(paddle::framework::ir::kAllPlaces)
+    .RequirePassAttr(paddle::framework::ir::kGarbageCollector);

 USE_PASS(while_op_eager_deletion_pass);
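The control-dependency wiring in the hunk above follows a reuse-or-create pattern: if the computation op already exposes a dummy output var, the eager-deletion op hooks onto it; otherwise a fresh control-dep var is created. A simplified, std-only sketch of that pattern (VarBase/DummyVar are stand-ins for Paddle's VarHandleBase/DummyVarHandle):

#include <algorithm>
#include <cassert>
#include <vector>

// Reuse an existing "dummy" output as a control dependency if one exists,
// otherwise create one. Simplified stand-in types, not Paddle's handles.
struct VarBase { virtual ~VarBase() = default; };
struct DummyVar : VarBase {};

static VarBase *FindOrCreateDummy(std::vector<VarBase *> *outputs) {
  auto it = std::find_if(outputs->begin(), outputs->end(), [](VarBase *v) {
    return dynamic_cast<DummyVar *>(v) != nullptr;  // same test as the pass
  });
  if (it != outputs->end()) return *it;
  outputs->push_back(new DummyVar);  // new control-dep var
  return outputs->back();
}

int main() {
  std::vector<VarBase *> outs;
  VarBase *d1 = FindOrCreateDummy(&outs);
  VarBase *d2 = FindOrCreateDummy(&outs);
  assert(d1 == d2 && outs.size() == 1);  // created once, then reused
  delete outs[0];
}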
paddle/fluid/framework/details/inplace_op_pass.cc → paddle/fluid/framework/ir/memory_optimize_pass/inplace_op_pass.cc
...
@@ -16,9 +16,9 @@
 #include <queue>
 #include <string>
 #include <unordered_set>
-#include "paddle/fluid/framework/details/memory_optimize_pass.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_pass.h"
 #include "paddle/fluid/framework/ir/pass.h"
 #include "paddle/fluid/framework/op_info.h"
...
@@ -52,7 +52,7 @@ DECLARE_string(memory_optimize_debug);
 namespace paddle {
 namespace framework {
-namespace details {
+namespace ir {

 // clang-format off
 const std::string kInplacedOpWhiteList[] = {  // NOLINT
...
@@ -199,8 +199,8 @@ bool InplacePass::CheckOpDeps(ir::Node *op,
 void InplacePass::CollectSkipVars(ir::Graph *graph,
                                   const std::vector<ir::Node *> &ops) const {
   // 1. Collect op role vars
-  PADDLE_ENFORCE(graph->Has(details::kMemOptSkipVars),
-                 "Graph should have attr %s", details::kMemOptSkipVars);
+  PADDLE_ENFORCE(graph->Has(kMemOptSkipVars), "Graph should have attr %s",
+                 kMemOptSkipVars);
   auto &mem_opt_whitelist = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
   for (const auto &var : mem_opt_whitelist) {
     skip_vars_.emplace(var);
...
@@ -452,8 +452,7 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const {
         continue;
       }

-      if (details::NodeSize(*in_node->Var()) !=
-              details::NodeSize(*out_node->Var()) &&
+      if (NodeSize(*in_node->Var()) != NodeSize(*out_node->Var()) &&
           kSameShapeOpWhiteSet.count(op_desc->Type()) == 0) {
         VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg
                 << " is not the same size with "
...
@@ -476,9 +475,9 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const {
     }
   }

-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle

-REGISTER_PASS(inplace_pass, paddle::framework::details::InplacePass)
-    .RequirePassAttr(paddle::framework::details::kUseCuda);
+REGISTER_PASS(inplace_pass, paddle::framework::ir::InplacePass)
+    .RequirePassAttr(paddle::framework::ir::kUseCuda);
paddle/fluid/framework/details/memory_optimize_helper.cc → paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.cc
...
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/framework/details/memory_optimize_helper.h"
+#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
 #include <algorithm>
 #include <deque>
 #include <functional>
...
@@ -32,14 +32,15 @@
 namespace paddle {
 namespace framework {
-namespace details {
+namespace ir {
 using paddle::framework::VarDesc;

 std::vector<ir::Node *> SortOpLikeDescOrder(const ir::Graph &graph) {
-  PADDLE_ENFORCE(graph.Has(kStaleProgramOpDescs),
+  PADDLE_ENFORCE(graph.Has(details::kStaleProgramOpDescs),
                  "Graph has no attribute of kStaleProgramOpDescs.");
   // 1. get op desc order
-  auto &op_descs = graph.Get<const std::vector<OpDesc *>>(kStaleProgramOpDescs);
+  auto &op_descs =
+      graph.Get<const std::vector<OpDesc *>>(details::kStaleProgramOpDescs);

   // 2. topology sort order
   auto nodes = graph.Nodes();
...
@@ -563,6 +564,6 @@ ir::Node* ControlFlowGraph::GetNodeByName(const std::string& name,
   return found_node;
 }
-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/details/memory_optimize_helper.h → paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h
...
@@ -29,7 +29,7 @@
 namespace paddle {
 namespace framework {
-namespace details {
+namespace ir {

 /// this attribute is used to avoid some core variables removed/reused
 /// in memory optimize related passes
...
@@ -184,6 +184,6 @@ void FilterVariables(const Container& nodes, Callback callback) {
   FilterVariableImpl<Container, Callback>()(nodes, callback);
 }
-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/details/memory_optimize_helper_test.cc → paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper_test.cc
...
@@ -11,8 +11,7 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/framework/details/memory_optimize_helper.h"
+#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
 #include <algorithm>
 #include <iostream>
 #include <iterator>
...
@@ -32,7 +31,7 @@
 namespace paddle {
 namespace framework {
-namespace details {
+namespace ir {

 TEST(OrderedSet, Normal) {
   OrderedSet pool;
...
@@ -153,7 +152,7 @@ TEST(OrderedSet, FindBestFitNode) {
   ASSERT_TRUE(cache == nullptr);
 }
-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle
...
@@ -188,7 +187,7 @@ REGISTER_OPERATOR(dummy, paddle::framework::DummyOp,
 namespace paddle {
 namespace framework {
-namespace details {
+namespace ir {

 inline static ProgramDesc FillProgramDesc() {
   ProgramDesc prog;
...
@@ -521,6 +520,6 @@ TEST(SortOpLikeDescOrder, AddAndReplaceOpDescInplace) {
   }
 }
-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/details/memory_optimize_pass.cc → paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_pass.cc
...
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/framework/details/memory_optimize_pass.h"
+#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_pass.h"
 #include <algorithm>
 #include <atomic>
 #include <deque>
...
@@ -42,12 +42,12 @@ DEFINE_string(memory_optimize_debug, "",
 namespace paddle {
 namespace framework {
-namespace details {
+namespace ir {

 void MemoryOptimizePass::ApplyImpl(ir::Graph *graph) const {
   CollectSkipVarsSet(graph);

-  cfg_.reset(new details::ControlFlowGraph(*graph));
+  cfg_.reset(new ControlFlowGraph(*graph));
   cfg_->LiveVariableAnalysis();
   InitSSAGraphNodes();
...
@@ -205,7 +205,7 @@ void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {
 void MemoryOptimizePass::CollectSkipVarsSet(ir::Graph* graph) const {
   // fill skip_set_
-  PADDLE_ENFORCE(graph->Has(details::kMemOptSkipVars));
+  PADDLE_ENFORCE(graph->Has(kMemOptSkipVars));
   auto& mem_opt_whitelist = graph->Get<MemOptSkipVars>(kMemOptSkipVars);
   for (const auto& var : mem_opt_whitelist) {
     skip_set_.emplace(var);
...
@@ -316,10 +316,9 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
 }
 }
-
-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle

 REGISTER_PASS(memory_optimize_pass,
-              paddle::framework::details::MemoryOptimizePass)
+              paddle::framework::ir::MemoryOptimizePass)
     .RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
paddle/fluid/framework/details/memory_optimize_pass.h → paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_pass.h
...
@@ -26,13 +26,13 @@
 #include <vector>
 #include "paddle/fluid/framework/data_type.h"
-#include "paddle/fluid/framework/details/memory_optimize_helper.h"
 #include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
 #include "paddle/fluid/framework/ir/pass.h"

 namespace paddle {
 namespace framework {
-namespace details {
+namespace ir {

 class MemoryOptimizePass : public ir::Pass {
  protected:
...
@@ -67,6 +67,6 @@ class MemoryOptimizePass : public ir::Pass {
   mutable std::map<std::string, std::vector<ir::Node *>> var_nodes_;
 };
-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/details/op_graph_view.cc → paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.cc
...
@@ -12,17 +12,19 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/framework/details/op_graph_view.h"
+#include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h"
 #include <queue>
 #include <utility>

 namespace paddle {
 namespace framework {
-namespace details {
+namespace ir {

-OpGraphView::OpGraphView(const std::vector<OpHandleBase *> &ops) { Build(ops); }
+OpGraphView::OpGraphView(const std::vector<details::OpHandleBase *> &ops) {
+  Build(ops);
+}

-void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
+void OpGraphView::Build(const std::vector<details::OpHandleBase *> &ops) {
   preceding_ops_.clear();
   pending_ops_.clear();
   for (auto &op : ops) {
...
@@ -40,8 +42,8 @@ void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
                  "There are duplicate ops in graph.");
 }

-std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
-  std::unordered_set<OpHandleBase *> ret;
+std::unordered_set<details::OpHandleBase *> OpGraphView::AllOps() const {
+  std::unordered_set<details::OpHandleBase *> ret;
   ret.reserve(preceding_ops_.size());
   for (auto &pair : preceding_ops_) {
     ret.insert(pair.first);
...
@@ -49,21 +51,21 @@ std::unordered_set<OpHandleBase *> OpGraphView::AllOps() const {
   return ret;
 }

-bool OpGraphView::HasOp(OpHandleBase *op) const {
+bool OpGraphView::HasOp(details::OpHandleBase *op) const {
   return preceding_ops_.count(op) != 0;
 }

-void OpGraphView::EnforceHasOp(OpHandleBase *op) const {
+void OpGraphView::EnforceHasOp(details::OpHandleBase *op) const {
   PADDLE_ENFORCE(HasOp(op), "Cannot find op %s in OpGraphView",
                  op == nullptr ? "nullptr" : op->DebugString());
 }

-const std::unordered_set<OpHandleBase *> &OpGraphView::PendingOps(
-    OpHandleBase *op) const {
+const std::unordered_set<details::OpHandleBase *> &OpGraphView::PendingOps(
+    details::OpHandleBase *op) const {
   EnforceHasOp(op);
   return pending_ops_.at(op);
 }

-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/details/op_graph_view.h → paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h
...
@@ -22,39 +22,42 @@
...
@@ -22,39 +22,42 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
ir
{
class
OpGraphView
{
class
OpGraphView
{
public:
public:
explicit
OpGraphView
(
const
std
::
vector
<
OpHandleBase
*>
&
ops
);
explicit
OpGraphView
(
const
std
::
vector
<
details
::
OpHandleBase
*>
&
ops
);
std
::
unordered_set
<
OpHandleBase
*>
AllOps
()
const
;
std
::
unordered_set
<
details
::
OpHandleBase
*>
AllOps
()
const
;
const
std
::
unordered_set
<
OpHandleBase
*>
&
PendingOps
(
OpHandleBase
*
op
)
const
;
const
std
::
unordered_set
<
details
::
OpHandleBase
*>
&
PendingOps
(
details
::
OpHandleBase
*
op
)
const
;
bool
HasOp
(
OpHandleBase
*
op
)
const
;
bool
HasOp
(
details
::
OpHandleBase
*
op
)
const
;
// Use a visitor to visit all pending ops of op
// Use a visitor to visit all pending ops of op
// Stop when callback returns false
// Stop when callback returns false
template
<
typename
Callback
>
template
<
typename
Callback
>
bool
VisitAllPendingOps
(
OpHandleBase
*
op
,
Callback
&&
callback
)
const
;
bool
VisitAllPendingOps
(
details
::
OpHandleBase
*
op
,
Callback
&&
callback
)
const
;
private:
private:
void
Build
(
const
std
::
vector
<
OpHandleBase
*>
&
ops
);
void
Build
(
const
std
::
vector
<
details
::
OpHandleBase
*>
&
ops
);
void
EnforceHasOp
(
OpHandleBase
*
op
)
const
;
void
EnforceHasOp
(
details
::
OpHandleBase
*
op
)
const
;
std
::
unordered_map
<
OpHandleBase
*
,
std
::
unordered_set
<
OpHandleBase
*>>
std
::
unordered_map
<
details
::
OpHandleBase
*
,
std
::
unordered_set
<
details
::
OpHandleBase
*>>
preceding_ops_
;
preceding_ops_
;
std
::
unordered_map
<
OpHandleBase
*
,
std
::
unordered_set
<
OpHandleBase
*>>
std
::
unordered_map
<
details
::
OpHandleBase
*
,
std
::
unordered_set
<
details
::
OpHandleBase
*>>
pending_ops_
;
pending_ops_
;
};
};
template
<
typename
Callback
>
template
<
typename
Callback
>
bool
OpGraphView
::
VisitAllPendingOps
(
OpHandleBase
*
op
,
bool
OpGraphView
::
VisitAllPendingOps
(
details
::
OpHandleBase
*
op
,
Callback
&&
callback
)
const
{
Callback
&&
callback
)
const
{
EnforceHasOp
(
op
);
EnforceHasOp
(
op
);
std
::
unordered_set
<
OpHandleBase
*>
visited
;
std
::
unordered_set
<
details
::
OpHandleBase
*>
visited
;
std
::
queue
<
OpHandleBase
*>
q
;
std
::
queue
<
details
::
OpHandleBase
*>
q
;
q
.
push
(
op
);
q
.
push
(
op
);
while
(
!
q
.
empty
())
{
while
(
!
q
.
empty
())
{
op
=
q
.
front
();
op
=
q
.
front
();
...
@@ -72,6 +75,6 @@ bool OpGraphView::VisitAllPendingOps(OpHandleBase *op,
...
@@ -72,6 +75,6 @@ bool OpGraphView::VisitAllPendingOps(OpHandleBase *op,
return
true
;
return
true
;
}
}
}
// namespace
details
}
// namespace
ir
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
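VisitAllPendingOps, declared above, is a breadth-first walk over pending_ops_ that aborts as soon as the callback returns false. A minimal sketch of the same early-exit BFS pattern, assuming plain integer nodes instead of details::OpHandleBase pointers:

#include <functional>
#include <iostream>
#include <queue>
#include <unordered_map>
#include <unordered_set>

using Node = int;
using PendingMap = std::unordered_map<Node, std::unordered_set<Node>>;

// BFS over the pending-ops adjacency map; aborts when the callback returns
// false, mirroring the early-exit contract of OpGraphView::VisitAllPendingOps.
bool VisitAllPending(const PendingMap &pending, Node start,
                     const std::function<bool(Node)> &callback) {
  std::unordered_set<Node> visited;
  std::queue<Node> q;
  q.push(start);
  while (!q.empty()) {
    Node cur = q.front();
    q.pop();
    auto it = pending.find(cur);
    if (it == pending.end()) continue;
    for (Node next : it->second) {
      if (!visited.insert(next).second) continue;  // already enqueued
      if (!callback(next)) return false;           // early exit
      q.push(next);
    }
  }
  return true;
}

int main() {
  PendingMap pending{{0, {1, 2}}, {1, {3}}, {2, {3}}, {3, {}}};
  VisitAllPending(pending, 0, [](Node n) {
    std::cout << "visit " << n << "\n";
    return true;  // keep traversing
  });
}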
paddle/fluid/framework/details/record_skip_memory_opt_vars_pass.cc → paddle/fluid/framework/ir/memory_optimize_pass/record_skip_memory_opt_vars_pass.cc
...
@@ -15,16 +15,16 @@
 #include <string>
 #include <unordered_set>
 #include <vector>
-#include "paddle/fluid/framework/details/memory_optimize_helper.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
 #include "paddle/fluid/framework/ir/pass.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/operator.h"

 namespace paddle {
 namespace framework {
-namespace details {
+namespace ir {

 class RecordSkipMemoryOptVarsPass : public ir::Pass {
  protected:
...
@@ -162,9 +162,9 @@ class RecordSkipMemoryOptVarsPass : public ir::Pass {
 };

-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle

 REGISTER_PASS(record_skip_memory_opt_vars_pass,
-              paddle::framework::details::RecordSkipMemoryOptVarsPass);
+              paddle::framework::ir::RecordSkipMemoryOptVarsPass);
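Each REGISTER_PASS invocation in these files binds a pass name to its class through a static registry populated at program load. The sketch below shows only the general static-registration idiom; PassRegistry, Registrar, and DEMO_REGISTER_PASS are hypothetical names, not the actual machinery in paddle/fluid/framework/ir/pass.h:

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct Pass {
  virtual ~Pass() = default;
  virtual void Apply() const = 0;
};

// Process-wide name -> factory map, filled in by static Registrar objects.
struct PassRegistry {
  static std::map<std::string, std::function<std::unique_ptr<Pass>()>> &Map() {
    static std::map<std::string, std::function<std::unique_ptr<Pass>()>> m;
    return m;
  }
};

template <typename T>
struct Registrar {
  explicit Registrar(const std::string &name) {
    PassRegistry::Map()[name] = [] { return std::unique_ptr<Pass>(new T); };
  }
};

// The macro expands to a namespace-scope static whose constructor registers
// the pass before main() runs, which is the same trick REGISTER_PASS relies on.
#define DEMO_REGISTER_PASS(name, type) \
  static Registrar<type> registrar_##name(#name)

struct HelloPass : Pass {
  void Apply() const override { std::cout << "hello pass applied\n"; }
};
DEMO_REGISTER_PASS(hello_pass, HelloPass);

int main() {
  PassRegistry::Map().at("hello_pass")()->Apply();
}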
paddle/fluid/framework/details/reference_count_pass.cc → paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass.cc
...
@@ -24,14 +24,20 @@
 #include "paddle/fluid/framework/details/computation_op_handle.h"
 #include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
-#include "paddle/fluid/framework/details/op_graph_view.h"
-#include "paddle/fluid/framework/details/reference_count_pass.h"
-#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
+#include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h"
+#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
+#include "paddle/fluid/framework/ir/pass.h"

 namespace paddle {
 namespace framework {
-namespace details {
+namespace ir {

+class ReferenceCountPass : public ir::Pass {
+ protected:
+  void ApplyImpl(ir::Graph *graph) const override;
+};
+
 // A functor to shrink/remove operators who depend on other operators in a set
 class ShrinkDepsOpFunctor {
...
@@ -39,19 +45,21 @@ class ShrinkDepsOpFunctor {
   enum RelationShip { kSame = 0, kNoDeps = 1, kBefore = 2, kAfter = 3 };

  public:
-  explicit ShrinkDepsOpFunctor(const std::vector<OpHandleBase *> &all_ops)
+  explicit ShrinkDepsOpFunctor(
+      const std::vector<details::OpHandleBase *> &all_ops)
       : graph_(all_ops) {}

   template <typename OpSet>
   OpSet operator()(const OpSet &op_set) const {
     using KeyType = typename OpSet::key_type;
     static_assert(
-        std::is_base_of<OpHandleBase,
+        std::is_base_of<details::OpHandleBase,
                         typename std::remove_pointer<KeyType>::type>::value,
-        "Key type of OpSet must be OpHandleBase, or derived of OpHandleBase");
+        "Key type of OpSet must be details::OpHandleBase, or derived of "
+        "details::OpHandleBase");

     if (op_set.size() <= 1) return op_set;
-    std::vector<OpHandleBase *> ops(op_set.begin(), op_set.end());
+    std::vector<details::OpHandleBase *> ops(op_set.begin(), op_set.end());
     OpSet ret;
     auto rels = GetRelations(ops);
     auto not_before = [](RelationShip r) { return r != kBefore; };
...
@@ -65,8 +73,8 @@ class ShrinkDepsOpFunctor {
  private:
   std::vector<std::vector<RelationShip>> GetRelations(
-      const std::vector<OpHandleBase *> &ops) const {
-    std::unordered_map<OpHandleBase *, size_t> op_to_idx;
+      const std::vector<details::OpHandleBase *> &ops) const {
+    std::unordered_map<details::OpHandleBase *, size_t> op_to_idx;
     for (size_t i = 0; i < ops.size(); ++i) {
       PADDLE_ENFORCE(graph_.HasOp(ops[i]), "Op does not exist in graph");
       op_to_idx[ops[i]] = i;
...
@@ -81,7 +89,7 @@ class ShrinkDepsOpFunctor {
     size_t found_num = ops.size();
     size_t total_num = ops.size() * ops.size();
-    auto visitor = [&](OpHandleBase *op, size_t i) {
+    auto visitor = [&](details::OpHandleBase *op, size_t i) {
       auto it = op_to_idx.find(op);
       if (it != op_to_idx.end()) {
         size_t j = it->second;
...
@@ -98,7 +106,9 @@ class ShrinkDepsOpFunctor {
     };

     for (size_t i = 0; i < ops.size(); ++i) {
-      auto sub_visitor = [&, i](OpHandleBase *op) { return visitor(op, i); };
+      auto sub_visitor = [&, i](details::OpHandleBase *op) {
+        return visitor(op, i);
+      };
       if (!graph_.VisitAllPendingOps(ops[i], sub_visitor)) {
         break;
       }
...
@@ -133,8 +143,8 @@ class ShrinkDepsOpFunctor {
  */
 static bool ShrinkNoNeedBufferVarOpDependency(
     const std::string &var_name,
-    std::unordered_set<ComputationOpHandle *> *op_handles) {
-  std::vector<ComputationOpHandle *> skip_ops;
+    std::unordered_set<details::ComputationOpHandle *> *op_handles) {
+  std::vector<details::ComputationOpHandle *> skip_ops;
   for (auto *op_handle : *op_handles) {
     auto *op_base = op_handle->GetOp();
     auto &inferer = op_base->Info().NoNeedBufferVarsInferer();
...
@@ -195,15 +205,15 @@ static bool ShrinkNoNeedBufferVarOpDependency(
  * Find the nearest downstream computation op handle. If the op is a
  * computation op, just return itself.
  */
-static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
-    OpHandleBase *op, size_t scope_idx) {
-  std::queue<OpHandleBase *> q;
-  std::unordered_set<OpHandleBase *> visited;
+static details::ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
+    details::OpHandleBase *op, size_t scope_idx) {
+  std::queue<details::OpHandleBase *> q;
+  std::unordered_set<details::OpHandleBase *> visited;
   q.push(op);
   while (!q.empty()) {
     auto *op = q.front();
     q.pop();
-    auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
+    auto *compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
     if (compute_op != nullptr && compute_op->GetScopeIdx() == scope_idx) {
       return compute_op;
     }
...
@@ -220,13 +230,13 @@ static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
 enum LastLiveOpSearchStatus { kSuccess, kFailure, kShouldPrecede };

-static std::unordered_set<ComputationOpHandle *>
-ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
+static std::unordered_set<details::ComputationOpHandle *>
+ExtractComputationOpFromLastLivedVar(details::VarHandle *var, size_t scope_idx,
                                      const std::string &var_name,
                                      const ShrinkDepsOpFunctor &shrink_func,
                                      LastLiveOpSearchStatus *status) {
   // stage one. Get last op for variable.
-  std::unordered_set<OpHandleBase *> candidates;
+  std::unordered_set<details::OpHandleBase *> candidates;
   {
     if (var->PendingOps().empty() && var->GeneratedOp()) {
       // No operator depends on this variable. So the last operator is the op
...
@@ -251,7 +261,7 @@ ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
   // some op handle may operate on many DeviceContext, however, our garbage
   // collector can only wait one DeviceContext for now. So currently, we wait
   // the nearest compute op.
-  std::unordered_set<ComputationOpHandle *> computation_op;
+  std::unordered_set<details::ComputationOpHandle *> computation_op;
   {
     for (auto *op : candidates) {
       auto *compute_op =
...
@@ -293,13 +303,13 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
                  "Last Live Ops and Reference Counts of vars should be "
                  "initialized at here.");

-  const auto &vars = graph->Get<GraphVars>(kGraphVars);
+  const auto &vars = graph->Get<details::GraphVars>(details::kGraphVars);

   last_live_ops_of_vars.resize(vars.size());
   ref_cnts.resize(vars.size());

   ShrinkDepsOpFunctor shrink_func(
-      ir::FilterByNodeWrapper<OpHandleBase>(*graph));
+      ir::FilterByNodeWrapper<details::OpHandleBase>(*graph));

   VLOG(1) << "Place number: " << vars.size();
   for (size_t i = 0; i < vars.size(); ++i) {
...
@@ -360,11 +370,10 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
   }
 }

-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle

 REGISTER_PASS(reference_count_pass,
-              paddle::framework::details::ReferenceCountPass)
-    .RequirePassAttr(paddle::framework::details::kGlobalReferenceCount)
-    .RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars);
+              paddle::framework::ir::ReferenceCountPass)
+    .RequirePassAttr(paddle::framework::ir::kGlobalReferenceCount)
+    .RequirePassAttr(paddle::framework::ir::kLastLiveOpsOfVars);
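FindNextComputationOpHandleOrReturnItself above searches downstream with BFS and relies on dynamic_cast to stop at the first handle of the derived computation type on the requested scope. A compact sketch of that filter-by-dynamic_cast search, with hypothetical OpBase/ComputeOp types standing in for details::OpHandleBase/details::ComputationOpHandle:

#include <iostream>
#include <queue>
#include <unordered_set>
#include <vector>

struct OpBase {
  std::vector<OpBase *> pending;  // downstream ops
  virtual ~OpBase() = default;    // polymorphic, so dynamic_cast works
};

struct ComputeOp : OpBase {
  size_t scope_idx = 0;
};

// BFS downstream from `op`; return the first ComputeOp on `scope_idx`,
// or the op itself if it already qualifies, as the pass does.
ComputeOp *FindNextComputeOp(OpBase *op, size_t scope_idx) {
  std::queue<OpBase *> q;
  std::unordered_set<OpBase *> visited;
  q.push(op);
  while (!q.empty()) {
    OpBase *cur = q.front();
    q.pop();
    if (auto *c = dynamic_cast<ComputeOp *>(cur)) {
      if (c->scope_idx == scope_idx) return c;
    }
    for (OpBase *next : cur->pending) {
      if (visited.insert(next).second) q.push(next);
    }
  }
  return nullptr;  // no compute op downstream on that scope
}

int main() {
  OpBase a;
  ComputeOp b;
  b.scope_idx = 1;
  a.pending = {&b};
  std::cout << (FindNextComputeOp(&a, 1) == &b) << "\n";  // prints 1
}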
paddle/fluid/framework/details/reference_count_pass_helper.cc → paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.cc
...
@@ -12,23 +12,24 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
+#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
 #include "paddle/fluid/framework/details/var_handle.h"
 #include "paddle/fluid/framework/var_desc.h"

 namespace paddle {
 namespace framework {
-namespace details {
+namespace ir {

-VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars) {
+VarDesc *TryGetLatestVarDesc(const std::vector<details::VarHandle *> &vars) {
   VarDesc *var_desc = nullptr;
   std::find_if(vars.rbegin(), vars.rend(),
-               [&](VarHandle *var_handle) -> bool {
+               [&](details::VarHandle *var_handle) -> bool {
                  var_desc = var_handle->Node()->Var();
                  return var_desc != nullptr;
                });
   return var_desc;
 }

-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle
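TryGetLatestVarDesc scans a variable's version list from newest to oldest and keeps the first handle whose IR node still carries a VarDesc. The same reverse std::find_if idiom in isolation (the Handle struct here is an illustrative stand-in for details::VarHandle):

#include <algorithm>
#include <iostream>
#include <vector>

struct VarDesc { const char *name; };
struct Handle { VarDesc *desc; };  // stand-in for details::VarHandle

// Walk versions newest-to-oldest; keep the first non-null VarDesc found.
VarDesc *TryGetLatest(const std::vector<Handle> &vars) {
  VarDesc *found = nullptr;
  std::find_if(vars.rbegin(), vars.rend(), [&](const Handle &h) {
    found = h.desc;
    return found != nullptr;  // stop once a descriptor is present
  });
  return found;
}

int main() {
  VarDesc d{"x@v1"};
  std::vector<Handle> vars{{&d}, {nullptr}};  // newest version lacks a desc
  std::cout << TryGetLatest(vars)->name << "\n";  // prints x@v1
}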
paddle/fluid/framework/details/reference_count_pass_helper.h → paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h
...
@@ -22,17 +22,16 @@
 #include <unordered_set>
 #include <vector>

+#include "paddle/fluid/framework/details/computation_op_handle.h"
+#include "paddle/fluid/framework/details/var_handle.h"
 #include "paddle/fluid/framework/garbage_collector.h"

 namespace paddle {
 namespace framework {

 class VarDesc;
-class VarHandle;

-namespace details {
+namespace ir {

-class ComputationOpHandle;
 using ReferenceCountMap = std::unordered_map<std::string, size_t>;
...
@@ -48,11 +47,12 @@ const char kGarbageCollector[] = "garbage_collector";
 const char kAllPlaces[] = "all_places";

 using LastLiveOpsOfVars =
-    std::unordered_map<std::string, std::unordered_set<ComputationOpHandle *>>;
+    std::unordered_map<std::string,
+                       std::unordered_set<details::ComputationOpHandle *>>;
 const char kLastLiveOpsOfVars[] = "last_live_ops_of_var";

-VarDesc *TryGetLatestVarDesc(const std::vector<VarHandle *> &vars);
+VarDesc *TryGetLatestVarDesc(const std::vector<details::VarHandle *> &vars);

-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/details/while_op_eager_deletion_pass.cc → paddle/fluid/framework/ir/memory_optimize_pass/while_op_eager_deletion_pass.cc
...
@@ -19,19 +19,19 @@
 namespace paddle {
 namespace framework {
-namespace details {
+namespace ir {

 class WhileOpEagerDeletionPass : public ir::Pass {
  protected:
   void ApplyImpl(ir::Graph *graph) const override {
-    auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*graph);
+    auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph);

     // Find all while_op and while_grad_op
     std::unordered_map<size_t, std::pair<std::vector<OperatorBase *>,
                                          std::vector<OperatorBase *>>>
         target_ops;
     for (auto *op : all_ops) {
-      auto compute_op = dynamic_cast<ComputationOpHandle *>(op);
+      auto compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
       if (compute_op == nullptr) continue;

       if (compute_op->Name() == "while") {
...
@@ -52,9 +52,9 @@ class WhileOpEagerDeletionPass : public ir::Pass {
 };

-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle

 REGISTER_PASS(while_op_eager_deletion_pass,
-              paddle::framework::details::WhileOpEagerDeletionPass);
+              paddle::framework::ir::WhileOpEagerDeletionPass);
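WhileOpEagerDeletionPass first buckets the forward while ops and their while_grad counterparts per scope index before wiring eager deletion between each pair. A simplified sketch of that grouping step, assuming a bare Op struct with just a name and a scope index in place of the real operator handles:

#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

struct Op {  // stand-in for an operator with a name and a scope index
  std::string name;
  size_t scope_idx;
};

// Forward/backward while ops collected per scope, mirroring the
// target_ops map in the pass above.
using OpPair = std::pair<std::vector<const Op *>, std::vector<const Op *>>;

std::map<size_t, OpPair> GroupWhileOps(const std::vector<Op> &all_ops) {
  std::map<size_t, OpPair> target_ops;
  for (const Op &op : all_ops) {
    if (op.name == "while") {
      target_ops[op.scope_idx].first.push_back(&op);
    } else if (op.name == "while_grad") {
      target_ops[op.scope_idx].second.push_back(&op);
    }  // every other op is ignored
  }
  return target_ops;
}

int main() {
  std::vector<Op> ops{{"while", 0}, {"while_grad", 0}, {"relu", 0}};
  auto grouped = GroupWhileOps(ops);
  std::cout << grouped[0].first.size() << " forward, "
            << grouped[0].second.size() << " backward\n";  // 1 forward, 1 backward
}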
paddle/fluid/framework/ir/multi_devices_graph_pass/CMakeLists.txt
0 → 100644
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper)
cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_helper)
set(ALL_REDUCE_OP_HANDLES all_reduce_op_handle)
if(WITH_GPU AND WITH_DGC)
    list(APPEND ALL_REDUCE_OP_HANDLES sparse_all_reduce_op_handle)
endif()
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
        scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle ${ALL_REDUCE_OP_HANDLES}
        reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS all_reduce_op_handle graph graph_helper pass)
paddle/fluid/framework/details/all_reduce_deps_pass.cc → paddle/fluid/framework/ir/multi_devices_graph_pass/all_reduce_deps_pass.cc
...
@@ -23,7 +23,6 @@
 #include "paddle/fluid/framework/details/all_reduce_op_handle.h"
 #include "paddle/fluid/framework/details/container_cast.h"
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
-#include "paddle/fluid/framework/details/op_graph_view.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/ir/pass.h"
...
@@ -31,17 +30,18 @@
 namespace paddle {
 namespace framework {
-namespace details {
+namespace ir {

 class AllReduceDepsPass : public ir::Pass {
  protected:
   void ApplyImpl(ir::Graph *graph) const override {
-    std::vector<AllReduceOpHandle *> all_reduce_op_handles =
+    std::vector<details::AllReduceOpHandle *> all_reduce_op_handles =
         GetSortedAllReduceOps(*graph);

     for (size_t i = 1; i < all_reduce_op_handles.size(); ++i) {
-      auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
-      graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
+      auto *dep_var =
+          new details::DummyVarHandle(graph->CreateControlDepVar());
+      graph->Get<details::GraphDepVars>(details::kGraphDepVars)
+          .emplace(dep_var);
       all_reduce_op_handles[i - 1]->AddOutput(dep_var);
       all_reduce_op_handles[i]->AddInput(dep_var);
     }
...
@@ -51,16 +51,16 @@ class AllReduceDepsPass : public ir::Pass {
   }

-  std::vector<AllReduceOpHandle *> GetSortedAllReduceOps(
+  std::vector<details::AllReduceOpHandle *> GetSortedAllReduceOps(
       const ir::Graph &graph) const {
-    std::vector<AllReduceOpHandle *> all_reduce_op_handles;
-    std::unordered_map<OpHandleBase *, size_t> pending_ops;
-    std::unordered_set<OpHandleBase *> ready_ops;
-    std::unordered_set<OpHandleBase *> next_ready_ops;
+    std::vector<details::AllReduceOpHandle *> all_reduce_op_handles;
+    std::unordered_map<details::OpHandleBase *, size_t> pending_ops;
+    std::unordered_set<details::OpHandleBase *> ready_ops;
+    std::unordered_set<details::OpHandleBase *> next_ready_ops;

-    auto op_handles = ir::FilterByNodeWrapper<OpHandleBase>(graph);
+    auto op_handles = ir::FilterByNodeWrapper<details::OpHandleBase>(graph);
     size_t num_of_ops = op_handles.size();
-    for (OpHandleBase *op : op_handles) {
+    for (details::OpHandleBase *op : op_handles) {
       size_t not_ready_vars = op->NotReadyInputSize();
       if (not_ready_vars) {
         pending_ops.insert({op, not_ready_vars});
...
@@ -94,11 +94,12 @@ class AllReduceDepsPass : public ir::Pass {
   }

   void GetSortedAllReduceOps(
-      const std::unordered_set<OpHandleBase *>& ready_ops,
-      std::vector<AllReduceOpHandle *>* all_reduce_op_handles) const {
-    std::vector<AllReduceOpHandle *> current_all_reduce_op_handles;
+      const std::unordered_set<details::OpHandleBase *>& ready_ops,
+      std::vector<details::AllReduceOpHandle *>* all_reduce_op_handles) const {
+    std::vector<details::AllReduceOpHandle *> current_all_reduce_op_handles;
     for (auto &op_handle : ready_ops) {
-      auto all_reduce_op_handle = dynamic_cast<AllReduceOpHandle *>(op_handle);
+      auto all_reduce_op_handle =
+          dynamic_cast<details::AllReduceOpHandle *>(op_handle);
       if (all_reduce_op_handle) {
        current_all_reduce_op_handles.emplace_back(all_reduce_op_handle);
       }
...
@@ -109,10 +110,12 @@ class AllReduceDepsPass : public ir::Pass {
     // Sort the current_all_reduce_op_handles according to the name of input.
     sort(current_all_reduce_op_handles.begin(),
          current_all_reduce_op_handles.end(),
-         [](const AllReduceOpHandle *left,
-            const AllReduceOpHandle *right) -> bool {
-           auto left_in_vars = DynamicCast<VarHandle>(left->Inputs());
-           auto right_in_vars = DynamicCast<VarHandle>(right->Inputs());
+         [](const details::AllReduceOpHandle *left,
+            const details::AllReduceOpHandle *right) -> bool {
+           auto left_in_vars =
+               details::DynamicCast<details::VarHandle>(left->Inputs());
+           auto right_in_vars =
+               details::DynamicCast<details::VarHandle>(right->Inputs());
           PADDLE_ENFORCE_GT(left_in_vars.size(), 0);
           PADDLE_ENFORCE_EQ(left_in_vars.size(), right_in_vars.size());
           return left_in_vars[0]->Name() > right_in_vars[0]->Name();
...
@@ -123,15 +126,15 @@ class AllReduceDepsPass : public ir::Pass {
         current_all_reduce_op_handles.end());
   }

   void DebugString(
       const ir::Graph &graph,
-      const std::vector<AllReduceOpHandle *>&
+      const std::vector<details::AllReduceOpHandle *>&
          all_reduce_op_handles) const {
     // get vars order
     std::map<int, std::vector<std::string>> vars =
         GetSoredGradientsFromStaleProgram(graph);
     std::stringstream out;
     size_t grads_of_stale_program = 0;
-    out << "Get Order From kStaleProgramOpDescs: ";
+    out << "Get Order From details::kStaleProgramOpDescs: ";
     for (auto &var : vars) {
       out << "Order " << var.first << " [";
       for (auto &var_name : var.second) {
...
@@ -147,7 +150,7 @@ class AllReduceDepsPass : public ir::Pass {
     for (auto &op : all_reduce_op_handles) {
       bool find_valid_input = false;
       for (auto &in_var : op->Inputs()) {
-        if (dynamic_cast<VarHandle *>(in_var)) {
+        if (dynamic_cast<details::VarHandle *>(in_var)) {
           out2 << in_var->Name() << ", ";
           find_valid_input = true;
           break;
...
@@ -165,7 +168,8 @@ class AllReduceDepsPass : public ir::Pass {
   std::map<int, std::vector<std::string>> GetSoredGradientsFromStaleProgram(
       const ir::Graph &graph) const {
     std::map<int, std::vector<std::string>> vars;
-    auto ops = graph.Get<const std::vector<OpDesc *>>(kStaleProgramOpDescs);
+    auto ops =
+        graph.Get<const std::vector<OpDesc *>>(details::kStaleProgramOpDescs);
     int order = 0;
     for (auto *op_desc : ops) {
       try {
...
@@ -193,10 +197,9 @@ class AllReduceDepsPass : public ir::Pass {
     return vars;
   }
 };

-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle

 REGISTER_PASS(all_reduce_deps_pass,
-              paddle::framework::details::AllReduceDepsPass)
+              paddle::framework::ir::AllReduceDepsPass)
     .RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
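GetSortedAllReduceOps above is a Kahn-style level sweep: ops whose inputs are all ready form one level, the all-reduce handles found in that level are sorted by input name for a deterministic order, and finishing an op releases its consumers. A compact sketch of that level-by-level scheme under simplified assumptions (integer node ids and an explicit edge list instead of op handles):

#include <algorithm>
#include <iostream>
#include <unordered_map>
#include <utility>
#include <vector>

// Level-by-level topological sweep: start from ops with zero unready inputs,
// then release consumers as their pending counts hit zero.
std::vector<int> LevelOrder(int num_nodes,
                            const std::vector<std::pair<int, int>> &edges) {
  std::unordered_map<int, std::vector<int>> consumers;
  std::unordered_map<int, size_t> pending;
  for (auto &e : edges) {
    consumers[e.first].push_back(e.second);
    ++pending[e.second];
  }
  std::vector<int> ready, order;
  for (int n = 0; n < num_nodes; ++n)
    if (pending[n] == 0) ready.push_back(n);
  while (!ready.empty()) {
    // Sort within a level for a deterministic order, as the pass sorts
    // all-reduce handles by the name of their first input variable.
    std::sort(ready.begin(), ready.end());
    std::vector<int> next;
    for (int n : ready) {
      order.push_back(n);
      for (int c : consumers[n])
        if (--pending[c] == 0) next.push_back(c);
    }
    ready.swap(next);
  }
  return order;
}

int main() {
  auto order = LevelOrder(4, {{0, 2}, {1, 2}, {2, 3}});
  for (int n : order) std::cout << n << " ";  // prints 0 1 2 3
  std::cout << "\n";
}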
paddle/fluid/framework/details/fuse_all_reduce_op_pass.cc → paddle/fluid/framework/ir/multi_devices_graph_pass/fuse_all_reduce_op_pass.cc
...
@@ -24,21 +24,22 @@
 namespace paddle {
 namespace framework {
-namespace details {
+namespace ir {

 class FuseAllReduceOpPass : public ir::Pass {
  protected:
   void ApplyImpl(ir::Graph *graph) const override {
     ir::Graph &result = *graph;

-    auto &places = Get<const std::vector<platform::Place>>(kPlaces);
-    auto &local_scopes = Get<const std::vector<Scope *>>(kLocalScopes);
+    auto &places = Get<const std::vector<platform::Place>>(details::kPlaces);
+    auto &local_scopes =
+        Get<const std::vector<Scope *>>(details::kLocalScopes);
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-    auto *nccl_ctxs = &Get<platform::NCCLContextMap>(kNCCLCtxs);
+    auto *nccl_ctxs = &Get<platform::NCCLContextMap>(details::kNCCLCtxs);
 #endif

     std::unordered_set<std::string> grads;
-    auto &params_grads = result.Get<ParamsAndGrads>(kParamsAndGrads);
+    auto &params_grads =
+        result.Get<details::ParamsAndGrads>(details::kParamsAndGrads);
     size_t num_of_all_reduce = params_grads.size();
     grads.reserve(num_of_all_reduce);
     for (auto p_g : params_grads) {
...
@@ -50,11 +51,12 @@ class FuseAllReduceOpPass : public ir::Pass {
     all_reduce_ops.reserve(grads.size());
     for (auto &node : result.Nodes()) {
       if (node->IsOp()) {
-        PADDLE_ENFORCE(node->IsWrappedBy<OpHandleBase>());
+        PADDLE_ENFORCE(node->IsWrappedBy<details::OpHandleBase>());
         auto *all_reduce_op_handle =
-            dynamic_cast<AllReduceOpHandle *>(&node->Wrapper<OpHandleBase>());
+            dynamic_cast<details::AllReduceOpHandle *>(
+                &node->Wrapper<details::OpHandleBase>());
         if (all_reduce_op_handle) {
-          auto inputs = DynamicCast<VarHandle>(all_reduce_op_handle->Inputs());
+          auto inputs = details::DynamicCast<details::VarHandle>(
+              all_reduce_op_handle->Inputs());
           PADDLE_ENFORCE_EQ(inputs.size(), num_place);
           // The inputs' name should be the same.
           auto &grad_name = inputs[0]->name();
...
@@ -80,7 +82,7 @@ class FuseAllReduceOpPass : public ir::Pass {
     VLOG(10) << "Insert fused_all_reduce";
     auto &group_grads_params =
-        graph->Get<GroupGradsAndParams>(kGroupGradsAndParams);
+        graph->Get<details::GroupGradsAndParams>(details::kGroupGradsAndParams);

     for (auto &group_g_p : group_grads_params) {
       size_t group_size = group_g_p.size();
...
@@ -108,24 +110,25 @@ class FuseAllReduceOpPass : public ir::Pass {
                               const platform::NCCLContextMap *nccl_ctxs,
 #endif
                               ir::Graph *result) const {
-    std::vector<VarHandleBase *> inputs;
-    std::vector<VarHandleBase *> outputs;
+    std::vector<details::VarHandleBase *> inputs;
+    std::vector<details::VarHandleBase *> outputs;
     for (auto &op : all_reduce_ops) {
-      auto &op_handle = op->Wrapper<OpHandleBase>();
+      auto &op_handle = op->Wrapper<details::OpHandleBase>();
       inputs.insert(inputs.end(), op_handle.Inputs().begin(),
                     op_handle.Inputs().end());
       // Remove output
       for_each(op_handle.Inputs().begin(), op_handle.Inputs().end(),
-               [&op_handle](VarHandleBase *var_handle) {
+               [&op_handle](details::VarHandleBase *var_handle) {
                  var_handle->RemoveOutput(&op_handle, op_handle.Node());
               });

       outputs.insert(outputs.end(), op_handle.Outputs().begin(),
                      op_handle.Outputs().end());
       // Remove Input
       for_each(op_handle.Outputs().begin(), op_handle.Outputs().end(),
-               [](VarHandleBase *var_handle) {
+               [](details::VarHandleBase *var_handle) {
                  var_handle->ClearGeneratedOp();
               });

       result->RemoveNode(op_handle.Node());
     }
...
@@ -140,21 +143,22 @@ class FuseAllReduceOpPass : public ir::Pass {
   }

  private:
-  void CreateFusedAllReduceOp(const std::vector<VarHandleBase *> &inputs,
-                              const std::vector<VarHandleBase *> &outputs,
-                              const size_t num_of_all_reduce,
-                              const std::vector<platform::Place> &places,
-                              const std::vector<Scope *> &local_scopes,
+  void CreateFusedAllReduceOp(
+      const std::vector<details::VarHandleBase *> &inputs,
+      const std::vector<details::VarHandleBase *> &outputs,
+      const size_t num_of_all_reduce,
+      const std::vector<platform::Place> &places,
+      const std::vector<Scope *> &local_scopes,
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-                              const platform::NCCLContextMap *nccl_ctxs,
+      const platform::NCCLContextMap *nccl_ctxs,
 #endif
-                              ir::Graph *result) const {
+      ir::Graph *result) const {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-    auto *op_handle = new FusedAllReduceOpHandle(
+    auto *op_handle = new details::FusedAllReduceOpHandle(
         result->CreateEmptyNode("fused_all_reduce", ir::Node::Type::kOperation),
         local_scopes, places, num_of_all_reduce, nccl_ctxs);
 #else
-    auto *op_handle = new FusedAllReduceOpHandle(
+    auto *op_handle = new details::FusedAllReduceOpHandle(
        result->CreateEmptyNode("fused_all_reduce", ir::Node::Type::kOperation),
        local_scopes, places, num_of_all_reduce);
 #endif
...
@@ -176,8 +180,9 @@ class FuseAllReduceOpPass : public ir::Pass {
 #endif
   }

-  void SetCommunicationContext(const std::vector<platform::Place> &places,
-                               FusedAllReduceOpHandle *op_handle) const {
+  void SetCommunicationContext(
+      const std::vector<platform::Place> &places,
+      details::FusedAllReduceOpHandle *op_handle) const {
     for (size_t i = 0; i < places.size(); ++i) {
       op_handle->SetDeviceContext(
           places[i], platform::DeviceContextPool::Instance().Get(places[i]));
...
@@ -185,9 +190,9 @@ class FuseAllReduceOpPass : public ir::Pass {
   }
 };

-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle

 REGISTER_PASS(fuse_all_reduce_op_pass,
-              paddle::framework::details::FuseAllReduceOpPass);
+              paddle::framework::ir::FuseAllReduceOpPass);
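The fuse pass gathers the inputs and outputs of every grouped all-reduce handle, detaches them from the old handles, and reattaches them to a single fused handle before deleting the originals. A sketch of that collect-and-rewire step with simple stand-in Op/Var types (the real pass works on var handles via RemoveOutput/ClearGeneratedOp, not raw pointer lists):

#include <algorithm>
#include <iostream>
#include <vector>

struct Var;  // forward declaration so Op can hold Var pointers
struct Op {
  std::vector<Var *> inputs, outputs;
};
struct Var {
  Op *generated_by = nullptr;
  std::vector<Op *> consumers;
};

// Move all edges of the ops being fused onto one fused op, as the pass does
// before removing the original all-reduce nodes from the graph.
void FuseInto(const std::vector<Op *> &to_fuse, Op *fused) {
  for (Op *op : to_fuse) {
    for (Var *in : op->inputs) {
      // Detach this op as a consumer, attach the fused op instead.
      auto &c = in->consumers;
      c.erase(std::remove(c.begin(), c.end(), op), c.end());
      fused->inputs.push_back(in);
      in->consumers.push_back(fused);
    }
    for (Var *out : op->outputs) {
      out->generated_by = fused;  // the fused op now produces these vars
      fused->outputs.push_back(out);
    }
  }
}

int main() {
  Var v1, v2;
  Op a{{&v1}, {&v2}};
  v1.consumers = {&a};
  v2.generated_by = &a;
  Op fused;
  FuseInto({&a}, &fused);
  std::cout << (v2.generated_by == &fused) << "\n";  // prints 1
}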
paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.cc → paddle/fluid/framework/ir/multi_devices_graph_pass/modify_op_lock_and_record_event_pass.cc
...
@@ -12,21 +12,20 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h"
 #include "paddle/fluid/framework/details/computation_op_handle.h"
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
-#include "paddle/fluid/framework/details/op_graph_view.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h"

 namespace paddle {
 namespace framework {
-namespace details {
+namespace ir {

 static bool IsLockAndRecordEventFreeComputationOpHandle(
-    ComputationOpHandle *op, const OpGraphView &graph_view) {
+    details::ComputationOpHandle *op, const OpGraphView &graph_view) {
   if (!platform::is_gpu_place(op->GetPlace())) return false;
   for (auto &pending_op : graph_view.PendingOps(op)) {
-    auto *tmp = dynamic_cast<ComputationOpHandle *>(pending_op);
+    auto *tmp = dynamic_cast<details::ComputationOpHandle *>(pending_op);
     if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) {
       return false;
     }
...
@@ -34,25 +33,27 @@ static bool IsLockAndRecordEventFreeComputationOpHandle(
   return true;
 }

-void ModifyOpLockAndRecordEventPass::ApplyImpl(ir::Graph *ir_graph) const {
-  auto all_ops = ir::FilterByNodeWrapper<OpHandleBase>(*ir_graph);
-  OpGraphView graph_view(all_ops);
-  for (auto &op : all_ops) {
-    auto *compute_op = dynamic_cast<ComputationOpHandle *>(op);
-    if (compute_op == nullptr) continue;
-    bool is_lock_and_record_event_free =
-        IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view);
-    compute_op->SetLockAndRecordEventFree(is_lock_and_record_event_free);
-    if (is_lock_and_record_event_free) {
-      VLOG(10) << "Set is_lock_and_record_event_free be true in op "
-               << compute_op->DebugString();
+class ModifyOpLockAndRecordEventPass : public ir::Pass {
+ protected:
+  void ApplyImpl(ir::Graph *graph) const override {
+    auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph);
+    OpGraphView graph_view(all_ops);
+    for (auto &op : all_ops) {
+      auto *compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
+      if (compute_op == nullptr) continue;
+      bool is_lock_and_record_event_free =
+          IsLockAndRecordEventFreeComputationOpHandle(compute_op, graph_view);
+      compute_op->SetLockAndRecordEventFree(is_lock_and_record_event_free);
+      if (is_lock_and_record_event_free) {
+        VLOG(10) << "Set is_lock_and_record_event_free be true in op "
+                 << compute_op->DebugString();
+      }
     }
   }
-}
+};

-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle

 REGISTER_PASS(modify_op_lock_and_record_event_pass,
-              paddle::framework::details::ModifyOpLockAndRecordEventPass);
+              paddle::framework::ir::ModifyOpLockAndRecordEventPass);
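IsLockAndRecordEventFreeComputationOpHandle grants the lock-free flag only when every pending op is itself a computation op on the same place (after first requiring a GPU place). The core predicate is an all_of over the pending set, sketched here with stand-in types instead of the details:: handles:

#include <algorithm>
#include <iostream>
#include <vector>

struct Op {
  bool is_compute = false;
  int place = -1;  // device-id stand-in for platform::Place
  std::vector<Op *> pending;
};

// True iff every downstream op is a compute op on the same place, mirroring
// the loop in IsLockAndRecordEventFreeComputationOpHandle (the real check
// also requires op's place to be a GPU place first).
bool IsLockFree(const Op &op) {
  return std::all_of(op.pending.begin(), op.pending.end(), [&](Op *p) {
    return p->is_compute && p->place == op.place;
  });
}

int main() {
  Op consumer{true, 0, {}};
  Op producer{true, 0, {&consumer}};
  std::cout << IsLockFree(producer) << "\n";  // prints 1
}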
paddle/fluid/framework/details/multi_devices_graph_check_pass.cc → paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_check_pass.cc
...
@@ -19,7 +19,7 @@
 namespace paddle {
 namespace framework {
-namespace details {
+namespace ir {

 class SSAGraghBuilderWithChecker : public ir::Pass {
  protected:
...
@@ -28,19 +28,19 @@ class SSAGraghBuilderWithChecker : public ir::Pass {
   }

   bool IsValidGraph(const ir::Graph *graph) const {
-    std::unordered_map<OpHandleBase *, size_t> pending_ops;
-    std::unordered_set<VarHandleBase *> pending_vars;
-    std::unordered_set<VarHandleBase *> ready_vars;
-    std::unordered_set<OpHandleBase *> ready_ops;
+    std::unordered_map<details::OpHandleBase *, size_t> pending_ops;
+    std::unordered_set<details::VarHandleBase *> pending_vars;
+    std::unordered_set<details::VarHandleBase *> ready_vars;
+    std::unordered_set<details::OpHandleBase *> ready_ops;

-    auto insert_pending_var = [&](VarHandleBase *var) {
+    auto insert_pending_var = [&](details::VarHandleBase *var) {
       pending_vars.insert(var);
       if (var->GeneratedOp() == nullptr) {
         ready_vars.emplace(var);
       }
     };

-    for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
+    for (auto &var_map :
+         graph->Get<details::GraphVars>(details::kGraphVars)) {
       for (auto &name_pair : var_map) {
         for (auto &version_pair : name_pair.second) {
           insert_pending_var(version_pair);
...
@@ -48,11 +48,12 @@ class SSAGraghBuilderWithChecker : public ir::Pass {
       }
     }

-    for (auto &var : graph->Get<GraphDepVars>(kGraphDepVars)) {
+    for (auto &var :
+         graph->Get<details::GraphDepVars>(details::kGraphDepVars)) {
       insert_pending_var(var);
     }

-    for (OpHandleBase *op : ir::FilterByNodeWrapper<OpHandleBase>(*graph)) {
+    for (details::OpHandleBase *op :
+         ir::FilterByNodeWrapper<details::OpHandleBase>(*graph)) {
       if (op->Inputs().empty()) {
         ready_ops.insert(op);
       } else {
...
@@ -60,7 +61,7 @@ class SSAGraghBuilderWithChecker : public ir::Pass {
       }
     }

-    auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
+    auto run_all_ops = [&](std::unordered_set<details::OpHandleBase *> &set) {
       for (auto *op : set) {
         for (auto out : op->Outputs()) {
           ready_vars.emplace(out);
...
@@ -91,11 +92,11 @@ class SSAGraghBuilderWithChecker : public ir::Pass {
   }
 };

-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle

 REGISTER_PASS(multi_devices_check_pass,
-              paddle::framework::details::SSAGraghBuilderWithChecker)
+              paddle::framework::ir::SSAGraghBuilderWithChecker)
     .RequireGraphAttr(paddle::framework::details::kGraphVars)
     .RequireGraphAttr(paddle::framework::details::kGraphDepVars);
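IsValidGraph validates the SSA graph by simulating execution: variables without a generating op start out ready, an op runs once all its inputs are ready, and the graph is valid iff nothing is left pending at the end. A self-contained sketch of that drain loop, assuming simple Op/Var stand-ins rather than the details:: handle types:

#include <iostream>
#include <unordered_set>
#include <vector>

struct Var;  // forward declaration so Op can hold Var pointers
struct Op {
  std::vector<Var *> inputs, outputs;
};
struct Var {
  Op *generated_by = nullptr;
};

// Drain the graph: run every op whose inputs are ready; the graph is a valid
// SSA graph iff no pending op is left, mirroring the checker's simulation.
bool IsValid(const std::vector<Op *> &ops, const std::vector<Var *> &vars) {
  std::unordered_set<Var *> ready_vars;
  for (Var *v : vars)
    if (v->generated_by == nullptr) ready_vars.insert(v);  // graph inputs
  std::unordered_set<Op *> pending(ops.begin(), ops.end());
  bool progressed = true;
  while (progressed) {
    progressed = false;
    for (auto it = pending.begin(); it != pending.end();) {
      Op *op = *it;
      bool runnable = true;
      for (Var *in : op->inputs)
        if (!ready_vars.count(in)) { runnable = false; break; }
      if (runnable) {
        for (Var *out : op->outputs) ready_vars.insert(out);
        it = pending.erase(it);
        progressed = true;
      } else {
        ++it;
      }
    }
  }
  return pending.empty();  // leftover ops mean a cycle or a missing input
}

int main() {
  Var x, y;
  Op op{{&x}, {&y}};
  y.generated_by = &op;
  std::cout << IsValid({&op}, {&x, &y}) << "\n";  // prints 1
}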
paddle/fluid/framework/
detail
s/multi_devices_graph_pass.cc
→
paddle/fluid/framework/
ir/multi_devices_graph_pas
s/multi_devices_graph_pass.cc
浏览文件 @
04bd413a
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/
detail
s/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/
ir/multi_devices_graph_pas
s/multi_devices_graph_pass.h"
#include <algorithm>
#include <algorithm>
#include <fstream>
#include <fstream>
#include <memory>
#include <memory>
...
@@ -40,13 +40,13 @@
...
@@ -40,13 +40,13 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
ir
{
namespace
{
namespace
{
// TODO(panyx0718): Clean this up as well.
// TODO(panyx0718): Clean this up as well.
// all operators. NOTE that even we use a vector here, the operators is
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
// unordered.
typedef
std
::
vector
<
OpHandleBase
*>
GraphOps
;
typedef
std
::
vector
<
details
::
OpHandleBase
*>
GraphOps
;
const
char
kGraphOps
[]
=
"ops"
;
const
char
kGraphOps
[]
=
"ops"
;
bool
OpHaveRole
(
const
ir
::
Node
&
node
,
const
framework
::
OpRole
&
role
)
{
bool
OpHaveRole
(
const
ir
::
Node
&
node
,
const
framework
::
OpRole
&
role
)
{
...
@@ -56,7 +56,7 @@ bool OpHaveRole(const ir::Node &node, const framework::OpRole &role) {
...
@@ -56,7 +56,7 @@ bool OpHaveRole(const ir::Node &node, const framework::OpRole &role) {
}
}
void
PolishGraphToSupportDataHazards
(
ir
::
Graph
*
graph
)
{
void
PolishGraphToSupportDataHazards
(
ir
::
Graph
*
graph
)
{
for
(
auto
&
var_map
:
graph
->
Get
<
GraphVars
>
(
kGraphVars
))
{
for
(
auto
&
var_map
:
graph
->
Get
<
details
::
GraphVars
>
(
details
::
kGraphVars
))
{
for
(
auto
&
name_pair
:
var_map
)
{
for
(
auto
&
name_pair
:
var_map
)
{
if
(
name_pair
.
second
.
size
()
<=
1
)
{
if
(
name_pair
.
second
.
size
()
<=
1
)
{
continue
;
continue
;
...
@@ -65,7 +65,7 @@ void PolishGraphToSupportDataHazards(ir::Graph *graph) {
...
@@ -65,7 +65,7 @@ void PolishGraphToSupportDataHazards(ir::Graph *graph) {
auto
it_old
=
name_pair
.
second
.
rbegin
();
auto
it_old
=
name_pair
.
second
.
rbegin
();
++
it_old
;
++
it_old
;
for
(;
it_old
!=
name_pair
.
second
.
rend
();
it_new
=
it_old
,
++
it_old
)
{
for
(;
it_old
!=
name_pair
.
second
.
rend
();
it_new
=
it_old
,
++
it_old
)
{
OpHandleBase
*
write_op
=
(
*
it_new
)
->
GeneratedOp
();
details
::
OpHandleBase
*
write_op
=
(
*
it_new
)
->
GeneratedOp
();
const
auto
&
read_ops
=
(
*
it_old
)
->
PendingOps
();
const
auto
&
read_ops
=
(
*
it_old
)
->
PendingOps
();
for
(
auto
*
read_op
:
read_ops
)
{
for
(
auto
*
read_op
:
read_ops
)
{
...
@@ -85,28 +85,31 @@ void PolishGraphToSupportDataHazards(ir::Graph *graph) {
...
@@ -85,28 +85,31 @@ void PolishGraphToSupportDataHazards(ir::Graph *graph) {
}
}
if
(
has_dep
)
continue
;
if
(
has_dep
)
continue
;
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
auto
*
dep_var
=
new
details
::
DummyVarHandle
(
graph
->
CreateControlDepVar
());
read_op
->
AddOutput
(
dep_var
);
read_op
->
AddOutput
(
dep_var
);
write_op
->
AddInput
(
dep_var
);
write_op
->
AddInput
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
graph
->
Get
<
details
::
GraphDepVars
>
(
details
::
kGraphDepVars
)
.
emplace
(
dep_var
);
}
}
}
}
}
}
}
}
}
}
VarHandle
*
CreateOrGetLatestVarHandle
(
ir
::
Graph
*
graph
,
ir
::
Node
*
node
,
details
::
VarHandle
*
CreateOrGetLatestVarHandle
(
ir
::
Graph
*
graph
,
ir
::
Node
*
node
,
const
platform
::
Place
&
place
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
size_t
place_offset
)
{
auto
&
var_holders
=
graph
->
Get
<
GraphVars
>
(
kGraphVars
)[
place_offset
];
auto
&
var_holders
=
graph
->
Get
<
details
::
GraphVars
>
(
details
::
kGraphVars
)[
place_offset
];
auto
&
var_holder
=
var_holders
[
node
->
Name
()];
auto
&
var_holder
=
var_holders
[
node
->
Name
()];
VarHandle
*
var
=
nullptr
;
details
::
VarHandle
*
var
=
nullptr
;
if
(
var_holder
.
empty
())
{
if
(
var_holder
.
empty
())
{
if
(
node
->
Var
())
{
if
(
node
->
Var
())
{
var
=
new
VarHandle
(
graph
->
CreateVarNode
(
node
->
Var
()),
0
,
place_offset
,
var
=
new
details
::
VarHandle
(
graph
->
CreateVarNode
(
node
->
Var
()),
0
,
node
->
Name
(),
place
);
place_offset
,
node
->
Name
(),
place
);
}
else
{
}
else
{
var
=
new
VarHandle
(
var
=
new
details
::
VarHandle
(
graph
->
CreateEmptyNode
(
node
->
Name
(),
ir
::
Node
::
Type
::
kVariable
),
0
,
graph
->
CreateEmptyNode
(
node
->
Name
(),
ir
::
Node
::
Type
::
kVariable
),
0
,
place_offset
,
node
->
Name
(),
place
);
place_offset
,
node
->
Name
(),
place
);
}
}
...
@@ -117,14 +120,14 @@ VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node,
...
@@ -117,14 +120,14 @@ VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node,
return
var
;
return
var
;
}
}
void
CreateOpOutput
(
ir
::
Graph
*
graph
,
OpHandleBase
*
op_handle
,
void
CreateOpOutput
(
ir
::
Graph
*
graph
,
details
::
OpHandleBase
*
op_handle
,
ir
::
Node
*
new_node
,
const
platform
::
Place
&
place
,
ir
::
Node
*
new_node
,
const
platform
::
Place
&
place
,
size_t
place_offset
)
{
size_t
place_offset
)
{
auto
&
vars
=
auto
&
vars
=
graph
->
Get
<
details
::
GraphVars
>
(
graph
->
Get
<
GraphVars
>
(
kGraphVars
)[
place_offset
][
new_node
->
Name
()];
details
::
kGraphVars
)[
place_offset
][
new_node
->
Name
()];
size_t
version
=
vars
.
size
();
size_t
version
=
vars
.
size
();
auto
var
=
auto
var
=
new
details
::
VarHandle
(
new_node
,
version
,
place_offset
,
new
VarHandle
(
new_node
,
version
,
place_offset
,
new_node
->
Name
(),
place
);
new_node
->
Name
(),
place
);
vars
.
emplace_back
(
var
);
vars
.
emplace_back
(
var
);
op_handle
->
AddOutput
(
var
);
op_handle
->
AddOutput
(
var
);
}
}
...
@@ -134,8 +137,10 @@ void AddOutputToLeafOps(ir::Graph *graph) {
...
@@ -134,8 +137,10 @@ void AddOutputToLeafOps(ir::Graph *graph) {
if
(
!
op
->
Outputs
().
empty
())
{
if
(
!
op
->
Outputs
().
empty
())
{
continue
;
continue
;
}
}
auto
*
dummy_leaf
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
auto
*
dummy_leaf
=
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dummy_leaf
);
new
details
::
DummyVarHandle
(
graph
->
CreateControlDepVar
());
graph
->
Get
<
details
::
GraphDepVars
>
(
details
::
kGraphDepVars
)
.
emplace
(
dummy_leaf
);
op
->
AddOutput
(
dummy_leaf
);
op
->
AddOutput
(
dummy_leaf
);
}
}
}
}
...
@@ -148,11 +153,11 @@ void MultiDevSSAGraphBuilderBase::Init() const {
...
@@ -148,11 +153,11 @@ void MultiDevSSAGraphBuilderBase::Init() const {
loss_var_name_
=
Get
<
const
std
::
string
>
(
kLossVarName
);
loss_var_name_
=
Get
<
const
std
::
string
>
(
kLossVarName
);
VLOG
(
10
)
<<
"Init MultiDevSSAGraphBuilder, loss name: "
<<
loss_var_name_
;
VLOG
(
10
)
<<
"Init MultiDevSSAGraphBuilder, loss name: "
<<
loss_var_name_
;
places_
=
Get
<
const
std
::
vector
<
platform
::
Place
>>
(
kPlaces
);
places_
=
Get
<
const
std
::
vector
<
platform
::
Place
>>
(
details
::
kPlaces
);
local_scopes_
=
Get
<
const
std
::
vector
<
Scope
*>>
(
kLocalScopes
);
local_scopes_
=
Get
<
const
std
::
vector
<
Scope
*>>
(
details
::
kLocalScopes
);
strategy_
=
Get
<
const
BuildStrategy
>
(
kStrategy
);
strategy_
=
Get
<
const
details
::
BuildStrategy
>
(
kStrategy
);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
nccl_ctxs_
=
&
Get
<
platform
::
NCCLContextMap
>
(
kNCCLCtxs
);
nccl_ctxs_
=
&
Get
<
platform
::
NCCLContextMap
>
(
details
::
kNCCLCtxs
);
#endif
#endif
PADDLE_ENFORCE_EQ
(
places_
.
size
(),
local_scopes_
.
size
());
PADDLE_ENFORCE_EQ
(
places_
.
size
(),
local_scopes_
.
size
());
}
}
...
@@ -172,8 +177,8 @@ void MultiDevSSAGraphBuilderBase::ApplyImpl(ir::Graph *graph) const {
...
@@ -172,8 +177,8 @@ void MultiDevSSAGraphBuilderBase::ApplyImpl(ir::Graph *graph) const {
}
}
// We cannot invoke resize. It is a bug of GCC 4.8
// We cannot invoke resize. It is a bug of GCC 4.8
result
.
Set
(
kGraphVars
,
new
GraphVars
(
places_
.
size
()));
result
.
Set
(
details
::
kGraphVars
,
new
details
::
GraphVars
(
places_
.
size
()));
result
.
Set
(
kGraphDepVars
,
new
GraphDepVars
);
result
.
Set
(
details
::
kGraphDepVars
,
new
details
::
GraphDepVars
);
result
.
Set
(
kGraphOps
,
new
GraphOps
);
result
.
Set
(
kGraphOps
,
new
GraphOps
);
bool
is_forwarding
=
true
;
bool
is_forwarding
=
true
;
...
@@ -260,13 +265,13 @@ void MultiDevSSAGraphBuilderBase::InsertScaleLossGradOp(
...
@@ -260,13 +265,13 @@ void MultiDevSSAGraphBuilderBase::InsertScaleLossGradOp(
// user can customize loss@grad if not use_default_grad_scale_
// user can customize loss@grad if not use_default_grad_scale_
size_t
loss_scale
=
0
;
size_t
loss_scale
=
0
;
switch
(
this
->
strategy_
.
gradient_scale_
)
{
switch
(
this
->
strategy_
.
gradient_scale_
)
{
case
BuildStrategy
::
GradientScaleStrategy
::
kOne
:
case
details
::
BuildStrategy
::
GradientScaleStrategy
::
kOne
:
loss_scale
=
1
;
loss_scale
=
1
;
break
;
break
;
case
BuildStrategy
::
GradientScaleStrategy
::
kCoeffNumDevice
:
case
details
::
BuildStrategy
::
GradientScaleStrategy
::
kCoeffNumDevice
:
loss_scale
=
Get
<
size_t
>
(
kNRanks
);
loss_scale
=
Get
<
size_t
>
(
kNRanks
);
break
;
break
;
case
BuildStrategy
::
GradientScaleStrategy
::
kCustomized
:
case
details
::
BuildStrategy
::
GradientScaleStrategy
::
kCustomized
:
loss_scale
=
0
;
loss_scale
=
0
;
break
;
break
;
default:
default:
...
@@ -328,7 +333,8 @@ void MultiDevSSAGraphBuilderBase::CreateOpHandleIOs(ir::Graph *result,
       platform::DeviceContextPool::Instance().Get(p));
   for (ir::Node *input : node->inputs) {
-    VarHandle *var = CreateOrGetLatestVarHandle(result, input, p, place_id);
+    details::VarHandle *var =
+        CreateOrGetLatestVarHandle(result, input, p, place_id);
     op_handle->AddInput(var);
   }
...
@@ -345,7 +351,7 @@ void MultiDevSSAGraphBuilderBase::CreateOpHandleIOs(ir::Graph *result,
 }
 
 void MultiDevSSAGraphBuilderBase::SetCommunicationContext(
-    OpHandleBase *op_handle, const platform::Place &p) const {
+    details::OpHandleBase *op_handle, const platform::Place &p) const {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   if (nccl_ctxs_ == nullptr) {
     op_handle->SetDeviceContext(p,
...
@@ -361,25 +367,28 @@ void MultiDevSSAGraphBuilderBase::CreateBroadcastOp(ir::Graph *result,
                                                     const std::string &p_name,
                                                     size_t src_dev_id) const {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-  auto *op_handle = new BroadcastOpHandle(
+  auto *op_handle = new details::BroadcastOpHandle(
       result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
       local_scopes_, places_, nccl_ctxs_);
 #else
-  auto *op_handle = new BroadcastOpHandle(
+  auto *op_handle = new details::BroadcastOpHandle(
       result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation),
       local_scopes_, places_);
 #endif
   result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
 
-  auto *in =
-      result->Get<GraphVars>(kGraphVars).at(src_dev_id).at(p_name).back();
+  auto *in = result->Get<details::GraphVars>(details::kGraphVars)
+                 .at(src_dev_id)
+                 .at(p_name)
+                 .back();
   op_handle->AddInput(in);
 
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
-    auto &vars = result->Get<GraphVars>(kGraphVars).at(i).at(p_name);
-    auto *out_var = new VarHandle(
+    auto &vars =
+        result->Get<details::GraphVars>(details::kGraphVars).at(i).at(p_name);
+    auto *out_var = new details::VarHandle(
         result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), vars.size(),
         i, p_name, p);
     vars.emplace_back(out_var);
@@ -391,11 +400,11 @@ void MultiDevSSAGraphBuilderBase::CreateFusedBroadcastOp(
...
@@ -391,11 +400,11 @@ void MultiDevSSAGraphBuilderBase::CreateFusedBroadcastOp(
ir
::
Graph
*
result
,
ir
::
Graph
*
result
,
const
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
&
bcast_varnames
)
const
{
const
std
::
vector
<
std
::
unordered_set
<
std
::
string
>>
&
bcast_varnames
)
const
{
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto
*
op_handle
=
new
FusedBroadcastOpHandle
(
auto
*
op_handle
=
new
details
::
FusedBroadcastOpHandle
(
result
->
CreateEmptyNode
(
"fused_broadcast"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"fused_broadcast"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
,
nccl_ctxs_
);
local_scopes_
,
places_
,
nccl_ctxs_
);
#else
#else
auto
*
op_handle
=
new
FusedBroadcastOpHandle
(
auto
*
op_handle
=
new
details
::
FusedBroadcastOpHandle
(
result
->
CreateEmptyNode
(
"fused_broadcast"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"fused_broadcast"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
);
local_scopes_
,
places_
);
#endif
#endif
...
@@ -408,14 +417,17 @@ void MultiDevSSAGraphBuilderBase::CreateFusedBroadcastOp(
   for (size_t dev_id = 0; dev_id < bcast_varnames.size(); ++dev_id) {
     for (auto &p_name : bcast_varnames[dev_id]) {
-      auto *in =
-          result->Get<GraphVars>(kGraphVars).at(dev_id).at(p_name).back();
+      auto *in = result->Get<details::GraphVars>(details::kGraphVars)
+                     .at(dev_id)
+                     .at(p_name)
+                     .back();
       op_handle->AddInput(in);
       for (size_t out_dev_id = 0; out_dev_id < places_.size(); ++out_dev_id) {
         auto &p = places_[out_dev_id];
-        auto &vars =
-            result->Get<GraphVars>(kGraphVars).at(out_dev_id).at(p_name);
-        auto *out_var = new VarHandle(
+        auto &vars = result->Get<details::GraphVars>(details::kGraphVars)
+                         .at(out_dev_id)
+                         .at(p_name);
+        auto *out_var = new details::VarHandle(
             result->CreateEmptyNode(p_name, ir::Node::Type::kVariable),
             vars.size(), out_dev_id, p_name, p);
         vars.emplace_back(out_var);
...
@@ -429,39 +441,44 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,
                                                         ir::Node *node,
                                                         size_t dev_id) const {
   result->Get<GraphOps>(kGraphOps).emplace_back(
-      new ComputationOpHandle(result->CreateOpNode(node->Op()),
-                              local_scopes_[dev_id], places_[dev_id], dev_id));
+      new details::ComputationOpHandle(result->CreateOpNode(node->Op()),
+                                       local_scopes_[dev_id], places_[dev_id],
+                                       dev_id));
   CreateOpHandleIOs(result, node, dev_id);
 }
 
 void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
                                                     const std::string &og,
                                                     bool is_encoded) const {
-  OpHandleBase *op_handle = nullptr;
+  details::OpHandleBase *op_handle = nullptr;
 
   auto append_allreduce_op = [&](
       const std::vector<Scope *> &scopes,
-      const std::vector<platform::Place> &places) -> OpHandleBase * {
+      const std::vector<platform::Place> &places) -> details::OpHandleBase * {
 #if defined(PADDLE_WITH_DGC)
     if (is_encoded) {
-      result->Get<GraphOps>(kGraphOps).emplace_back(new SparseAllReduceOpHandle(
-          result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
-          scopes, places, nccl_ctxs_, is_encoded,
-          static_cast<int>(strategy_.trainers_endpoints_.size()) *
-              places_.size()));
+      result->Get<GraphOps>(kGraphOps).emplace_back(
+          new details::SparseAllReduceOpHandle(
+              result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
+              scopes, places, nccl_ctxs_, is_encoded,
+              static_cast<int>(strategy_.trainers_endpoints_.size()) *
+                  places_.size()));
     } else {
-      result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
-          result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
-          scopes, places, nccl_ctxs_));
+      result->Get<GraphOps>(kGraphOps).emplace_back(
+          new details::AllReduceOpHandle(
+              result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
+              scopes, places, nccl_ctxs_));
     }
 #elif defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
-    result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
-        result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
-        scopes, places, nccl_ctxs_));
+    result->Get<GraphOps>(kGraphOps).emplace_back(
+        new details::AllReduceOpHandle(
+            result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
+            scopes, places, nccl_ctxs_));
 #else
-    result->Get<GraphOps>(kGraphOps).emplace_back(new AllReduceOpHandle(
-        result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
-        scopes, places));
+    result->Get<GraphOps>(kGraphOps).emplace_back(
+        new details::AllReduceOpHandle(
+            result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation),
+            scopes, places));
 #endif
     return result->Get<GraphOps>(kGraphOps).back();
   };
...
@@ -475,15 +492,15 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
     }
     SetCommunicationContext(op_handle, places_[i]);
-    auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
+    auto &vars = result->Get<details::GraphVars>(details::kGraphVars)[i][og];
     PADDLE_ENFORCE(!vars.empty());
     auto &prev_grad = vars.back();
     op_handle->AddInput(prev_grad);
     VLOG(10) << "all_reduce_op_handle add input " << prev_grad->DebugString();
 
-    auto var =
-        new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
-                      vars.size(), i, og, places_[i]);
+    auto var = new details::VarHandle(
+        result->CreateEmptyNode(og, ir::Node::Type::kVariable), vars.size(),
+        i, og, places_[i]);
     vars.emplace_back(var);
     op_handle->AddOutput(var);
     VLOG(10) << "all_reduce_op_handle add output " << og
@@ -497,7 +514,7 @@ void MultiDevSSAGraphBuilderBase::CreateScaleLossGradOp(
...
@@ -497,7 +514,7 @@ void MultiDevSSAGraphBuilderBase::CreateScaleLossGradOp(
proto
::
VarType
::
Type
dtype
)
const
{
proto
::
VarType
::
Type
dtype
)
const
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
*
dev_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
places_
[
i
]);
auto
*
dev_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
places_
[
i
]);
auto
*
op_handle
=
new
ScaleLossGradOpHandle
(
auto
*
op_handle
=
new
details
::
ScaleLossGradOpHandle
(
result
->
CreateEmptyNode
(
"scale_loss_grad"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"scale_loss_grad"
,
ir
::
Node
::
Type
::
kOperation
),
loss_scale
,
local_scopes_
[
i
],
places_
[
i
],
dev_ctx
,
dtype
);
loss_scale
,
local_scopes_
[
i
],
places_
[
i
],
dev_ctx
,
dtype
);
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
op_handle
);
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
op_handle
);
...
@@ -518,20 +535,21 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOps(
...
@@ -518,20 +535,21 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOps(
for
(
size_t
scope_idx
=
0
;
scope_idx
<
num_places
;
++
scope_idx
)
{
for
(
size_t
scope_idx
=
0
;
scope_idx
<
num_places
;
++
scope_idx
)
{
auto
p
=
places_
[
scope_idx
];
auto
p
=
places_
[
scope_idx
];
auto
s
=
local_scopes_
[
scope_idx
];
auto
s
=
local_scopes_
[
scope_idx
];
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ComputationOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
result
->
CreateOpNode
(
node
->
Op
()),
s
,
p
,
scope_idx
));
new
details
::
ComputationOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
s
,
p
,
scope_idx
));
CreateOpHandleIOs
(
result
,
node
,
scope_idx
);
CreateOpHandleIOs
(
result
,
node
,
scope_idx
);
}
}
}
}
VarHandle
*
MultiDevSSAGraphBuilderBase
::
CreateReduceOp
(
details
::
VarHandle
*
MultiDevSSAGraphBuilderBase
::
CreateReduceOp
(
ir
::
Graph
*
result
,
const
std
::
string
&
og
,
size_t
dst_dev_id
)
const
{
ir
::
Graph
*
result
,
const
std
::
string
&
og
,
size_t
dst_dev_id
)
const
{
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ReduceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
details
::
ReduceOpHandle
(
result
->
CreateEmptyNode
(
"reduce"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"reduce"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
,
nccl_ctxs_
));
local_scopes_
,
places_
,
nccl_ctxs_
));
#else
#else
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ReduceOpHandle
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
details
::
ReduceOpHandle
(
result
->
CreateEmptyNode
(
"reduce"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"reduce"
,
ir
::
Node
::
Type
::
kOperation
),
local_scopes_
,
places_
));
local_scopes_
,
places_
));
#endif
#endif
...
@@ -540,15 +558,16 @@ VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
-    auto &vars = result->Get<GraphVars>(kGraphVars)[i][og];
+    auto &vars = result->Get<details::GraphVars>(details::kGraphVars)[i][og];
     PADDLE_ENFORCE(!vars.empty());
     auto &prev_grad = vars.back();
     op_handle->AddInput(prev_grad);
   }
-  auto &vars = result->Get<GraphVars>(kGraphVars)[dst_dev_id][og];
-  auto var =
-      new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable),
-                    vars.size(), dst_dev_id, og, places_[dst_dev_id]);
+  auto &vars =
+      result->Get<details::GraphVars>(details::kGraphVars)[dst_dev_id][og];
+  auto var = new details::VarHandle(
+      result->CreateEmptyNode(og, ir::Node::Type::kVariable), vars.size(),
+      dst_dev_id, og, places_[dst_dev_id]);
   vars.emplace_back(var);
   op_handle->AddOutput(var);
   return var;
...
@@ -596,7 +615,7 @@ int BalanceVarSSAGraphBuilder::GetVarDeviceID(
 }
 
 int BalanceVarSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
-  if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
+  if (strategy_.reduce_ != details::BuildStrategy::ReduceStrategy::kReduce) {
    return -1;
  }
  if (!OpHaveRole(*node, framework::OpRole::kOptimize)) {
...
@@ -830,9 +849,10 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
 void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
   auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
   for (ir::Node *input : node->inputs) {
-    VarHandle *var = nullptr;
+    details::VarHandle *var = nullptr;
     for (int place_offset = 0; place_offset < num_places; ++place_offset) {
-      auto &var_holders = result->Get<GraphVars>(kGraphVars)[place_offset];
+      auto &var_holders =
+          result->Get<details::GraphVars>(details::kGraphVars)[place_offset];
       auto &var_holder = var_holders[input->Name()];
       if (!var_holder.empty()) {
         var = *var_holder.rbegin();
...
@@ -852,7 +872,8 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const {
                    "This hack no longer holds, please fix.");
     // the variable name which contains .block means it was splited by
     // split_byref op
-    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce &&
+    if (strategy_.reduce_ ==
+            details::BuildStrategy::ReduceStrategy::kAllReduce &&
         node->inputs[0]->Name().find(".block") == std::string::npos) {
       std::vector<std::string> input_var_names;
       for (ir::Node *n : node->inputs) {
...
@@ -898,10 +919,11 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const {
   // Create fetch_barrier op handle to enable output on all devices.
   // **NOTE** fetch_barrier should output variables list same as recv op does.
   if (node->Op()->Type() == "fetch_barrier") {
-    result->Get<GraphOps>(kGraphOps).emplace_back(new FetchBarrierOpHandle(
-        result->CreateOpNode(node->Op()), local_scopes_, places_));
+    result->Get<GraphOps>(kGraphOps).emplace_back(
+        new details::FetchBarrierOpHandle(
+            result->CreateOpNode(node->Op()), local_scopes_, places_));
   } else {
-    result->Get<GraphOps>(kGraphOps).emplace_back(new RPCOpHandle(
-        result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id],
-        node->Op()->Type(), places_[op_dev_id]));
+    result->Get<GraphOps>(kGraphOps).emplace_back(new details::RPCOpHandle(
+        result->CreateOpNode(node->Op()), *node->Op(),
+        local_scopes_[op_dev_id], node->Op()->Type(), places_[op_dev_id]));
   }
...
@@ -954,7 +976,8 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
              node->Op()->Type() == "split_ids") {
     // TODO(paddle-dev): getting the first var is not safe.
     op_dev_id = GetVarDeviceID(input_var_names[0]);
-    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
+    if (strategy_.reduce_ ==
+        details::BuildStrategy::ReduceStrategy::kAllReduce) {
       op_dev_id = GetAppropriateDeviceID(input_var_names);
       for (auto &varname : input_var_names) {
         sharded_var_device_.emplace(varname, op_dev_id);
...
@@ -985,7 +1008,7 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
 #if defined(PADDLE_WITH_DGC)
 bool AllReduceSSAGraphBuilder::IsEncoded(const std::string &p_name) const {
-  auto u_name = p_name + g_dgc_u;
+  auto u_name = p_name + details::g_dgc_u;
   auto it = all_vars_.find(u_name);
   if (it == all_vars_.end()) {
     VLOG(10) << "can't find u_name, so it's not encoded:" << u_name;
...
@@ -1006,12 +1029,12 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
   // collective gradient to each device
   size_t cur_device_id = 0;
   switch (strategy_.reduce_) {
-    case BuildStrategy::ReduceStrategy::kReduce:
+    case details::BuildStrategy::ReduceStrategy::kReduce:
       cur_device_id = GetAppropriateDeviceID({g_name});
       CreateReduceOp(result, g_name, cur_device_id);
       sharded_var_device_.emplace(g_name, cur_device_id);
       break;
-    case BuildStrategy::ReduceStrategy::kAllReduce:
+    case details::BuildStrategy::ReduceStrategy::kAllReduce:
       if (IsSparseGradient(g_name)) {
         CreateReduceOp(result, g_name, 0);
         CreateBroadcastOp(result, g_name, 0);
...
@@ -1038,7 +1061,7 @@ void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const {
   // 4. CPU && Reduce: because all parameters share the same memory, did not
   // broadcast received parameters.
   if (!UseGPU() &&
-      strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
+      strategy_.reduce_ == details::BuildStrategy::ReduceStrategy::kReduce) {
     return;
   }
   if (strategy_.fuse_broadcast_ops_) {
...
@@ -1064,29 +1087,28 @@ static int MultiDevSSAGraphBuilderRegister(const std::string &builder_mode) {
   return 0;
 }
 
-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle
 
 #define REGISTER_MULTI_DEVICES_PASS(pass_name, pass_class)                    \
   STATIC_ASSERT_GLOBAL_NAMESPACE(                                             \
       _reg_ssa_graph_builder_##pass_name,                                     \
       "REGISTER_MULTI_DEVICES_PASS must be called in global namespace.");     \
   int _reg_ssa_graph_builder_entry_##pass_name =                              \
-      paddle::framework::details::MultiDevSSAGraphBuilderRegister(#pass_name); \
+      paddle::framework::ir::MultiDevSSAGraphBuilderRegister(#pass_name);     \
   REGISTER_PASS(pass_name, pass_class)                                        \
-      .RequirePassAttr(paddle::framework::details::kLossVarName)              \
+      .RequirePassAttr(paddle::framework::ir::kLossVarName)                   \
       .RequirePassAttr(paddle::framework::details::kPlaces)                   \
       .RequirePassAttr(paddle::framework::details::kLocalScopes)              \
-      .RequirePassAttr(paddle::framework::details::kStrategy)                 \
-      .RequirePassAttr(paddle::framework::details::kNRanks)
+      .RequirePassAttr(paddle::framework::ir::kStrategy)                      \
+      .RequirePassAttr(paddle::framework::ir::kNRanks)
 
 REGISTER_MULTI_DEVICES_PASS(reduce_mode_multi_devices_pass,
-                            paddle::framework::details::ReduceSSAGraphBuilder);
+                            paddle::framework::ir::ReduceSSAGraphBuilder);
 REGISTER_MULTI_DEVICES_PASS(
     all_reduce_mode_multi_devices_pass,
-    paddle::framework::details::AllReduceSSAGraphBuilder);
+    paddle::framework::ir::AllReduceSSAGraphBuilder);
 REGISTER_MULTI_DEVICES_PASS(dist_multi_devices_pass,
-                            paddle::framework::details::DistSSAGraphBuilder);
+                            paddle::framework::ir::DistSSAGraphBuilder);
 REGISTER_MULTI_DEVICES_PASS(async_multi_devices_pass,
-                            paddle::framework::details::AsyncSSAGraphBuilder);
+                            paddle::framework::ir::AsyncSSAGraphBuilder);
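
For context on how these registrations are consumed, the sketch below is illustrative only (not part of this commit); it shows roughly how a caller fetches one of the registered multi-devices passes and supplies the attributes that REGISTER_MULTI_DEVICES_PASS marks as required. The attribute values (loss_var_name, places, local_scopes, strategy, nranks) are assumed to be prepared by the caller, and the SetNotOwned/Apply calls mirror the pass usage visible in parallel_executor.cc later in this diff.

// Illustrative sketch only; all attribute values are caller-provided
// placeholders. Attribute keys follow the post-move namespaces used in the
// macro above (some constants live in ir::, some remain in details::).
auto pass = ir::PassRegistry::Instance().Get("reduce_mode_multi_devices_pass");
pass->SetNotOwned<const std::string>(ir::kLossVarName, &loss_var_name);
pass->SetNotOwned<const std::vector<platform::Place>>(details::kPlaces, &places);
pass->SetNotOwned<const std::vector<Scope *>>(details::kLocalScopes, &local_scopes);
pass->SetNotOwned<const details::BuildStrategy>(ir::kStrategy, &strategy);
pass->Set<size_t>(ir::kNRanks, new size_t(nranks));  // pass takes ownership
graph = pass->Apply(graph);  // rewrites the graph into multi-devices SSA form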
paddle/fluid/framework/details/multi_devices_graph_pass.h → paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h
...
@@ -31,7 +31,7 @@ class NCCLContextMap;
 namespace framework {
 class Scope;
-namespace details {
+namespace ir {
 
 constexpr char kLossVarName[] = "loss_var_name";
 constexpr char kStrategy[] = "strategy";
...
@@ -69,8 +69,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
                              ir::Node *out_var_node, size_t loss_scale,
                              proto::VarType::Type dtype) const;
 
-  VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
-                            size_t dst_dev_id) const;
+  details::VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
+                                     size_t dst_dev_id) const;
 
   void CreateComputationalOp(ir::Graph *result, ir::Node *node,
                              size_t dev_id) const;
...
@@ -89,7 +89,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
       ir::Graph *result,
       const std::vector<std::unordered_set<std::string>> &bcast_varnames) const;
 
-  void SetCommunicationContext(OpHandleBase *op_handle,
+  void SetCommunicationContext(details::OpHandleBase *op_handle,
                                const platform::Place &p) const;
 
   void CreateOpHandleIOs(ir::Graph *result, ir::Node *node,
@@ -103,7 +103,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
...
@@ -103,7 +103,7 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
mutable
std
::
vector
<
platform
::
Place
>
places_
;
mutable
std
::
vector
<
platform
::
Place
>
places_
;
mutable
std
::
vector
<
Scope
*>
local_scopes_
;
mutable
std
::
vector
<
Scope
*>
local_scopes_
;
mutable
BuildStrategy
strategy_
;
mutable
details
::
BuildStrategy
strategy_
;
mutable
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
all_vars_
;
mutable
std
::
unordered_map
<
std
::
string
,
VarDesc
*>
all_vars_
;
};
};
...
@@ -209,6 +209,6 @@ class DistSSAGraphBuilder : public BalanceVarSSAGraphBuilder {
 std::unordered_set<std::string> &MultiDevSSAGraphBuilder();
 
-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/details/multi_devices_graph_print_pass.cc → paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.cc
...
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
+#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h"
 #include <memory>
 #include <string>
 #include <unordered_map>
...
@@ -21,11 +21,21 @@
 namespace paddle {
 namespace framework {
-namespace details {
+namespace ir {
+
+class SSAGraghBuilderWithPrinterPass : public ir::Pass {
+ protected:
+  void ApplyImpl(ir::Graph *graph) const override {
+    std::unique_ptr<std::ostream> fout(
+        new std::ofstream(Get<std::string>(kGraphvizPath)));
+    PADDLE_ENFORCE(fout->good());
+    Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
+  }
+};
 
 template <typename Callback>
 static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
-  for (auto &each : graph.Get<GraphVars>(kGraphVars)) {
+  for (auto &each : graph.Get<details::GraphVars>(details::kGraphVars)) {
     for (auto &pair1 : each) {
       for (auto &pair2 : pair1.second) {
         callback(*pair2);
...
@@ -33,7 +43,7 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
       }
     }
   }
 
-  for (auto &var : graph.Get<GraphDepVars>(kGraphDepVars)) {
+  for (auto &var : graph.Get<details::GraphDepVars>(details::kGraphDepVars)) {
     callback(*var);
   }
 }
...
@@ -41,14 +51,14 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
 void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
                                     std::ostream &sout) const {
   size_t var_id = 0;
-  std::unordered_map<const VarHandleBase *, size_t> vars;
+  std::unordered_map<const details::VarHandleBase *, size_t> vars;
 
   sout << "digraph G {\n";
 
-  IterAllVar(graph, [&](const VarHandleBase &var) {
+  IterAllVar(graph, [&](const details::VarHandleBase &var) {
     auto *var_ptr = &var;
-    auto *var_handle_ptr = dynamic_cast<const VarHandle *>(var_ptr);
-    auto *dummy_ptr = dynamic_cast<const DummyVarHandle *>(var_ptr);
+    auto *var_handle_ptr = dynamic_cast<const details::VarHandle *>(var_ptr);
+    auto *dummy_ptr = dynamic_cast<const details::DummyVarHandle *>(var_ptr);
 
     size_t cur_var_id = var_id++;
     vars[var_ptr] = cur_var_id;
...
@@ -65,7 +75,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
   });
 
   size_t op_id = 0;
-  for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(graph)) {
+  for (auto &op : ir::FilterByNodeWrapper<details::OpHandleBase>(graph)) {
     std::string op_name = "op_" + std::to_string(op_id++);
     sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
          << std::endl;
...
@@ -82,10 +92,10 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
   sout << "}\n";
 }
 
-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle
 
 REGISTER_PASS(multi_devices_print_pass,
-              paddle::framework::details::SSAGraghBuilderWithPrinter)
-    .RequirePassAttr(paddle::framework::details::kGraphvizPath);
+              paddle::framework::ir::SSAGraghBuilderWithPrinterPass)
+    .RequirePassAttr(paddle::framework::ir::kGraphvizPath);
paddle/fluid/framework/details/multi_devices_graph_print_pass.h → paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_print_pass.h
...
@@ -24,7 +24,7 @@
 namespace paddle {
 namespace framework {
-namespace details {
+namespace ir {
 
 constexpr char kGraphvizPath[] = "debug_graphviz_path";
...
@@ -39,16 +39,6 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
   void Print(const ir::Graph &graph,
              std::ostream &sout) const override;
 };
 
-class SSAGraghBuilderWithPrinter : public ir::Pass {
- protected:
-  void ApplyImpl(ir::Graph *graph) const override {
-    std::unique_ptr<std::ostream> fout(
-        new std::ofstream(Get<std::string>(kGraphvizPath)));
-    PADDLE_ENFORCE(fout->good());
-    Get<GraphvizSSAGraphPrinter>("graph_printer").Print(*graph, *fout);
-  }
-};
-
-}  // namespace details
+}  // namespace ir
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/ir/multi_devices_graph_pass/sequential_execution_pass.cc
0 → 100644
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_proto_maker.h"

namespace paddle {
namespace framework {
namespace ir {

static bool IsSameOpDesc(OpDesc *op1, OpDesc *op2) {
  return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
         op1->Outputs() == op2->Outputs();
}

class SequentialExecutionPass : public ir::Pass {
 protected:
  void ApplyImpl(ir::Graph *graph) const override {
    // FIXME(zjl): Insert dependencies between some distributed ops may cause
    // the multi_devices_graph_pass fails. So we skip these ops here.
    // Indeed, maybe we should not insert dependencies between these ops
    // casually, which may cause deadlock easily.
    // We should add more skipped distributed ops when found errors in
    // multi_devices_graph_pass
    static std::unordered_set<std::string> skip_dist_ops{
        "send", "recv", "send_barrier", "fetch_barrier"};

    auto &ops = graph->Get<const std::vector<OpDesc *>>(
        details::kStaleProgramOpDescs);
    std::vector<ir::Node *> op_node_list;
    op_node_list.reserve(ops.size());

    std::unordered_map<ir::Node *, size_t> op_deps;
    std::unordered_map<ir::Node *, std::unordered_set<ir::Node *>> pending_ops;
    std::unordered_set<ir::Node *> ready_ops;

    for (ir::Node *node : graph->Nodes()) {
      if (!node->IsOp()) continue;
      std::unordered_set<ir::Node *> preceding_ops;
      for (auto *in : node->inputs) {
        PADDLE_ENFORCE(in->IsVar(),
                       "Preceding Node of Op Nodes must be Var Node");
        if (in->inputs.empty()) continue;
        PADDLE_ENFORCE(in->inputs.size() == 1 && in->inputs[0]->IsOp(),
                       "Preceding Op Node of Var Node must be unique");
        preceding_ops.insert(in->inputs[0]);
        pending_ops[in->inputs[0]].insert(node);
      }
      op_deps[node] = preceding_ops.size();
      if (preceding_ops.empty()) {
        ready_ops.insert(node);
      }
    }

    for (auto *op_desc : ops) {
      ir::Node *found_node = nullptr;
      for (auto *node : ready_ops) {
        if (IsSameOpDesc(op_desc, node->Op())) {
          PADDLE_ENFORCE(found_node == nullptr,
                         "Found multiple op_desc in graph: %s",
                         op_desc->Type());
          found_node = node;
        }
      }

      PADDLE_ENFORCE_NOT_NULL(found_node, "Cannot find op_desc in graph: %s",
                              op_desc->Type());
      for (auto *pending_op : pending_ops[found_node]) {
        if (--op_deps.at(pending_op) == 0) {
          ready_ops.insert(pending_op);
        }
      }
      ready_ops.erase(found_node);
      if (skip_dist_ops.count(op_desc->Type()) == 0) {
        op_node_list.push_back(found_node);
      }
    }

    for (size_t i = 1; i < op_node_list.size(); ++i) {
      auto *dep_var = graph->CreateControlDepVar();
      op_node_list[i]->inputs.push_back(dep_var);
      op_node_list[i - 1]->outputs.push_back(dep_var);
      dep_var->outputs.push_back(op_node_list[i]);
      dep_var->inputs.push_back(op_node_list[i - 1]);
      VLOG(10) << "Add dependencies between " << op_node_list[i - 1]->Name()
               << " and " << op_node_list[i]->Name();
    }
  }
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(sequential_execution_pass,
              paddle::framework::ir::SequentialExecutionPass)
    .RequireGraphAttr(paddle::framework::details::kStaleProgramOpDescs);
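
The pass above walks the OpDescs of the stale program in their original order, matches each one to a ready graph node, and then chains the matched nodes with control-dependency variables so they execute strictly in program order. A rough usage sketch follows (hypothetical driver code, not from this commit), assuming the graph was built from a ProgramDesc so that kStaleProgramOpDescs is already attached:

// Illustrative only: fetch the pass from the registry and apply it; the graph
// must already carry the kStaleProgramOpDescs attribute it requires.
auto pass = ir::PassRegistry::Instance().Get("sequential_execution_pass");
graph = pass->Apply(graph);  // ops are now serialized via control-dep vars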
paddle/fluid/framework/ir/sync_batch_norm_pass.cc
...
@@ -12,30 +12,32 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/framework/ir/sync_batch_norm_pass.h"
 #include <memory>
 #include <string>
 #include <utility>
+#include "paddle/fluid/framework/ir/pass.h"
 
 namespace paddle {
 namespace framework {
 namespace ir {
 
-void SyncBatchNormPass::ApplyImpl(ir::Graph *graph) const {
-  VLOG(3) << "Use synchronous batch norm";
-  for (const Node *n : graph->Nodes()) {
-    if (n->IsOp()) {
-      auto *op = n->Op();
-      if (op->Type() == "batch_norm") {
-        op->SetType("sync_batch_norm");
-      }
-      if (op->Type() == "batch_norm_grad") {
-        op->SetType("sync_batch_norm_grad");
-      }
+class SyncBatchNormPass : public Pass {
+ protected:
+  void ApplyImpl(ir::Graph *graph) const override {
+    VLOG(3) << "Use synchronous batch norm";
+    for (const Node *n : graph->Nodes()) {
+      if (n->IsOp()) {
+        auto *op = n->Op();
+        if (op->Type() == "batch_norm") {
+          op->SetType("sync_batch_norm");
+        }
+        if (op->Type() == "batch_norm_grad") {
+          op->SetType("sync_batch_norm_grad");
+        }
+      }
     }
   }
-}
+};
 
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
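
Since the class is now file-local, the only way to reach it is through the pass registry. A rough usage sketch, modeled on the pass tester below (the contents of prog are a placeholder assumption):

// Illustrative only: build a graph from a ProgramDesc containing batch_norm
// ops and run the pass; afterwards the op types read sync_batch_norm and
// sync_batch_norm_grad.
ProgramDesc prog;  // assumed to contain batch_norm / batch_norm_grad ops
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass = ir::PassRegistry::Instance().Get("sync_batch_norm_pass");
graph.reset(pass->Apply(graph.release()));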
...
paddle/fluid/framework/ir/sync_batch_norm_pass.h
deleted 100644 → 0

/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>

#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {

class SyncBatchNormPass : public Pass {
 protected:
  void ApplyImpl(ir::Graph *graph) const override;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/ir/sync_batch_norm_pass_tester.cc
...
@@ -12,9 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/framework/ir/sync_batch_norm_pass.h"
 #include <gtest/gtest.h>
+#include "paddle/fluid/framework/ir/pass.h"
 #include "paddle/fluid/framework/program_desc.h"
 
 namespace paddle {
 namespace framework {
 namespace ir {
...
paddle/fluid/framework/parallel_executor.cc
...
@@ -23,11 +23,11 @@ limitations under the License. */
 #include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
-#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
 #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
 #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
 #include "paddle/fluid/platform/profiler.h"
 
 #ifdef WITH_GPERFTOOLS
...
@@ -110,9 +110,9 @@ class ParallelExecutorPrivate {
   // global_ref_cnts_ is only initialized when ParallelExecutor constructs, and
   // then keeps unchanged
   // Before each iteration, runtime_ref_cnts_ is reset to global_ref_cnts_
-  std::vector<details::ReferenceCountMap> global_ref_cnts_;
-  std::vector<details::AtomicReferenceCountMap> runtime_ref_cnts_;
-  details::GarbageCollectorMap gcs_;
+  std::vector<ir::ReferenceCountMap> global_ref_cnts_;
+  std::vector<ir::AtomicReferenceCountMap> runtime_ref_cnts_;
+  ir::GarbageCollectorMap gcs_;
 };
 
 ir::Graph *ParallelExecutorPrivate::PrepareGCAndRefCnts(
...
@@ -150,25 +150,23 @@ ir::Graph *ParallelExecutorPrivate::PrepareGCAndRefCnts(
   }
 
   if (!gcs_.empty()) {
-    std::vector<details::LastLiveOpsOfVars> last_live_ops_of_vars;
+    std::vector<ir::LastLiveOpsOfVars> last_live_ops_of_vars;
 
     auto ref_cnt_pass =
         ir::PassRegistry::Instance().Get("reference_count_pass");
-    ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount,
-                              &global_ref_cnts_);
-    ref_cnt_pass->SetNotOwned(details::kLastLiveOpsOfVars,
-                              &last_live_ops_of_vars);
+    ref_cnt_pass->SetNotOwned(ir::kGlobalReferenceCount, &global_ref_cnts_);
+    ref_cnt_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars);
     graph = ref_cnt_pass->Apply(graph);
     VLOG(10) << "ReferenceCountPass Applied";
 
     auto eager_deletion_pass =
         ir::PassRegistry::Instance().Get("eager_deletion_pass");
-    eager_deletion_pass->SetNotOwned(details::kRuntimeReferenceCount,
+    eager_deletion_pass->SetNotOwned(ir::kRuntimeReferenceCount,
                                      &runtime_ref_cnts_);
-    eager_deletion_pass->SetNotOwned(details::kGarbageCollector, &gcs_);
-    eager_deletion_pass->SetNotOwned(details::kLastLiveOpsOfVars,
+    eager_deletion_pass->SetNotOwned(ir::kGarbageCollector, &gcs_);
+    eager_deletion_pass->SetNotOwned(ir::kLastLiveOpsOfVars,
                                      &last_live_ops_of_vars);
-    eager_deletion_pass->SetNotOwned(details::kAllPlaces, &places_);
+    eager_deletion_pass->SetNotOwned(ir::kAllPlaces, &places_);
     graph = eager_deletion_pass->Apply(graph);
     VLOG(10) << "EagerDeletionPass Applied";
   }
...
paddle/fluid/pybind/const_value.cc
...
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/pybind/const_value.h"
-#include "paddle/fluid/framework/details/memory_optimize_pass.h"
+#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimize_pass.h"
 #include "paddle/fluid/framework/ir/node.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/operator.h"
...
@@ -34,7 +34,7 @@ void BindConstValue(pybind11::module* m) {
   m->def("kControlDepVarName",
          [] { return framework::ir::Node::kControlDepVarName; });
   m->def("kNewGradSuffix", [] { return framework::kNewGradSuffix; });
-  m->def("kMemOptSkipVars", [] { return framework::details::kMemOptSkipVars; });
+  m->def("kMemOptSkipVars", [] { return framework::ir::kMemOptSkipVars; });
 
   auto op_proto_and_checker_maker =
       m->def_submodule("op_proto_and_checker_maker");
...
paddle/fluid/pybind/pybind.cc
...
@@ -21,11 +21,11 @@ limitations under the License. */
 #include <utility>
 #include <vector>
 
-#include "paddle/fluid/framework/details/alloc_continuous_space_for_grad_pass.h"
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/feed_fetch_method.h"
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/garbage_collector.h"
+#include "paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.h"
 #include "paddle/fluid/framework/ir/pass_builder.h"
 #include "paddle/fluid/framework/lod_rank_table.h"
 #include "paddle/fluid/framework/lod_tensor.h"
...
@@ -170,9 +170,9 @@ PYBIND11_MODULE(core, m) {
   m.def("_set_eager_deletion_mode", &paddle::framework::SetEagerDeletionMode);
   m.def("_set_fuse_parameter_group_size",
-        &paddle::framework::details::SetFuseParameterGroupsSize);
+        &paddle::framework::ir::SetFuseParameterGroupsSize);
   m.def("_set_fuse_parameter_memory_size",
-        &paddle::framework::details::SetFuseParameterMemorySize);
+        &paddle::framework::ir::SetFuseParameterMemorySize);
   m.add_object("_cleanup",
                py::capsule([]() { ScopePool::Instance().Clear(); }));
...