Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
472f16b5
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
472f16b5
编写于
3月 11, 2019
作者:
Z
Zeng Jinle
提交者:
GitHub
3月 11, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #16063 from sneaxiy/enhance_gc
Enhance gc
上级
e31f6e98
732fa00e
变更
22
隐藏空白更改
内联
并排
Showing
22 changed file
with
847 addition
and
59 deletion
+847
-59
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-1
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+2
-1
paddle/fluid/framework/details/computation_op_handle.h
paddle/fluid/framework/details/computation_op_handle.h
+3
-0
paddle/fluid/framework/details/eager_deletion_op_handle.cc
paddle/fluid/framework/details/eager_deletion_op_handle.cc
+12
-2
paddle/fluid/framework/details/eager_deletion_pass.cc
paddle/fluid/framework/details/eager_deletion_pass.cc
+166
-5
paddle/fluid/framework/details/reference_count_pass.cc
paddle/fluid/framework/details/reference_count_pass.cc
+4
-9
paddle/fluid/framework/details/reference_count_pass_helper.cc
...le/fluid/framework/details/reference_count_pass_helper.cc
+14
-1
paddle/fluid/framework/details/reference_count_pass_helper.h
paddle/fluid/framework/details/reference_count_pass_helper.h
+8
-1
paddle/fluid/framework/details/while_op_eager_deletion_pass.cc
...e/fluid/framework/details/while_op_eager_deletion_pass.cc
+62
-0
paddle/fluid/framework/executor.cc
paddle/fluid/framework/executor.cc
+31
-15
paddle/fluid/framework/executor.h
paddle/fluid/framework/executor.h
+13
-4
paddle/fluid/operators/controlflow/CMakeLists.txt
paddle/fluid/operators/controlflow/CMakeLists.txt
+1
-0
paddle/fluid/operators/controlflow/while_op.cc
paddle/fluid/operators/controlflow/while_op.cc
+1
-8
paddle/fluid/operators/controlflow/while_op_helper.cc
paddle/fluid/operators/controlflow/while_op_helper.cc
+291
-0
paddle/fluid/operators/controlflow/while_op_helper.h
paddle/fluid/operators/controlflow/while_op_helper.h
+43
-0
paddle/fluid/operators/recurrent_op.cc
paddle/fluid/operators/recurrent_op.cc
+6
-2
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+4
-2
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+5
-5
python/paddle/fluid/executor.py
python/paddle/fluid/executor.py
+1
-1
python/paddle/fluid/tests/unittests/test_eager_deletion_transformer.py
.../fluid/tests/unittests/test_eager_deletion_transformer.py
+1
-2
python/paddle/fluid/tests/unittests/test_eager_deletion_while_op.py
...dle/fluid/tests/unittests/test_eager_deletion_while_op.py
+153
-0
python/paddle/fluid/tests/unittests/test_partial_eager_deletion_transformer.py
...ests/unittests/test_partial_eager_deletion_transformer.py
+25
-0
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
472f16b5
...
@@ -174,7 +174,7 @@ else()
...
@@ -174,7 +174,7 @@ else()
cc_test
(
test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op
)
cc_test
(
test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op
)
endif
()
endif
()
target_link_libraries
(
executor garbage_collector
)
target_link_libraries
(
executor garbage_collector
while_op_helper
)
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor
...
...
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
472f16b5
...
@@ -61,7 +61,8 @@ cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_
...
@@ -61,7 +61,8 @@ cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_
cc_library
(
modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper
)
cc_library
(
modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper
)
cc_library
(
reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle
)
cc_library
(
reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle
)
cc_library
(
eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper
)
cc_library
(
eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper
)
cc_library
(
eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass
)
cc_library
(
while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle
)
cc_library
(
eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass
)
cc_library
(
reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper
)
cc_library
(
reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper
)
cc_library
(
sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass
)
cc_library
(
sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass
)
...
...
paddle/fluid/framework/details/computation_op_handle.h
浏览文件 @
472f16b5
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#pragma once
#pragma once
#include <memory>
#include <string>
#include <string>
#include <vector>
#include <vector>
...
@@ -31,6 +32,8 @@ class ComputationOpHandle : public OpHandleBase {
...
@@ -31,6 +32,8 @@ class ComputationOpHandle : public OpHandleBase {
ComputationOpHandle
(
ir
::
Node
*
node
,
Scope
*
scope
,
platform
::
Place
place
,
ComputationOpHandle
(
ir
::
Node
*
node
,
Scope
*
scope
,
platform
::
Place
place
,
size_t
scope_idx
);
size_t
scope_idx
);
OperatorBase
*
GetOp
()
{
return
op_
.
get
();
}
std
::
string
Name
()
const
override
;
std
::
string
Name
()
const
override
;
const
Scope
*
GetScope
()
const
{
return
scope_
;
}
const
Scope
*
GetScope
()
const
{
return
scope_
;
}
...
...
paddle/fluid/framework/details/eager_deletion_op_handle.cc
浏览文件 @
472f16b5
...
@@ -12,6 +12,10 @@
...
@@ -12,6 +12,10 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include <memory>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
...
@@ -45,6 +49,7 @@ EagerDeletionOpHandle::EagerDeletionOpHandle(
...
@@ -45,6 +49,7 @@ EagerDeletionOpHandle::EagerDeletionOpHandle(
}
}
}
}
#endif
#endif
PADDLE_ENFORCE
(
!
var_names_
.
empty
(),
"Var names cannot be empty"
);
}
}
EagerDeletionOpHandle
::~
EagerDeletionOpHandle
()
{
EagerDeletionOpHandle
::~
EagerDeletionOpHandle
()
{
...
@@ -60,15 +65,20 @@ EagerDeletionOpHandle::~EagerDeletionOpHandle() {
...
@@ -60,15 +65,20 @@ EagerDeletionOpHandle::~EagerDeletionOpHandle() {
std
::
string
EagerDeletionOpHandle
::
Name
()
const
{
return
"eager_deletion"
;
}
std
::
string
EagerDeletionOpHandle
::
Name
()
const
{
return
"eager_deletion"
;
}
void
EagerDeletionOpHandle
::
RunImpl
()
{
void
EagerDeletionOpHandle
::
RunImpl
()
{
auto
*
exec_scope
=
scope_
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
()
;
Scope
*
exec_scope
=
nullptr
;
std
::
deque
<
std
::
shared_ptr
<
memory
::
Allocation
>>
garbages
;
std
::
deque
<
std
::
shared_ptr
<
memory
::
Allocation
>>
garbages
;
for
(
auto
&
name
:
var_names_
)
{
for
(
auto
&
name
:
var_names_
)
{
auto
it
=
ref_cnts_
->
find
(
name
);
auto
it
=
ref_cnts_
->
find
(
name
);
//
Var not found, not r
eference count has not decreased to 0
//
R
eference count has not decreased to 0
if
(
it
==
ref_cnts_
->
end
()
||
it
->
second
.
fetch_sub
(
1
)
!=
1
)
{
if
(
it
==
ref_cnts_
->
end
()
||
it
->
second
.
fetch_sub
(
1
)
!=
1
)
{
continue
;
continue
;
}
}
if
(
!
exec_scope
)
{
exec_scope
=
scope_
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
();
}
// Var not found
auto
*
var
=
exec_scope
->
FindVar
(
name
);
auto
*
var
=
exec_scope
->
FindVar
(
name
);
if
(
var
==
nullptr
)
{
if
(
var
==
nullptr
)
{
continue
;
continue
;
...
...
paddle/fluid/framework/details/eager_deletion_pass.cc
浏览文件 @
472f16b5
...
@@ -12,20 +12,173 @@
...
@@ -12,20 +12,173 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include <algorithm>
#include <functional>
#include <queue>
#include <queue>
#include <string>
#include <string>
#include <tuple>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
DEFINE_double
(
memory_fraction_of_eager_deletion
,
1.0
,
"Fraction of eager deletion. If less than 1.0, all variables in "
"the program would be sorted according to its memory size, and "
"only the FLAGS_memory_fraction_of_eager_deletion of the largest "
"variables would be deleted."
);
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
// op -> variables which can be deleted after op runs
using
OpToVarNameSetMap
=
std
::
unordered_map
<
ComputationOpHandle
*
,
std
::
unordered_set
<
std
::
string
>>
;
// Check whether the variable is LoDTensor based on static VarDesc info
static
bool
IsLoDTensor
(
VarDesc
*
var
)
{
return
var
->
Proto
()
->
type
().
type
()
==
proto
::
VarType
::
LOD_TENSOR
;
}
// Get memory size of LoDTensor
static
int64_t
GetMemorySize
(
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
VarHandle
*>>
&
vars
,
const
std
::
string
&
var_name
)
{
auto
*
var_desc
=
TryGetLatestVarDesc
(
vars
.
at
(
var_name
));
PADDLE_ENFORCE_NOT_NULL
(
var_desc
);
PADDLE_ENFORCE
(
IsLoDTensor
(
var_desc
));
auto
dims
=
var_desc
->
GetShape
();
return
SizeOfType
(
var_desc
->
GetDataType
())
*
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
static_cast
<
int64_t
>
(
1
),
std
::
multiplies
<
int64_t
>
());
}
// Split all variables in the graph into LoDTensor and Non-LoDTensor (e.g.
// SelectedRows, LoDTensorArray)
// Since partial GC is based on static analysis of memory size of each variable
// So we should skip SelectedRows and LoDTensorArray here
static
void
SplitIntoLoDTensorAndNonLoDTensorVars
(
const
OpToVarNameSetMap
&
m
,
const
GraphVars
&
vars
,
OpToVarNameSetMap
*
lod_tensors
,
OpToVarNameSetMap
*
other_vars
)
{
lod_tensors
->
clear
();
other_vars
->
clear
();
for
(
auto
&
op_vars_pair
:
m
)
{
for
(
auto
&
var_name
:
op_vars_pair
.
second
)
{
auto
*
var_desc
=
TryGetLatestVarDesc
(
vars
[
op_vars_pair
.
first
->
GetScopeIdx
()].
at
(
var_name
));
if
(
IsLoDTensor
(
var_desc
))
{
(
*
lod_tensors
)[
op_vars_pair
.
first
].
insert
(
var_name
);
}
else
{
(
*
other_vars
)[
op_vars_pair
.
first
].
insert
(
var_name
);
}
}
}
}
struct
GCVarInfo
{
GCVarInfo
(
const
std
::
string
&
name
,
int64_t
memory_size
,
ComputationOpHandle
*
op
,
size_t
scope_idx
)
:
name_
(
name
),
memory_size_
(
memory_size
),
op_
(
op
),
scope_idx_
(
scope_idx
)
{}
std
::
string
name_
;
// variable name
int64_t
memory_size_
;
// memory size
ComputationOpHandle
*
op_
;
// op after which the variable could be deleted
size_t
scope_idx_
;
// scope index where the variable locates
int64_t
AbsMemorySize
()
const
{
return
std
::
abs
(
memory_size_
);
}
};
// Delete delete_lod_tensor_only is not used currently
static
OpToVarNameSetMap
ShrinkGCVars
(
const
OpToVarNameSetMap
&
m
,
const
GraphVars
&
vars
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
double
fraction_of_memory_size
,
bool
delete_lod_tensor_only
=
false
)
{
// Do not perform gc when fraction_of_memory_size = 0
if
(
fraction_of_memory_size
<=
0.0
)
return
{};
/**
* Step 1: Split all variables into LoDTensor and Non-LoDTensor.
* We can only calculate memory size of LoDTensors
*/
OpToVarNameSetMap
lod_tensors
,
other_vars
;
SplitIntoLoDTensorAndNonLoDTensorVars
(
m
,
vars
,
&
lod_tensors
,
&
other_vars
);
// Perform complete gc when fraction_of_memory_size >= 1
if
(
fraction_of_memory_size
>=
1.0
)
{
return
delete_lod_tensor_only
?
lod_tensors
:
m
;
}
/**
* Step 2: build GCVarInfos, and calculate total memory sizes of each device
*/
// place -> variable info (name, memory size, place, scope_idx)
std
::
map
<
platform
::
Place
,
std
::
vector
<
GCVarInfo
>>
place_to_vars
;
// place -> total memory sizes
std
::
map
<
platform
::
Place
,
int64_t
>
place_to_size
;
for
(
auto
&
op_vars_pair
:
lod_tensors
)
{
auto
*
op
=
op_vars_pair
.
first
;
auto
&
var_names
=
op_vars_pair
.
second
;
auto
scope_idx
=
op
->
GetScopeIdx
();
auto
&
place
=
places
[
scope_idx
];
for
(
auto
&
var_name
:
var_names
)
{
auto
var_size
=
GetMemorySize
(
vars
[
scope_idx
],
var_name
);
GCVarInfo
var_info
(
var_name
,
var_size
,
op
,
scope_idx
);
place_to_size
[
place
]
+=
var_info
.
AbsMemorySize
();
place_to_vars
[
place
].
emplace_back
(
std
::
move
(
var_info
));
}
}
/**
* Step 3: sort GCVarInfos, and only delete the largest variables.
*/
OpToVarNameSetMap
partial_vars
;
for
(
auto
&
place_to_var_pair
:
place_to_vars
)
{
auto
&
place
=
place_to_var_pair
.
first
;
auto
&
gc_vars
=
place_to_var_pair
.
second
;
std
::
sort
(
gc_vars
.
begin
(),
gc_vars
.
end
(),
[](
const
GCVarInfo
&
var1
,
const
GCVarInfo
&
var2
)
{
return
var1
.
AbsMemorySize
()
>
var2
.
AbsMemorySize
();
});
int64_t
accumulated_size
=
0
;
int64_t
size_threshold
=
static_cast
<
int64_t
>
(
fraction_of_memory_size
*
place_to_size
[
place
]);
for
(
size_t
i
=
0
;
i
<
gc_vars
.
size
()
&&
accumulated_size
<
size_threshold
;
++
i
)
{
partial_vars
[
gc_vars
[
i
].
op_
].
insert
(
gc_vars
[
i
].
name_
);
accumulated_size
+=
gc_vars
[
i
].
AbsMemorySize
();
}
}
/**
* Step 4: Combine other vars (SelectedRows, LoDTensorArray)
*/
if
(
!
delete_lod_tensor_only
)
{
for
(
auto
&
op_vars_pair
:
other_vars
)
{
partial_vars
[
op_vars_pair
.
first
].
insert
(
op_vars_pair
.
second
.
begin
(),
op_vars_pair
.
second
.
end
());
}
}
return
partial_vars
;
}
class
EagerDeletionPass
:
public
ir
::
Pass
{
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
};
std
::
unique_ptr
<
ir
::
Graph
>
EagerDeletionPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
EagerDeletionPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
auto
&
ref_cnts
=
auto
&
ref_cnts
=
...
@@ -43,9 +196,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
...
@@ -43,9 +196,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
// a reverse map of last_live_ops
// a reverse map of last_live_ops
// i.e., last op --> variable names which can be deleted.
// i.e., last op --> variable names which can be deleted.
std
::
unordered_map
<
ComputationOpHandle
*
,
std
::
unordered_set
<
std
::
string
>>
OpToVarNameSetMap
op_vars_map
;
op_vars_map
;
for
(
auto
&
var_ops_map
:
last_live_ops
)
{
for
(
auto
&
var_ops_map
:
last_live_ops
)
{
for
(
auto
&
var_ops_pair
:
var_ops_map
)
{
for
(
auto
&
var_ops_pair
:
var_ops_map
)
{
const
std
::
string
&
var_name
=
var_ops_pair
.
first
;
const
std
::
string
&
var_name
=
var_ops_pair
.
first
;
...
@@ -55,6 +206,9 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
...
@@ -55,6 +206,9 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
}
}
}
}
op_vars_map
=
ShrinkGCVars
(
op_vars_map
,
vars
,
places
,
FLAGS_memory_fraction_of_eager_deletion
);
for
(
auto
&
pair
:
op_vars_map
)
{
for
(
auto
&
pair
:
op_vars_map
)
{
auto
*
op
=
pair
.
first
;
auto
*
op
=
pair
.
first
;
auto
&
var_names
=
pair
.
second
;
auto
&
var_names
=
pair
.
second
;
...
@@ -85,8 +239,13 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
...
@@ -85,8 +239,13 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
eager_deletion_op
->
AddOutput
(
dummy_leaf
);
eager_deletion_op
->
AddOutput
(
dummy_leaf
);
}
}
VLOG
(
10
)
<<
"FLAGS_memory_fraction_of_eager_deletion = "
<<
FLAGS_memory_fraction_of_eager_deletion
;
VLOG
(
10
)
<<
"Create "
<<
op_vars_map
.
size
()
<<
" EagerDeletionOpHandle(s)"
;
VLOG
(
10
)
<<
"Create "
<<
op_vars_map
.
size
()
<<
" EagerDeletionOpHandle(s)"
;
return
graph
;
auto
while_op_eager_deletion_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"while_op_eager_deletion_pass"
);
return
while_op_eager_deletion_pass
->
Apply
(
std
::
move
(
graph
));
}
}
}
// namespace details
}
// namespace details
...
@@ -99,3 +258,5 @@ REGISTER_PASS(eager_deletion_pass,
...
@@ -99,3 +258,5 @@ REGISTER_PASS(eager_deletion_pass,
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLastLiveOpsOfVars
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLastLiveOpsOfVars
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kAllPlaces
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kAllPlaces
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kGarbageCollector
);
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kGarbageCollector
);
USE_PASS
(
while_op_eager_deletion_pass
);
paddle/fluid/framework/details/reference_count_pass.cc
浏览文件 @
472f16b5
...
@@ -12,9 +12,13 @@
...
@@ -12,9 +12,13 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include <memory>
#include <queue>
#include <queue>
#include <string>
#include <string>
#include <type_traits>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
...
@@ -189,15 +193,6 @@ ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
...
@@ -189,15 +193,6 @@ ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
return
shrink_func
(
computation_op
);
return
shrink_func
(
computation_op
);
}
}
static
VarDesc
*
TryGetLatestVarDesc
(
const
std
::
vector
<
VarHandle
*>
&
vars
)
{
VarDesc
*
var_desc
=
nullptr
;
std
::
find_if
(
vars
.
rbegin
(),
vars
.
rend
(),
[
&
](
VarHandle
*
var_handle
)
->
bool
{
var_desc
=
var_handle
->
Node
()
->
Var
();
return
var_desc
!=
nullptr
;
});
return
var_desc
;
}
std
::
unique_ptr
<
ir
::
Graph
>
ReferenceCountPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
ReferenceCountPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
auto
&
ref_cnts
=
Get
<
std
::
vector
<
ReferenceCountMap
>>
(
kGlobalReferenceCount
);
auto
&
ref_cnts
=
Get
<
std
::
vector
<
ReferenceCountMap
>>
(
kGlobalReferenceCount
);
...
...
paddle/fluid/framework/details/reference_count_pass_helper.cc
浏览文件 @
472f16b5
...
@@ -13,9 +13,22 @@
...
@@ -13,9 +13,22 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/var_desc.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
namespace
details
{}
// namespace details
namespace
details
{
VarDesc
*
TryGetLatestVarDesc
(
const
std
::
vector
<
VarHandle
*>
&
vars
)
{
VarDesc
*
var_desc
=
nullptr
;
std
::
find_if
(
vars
.
rbegin
(),
vars
.
rend
(),
[
&
](
VarHandle
*
var_handle
)
->
bool
{
var_desc
=
var_handle
->
Node
()
->
Var
();
return
var_desc
!=
nullptr
;
});
return
var_desc
;
}
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/details/reference_count_pass_helper.h
浏览文件 @
472f16b5
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#include <atomic>
#include <atomic>
#include <map>
#include <map>
#include <memory>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
...
@@ -25,6 +26,10 @@
...
@@ -25,6 +26,10 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
class
VarDesc
;
class
VarHandle
;
namespace
details
{
namespace
details
{
class
ComputationOpHandle
;
class
ComputationOpHandle
;
...
@@ -43,9 +48,11 @@ const char kGarbageCollector[] = "garbage_collector";
...
@@ -43,9 +48,11 @@ const char kGarbageCollector[] = "garbage_collector";
const
char
kAllPlaces
[]
=
"all_places"
;
const
char
kAllPlaces
[]
=
"all_places"
;
using
LastLiveOpsOfVars
=
using
LastLiveOpsOfVars
=
std
::
unordered_map
<
std
::
string
,
std
::
unordered_set
<
ComputationOpHandle
*>>
;
std
::
unordered_map
<
std
::
string
,
std
::
unordered_set
<
ComputationOpHandle
*>>
;
const
char
kLastLiveOpsOfVars
[]
=
"last_live_ops_of_var"
;
const
char
kLastLiveOpsOfVars
[]
=
"last_live_ops_of_var"
;
VarDesc
*
TryGetLatestVarDesc
(
const
std
::
vector
<
VarHandle
*>
&
vars
);
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/details/while_op_eager_deletion_pass.cc
0 → 100644
浏览文件 @
472f16b5
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
class
WhileOpEagerDeletionPass
:
public
ir
::
Pass
{
protected:
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
{
auto
all_ops
=
ir
::
FilterByNodeWrapper
<
OpHandleBase
>
(
*
graph
);
// Find all while_op and while_grad_op
std
::
unordered_map
<
size_t
,
std
::
pair
<
std
::
vector
<
OperatorBase
*>
,
std
::
vector
<
OperatorBase
*>>>
target_ops
;
for
(
auto
*
op
:
all_ops
)
{
auto
compute_op
=
dynamic_cast
<
ComputationOpHandle
*>
(
op
);
if
(
compute_op
==
nullptr
)
continue
;
if
(
compute_op
->
Name
()
==
"while"
)
{
target_ops
[
compute_op
->
GetScopeIdx
()].
first
.
emplace_back
(
compute_op
->
GetOp
());
}
else
if
(
compute_op
->
Name
()
==
"while_grad"
)
{
target_ops
[
compute_op
->
GetScopeIdx
()].
second
.
emplace_back
(
compute_op
->
GetOp
());
}
}
for
(
auto
&
ops_pair
:
target_ops
)
{
auto
&
while_ops
=
ops_pair
.
second
.
first
;
auto
&
while_grad_ops
=
ops_pair
.
second
.
second
;
operators
::
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp
(
while_ops
,
while_grad_ops
);
}
return
graph
;
}
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
while_op_eager_deletion_pass
,
paddle
::
framework
::
details
::
WhileOpEagerDeletionPass
);
paddle/fluid/framework/executor.cc
浏览文件 @
472f16b5
...
@@ -14,6 +14,10 @@ limitations under the License. */
...
@@ -14,6 +14,10 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor.h"
#include <deque>
#include <deque>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_rank_table.h"
...
@@ -23,6 +27,7 @@ limitations under the License. */
...
@@ -23,6 +27,7 @@ limitations under the License. */
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler.h"
...
@@ -75,11 +80,11 @@ static std::unordered_map<std::string, size_t> GetNonPersistableReferenceCounts(
...
@@ -75,11 +80,11 @@ static std::unordered_map<std::string, size_t> GetNonPersistableReferenceCounts(
ExecutorPrepareContext
::
ExecutorPrepareContext
(
ExecutorPrepareContext
::
ExecutorPrepareContext
(
const
framework
::
ProgramDesc
&
prog
,
size_t
block_id
,
const
framework
::
ProgramDesc
&
prog
,
size_t
block_id
,
const
std
::
vector
<
std
::
string
>&
skip_ref_cnt_vars
)
const
std
::
vector
<
std
::
string
>&
keep_vars
,
bool
force_disable_gc
)
:
prog_
(
prog
),
block_id_
(
block_id
)
{
:
prog_
(
prog
),
block_id_
(
block_id
)
,
force_disable_gc_
(
force_disable_gc
)
{
if
(
GetEagerDeletionThreshold
()
>=
0
)
{
if
(
GetEagerDeletionThreshold
()
>=
0
&&
!
force_disable_gc_
)
{
global_ref_cnts_
=
GetNonPersistableReferenceCounts
(
prog
.
Block
(
block_id
),
global_ref_cnts_
=
skip_ref_cnt
_vars
);
GetNonPersistableReferenceCounts
(
prog
.
Block
(
block_id
),
keep
_vars
);
}
}
}
}
...
@@ -184,13 +189,15 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
...
@@ -184,13 +189,15 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
}
}
void
Executor
::
Run
(
const
ProgramDesc
&
pdesc
,
Scope
*
scope
,
int
block_id
,
void
Executor
::
Run
(
const
ProgramDesc
&
pdesc
,
Scope
*
scope
,
int
block_id
,
bool
create_local_scope
,
bool
create_vars
)
{
bool
create_local_scope
,
bool
create_vars
,
const
std
::
vector
<
std
::
string
>&
skip_ref_cnt_vars
,
bool
force_disable_gc
)
{
platform
::
RecordBlock
b
(
block_id
);
platform
::
RecordBlock
b
(
block_id
);
if
(
FLAGS_use_mkldnn
)
EnableMKLDNN
(
pdesc
);
if
(
FLAGS_use_mkldnn
)
EnableMKLDNN
(
pdesc
);
#ifdef PADDLE_WITH_NGRAPH
#ifdef PADDLE_WITH_NGRAPH
if
(
FLAGS_use_ngraph
)
operators
::
NgraphEngine
::
EnableNgraph
(
pdesc
);
if
(
FLAGS_use_ngraph
)
operators
::
NgraphEngine
::
EnableNgraph
(
pdesc
);
#endif
#endif
auto
ctx
=
Prepare
(
pdesc
,
block_id
);
auto
ctx
=
Prepare
(
pdesc
,
block_id
,
skip_ref_cnt_vars
,
force_disable_gc
);
RunPreparedContext
(
ctx
.
get
(),
scope
,
create_local_scope
,
create_vars
);
RunPreparedContext
(
ctx
.
get
(),
scope
,
create_local_scope
,
create_vars
);
}
}
...
@@ -357,9 +364,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
...
@@ -357,9 +364,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
std
::
unique_ptr
<
ExecutorPrepareContext
>
Executor
::
Prepare
(
std
::
unique_ptr
<
ExecutorPrepareContext
>
Executor
::
Prepare
(
const
ProgramDesc
&
program
,
int
block_id
,
const
ProgramDesc
&
program
,
int
block_id
,
const
std
::
vector
<
std
::
string
>&
skip_ref_cnt_vars
)
{
const
std
::
vector
<
std
::
string
>&
skip_ref_cnt_vars
,
bool
force_disable_gc
)
{
std
::
unique_ptr
<
ExecutorPrepareContext
>
ctx
(
std
::
unique_ptr
<
ExecutorPrepareContext
>
ctx
(
new
ExecutorPrepareContext
(
new
ExecutorPrepareContext
(
program
,
block_id
,
skip_ref_cnt_vars
));
program
,
block_id
,
skip_ref_cnt_vars
,
force_disable_gc
));
PADDLE_ENFORCE_LT
(
static_cast
<
size_t
>
(
block_id
),
program
.
Size
());
PADDLE_ENFORCE_LT
(
static_cast
<
size_t
>
(
block_id
),
program
.
Size
());
auto
&
block
=
program
.
Block
(
block_id
);
auto
&
block
=
program
.
Block
(
block_id
);
for
(
auto
&
op_desc
:
block
.
AllOps
())
{
for
(
auto
&
op_desc
:
block
.
AllOps
())
{
...
@@ -370,7 +377,8 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
...
@@ -370,7 +377,8 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
std
::
vector
<
std
::
shared_ptr
<
ExecutorPrepareContext
>>
Executor
::
Prepare
(
std
::
vector
<
std
::
shared_ptr
<
ExecutorPrepareContext
>>
Executor
::
Prepare
(
const
ProgramDesc
&
program
,
const
std
::
vector
<
int
>&
block_ids
,
const
ProgramDesc
&
program
,
const
std
::
vector
<
int
>&
block_ids
,
const
std
::
vector
<
std
::
vector
<
std
::
string
>>&
skip_ref_cnt_vars
)
{
const
std
::
vector
<
std
::
vector
<
std
::
string
>>&
skip_ref_cnt_vars
,
bool
force_disable_gc
)
{
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
skip_ref_cnt_vars
.
empty
()
||
skip_ref_cnt_vars
.
size
()
==
block_ids
.
size
(),
skip_ref_cnt_vars
.
empty
()
||
skip_ref_cnt_vars
.
size
()
==
block_ids
.
size
(),
"skip_ref_cnt_vars should be either empty or equals to block number %d"
,
"skip_ref_cnt_vars should be either empty or equals to block number %d"
,
...
@@ -380,9 +388,11 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
...
@@ -380,9 +388,11 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
for
(
auto
&
bid
:
block_ids
)
{
for
(
auto
&
bid
:
block_ids
)
{
ExecutorPrepareContext
*
ctx
;
ExecutorPrepareContext
*
ctx
;
if
(
skip_ref_cnt_vars
.
empty
())
{
if
(
skip_ref_cnt_vars
.
empty
())
{
ctx
=
new
ExecutorPrepareContext
(
program
,
bid
);
ctx
=
new
ExecutorPrepareContext
(
program
,
bid
,
std
::
vector
<
std
::
string
>
(),
force_disable_gc
);
}
else
{
}
else
{
ctx
=
new
ExecutorPrepareContext
(
program
,
bid
,
skip_ref_cnt_vars
[
idx
]);
ctx
=
new
ExecutorPrepareContext
(
program
,
bid
,
skip_ref_cnt_vars
[
idx
],
force_disable_gc
);
}
}
PADDLE_ENFORCE_LT
(
static_cast
<
size_t
>
(
bid
),
program
.
Size
());
PADDLE_ENFORCE_LT
(
static_cast
<
size_t
>
(
bid
),
program
.
Size
());
auto
&
block
=
program
.
Block
(
bid
);
auto
&
block
=
program
.
Block
(
bid
);
...
@@ -409,8 +419,9 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
...
@@ -409,8 +419,9 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
int64_t
max_memory_size
=
GetEagerDeletionThreshold
();
int64_t
max_memory_size
=
GetEagerDeletionThreshold
();
std
::
unique_ptr
<
GarbageCollector
>
gc
;
std
::
unique_ptr
<
GarbageCollector
>
gc
;
// skip while_op and while_grad_op temporarily
// FIXME(zjl): recurrent_op is rather complex, we would
if
(
max_memory_size
>=
0
&&
!
keep_kids
)
{
// disable gc forcely in recurrent_op
if
(
!
ctx
->
force_disable_gc_
&&
max_memory_size
>=
0
)
{
ctx
->
ResetReferenceCount
();
ctx
->
ResetReferenceCount
();
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
if
(
platform
::
is_gpu_place
(
place_
))
{
if
(
platform
::
is_gpu_place
(
place_
))
{
...
@@ -428,6 +439,11 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
...
@@ -428,6 +439,11 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
}
}
#endif
#endif
// If gc is enabled and block size > 1
if
(
gc
&&
ctx
->
prog_
.
Size
()
>
1
)
{
operators
::
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp
(
ctx
->
block_id_
,
ctx
->
ops_
);
}
}
}
for
(
auto
&
op
:
ctx
->
ops_
)
{
for
(
auto
&
op
:
ctx
->
ops_
)
{
...
...
paddle/fluid/framework/executor.h
浏览文件 @
472f16b5
...
@@ -15,7 +15,9 @@ limitations under the License. */
...
@@ -15,7 +15,9 @@ limitations under the License. */
#pragma once
#pragma once
#include <map>
#include <map>
#include <memory>
#include <string>
#include <string>
#include <unordered_map>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_info.h"
...
@@ -30,7 +32,8 @@ namespace framework {
...
@@ -30,7 +32,8 @@ namespace framework {
struct
ExecutorPrepareContext
{
struct
ExecutorPrepareContext
{
ExecutorPrepareContext
(
const
framework
::
ProgramDesc
&
prog
,
size_t
block_id
,
ExecutorPrepareContext
(
const
framework
::
ProgramDesc
&
prog
,
size_t
block_id
,
const
std
::
vector
<
std
::
string
>&
skip_ref_cnt_vars
=
const
std
::
vector
<
std
::
string
>&
skip_ref_cnt_vars
=
std
::
vector
<
std
::
string
>
());
std
::
vector
<
std
::
string
>
(),
bool
force_disable_gc
=
false
);
~
ExecutorPrepareContext
();
~
ExecutorPrepareContext
();
...
@@ -38,6 +41,7 @@ struct ExecutorPrepareContext {
...
@@ -38,6 +41,7 @@ struct ExecutorPrepareContext {
const
framework
::
ProgramDesc
&
prog_
;
const
framework
::
ProgramDesc
&
prog_
;
size_t
block_id_
;
size_t
block_id_
;
bool
force_disable_gc_
;
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>
ops_
;
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>
ops_
;
std
::
unordered_map
<
std
::
string
,
size_t
>
global_ref_cnts_
;
std
::
unordered_map
<
std
::
string
,
size_t
>
global_ref_cnts_
;
...
@@ -66,7 +70,10 @@ class Executor {
...
@@ -66,7 +70,10 @@ class Executor {
* Scope
* Scope
*/
*/
void
Run
(
const
ProgramDesc
&
prog
,
Scope
*
scope
,
int
block_id
,
void
Run
(
const
ProgramDesc
&
prog
,
Scope
*
scope
,
int
block_id
,
bool
create_local_scope
=
true
,
bool
create_vars
=
true
);
bool
create_local_scope
=
true
,
bool
create_vars
=
true
,
const
std
::
vector
<
std
::
string
>&
skip_ref_cnt_vars
=
std
::
vector
<
std
::
string
>
(),
bool
force_disable_gc
=
false
);
// This API is very slow.
// This API is very slow.
void
Run
(
const
ProgramDesc
&
program
,
Scope
*
scope
,
void
Run
(
const
ProgramDesc
&
program
,
Scope
*
scope
,
...
@@ -79,12 +86,14 @@ class Executor {
...
@@ -79,12 +86,14 @@ class Executor {
static
std
::
unique_ptr
<
ExecutorPrepareContext
>
Prepare
(
static
std
::
unique_ptr
<
ExecutorPrepareContext
>
Prepare
(
const
ProgramDesc
&
program
,
int
block_id
,
const
ProgramDesc
&
program
,
int
block_id
,
const
std
::
vector
<
std
::
string
>&
skip_ref_cnt_vars
=
const
std
::
vector
<
std
::
string
>&
skip_ref_cnt_vars
=
std
::
vector
<
std
::
string
>
());
std
::
vector
<
std
::
string
>
(),
bool
force_disable_gc
=
false
);
static
std
::
vector
<
std
::
shared_ptr
<
ExecutorPrepareContext
>>
Prepare
(
static
std
::
vector
<
std
::
shared_ptr
<
ExecutorPrepareContext
>>
Prepare
(
const
ProgramDesc
&
program
,
const
std
::
vector
<
int
>&
block_ids
,
const
ProgramDesc
&
program
,
const
std
::
vector
<
int
>&
block_ids
,
const
std
::
vector
<
std
::
vector
<
std
::
string
>>&
skip_ref_cnt_vars
=
const
std
::
vector
<
std
::
vector
<
std
::
string
>>&
skip_ref_cnt_vars
=
std
::
vector
<
std
::
vector
<
std
::
string
>>
());
std
::
vector
<
std
::
vector
<
std
::
string
>>
(),
bool
force_disable_gc
=
false
);
void
CreateVariables
(
const
ProgramDesc
&
pdesc
,
Scope
*
scope
,
int
block_id
);
void
CreateVariables
(
const
ProgramDesc
&
pdesc
,
Scope
*
scope
,
int
block_id
);
...
...
paddle/fluid/operators/controlflow/CMakeLists.txt
浏览文件 @
472f16b5
include
(
operators
)
include
(
operators
)
register_operators
(
DEPS naive_executor
)
register_operators
(
DEPS naive_executor
)
cc_library
(
while_op_helper SRCS while_op_helper.cc DEPS operator
)
file
(
APPEND
${
pybind_file
}
"USE_OP(less_than);
\n
USE_OP(logical_and);
\n
USE_NO_KERNEL_OP(read_from_array);
\n
"
)
file
(
APPEND
${
pybind_file
}
"USE_OP(less_than);
\n
USE_OP(logical_and);
\n
USE_NO_KERNEL_OP(read_from_array);
\n
"
)
paddle/fluid/operators/controlflow/while_op.cc
浏览文件 @
472f16b5
...
@@ -18,6 +18,7 @@
...
@@ -18,6 +18,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -26,14 +27,6 @@ namespace operators {
...
@@ -26,14 +27,6 @@ namespace operators {
using
StepScopeVar
=
std
::
vector
<
framework
::
Scope
*>
;
using
StepScopeVar
=
std
::
vector
<
framework
::
Scope
*>
;
using
LoDTensor
=
framework
::
LoDTensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
static
constexpr
char
kStepBlock
[]
=
"sub_block"
;
static
constexpr
char
kCondition
[]
=
"Condition"
;
static
constexpr
char
kStepScopes
[]
=
"StepScopes"
;
static
constexpr
char
kX
[]
=
"X"
;
static
constexpr
char
kXGRAD
[]
=
"X@GRAD"
;
static
constexpr
char
kOutputs
[]
=
"Out"
;
static
constexpr
char
kSkipEagerDeletionVars
[]
=
"skip_eager_deletion_vars"
;
namespace
{
// NOLINT
namespace
{
// NOLINT
static
std
::
string
GetSkipEagerDeletionVarsDebugString
(
static
std
::
string
GetSkipEagerDeletionVarsDebugString
(
const
std
::
vector
<
std
::
string
>
&
vars
)
{
const
std
::
vector
<
std
::
string
>
&
vars
)
{
...
...
paddle/fluid/operators/controlflow/while_op_helper.cc
0 → 100644
浏览文件 @
472f16b5
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include <string>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/program_desc.h"
namespace
paddle
{
namespace
operators
{
// OpVariant is a wrapper class of OpDesc and OperatorBase
// So that API would be the same.
class
OpVariant
{
struct
InputsVisitor
:
public
boost
::
static_visitor
<
const
framework
::
VariableNameMap
*>
{
template
<
typename
OpType
>
const
framework
::
VariableNameMap
*
operator
()(
const
OpType
*
op
)
const
{
return
&
(
op
->
Inputs
());
}
};
struct
OutputsVisitor
:
public
boost
::
static_visitor
<
const
framework
::
VariableNameMap
*>
{
template
<
typename
OpType
>
const
framework
::
VariableNameMap
*
operator
()(
const
OpType
*
op
)
const
{
return
&
(
op
->
Outputs
());
}
};
struct
AttributeMapVisitor
:
public
boost
::
static_visitor
<
const
framework
::
AttributeMap
*>
{
const
framework
::
AttributeMap
*
operator
()(
const
framework
::
OpDesc
*
op
)
const
{
return
&
(
op
->
GetAttrMap
());
}
const
framework
::
AttributeMap
*
operator
()(
const
framework
::
OperatorBase
*
op
)
const
{
return
&
(
op
->
Attrs
());
}
};
struct
RawPointerVisitor
:
public
boost
::
static_visitor
<
const
void
*>
{
template
<
typename
OpType
>
const
void
*
operator
()(
const
OpType
*
op
)
const
{
return
op
;
}
};
public:
OpVariant
(
const
framework
::
OperatorBase
*
op
)
:
op_
(
op
)
{}
// NOLINT
OpVariant
(
const
framework
::
OpDesc
*
op
)
:
op_
(
op
)
{}
// NOLINT
const
framework
::
VariableNameMap
&
Inputs
()
const
{
return
*
boost
::
apply_visitor
(
InputsVisitor
(),
op_
);
}
const
framework
::
VariableNameMap
&
Outputs
()
const
{
return
*
boost
::
apply_visitor
(
OutputsVisitor
(),
op_
);
}
const
framework
::
AttributeMap
&
Attrs
()
const
{
return
*
boost
::
apply_visitor
(
AttributeMapVisitor
(),
op_
);
}
template
<
typename
AttrType
>
const
AttrType
&
Attr
(
const
std
::
string
&
name
)
const
{
auto
&
attrs
=
Attrs
();
auto
it
=
attrs
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
attrs
.
end
(),
"Cannot find attribute %s"
,
name
);
return
boost
::
get
<
AttrType
>
(
it
->
second
);
}
bool
operator
==
(
const
OpVariant
&
other
)
const
{
return
RawPointer
()
==
other
.
RawPointer
();
}
const
void
*
RawPointer
()
const
{
return
boost
::
apply_visitor
(
RawPointerVisitor
(),
op_
);
}
int
which
()
const
{
return
static_cast
<
int
>
(
op_
.
which
());
}
struct
Hasher
{
size_t
operator
()(
const
OpVariant
&
op
)
const
{
return
reinterpret_cast
<
size_t
>
(
op
.
RawPointer
());
}
};
private:
const
boost
::
variant
<
const
framework
::
OperatorBase
*
,
const
framework
::
OpDesc
*>
op_
;
};
static
std
::
string
GetDebugString
(
const
std
::
vector
<
std
::
string
>
&
names
)
{
if
(
names
.
empty
())
return
""
;
std
::
string
ret
=
names
[
0
];
for
(
size_t
i
=
1
;
i
<
names
.
size
();
++
i
)
{
ret
+=
(
" "
+
names
[
i
]);
}
return
ret
;
}
// Set skip variables of while_op and while_grad_op
// These variables should be skipped when eager deletion enables.
// It is because:
// 1. while_grad_op needs some variables defined in while_op.
// 2. while_grad_op needs variables from the previous time step.
static
void
SetSkipVars
(
const
OpVariant
&
op
,
std
::
vector
<
std
::
string
>
attr
)
{
auto
&
attrs
=
const_cast
<
framework
::
AttributeMap
&>
(
op
.
Attrs
());
VLOG
(
2
)
<<
"Prepare to skip "
<<
attr
.
size
()
<<
" var(s): "
<<
GetDebugString
(
attr
);
attrs
[
kSkipEagerDeletionVars
]
=
std
::
move
(
attr
);
}
// Check whether the forward while_op and while_grad_op match
// The program may have many while_ops.
static
bool
IsMatchedWhileOpAndWhileGradOp
(
const
OpVariant
&
fwd_op
,
const
OpVariant
&
grad_op
)
{
return
fwd_op
.
Inputs
().
at
(
kX
)
==
grad_op
.
Inputs
().
at
(
kX
)
&&
fwd_op
.
Outputs
().
at
(
kOutputs
)
==
grad_op
.
Inputs
().
at
(
kOutputs
);
}
// Test whether the variable is skippable in forward while_op
// The variable is skippable in while_op when the variable used in while_grad
// is not from grad_block.
static
bool
IsSkippableVar
(
const
std
::
string
&
name
,
framework
::
BlockDesc
*
grad_block
)
{
return
name
!=
framework
::
kEmptyVarName
&&
!
grad_block
->
HasVar
(
name
);
}
static
void
ModifyWhileOpAndWhileGradOpAttr
(
const
OpVariant
&
fwd_op
,
const
OpVariant
&
bwd_op
)
{
auto
*
grad_block
=
bwd_op
.
Attr
<
framework
::
BlockDesc
*>
(
kStepBlock
);
// Find all skippable variables in forward while_op
std
::
unordered_set
<
std
::
string
>
forward_skip_vars
;
for
(
auto
*
op_desc
:
grad_block
->
AllOps
())
{
for
(
auto
&
in_arg_name
:
op_desc
->
InputArgumentNames
())
{
if
(
IsSkippableVar
(
in_arg_name
,
grad_block
))
{
forward_skip_vars
.
insert
(
in_arg_name
);
}
}
for
(
auto
&
out_arg_name
:
op_desc
->
OutputArgumentNames
())
{
if
(
IsSkippableVar
(
out_arg_name
,
grad_block
))
{
forward_skip_vars
.
insert
(
out_arg_name
);
}
}
}
SetSkipVars
(
fwd_op
,
std
::
vector
<
std
::
string
>
(
forward_skip_vars
.
begin
(),
forward_skip_vars
.
end
()));
// Find all skippable variables in while_grad_op
// The skipped variables are those which would be used across time steps.
auto
&
fwd_input
=
fwd_op
.
Inputs
().
at
(
kX
);
auto
&
in_grads
=
bwd_op
.
Outputs
().
at
(
framework
::
GradVarName
(
kX
));
PADDLE_ENFORCE_EQ
(
fwd_input
.
size
(),
in_grads
.
size
(),
"Backward input gradient number does not match forward input number."
);
std
::
unordered_set
<
std
::
string
>
backward_skip_vars
;
for
(
size_t
i
=
0
;
i
<
in_grads
.
size
();
++
i
)
{
if
(
in_grads
[
i
]
==
framework
::
kEmptyVarName
)
{
continue
;
}
backward_skip_vars
.
insert
(
in_grads
[
i
]);
backward_skip_vars
.
insert
(
framework
::
GradVarName
(
fwd_input
[
i
]));
}
SetSkipVars
(
bwd_op
,
std
::
vector
<
std
::
string
>
(
backward_skip_vars
.
begin
(),
backward_skip_vars
.
end
()));
}
// Find all while_ops and while_grad_ops in the graph or program
// The while_grad_op and while_op may located in different blocks
// So we should traverse all blocks in the program and find them out.
static
void
FindAllWhileAndWhileGradOp
(
std
::
vector
<
OpVariant
>
*
while_ops
,
std
::
vector
<
OpVariant
>
*
while_grad_ops
)
{
PADDLE_ENFORCE_GE
(
while_ops
->
size
(),
while_grad_ops
->
size
());
if
(
while_ops
->
empty
())
return
;
const
auto
*
program
=
while_ops
->
front
().
Attr
<
framework
::
BlockDesc
*>
(
kStepBlock
)
->
Program
();
for
(
size_t
i
=
1
;
i
<
program
->
Size
();
++
i
)
{
auto
&
block
=
program
->
Block
(
i
);
for
(
size_t
j
=
0
;
j
<
block
.
OpSize
();
++
j
)
{
auto
*
op
=
block
.
Op
(
j
);
if
(
op
->
Type
()
==
"while"
)
{
while_ops
->
emplace_back
(
op
);
}
else
if
(
op
->
Type
()
==
"while_grad"
)
{
while_grad_ops
->
emplace_back
(
op
);
}
}
}
PADDLE_ENFORCE_GE
(
while_ops
->
size
(),
while_grad_ops
->
size
(),
"There are extra while_grad ops in the graph or program"
);
}
static
void
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl
(
std
::
vector
<
OpVariant
>
*
while_ops
,
std
::
vector
<
OpVariant
>
*
while_grad_ops
)
{
FindAllWhileAndWhileGradOp
(
while_ops
,
while_grad_ops
);
VLOG
(
2
)
<<
"Found while op num: "
<<
while_ops
->
size
()
<<
", while grad op num: "
<<
while_grad_ops
->
size
();
if
(
while_grad_ops
->
empty
())
{
return
;
}
std
::
unordered_set
<
OpVariant
,
OpVariant
::
Hasher
>
while_op_set
(
while_ops
->
begin
(),
while_ops
->
end
());
for
(
auto
&
bwd_op
:
*
while_grad_ops
)
{
const
OpVariant
*
matched_fwd_op
=
nullptr
;
for
(
auto
&
fwd_op
:
while_op_set
)
{
if
(
IsMatchedWhileOpAndWhileGradOp
(
fwd_op
,
bwd_op
))
{
PADDLE_ENFORCE
(
matched_fwd_op
==
nullptr
,
"Found multiple matched while ops"
);
matched_fwd_op
=
&
fwd_op
;
}
}
PADDLE_ENFORCE_NOT_NULL
(
matched_fwd_op
,
"Cannot find matched forward while op."
);
ModifyWhileOpAndWhileGradOpAttr
(
*
matched_fwd_op
,
bwd_op
);
while_op_set
.
erase
(
*
matched_fwd_op
);
}
}
void
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp
(
int
block_id
,
const
std
::
vector
<
std
::
unique_ptr
<
framework
::
OperatorBase
>>
&
all_ops
)
{
// If block_id is not 0, returns
// This is because all while_ops and while_grad_ops in the whole program
// would be processed when block_id is 0 (i.e. when Executor::Run() or
// ParallelExecutor constructs).
// What's more, all while_ops and while_grad_ops must be processed when
// block_id is zero. If not, while_op may run first and erase variables
// used in while_grad_op, and in this moment, while_grad_ops may be not
// constructed yet.
if
(
block_id
!=
0
)
return
;
std
::
vector
<
OpVariant
>
fwd_ops
,
bwd_ops
;
for
(
auto
&
op
:
all_ops
)
{
if
(
op
->
Type
()
==
"while"
)
{
fwd_ops
.
emplace_back
(
op
.
get
());
}
else
if
(
op
->
Type
()
==
"while_grad"
)
{
bwd_ops
.
emplace_back
(
op
.
get
());
}
}
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl
(
&
fwd_ops
,
&
bwd_ops
);
}
void
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp
(
const
std
::
vector
<
framework
::
OperatorBase
*>
&
while_ops
,
const
std
::
vector
<
framework
::
OperatorBase
*>
&
while_grad_ops
)
{
std
::
vector
<
OpVariant
>
fwd_ops
,
bwd_ops
;
fwd_ops
.
reserve
(
while_ops
.
size
());
for
(
auto
*
op
:
while_ops
)
{
fwd_ops
.
emplace_back
(
op
);
}
bwd_ops
.
reserve
(
while_grad_ops
.
size
());
for
(
auto
*
op
:
while_grad_ops
)
{
bwd_ops
.
emplace_back
(
op
);
}
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl
(
&
fwd_ops
,
&
bwd_ops
);
}
}
// namespace operators
}
// namespace paddle
paddle/fluid/
framework/details/eager_deletion_pass
.h
→
paddle/fluid/
operators/controlflow/while_op_helper
.h
浏览文件 @
472f16b5
// Copyright (c) 201
8
PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 201
9
PaddlePaddle Authors. All Rights Reserved.
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// you may not use this file except in compliance with the License.
...
@@ -14,19 +14,30 @@
...
@@ -14,19 +14,30 @@
#pragma once
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include <memory>
#include "paddle/fluid/framework/ir/pass.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/variant.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
operators
{
namespace
details
{
class
EagerDeletionPass
:
public
ir
::
Pass
{
static
constexpr
char
kStepBlock
[]
=
"sub_block"
;
protected:
static
constexpr
char
kCondition
[]
=
"Condition"
;
std
::
unique_ptr
<
ir
::
Graph
>
ApplyImpl
(
static
constexpr
char
kStepScopes
[]
=
"StepScopes"
;
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
override
;
static
constexpr
char
kX
[]
=
"X"
;
};
static
constexpr
char
kXGRAD
[]
=
"X@GRAD"
;
static
constexpr
char
kOutputs
[]
=
"Out"
;
static
constexpr
char
kSkipEagerDeletionVars
[]
=
"skip_eager_deletion_vars"
;
}
// namespace details
void
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp
(
}
// namespace framework
int
block_id
,
const
std
::
vector
<
std
::
unique_ptr
<
framework
::
OperatorBase
>>
&
all_ops
);
void
PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp
(
const
std
::
vector
<
framework
::
OperatorBase
*>
&
while_ops
,
const
std
::
vector
<
framework
::
OperatorBase
*>
&
while_grad_ops
);
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/recurrent_op.cc
浏览文件 @
472f16b5
...
@@ -282,7 +282,9 @@ class RecurrentOp : public RecurrentBase {
...
@@ -282,7 +282,9 @@ class RecurrentOp : public RecurrentBase {
// Every inputs are linked now, execute!
// Every inputs are linked now, execute!
executor
.
Run
(
*
program
,
&
cur_scope
,
block
->
ID
(),
executor
.
Run
(
*
program
,
&
cur_scope
,
block
->
ID
(),
false
/*create_local_scope*/
);
false
/*create_local_scope*/
,
true
/*create_vars*/
,
std
::
vector
<
std
::
string
>
()
/*skip_ref_cnt_vars*/
,
true
/*force_disable_gc*/
);
// get device context from pool
// get device context from pool
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
&
pool
=
...
@@ -398,7 +400,9 @@ class RecurrentGradOp : public RecurrentBase {
...
@@ -398,7 +400,9 @@ class RecurrentGradOp : public RecurrentBase {
VLOG
(
5
)
<<
"Recurrent memory linking finished "
;
VLOG
(
5
)
<<
"Recurrent memory linking finished "
;
// Run step block with cur_scope
// Run step block with cur_scope
executor
.
Run
(
*
program
,
&
cur_scope
,
block
->
ID
(),
executor
.
Run
(
*
program
,
&
cur_scope
,
block
->
ID
(),
false
/*create_local_scope*/
);
false
/*create_local_scope*/
,
true
/*create_vars*/
,
std
::
vector
<
std
::
string
>
()
/*skip_ref_cnt_vars*/
,
true
/*force_disable_gc*/
);
VLOG
(
5
)
<<
"executor.Run finished "
;
VLOG
(
5
)
<<
"executor.Run finished "
;
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
472f16b5
...
@@ -876,9 +876,11 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -876,9 +876,11 @@ All parameter, weight, gradient are variables in Paddle.
.
def
(
py
::
init
<
const
platform
::
Place
&>
())
.
def
(
py
::
init
<
const
platform
::
Place
&>
())
.
def
(
"close"
,
&
Executor
::
Close
)
.
def
(
"close"
,
&
Executor
::
Close
)
.
def
(
"run"
,
[](
Executor
&
self
,
const
ProgramDesc
&
prog
,
Scope
*
scope
,
.
def
(
"run"
,
[](
Executor
&
self
,
const
ProgramDesc
&
prog
,
Scope
*
scope
,
int
block_id
,
bool
create_local_scope
,
bool
create_vars
)
{
int
block_id
,
bool
create_local_scope
,
bool
create_vars
,
const
std
::
vector
<
std
::
string
>
&
fetch_vars
)
{
pybind11
::
gil_scoped_release
release
;
pybind11
::
gil_scoped_release
release
;
self
.
Run
(
prog
,
scope
,
block_id
,
create_local_scope
,
create_vars
);
self
.
Run
(
prog
,
scope
,
block_id
,
create_local_scope
,
create_vars
,
fetch_vars
);
});
});
m
.
def
(
"init_gflags"
,
framework
::
InitGflags
);
m
.
def
(
"init_gflags"
,
framework
::
InitGflags
);
...
...
python/paddle/fluid/__init__.py
浏览文件 @
472f16b5
...
@@ -128,11 +128,11 @@ def __bootstrap__():
...
@@ -128,11 +128,11 @@ def __bootstrap__():
'check_nan_inf'
,
'benchmark'
,
'eager_delete_scope'
,
'use_ngraph'
,
'check_nan_inf'
,
'benchmark'
,
'eager_delete_scope'
,
'use_ngraph'
,
'initial_cpu_memory_in_mb'
,
'init_allocated_mem'
,
'free_idle_memory'
,
'initial_cpu_memory_in_mb'
,
'init_allocated_mem'
,
'free_idle_memory'
,
'paddle_num_threads'
,
"dist_threadpool_size"
,
'eager_delete_tensor_gb'
,
'paddle_num_threads'
,
"dist_threadpool_size"
,
'eager_delete_tensor_gb'
,
'fast_eager_deletion_mode'
,
'
allocator_strategy
'
,
'fast_eager_deletion_mode'
,
'
memory_fraction_of_eager_deletion
'
,
'
reader_queue_speed_test_mode'
,
'print_sub_graph_dir
'
,
'
allocator_strategy'
,
'reader_queue_speed_test_mode
'
,
'p
e_profile_fname'
,
'warpctc_dir'
,
'inner_op_parallelism
'
,
'p
rint_sub_graph_dir'
,
'pe_profile_fname'
,
'warpctc_dir
'
,
'
enable_parallel_graph'
,
'multiple_of_cupti_buffer_size
'
,
'
inner_op_parallelism'
,
'enable_parallel_graph
'
,
'enable_subgraph_optimize'
'
multiple_of_cupti_buffer_size'
,
'
enable_subgraph_optimize'
]
]
if
'Darwin'
not
in
sysstr
:
if
'Darwin'
not
in
sysstr
:
read_env_flags
.
append
(
'use_pinned_memory'
)
read_env_flags
.
append
(
'use_pinned_memory'
)
...
...
python/paddle/fluid/executor.py
浏览文件 @
472f16b5
...
@@ -590,7 +590,7 @@ class Executor(object):
...
@@ -590,7 +590,7 @@ class Executor(object):
fetch_var_name
=
fetch_var_name
)
fetch_var_name
=
fetch_var_name
)
self
.
_feed_data
(
program
,
feed
,
feed_var_name
,
scope
)
self
.
_feed_data
(
program
,
feed
,
feed_var_name
,
scope
)
exe
.
run
(
program
.
desc
,
scope
,
0
,
True
,
True
)
exe
.
run
(
program
.
desc
,
scope
,
0
,
True
,
True
,
fetch_var_name
)
outs
=
self
.
_fetch_data
(
fetch_list
,
fetch_var_name
,
scope
)
outs
=
self
.
_fetch_data
(
fetch_list
,
fetch_var_name
,
scope
)
if
return_numpy
:
if
return_numpy
:
outs
=
as_numpy
(
outs
)
outs
=
as_numpy
(
outs
)
...
...
python/paddle/fluid/tests/unittests/test_eager_deletion_transformer.py
浏览文件 @
472f16b5
...
@@ -16,8 +16,7 @@ import os
...
@@ -16,8 +16,7 @@ import os
import
unittest
import
unittest
os
.
environ
[
'FLAGS_eager_delete_tensor_gb'
]
=
"0.0"
os
.
environ
[
'FLAGS_eager_delete_tensor_gb'
]
=
"0.0"
os
.
environ
[
os
.
environ
[
'RECORDIO_FILENAME'
]
=
'./eager_deletion_transformer.wmt16.recordio'
'RECORDIO_FILENAME'
]
=
'/tmp/eager_deletion_transformer.wmt16.recordio'
from
test_parallel_executor_transformer
import
TestTransformer
from
test_parallel_executor_transformer
import
TestTransformer
...
...
python/paddle/fluid/tests/unittests/test_eager_deletion_while_op.py
0 → 100644
浏览文件 @
472f16b5
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
os
os
.
environ
[
'CPU_NUM'
]
=
'2'
os
.
environ
[
'FLAGS_eager_delete_tensor_gb'
]
=
'0.0'
os
.
environ
[
'FLAGS_fast_eager_deletion_mode'
]
=
'1'
import
unittest
import
paddle.fluid
as
fluid
import
paddle.fluid.layers
as
layers
from
paddle.fluid.executor
import
Executor
import
paddle.fluid.core
as
core
from
paddle.fluid.backward
import
append_backward
import
paddle.fluid.compiler
as
compiler
import
numpy
import
multiprocessing
class
TestEagerDeletionWhileOpBase
(
unittest
.
TestCase
):
def
test_main
(
self
):
places
=
[
core
.
CPUPlace
(),
]
if
core
.
is_compiled_with_cuda
():
places
.
append
(
core
.
CUDAPlace
(
0
))
for
p
in
places
:
for
with_data_parallel
in
[
False
,
True
]:
with
fluid
.
program_guard
(
fluid
.
Program
(),
fluid
.
Program
()):
with
fluid
.
scope_guard
(
fluid
.
Scope
()):
self
.
run_main
(
p
,
with_data_parallel
)
def
run_main
(
self
,
place
,
with_data_parallel
):
self
.
place
=
place
self
.
with_data_parallel
=
with_data_parallel
if
not
core
.
is_compiled_with_cuda
()
and
isinstance
(
self
.
place
,
core
.
CUDAPlace
):
return
if
isinstance
(
self
.
place
,
core
.
CUDAPlace
):
device_cnt
=
core
.
get_cuda_device_count
(
)
if
self
.
with_data_parallel
else
1
else
:
device_cnt
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
multiprocessing
.
cpu_count
(
)))
if
self
.
with_data_parallel
else
1
d0
=
layers
.
data
(
"d0"
,
shape
=
[
10
],
append_batch_size
=
False
,
dtype
=
'float32'
)
d1
=
layers
.
data
(
"d1"
,
shape
=
[
10
],
append_batch_size
=
False
,
dtype
=
'float32'
)
d2
=
layers
.
data
(
"d2"
,
shape
=
[
10
],
append_batch_size
=
False
,
dtype
=
'float32'
)
i
=
layers
.
zeros
(
shape
=
[
1
],
dtype
=
'int64'
)
i
.
stop_gradient
=
True
init
=
layers
.
zeros
(
shape
=
[
10
],
dtype
=
'float32'
)
mem_array
=
layers
.
array_write
(
x
=
init
,
i
=
i
)
data_array
=
layers
.
array_write
(
x
=
d0
,
i
=
i
)
i
=
layers
.
increment
(
i
)
layers
.
array_write
(
d1
,
i
,
array
=
data_array
)
i
=
layers
.
increment
(
i
)
layers
.
array_write
(
d2
,
i
,
array
=
data_array
)
i
=
layers
.
zeros
(
shape
=
[
1
],
dtype
=
'int64'
)
i
.
stop_gradient
=
True
array_len
=
layers
.
fill_constant
(
shape
=
[
1
],
dtype
=
'int64'
,
value
=
1
)
array_len
.
stop_gradient
=
True
cond
=
layers
.
less_than
(
x
=
i
,
y
=
array_len
)
j
=
layers
.
fill_constant
(
shape
=
[
1
],
dtype
=
'int64'
,
value
=
1
)
j
.
stop_gradient
=
True
array_len2
=
layers
.
fill_constant
(
shape
=
[
1
],
dtype
=
'int64'
,
value
=
3
)
array_len2
.
stop_gradient
=
True
cond2
=
layers
.
less_than
(
x
=
j
,
y
=
array_len2
)
while_op
=
layers
.
While
(
cond
=
cond
)
while_op2
=
layers
.
While
(
cond
=
cond2
)
with
while_op
.
block
():
d
=
layers
.
array_read
(
array
=
data_array
,
i
=
i
)
prev
=
layers
.
array_read
(
array
=
mem_array
,
i
=
i
)
d
=
layers
.
reshape
(
d
,
shape
=
[
10
])
prev
=
layers
.
reshape
(
prev
,
shape
=
[
10
])
result
=
layers
.
sums
(
input
=
[
d
,
prev
])
i
=
layers
.
increment
(
x
=
i
,
in_place
=
True
)
layers
.
array_write
(
result
,
i
=
i
,
array
=
mem_array
)
layers
.
less_than
(
x
=
i
,
y
=
array_len
,
cond
=
cond
)
with
while_op2
.
block
():
d2
=
layers
.
array_read
(
array
=
data_array
,
i
=
j
)
prev2
=
layers
.
array_read
(
array
=
mem_array
,
i
=
j
)
d2
=
layers
.
reshape
(
d2
,
shape
=
[
10
])
prev2
=
layers
.
reshape
(
prev2
,
shape
=
[
10
])
result2
=
layers
.
sums
(
input
=
[
d2
,
prev2
])
j
=
layers
.
increment
(
x
=
j
,
in_place
=
True
)
layers
.
array_write
(
result2
,
i
=
j
,
array
=
mem_array
)
layers
.
less_than
(
x
=
j
,
y
=
array_len2
,
cond
=
cond2
)
sum_result
=
layers
.
array_read
(
array
=
mem_array
,
i
=
j
)
sum_result
.
persistable
=
True
tmp
=
layers
.
unsqueeze
(
sum_result
,
axes
=
[
0
])
tmp
=
layers
.
expand
(
tmp
,
expand_times
=
[
10
,
1
])
fc
=
layers
.
fc
(
tmp
,
size
=
256
)
loss
=
layers
.
mean
(
sum_result
)
optim
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
1e-3
)
optim
.
minimize
(
loss
)
exe
=
Executor
(
self
.
place
)
exe
.
run
(
fluid
.
default_startup_program
())
prog
=
compiler
.
CompiledProgram
(
fluid
.
default_main_program
())
if
self
.
with_data_parallel
:
prog
=
prog
.
with_data_parallel
()
for
_
in
range
(
5
):
d
=
[]
for
i
in
range
(
3
):
tmp
=
numpy
.
random
.
random
(
size
=
[
10
]).
astype
(
'float32'
)
if
not
self
.
with_data_parallel
:
d
.
append
(
tmp
)
else
:
d
.
append
(
numpy
.
array
([
tmp
]
*
device_cnt
))
outs
=
exe
.
run
(
program
=
prog
,
feed
=
{
'd0'
:
d
[
0
],
'd1'
:
d
[
1
],
'd2'
:
d
[
2
]},
fetch_list
=
[
sum_result
])
self
.
assertAlmostEqual
(
numpy
.
sum
(
d
),
numpy
.
sum
(
outs
[
0
]),
delta
=
0.01
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_partial_eager_deletion_transformer.py
0 → 100644
浏览文件 @
472f16b5
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
unittest
os
.
environ
[
'FLAGS_eager_delete_tensor_gb'
]
=
"0.0"
os
.
environ
[
'FLAGS_memory_fraction_of_eager_deletion'
]
=
"0.55"
os
.
environ
[
'RECORDIO_FILENAME'
]
=
'./p_gc_transformer.wmt16.recordio'
from
test_parallel_executor_transformer
import
TestTransformer
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录