Commit 387bac46
Authored Dec 07, 2018 by sneaxiy

refine code
test=develop

Parent: d0c8b9b9

Showing 10 changed files with 122 additions and 107 deletions (+122 / -107)
Changed files:
  paddle/fluid/framework/details/eager_deletion_pass.cc                +6   -4
  paddle/fluid/framework/details/op_graph_view.cc                      +2   -0
  paddle/fluid/framework/details/reference_count_pass.cc               +7   -7
  paddle/fluid/framework/details/reference_count_pass_helper.h         +6   -4
  paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc  +4   -4
  paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h   +1   -1
  paddle/fluid/framework/executor.cc                                   +4   -10
  paddle/fluid/framework/executor.h                                    +3   -3
  paddle/fluid/framework/parallel_executor.cc                          +86  -67
  paddle/fluid/operators/controlflow/while_op.cc                       +3   -7
paddle/fluid/framework/details/eager_deletion_pass.cc (+6 -4)

@@ -31,10 +31,11 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
   const auto &vars = graph->Get<GraphVars>(kGraphVars);
   auto &ref_cnts =
-      Get<std::vector<AtomicReferenceCountMap>>(kCurReferenceCount);
+      Get<std::vector<AtomicReferenceCountMap>>(kRuntimeReferenceCount);
   const auto &last_live_ops =
       Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
-  auto &gcs = Get<GarbageCollectorList>(kGarbageCollector);
+  auto &gcs = Get<GarbageCollectorMap>(kGarbageCollector);
+  const auto &places = Get<std::vector<platform::Place>>(kAllPlaces);
 
   ref_cnts = std::vector<AtomicReferenceCountMap>(vars.size());

@@ -58,7 +59,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
         graph->CreateEmptyNode("eager_deletion", ir::Node::Type::kOperation);
     auto *eager_deletion_op = new EagerDeletionOpHandle(
         eager_deletion_node, op->GetScope(), op->GetPlace(),
-        std::move(var_names), gcs[op->GetScopeIdx()].get(),
+        std::move(var_names), gcs.at(places[op->GetScopeIdx()]).get(),
         &(ref_cnts[op->GetScopeIdx()]));
 
     auto it = std::find_if(

@@ -90,6 +91,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
 REGISTER_PASS(eager_deletion_pass,
               paddle::framework::details::EagerDeletionPass)
-    .RequirePassAttr(paddle::framework::details::kCurReferenceCount)
+    .RequirePassAttr(paddle::framework::details::kRuntimeReferenceCount)
     .RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars)
+    .RequirePassAttr(paddle::framework::details::kAllPlaces)
     .RequirePassAttr(paddle::framework::details::kGarbageCollector);
paddle/fluid/framework/details/op_graph_view.cc (+2 -0)

@@ -23,6 +23,8 @@ namespace details {
 OpGraphView::OpGraphView(const std::vector<OpHandleBase *> &ops) { Build(ops); }
 
 void OpGraphView::Build(const std::vector<OpHandleBase *> &ops) {
+  preceding_ops_.clear();
+  pending_ops_.clear();
   for (auto &op : ops) {
     preceding_ops_[op];
     pending_ops_[op];
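The only functional change in this file is that OpGraphView::Build() now clears both adjacency maps before repopulating them, so the view can be rebuilt without inheriting entries from a previous Build() call. A minimal standalone sketch of that pattern (OpHandle, GraphView, and the member names below are simplified stand-ins, not the Paddle classes):

#include <cassert>
#include <unordered_map>
#include <unordered_set>
#include <vector>

struct OpHandle {};  // simplified stand-in for OpHandleBase

class GraphView {
 public:
  void Build(const std::vector<OpHandle *> &ops) {
    // Clearing first makes Build() idempotent: a second call replaces the
    // previous view instead of accumulating stale entries.
    preceding_.clear();
    pending_.clear();
    for (auto *op : ops) {
      preceding_[op];  // ensure every op has a (possibly empty) entry
      pending_[op];
    }
  }
  size_t Size() const { return preceding_.size(); }

 private:
  std::unordered_map<OpHandle *, std::unordered_set<OpHandle *>> preceding_;
  std::unordered_map<OpHandle *, std::unordered_set<OpHandle *>> pending_;
};

int main() {
  OpHandle a, b, c;
  GraphView view;
  view.Build({&a, &b, &c});
  view.Build({&a});          // rebuild with a smaller op list
  assert(view.Size() == 1);  // old entries were cleared, not merged
  return 0;
}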
paddle/fluid/framework/details/reference_count_pass.cc (+7 -7)

@@ -29,22 +29,22 @@ namespace paddle {
 namespace framework {
 namespace details {
 
-class OpConnectionDetector {
+class OpRelationDetector {
  public:
   enum RelationShip { kSame = 0, kNoDeps = 1, kBefore = 2, kAfter = 3 };
 
-  explicit OpConnectionDetector(const std::vector<OpHandleBase *> &all_ops)
+  explicit OpRelationDetector(const std::vector<OpHandleBase *> &all_ops)
       : graph_(all_ops) {}
 
   template <typename OpSet>
-  OpSet MaxNoDepOps(const OpSet &op_set) {
-    if (op_set.size() <= 1) return op_set;
+  OpSet MaxNoDepOps(const OpSet &op_set) const {
     using KeyType = typename OpSet::key_type;
     static_assert(
         std::is_base_of<OpHandleBase,
                         typename std::remove_pointer<KeyType>::type>::value,
-        "Key type of OpSet must be or derived of OpHandleBase");
+        "Key type of OpSet must be OpHandleBase, or derived of OpHandleBase");
 
+    if (op_set.size() <= 1) return op_set;
     std::vector<OpHandleBase *> ops(op_set.begin(), op_set.end());
     OpSet ret;
     auto rels = GetRelations(ops);

@@ -59,7 +59,7 @@ class OpConnectionDetector {
 
  private:
   std::vector<std::vector<RelationShip>> GetRelations(
-      const std::vector<OpHandleBase *> ops) {
+      const std::vector<OpHandleBase *> ops) const {
    std::unordered_map<OpHandleBase *, size_t> op_to_idx;
    for (size_t i = 0; i < ops.size(); ++i) {
      PADDLE_ENFORCE(graph_.HasOp(ops[i]), "Op does not exist in graph");

@@ -144,7 +144,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
   last_live_ops_of_vars = std::vector<LastLiveOpsOfVars>(vars.size());
   ref_cnts = std::vector<ReferenceCountMap>(vars.size());
 
-  OpConnectionDetector detector(ir::FilterByNodeWrapper<OpHandleBase>(*graph));
+  OpRelationDetector detector(ir::FilterByNodeWrapper<OpHandleBase>(*graph));
 
   for (size_t i = 0; i < vars.size(); ++i) {
     for (auto &name_var_pair : vars[i]) {
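Aside from the OpConnectionDetector → OpRelationDetector rename and the added const qualifiers, the hunk above moves the static_assert ahead of the early size check and rewords its message; the assertion constrains OpSet's key type at compile time. A self-contained sketch of what that check expresses (both struct names below are simplified stand-ins for the real op-handle classes):

#include <type_traits>
#include <unordered_set>

struct OpHandleBase {};
struct ComputationOpHandle : OpHandleBase {};

// Mirrors the static_assert in MaxNoDepOps: the set's key type must be a
// pointer to OpHandleBase or to a class derived from it.
template <typename OpSet>
struct KeyDerivesFromOpHandleBase
    : std::is_base_of<OpHandleBase,
                      typename std::remove_pointer<
                          typename OpSet::key_type>::type> {};

static_assert(
    KeyDerivesFromOpHandleBase<std::unordered_set<ComputationOpHandle *>>::value,
    "Key type of OpSet must be OpHandleBase, or derived of OpHandleBase");

int main() { return 0; }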
paddle/fluid/framework/details/reference_count_pass_helper.h (+6 -4)

@@ -15,6 +15,7 @@
 #pragma once
 
 #include <atomic>
+#include <map>
 #include <string>
 #include <unordered_map>
 #include <vector>

@@ -33,12 +34,13 @@ using ReferenceCountMap = std::unordered_map<std::string, size_t>;
 using AtomicReferenceCountMap =
     std::unordered_map<std::string, std::atomic<size_t>>;
 
-using GarbageCollectorList =
-    std::vector<std::unique_ptr<GarbageCollector<Tensor>>>;
+using GarbageCollectorMap =
+    std::map<platform::Place, std::unique_ptr<GarbageCollector<Tensor>>>;
 
-const char kGlobalReferenceCount[] = "reference_count";
-const char kCurReferenceCount[] = "current_reference_count";
+const char kGlobalReferenceCount[] = "global_reference_count";
+const char kRuntimeReferenceCount[] = "runtime_reference_count";
 const char kGarbageCollector[] = "garbage_collector";
+const char kAllPlaces[] = "all_places";
 
 using LastLiveOpsOfVars =
     std::unordered_map<std::string, std::unordered_set<ComputationOpHandle *>>;
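The central change in this header is the garbage-collector container: GarbageCollectorList, a vector of collectors addressed by position, becomes GarbageCollectorMap, a std::map keyed by platform::Place, so a collector is looked up by device rather than by scope index. A minimal sketch of how such a place-keyed map behaves (Place and GarbageCollector below are simplified stand-ins, not the Paddle types):

#include <iostream>
#include <map>
#include <memory>

// Simplified stand-ins: "Place" is a plain struct with an ordering and
// "GarbageCollector" is a trivial class; the real types live in
// paddle::platform and paddle::framework.
struct Place {
  int device_id;
  bool operator<(const Place &other) const {
    return device_id < other.device_id;
  }
};

struct GarbageCollector {
  explicit GarbageCollector(Place p) : place(p) {}
  Place place;
};

using GarbageCollectorMap = std::map<Place, std::unique_ptr<GarbageCollector>>;

int main() {
  GarbageCollectorMap gcs;
  Place gpu0{0}, gpu1{1};

  // One collector per place; count() gives a cheap "already created" check,
  // which the new PrepareGCAndRefCnts in parallel_executor.cc relies on.
  if (gcs.count(gpu0) == 0) {
    gcs[gpu0] = std::unique_ptr<GarbageCollector>(new GarbageCollector(gpu0));
  }
  gcs[gpu1] = std::unique_ptr<GarbageCollector>(new GarbageCollector(gpu1));

  // Lookup by place rather than by index, as in
  // gcs.at(places[op->GetScopeIdx()]) in eager_deletion_pass.cc.
  std::cout << "collector for device " << gcs.at(gpu1)->place.device_id << "\n";
  return 0;
}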
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc (+4 -4)

@@ -32,15 +32,15 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
       var_infos_(std::move(var_infos)),
       places_(std::move(places)) {
   if (Graph().Has(details::kGarbageCollector)) {
-    gc_ = &(Graph().Get<GarbageCollectorList>(details::kGarbageCollector));
+    gc_ = &(Graph().Get<GarbageCollectorMap>(details::kGarbageCollector));
   }
 }
 
 void ScopeBufferedSSAGraphExecutor::WaitAllGarbageCollectors() {
   if (gc_) {
-    for (auto &gc : *gc_) {
-      gc->Wait();
-      gc->Reset();
+    for (auto &gc_pair : *gc_) {
+      gc_pair.second->Wait();
+      gc_pair.second->Reset();
     }
   }
 }
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h (+1 -1)

@@ -60,7 +60,7 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
   std::vector<VariableInfo> var_infos_;
   std::vector<platform::Place> places_;
 
-  GarbageCollectorList *gc_{nullptr};
+  GarbageCollectorMap *gc_{nullptr};
 };
 
 }  // namespace details
 }  // namespace framework
paddle/fluid/framework/executor.cc (+4 -10)

@@ -56,13 +56,7 @@ static std::unordered_map<std::string, size_t> GetNonPersistableReferenceCounts(
           type != proto::VarType::LOD_TENSOR_ARRAY) {
         continue;
       }
-      auto it = ref_cnts.find(name);
-      if (it != ref_cnts.end()) {
-        ++it->second;
-      } else {
-        ref_cnts[name] = 1;
-      }
+      ++ref_cnts[name];
     }
   }
 };

@@ -79,8 +73,8 @@ ExecutorPrepareContext::ExecutorPrepareContext(
     const std::vector<std::string>& skip_ref_cnt_vars)
     : prog_(prog), block_id_(block_id) {
   if (GetEagerDeletionThreshold() >= 0) {
-    ref_cnts_ = GetNonPersistableReferenceCounts(prog.Block(block_id),
-                                                 skip_ref_cnt_vars);
+    global_ref_cnts_ = GetNonPersistableReferenceCounts(prog.Block(block_id),
+                                                        skip_ref_cnt_vars);
   }
 }

@@ -443,7 +437,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
       if (gc) {
         DeleteUnusedTensors(*local_scope, op.get(), gc.get(),
-                            &(ctx->cur_ref_cnts_));
+                            &(ctx->runtime_ref_cnts_));
       }
     }
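The first hunk above collapses the find-then-insert branching into a single ++ref_cnts[name]. This is behavior-preserving because std::unordered_map::operator[] value-initializes a missing size_t to 0 before the increment; a tiny standalone check of that equivalence:

#include <cassert>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, size_t> ref_cnts;

  // operator[] inserts a value-initialized size_t (0) for a missing key, so
  // one increment covers both the "first time seen" and the "already counted"
  // branches of the removed find()/else code.
  ++ref_cnts["x"];
  ++ref_cnts["x"];
  ++ref_cnts["y"];

  assert(ref_cnts["x"] == 2);
  assert(ref_cnts["y"] == 1);
  return 0;
}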
paddle/fluid/framework/executor.h (+3 -3)

@@ -34,14 +34,14 @@ struct ExecutorPrepareContext {
   ~ExecutorPrepareContext();
 
-  void ResetReferenceCount() { cur_ref_cnts_ = ref_cnts_; }
+  void ResetReferenceCount() { runtime_ref_cnts_ = global_ref_cnts_; }
 
   const framework::ProgramDesc& prog_;
   size_t block_id_;
   std::vector<std::unique_ptr<OperatorBase>> ops_;
 
-  std::unordered_map<std::string, size_t> ref_cnts_;
-  std::unordered_map<std::string, size_t> cur_ref_cnts_;
+  std::unordered_map<std::string, size_t> global_ref_cnts_;
+  std::unordered_map<std::string, size_t> runtime_ref_cnts_;
 };
 
 class Executor {
paddle/fluid/framework/parallel_executor.cc (+86 -67)

@@ -51,11 +51,22 @@ class ParallelExecutorPrivate {
     }
   }
 
-  void ResetRuntimeReferenceCount() {
-    for (size_t i = 0; i < rt_ref_cnts_.size(); ++i) {
-      for (auto &pair : rt_ref_cnts_[i]) {
-        rt_cur_ref_cnts_[i][pair.first] = pair.second;
+  std::unique_ptr<ir::Graph> PrepareGCAndRefCnts(
+      std::unique_ptr<ir::Graph> graph, size_t max_memory_size);
+
+  inline bool HasGarbageCollectors() const { return !gcs_.empty(); }
+
+  void ResetRuntimeReferenceCount(const std::vector<std::string> &fetch_tensors,
+                                  const std::string &fetched_var_name) {
+    for (size_t i = 0; i < runtime_ref_cnts_.size(); ++i) {
+      for (auto &pair : global_ref_cnts_[i]) {
+        runtime_ref_cnts_[i][pair.first] = pair.second;
+      }
+
+      for (auto &fetch_name : fetch_tensors) {
+        runtime_ref_cnts_[i].erase(fetch_name);
       }
+      runtime_ref_cnts_[i].erase(fetched_var_name);
     }
   }

@@ -71,14 +82,75 @@ class ParallelExecutorPrivate {
   bool use_cuda_;
   bool use_all_reduce_;
 
-  // rt_ref_cnts_ is only initialized when ParallelExecutor constructs, and then
-  // keeps unchanged
-  // Before each iteration, rt_cur_ref_cnts_ is reset to ref_cnts_
-  std::vector<details::ReferenceCountMap> rt_ref_cnts_;
-  std::vector<details::AtomicReferenceCountMap> rt_cur_ref_cnts_;
-  details::GarbageCollectorList gcs_;
+  // global_ref_cnts_ is only initialized when ParallelExecutor constructs, and
+  // then keeps unchanged
+  // Before each iteration, runtime_ref_cnts_ is reset to global_ref_cnts_
+  std::vector<details::ReferenceCountMap> global_ref_cnts_;
+  std::vector<details::AtomicReferenceCountMap> runtime_ref_cnts_;
+  details::GarbageCollectorMap gcs_;
 };
 
+std::unique_ptr<ir::Graph> ParallelExecutorPrivate::PrepareGCAndRefCnts(
+    std::unique_ptr<ir::Graph> graph, size_t max_memory_size) {
+  for (size_t i = 0; i < places_.size(); ++i) {
+    auto &place = places_[i];
+    if (gcs_.count(place) > 0) {
+      continue;
+    }
+#ifdef PADDLE_WITH_CUDA
+    GarbageCollector<Tensor> *gc = nullptr;
+    if (platform::is_gpu_place(place)) {
+      if (IsFastEagerDeletionModeEnabled()) {
+        gc = new UnsafeFastGPUGarbageCollector<Tensor>(
+            boost::get<platform::CUDAPlace>(place), max_memory_size);
+      } else {
+        gc = new StreamGarbageCollector<Tensor>(
+            boost::get<platform::CUDAPlace>(place), max_memory_size);
+      }
+      VLOG(10) << "Created " << i << "-th GarbageCollector at " << place;
+    } else if (platform::is_cpu_place(place)) {
+#endif
+      gc = new CPUGarbageCollector<Tensor>(
+          boost::get<platform::CPUPlace>(place), max_memory_size);
+      VLOG(10) << "Created GarbageCollector at " << place;
+#ifdef PADDLE_WITH_CUDA
+    }
+#endif
+    if (gc) {
+      gcs_[place] = std::unique_ptr<GarbageCollector<Tensor>>(gc);
+    }
+  }
+
+  if (!gcs_.empty()) {
+    std::vector<details::LastLiveOpsOfVars> last_live_ops_of_vars;
+
+    auto ref_cnt_pass =
+        ir::PassRegistry::Instance().Get("reference_count_pass");
+    ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount,
+                              &global_ref_cnts_);
+    ref_cnt_pass->SetNotOwned(details::kLastLiveOpsOfVars,
+                              &last_live_ops_of_vars);
+    graph = ref_cnt_pass->Apply(std::move(graph));
+    VLOG(10) << "ReferenceCountPass Applied";
+
+    auto eager_deletion_pass =
+        ir::PassRegistry::Instance().Get("eager_deletion_pass");
+    eager_deletion_pass->SetNotOwned(details::kRuntimeReferenceCount,
+                                     &runtime_ref_cnts_);
+    eager_deletion_pass->SetNotOwned(details::kGarbageCollector, &gcs_);
+    eager_deletion_pass->SetNotOwned(details::kLastLiveOpsOfVars,
+                                     &last_live_ops_of_vars);
+    eager_deletion_pass->SetNotOwned(details::kAllPlaces, &places_);
+    graph = eager_deletion_pass->Apply(std::move(graph));
+    VLOG(10) << "EagerDeletionPass Applied";
+
+    graph->SetNotOwned(details::kGarbageCollector, &gcs_);
+  }
+
+  return graph;
+}
+
 std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
   return member_->local_scopes_;
 }

@@ -153,54 +225,8 @@ ParallelExecutor::ParallelExecutor(
   auto max_memory_size = GetEagerDeletionThreshold();
   if (max_memory_size >= 0) {
-    size_t place_num = member_->places_.size();
-    for (size_t i = 0; i < place_num; ++i) {
-      auto &place = member_->places_[i];
-#ifdef PADDLE_WITH_CUDA
-      if (platform::is_gpu_place(place)) {
-        if (IsFastEagerDeletionModeEnabled()) {
-          member_->gcs_.emplace_back(new UnsafeFastGPUGarbageCollector<Tensor>(
-              boost::get<platform::CUDAPlace>(place), max_memory_size));
-        } else {
-          member_->gcs_.emplace_back(new StreamGarbageCollector<Tensor>(
-              boost::get<platform::CUDAPlace>(place), max_memory_size));
-        }
-        VLOG(10) << "Created " << i << "-th GarbageCollector at " << place;
-      } else if (platform::is_cpu_place(place)) {
-#endif
-        member_->gcs_.emplace_back(new CPUGarbageCollector<Tensor>(
-            boost::get<platform::CPUPlace>(place), max_memory_size));
-        VLOG(10) << "Created " << i << "-th GarbageCollector at " << place;
-#ifdef PADDLE_WITH_CUDA
-      }
-#endif
-    }
-  }
-
-  if (!member_->gcs_.empty()) {
-    std::vector<details::LastLiveOpsOfVars> last_live_ops_of_vars;
-
-    auto ref_cnt_pass =
-        ir::PassRegistry::Instance().Get("reference_count_pass");
-    ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount,
-                              &(member_->rt_ref_cnts_));
-    ref_cnt_pass->SetNotOwned(details::kLastLiveOpsOfVars,
-                              &last_live_ops_of_vars);
-    graph = ref_cnt_pass->Apply(std::move(graph));
-    VLOG(10) << "ReferenceCountPass Applied";
-
-    auto eager_deletion_pass =
-        ir::PassRegistry::Instance().Get("eager_deletion_pass");
-    eager_deletion_pass->SetNotOwned(details::kCurReferenceCount,
-                                     &(member_->rt_cur_ref_cnts_));
-    eager_deletion_pass->SetNotOwned(details::kGarbageCollector,
-                                     &(member_->gcs_));
-    eager_deletion_pass->SetNotOwned(details::kLastLiveOpsOfVars,
-                                     &last_live_ops_of_vars);
-    graph = eager_deletion_pass->Apply(std::move(graph));
-    VLOG(10) << "EagerDeletionPass Applied";
-
-    graph->SetNotOwned(details::kGarbageCollector, &(member_->gcs_));
+    graph = member_->PrepareGCAndRefCnts(std::move(graph),
+                                         static_cast<size_t>(max_memory_size));
   }
 
   // Step 3. Create vars in each scope. Passes may also create new vars.

@@ -316,15 +342,8 @@ void ParallelExecutor::BCastParamsToDevices(
 void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
                            const std::string &fetched_var_name) {
   platform::RecordBlock b(0);
-  if (!member_->gcs_.empty()) {
-    member_->ResetRuntimeReferenceCount();
-    size_t n = member_->rt_ref_cnts_.size();
-    for (size_t i = 0; i < n; ++i) {
-      for (auto &fetch_name : fetch_tensors) {
-        member_->rt_cur_ref_cnts_[i].erase(fetch_name);
-      }
-      member_->rt_cur_ref_cnts_[i].erase(fetched_var_name);
-    }
+  if (member_->HasGarbageCollectors()) {
+    member_->ResetRuntimeReferenceCount(fetch_tensors, fetched_var_name);
   }
   auto fetch_data = member_->executor_->Run(fetch_tensors);
   *member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
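Taken together, this file moves garbage-collector construction and the two IR passes out of the ParallelExecutor constructor into PrepareGCAndRefCnts, and folds the per-run bookkeeping that Run() used to do inline into ResetRuntimeReferenceCount(fetch_tensors, fetched_var_name). A self-contained sketch of that reset protocol, using plain maps where the real code keeps one atomic map per place:

#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

using RefCountMap = std::unordered_map<std::string, size_t>;

// Sketch of the per-iteration reset: runtime counts are re-seeded from the
// immutable global counts, then every fetched variable is erased so eager
// deletion never frees a tensor the caller asked for. The names and the
// sample values below are illustrative only.
void ResetRuntimeReferenceCount(const RefCountMap &global_ref_cnts,
                                const std::vector<std::string> &fetch_tensors,
                                const std::string &fetched_var_name,
                                RefCountMap *runtime_ref_cnts) {
  for (auto &pair : global_ref_cnts) {
    (*runtime_ref_cnts)[pair.first] = pair.second;
  }
  for (auto &fetch_name : fetch_tensors) {
    runtime_ref_cnts->erase(fetch_name);
  }
  runtime_ref_cnts->erase(fetched_var_name);
}

int main() {
  RefCountMap global{{"a", 2}, {"b", 1}, {"loss", 3}};
  RefCountMap runtime;
  ResetRuntimeReferenceCount(global, {"loss"}, "fetched_vars", &runtime);
  assert(runtime.count("loss") == 0);  // fetched vars are never eagerly freed
  assert(runtime.at("a") == 2);        // other vars start at their global count
  return 0;
}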
paddle/fluid/operators/controlflow/while_op.cc (+3 -7)

@@ -74,9 +74,7 @@ class WhileOp : public framework::OperatorBase {
     bool is_test = Attr<bool>("is_test");
 
     auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
-    if (framework::GetEagerDeletionThreshold() >= 0) {
-      VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
-    }
+    VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
 
     auto ctx = executor.Prepare(*program, block->ID(), skip_vars);
     while (cond.data<bool>()[0]) {

@@ -144,9 +142,7 @@ class WhileGradOp : public framework::OperatorBase {
     auto *program = block->Program();
 
     auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
-    if (framework::GetEagerDeletionThreshold() >= 0) {
-      VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
-    }
+    VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
 
     auto ctx = executor.Prepare(*program, block->ID(), skip_vars);
     auto *step_scopes =

@@ -369,7 +365,7 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
     // while operator could be renamed.
     while_grad->SetAttr("original_output_grad", output_grads_list);
 
-    /* The followi_ng codes are used in eager deletion mode */
+    /* The following codes are used in eager deletion mode */
     std::unordered_set<std::string> bwd_skip_vars;
     if (framework::GetEagerDeletionThreshold() >= 0) {
       std::unordered_set<std::string> fwd_skip_vars;