Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
387bac46
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
387bac46
编写于
12月 07, 2018
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine code
test=develop
上级
d0c8b9b9
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
122 addition
and
107 deletion
+122
-107
paddle/fluid/framework/details/eager_deletion_pass.cc
paddle/fluid/framework/details/eager_deletion_pass.cc
+6
-4
paddle/fluid/framework/details/op_graph_view.cc
paddle/fluid/framework/details/op_graph_view.cc
+2
-0
paddle/fluid/framework/details/reference_count_pass.cc
paddle/fluid/framework/details/reference_count_pass.cc
+7
-7
paddle/fluid/framework/details/reference_count_pass_helper.h
paddle/fluid/framework/details/reference_count_pass_helper.h
+6
-4
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
...id/framework/details/scope_buffered_ssa_graph_executor.cc
+4
-4
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h
...uid/framework/details/scope_buffered_ssa_graph_executor.h
+1
-1
paddle/fluid/framework/executor.cc
paddle/fluid/framework/executor.cc
+4
-10
paddle/fluid/framework/executor.h
paddle/fluid/framework/executor.h
+3
-3
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+86
-67
paddle/fluid/operators/controlflow/while_op.cc
paddle/fluid/operators/controlflow/while_op.cc
+3
-7
未找到文件。
paddle/fluid/framework/details/eager_deletion_pass.cc
浏览文件 @
387bac46
...
...
@@ -31,10 +31,11 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
const
auto
&
vars
=
graph
->
Get
<
GraphVars
>
(
kGraphVars
);
auto
&
ref_cnts
=
Get
<
std
::
vector
<
AtomicReferenceCountMap
>>
(
k
Cur
ReferenceCount
);
Get
<
std
::
vector
<
AtomicReferenceCountMap
>>
(
k
Runtime
ReferenceCount
);
const
auto
&
last_live_ops
=
Get
<
std
::
vector
<
LastLiveOpsOfVars
>>
(
kLastLiveOpsOfVars
);
auto
&
gcs
=
Get
<
GarbageCollectorList
>
(
kGarbageCollector
);
auto
&
gcs
=
Get
<
GarbageCollectorMap
>
(
kGarbageCollector
);
const
auto
&
places
=
Get
<
std
::
vector
<
platform
::
Place
>>
(
kAllPlaces
);
ref_cnts
=
std
::
vector
<
AtomicReferenceCountMap
>
(
vars
.
size
());
...
...
@@ -58,7 +59,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
graph
->
CreateEmptyNode
(
"eager_deletion"
,
ir
::
Node
::
Type
::
kOperation
);
auto
*
eager_deletion_op
=
new
EagerDeletionOpHandle
(
eager_deletion_node
,
op
->
GetScope
(),
op
->
GetPlace
(),
std
::
move
(
var_names
),
gcs
[
op
->
GetScopeIdx
()]
.
get
(),
std
::
move
(
var_names
),
gcs
.
at
(
places
[
op
->
GetScopeIdx
()])
.
get
(),
&
(
ref_cnts
[
op
->
GetScopeIdx
()]));
auto
it
=
std
::
find_if
(
...
...
@@ -90,6 +91,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
REGISTER_PASS
(
eager_deletion_pass
,
paddle
::
framework
::
details
::
EagerDeletionPass
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
k
Cur
ReferenceCount
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
k
Runtime
ReferenceCount
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kLastLiveOpsOfVars
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kAllPlaces
)
.
RequirePassAttr
(
paddle
::
framework
::
details
::
kGarbageCollector
);
paddle/fluid/framework/details/op_graph_view.cc
浏览文件 @
387bac46
...
...
@@ -23,6 +23,8 @@ namespace details {
OpGraphView
::
OpGraphView
(
const
std
::
vector
<
OpHandleBase
*>
&
ops
)
{
Build
(
ops
);
}
void
OpGraphView
::
Build
(
const
std
::
vector
<
OpHandleBase
*>
&
ops
)
{
preceding_ops_
.
clear
();
pending_ops_
.
clear
();
for
(
auto
&
op
:
ops
)
{
preceding_ops_
[
op
];
pending_ops_
[
op
];
...
...
paddle/fluid/framework/details/reference_count_pass.cc
浏览文件 @
387bac46
...
...
@@ -29,22 +29,22 @@ namespace paddle {
namespace
framework
{
namespace
details
{
class
Op
Connec
tionDetector
{
class
Op
Rela
tionDetector
{
public:
enum
RelationShip
{
kSame
=
0
,
kNoDeps
=
1
,
kBefore
=
2
,
kAfter
=
3
};
explicit
Op
Connec
tionDetector
(
const
std
::
vector
<
OpHandleBase
*>
&
all_ops
)
explicit
Op
Rela
tionDetector
(
const
std
::
vector
<
OpHandleBase
*>
&
all_ops
)
:
graph_
(
all_ops
)
{}
template
<
typename
OpSet
>
OpSet
MaxNoDepOps
(
const
OpSet
&
op_set
)
{
if
(
op_set
.
size
()
<=
1
)
return
op_set
;
OpSet
MaxNoDepOps
(
const
OpSet
&
op_set
)
const
{
using
KeyType
=
typename
OpSet
::
key_type
;
static_assert
(
std
::
is_base_of
<
OpHandleBase
,
typename
std
::
remove_pointer
<
KeyType
>::
type
>::
value
,
"Key type of OpSet must be or derived of OpHandleBase"
);
"Key type of OpSet must be
OpHandleBase,
or derived of OpHandleBase"
);
if
(
op_set
.
size
()
<=
1
)
return
op_set
;
std
::
vector
<
OpHandleBase
*>
ops
(
op_set
.
begin
(),
op_set
.
end
());
OpSet
ret
;
auto
rels
=
GetRelations
(
ops
);
...
...
@@ -59,7 +59,7 @@ class OpConnectionDetector {
private:
std
::
vector
<
std
::
vector
<
RelationShip
>>
GetRelations
(
const
std
::
vector
<
OpHandleBase
*>
ops
)
{
const
std
::
vector
<
OpHandleBase
*>
ops
)
const
{
std
::
unordered_map
<
OpHandleBase
*
,
size_t
>
op_to_idx
;
for
(
size_t
i
=
0
;
i
<
ops
.
size
();
++
i
)
{
PADDLE_ENFORCE
(
graph_
.
HasOp
(
ops
[
i
]),
"Op does not exist in graph"
);
...
...
@@ -144,7 +144,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
last_live_ops_of_vars
=
std
::
vector
<
LastLiveOpsOfVars
>
(
vars
.
size
());
ref_cnts
=
std
::
vector
<
ReferenceCountMap
>
(
vars
.
size
());
Op
Connec
tionDetector
detector
(
ir
::
FilterByNodeWrapper
<
OpHandleBase
>
(
*
graph
));
Op
Rela
tionDetector
detector
(
ir
::
FilterByNodeWrapper
<
OpHandleBase
>
(
*
graph
));
for
(
size_t
i
=
0
;
i
<
vars
.
size
();
++
i
)
{
for
(
auto
&
name_var_pair
:
vars
[
i
])
{
...
...
paddle/fluid/framework/details/reference_count_pass_helper.h
浏览文件 @
387bac46
...
...
@@ -15,6 +15,7 @@
#pragma once
#include <atomic>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>
...
...
@@ -33,12 +34,13 @@ using ReferenceCountMap = std::unordered_map<std::string, size_t>;
using
AtomicReferenceCountMap
=
std
::
unordered_map
<
std
::
string
,
std
::
atomic
<
size_t
>>
;
using
GarbageCollector
List
=
std
::
vector
<
std
::
unique_ptr
<
GarbageCollector
<
Tensor
>>>
;
using
GarbageCollector
Map
=
std
::
map
<
platform
::
Place
,
std
::
unique_ptr
<
GarbageCollector
<
Tensor
>>>
;
const
char
kGlobalReferenceCount
[]
=
"reference_count"
;
const
char
k
CurReferenceCount
[]
=
"current
_reference_count"
;
const
char
kGlobalReferenceCount
[]
=
"
global_
reference_count"
;
const
char
k
RuntimeReferenceCount
[]
=
"runtime
_reference_count"
;
const
char
kGarbageCollector
[]
=
"garbage_collector"
;
const
char
kAllPlaces
[]
=
"all_places"
;
using
LastLiveOpsOfVars
=
std
::
unordered_map
<
std
::
string
,
std
::
unordered_set
<
ComputationOpHandle
*>>
;
...
...
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
浏览文件 @
387bac46
...
...
@@ -32,15 +32,15 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
var_infos_
(
std
::
move
(
var_infos
)),
places_
(
std
::
move
(
places
))
{
if
(
Graph
().
Has
(
details
::
kGarbageCollector
))
{
gc_
=
&
(
Graph
().
Get
<
GarbageCollector
List
>
(
details
::
kGarbageCollector
));
gc_
=
&
(
Graph
().
Get
<
GarbageCollector
Map
>
(
details
::
kGarbageCollector
));
}
}
void
ScopeBufferedSSAGraphExecutor
::
WaitAllGarbageCollectors
()
{
if
(
gc_
)
{
for
(
auto
&
gc
:
*
gc_
)
{
gc
->
Wait
();
gc
->
Reset
();
for
(
auto
&
gc
_pair
:
*
gc_
)
{
gc
_pair
.
second
->
Wait
();
gc
_pair
.
second
->
Reset
();
}
}
}
...
...
paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h
浏览文件 @
387bac46
...
...
@@ -60,7 +60,7 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
std
::
vector
<
VariableInfo
>
var_infos_
;
std
::
vector
<
platform
::
Place
>
places_
;
GarbageCollector
List
*
gc_
{
nullptr
};
GarbageCollector
Map
*
gc_
{
nullptr
};
};
}
// namespace details
}
// namespace framework
...
...
paddle/fluid/framework/executor.cc
浏览文件 @
387bac46
...
...
@@ -56,13 +56,7 @@ static std::unordered_map<std::string, size_t> GetNonPersistableReferenceCounts(
type
!=
proto
::
VarType
::
LOD_TENSOR_ARRAY
)
{
continue
;
}
auto
it
=
ref_cnts
.
find
(
name
);
if
(
it
!=
ref_cnts
.
end
())
{
++
it
->
second
;
}
else
{
ref_cnts
[
name
]
=
1
;
}
++
ref_cnts
[
name
];
}
}
};
...
...
@@ -79,7 +73,7 @@ ExecutorPrepareContext::ExecutorPrepareContext(
const
std
::
vector
<
std
::
string
>&
skip_ref_cnt_vars
)
:
prog_
(
prog
),
block_id_
(
block_id
)
{
if
(
GetEagerDeletionThreshold
()
>=
0
)
{
ref_cnts_
=
GetNonPersistableReferenceCounts
(
prog
.
Block
(
block_id
),
global_
ref_cnts_
=
GetNonPersistableReferenceCounts
(
prog
.
Block
(
block_id
),
skip_ref_cnt_vars
);
}
}
...
...
@@ -443,7 +437,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
if
(
gc
)
{
DeleteUnusedTensors
(
*
local_scope
,
op
.
get
(),
gc
.
get
(),
&
(
ctx
->
cur
_ref_cnts_
));
&
(
ctx
->
runtime
_ref_cnts_
));
}
}
...
...
paddle/fluid/framework/executor.h
浏览文件 @
387bac46
...
...
@@ -34,14 +34,14 @@ struct ExecutorPrepareContext {
~
ExecutorPrepareContext
();
void
ResetReferenceCount
()
{
cur_ref_cnts_
=
ref_cnts_
;
}
void
ResetReferenceCount
()
{
runtime_ref_cnts_
=
global_
ref_cnts_
;
}
const
framework
::
ProgramDesc
&
prog_
;
size_t
block_id_
;
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>
ops_
;
std
::
unordered_map
<
std
::
string
,
size_t
>
ref_cnts_
;
std
::
unordered_map
<
std
::
string
,
size_t
>
cur
_ref_cnts_
;
std
::
unordered_map
<
std
::
string
,
size_t
>
global_
ref_cnts_
;
std
::
unordered_map
<
std
::
string
,
size_t
>
runtime
_ref_cnts_
;
};
class
Executor
{
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
387bac46
...
...
@@ -51,11 +51,22 @@ class ParallelExecutorPrivate {
}
}
void
ResetRuntimeReferenceCount
()
{
for
(
size_t
i
=
0
;
i
<
rt_ref_cnts_
.
size
();
++
i
)
{
for
(
auto
&
pair
:
rt_ref_cnts_
[
i
])
{
rt_cur_ref_cnts_
[
i
][
pair
.
first
]
=
pair
.
second
;
std
::
unique_ptr
<
ir
::
Graph
>
PrepareGCAndRefCnts
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
,
size_t
max_memory_size
);
inline
bool
HasGarbageCollectors
()
const
{
return
!
gcs_
.
empty
();
}
void
ResetRuntimeReferenceCount
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
const
std
::
string
&
fetched_var_name
)
{
for
(
size_t
i
=
0
;
i
<
runtime_ref_cnts_
.
size
();
++
i
)
{
for
(
auto
&
pair
:
global_ref_cnts_
[
i
])
{
runtime_ref_cnts_
[
i
][
pair
.
first
]
=
pair
.
second
;
}
for
(
auto
&
fetch_name
:
fetch_tensors
)
{
runtime_ref_cnts_
[
i
].
erase
(
fetch_name
);
}
runtime_ref_cnts_
[
i
].
erase
(
fetched_var_name
);
}
}
...
...
@@ -71,14 +82,75 @@ class ParallelExecutorPrivate {
bool
use_cuda_
;
bool
use_all_reduce_
;
//
rt_ref_cnts_ is only initialized when ParallelExecutor constructs, and then
// keeps unchanged
// Before each iteration, r
t_cur_ref_cnts_ is reset to
ref_cnts_
std
::
vector
<
details
::
ReferenceCountMap
>
rt
_ref_cnts_
;
std
::
vector
<
details
::
AtomicReferenceCountMap
>
r
t_cur
_ref_cnts_
;
details
::
GarbageCollector
List
gcs_
;
//
global_ref_cnts_ is only initialized when ParallelExecutor constructs, and
//
then
keeps unchanged
// Before each iteration, r
untime_ref_cnts_ is reset to global_
ref_cnts_
std
::
vector
<
details
::
ReferenceCountMap
>
global
_ref_cnts_
;
std
::
vector
<
details
::
AtomicReferenceCountMap
>
r
untime
_ref_cnts_
;
details
::
GarbageCollector
Map
gcs_
;
};
std
::
unique_ptr
<
ir
::
Graph
>
ParallelExecutorPrivate
::
PrepareGCAndRefCnts
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
,
size_t
max_memory_size
)
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
place
=
places_
[
i
];
if
(
gcs_
.
count
(
place
)
>
0
)
{
continue
;
}
#ifdef PADDLE_WITH_CUDA
GarbageCollector
<
Tensor
>
*
gc
=
nullptr
;
if
(
platform
::
is_gpu_place
(
place
))
{
if
(
IsFastEagerDeletionModeEnabled
())
{
gc
=
new
UnsafeFastGPUGarbageCollector
<
Tensor
>
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place
),
max_memory_size
);
}
else
{
gc
=
new
StreamGarbageCollector
<
Tensor
>
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place
),
max_memory_size
);
}
VLOG
(
10
)
<<
"Created "
<<
i
<<
"-th GarbageCollector at "
<<
place
;
}
else
if
(
platform
::
is_cpu_place
(
place
))
{
#endif
gc
=
new
CPUGarbageCollector
<
Tensor
>
(
boost
::
get
<
platform
::
CPUPlace
>
(
place
),
max_memory_size
);
VLOG
(
10
)
<<
"Created GarbageCollector at "
<<
place
;
#ifdef PADDLE_WITH_CUDA
}
#endif
if
(
gc
)
{
gcs_
[
place
]
=
std
::
unique_ptr
<
GarbageCollector
<
Tensor
>>
(
gc
);
}
}
if
(
gcs_
.
empty
())
{
std
::
vector
<
details
::
LastLiveOpsOfVars
>
last_live_ops_of_vars
;
auto
ref_cnt_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"reference_count_pass"
);
ref_cnt_pass
->
SetNotOwned
(
details
::
kGlobalReferenceCount
,
&
global_ref_cnts_
);
ref_cnt_pass
->
SetNotOwned
(
details
::
kLastLiveOpsOfVars
,
&
last_live_ops_of_vars
);
graph
=
ref_cnt_pass
->
Apply
(
std
::
move
(
graph
));
VLOG
(
10
)
<<
"ReferenceCountPass Applied"
;
auto
eager_deletion_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"eager_deletion_pass"
);
eager_deletion_pass
->
SetNotOwned
(
details
::
kRuntimeReferenceCount
,
&
runtime_ref_cnts_
);
eager_deletion_pass
->
SetNotOwned
(
details
::
kGarbageCollector
,
&
gcs_
);
eager_deletion_pass
->
SetNotOwned
(
details
::
kLastLiveOpsOfVars
,
&
last_live_ops_of_vars
);
eager_deletion_pass
->
SetNotOwned
(
details
::
kAllPlaces
,
&
places_
);
graph
=
eager_deletion_pass
->
Apply
(
std
::
move
(
graph
));
VLOG
(
10
)
<<
"EagerDeletionPass Applied"
;
graph
->
SetNotOwned
(
details
::
kGarbageCollector
,
&
gcs_
);
}
return
graph
;
}
std
::
vector
<
Scope
*>
&
ParallelExecutor
::
GetLocalScopes
()
{
return
member_
->
local_scopes_
;
}
...
...
@@ -153,54 +225,8 @@ ParallelExecutor::ParallelExecutor(
auto
max_memory_size
=
GetEagerDeletionThreshold
();
if
(
max_memory_size
>=
0
)
{
size_t
place_num
=
member_
->
places_
.
size
();
for
(
size_t
i
=
0
;
i
<
place_num
;
++
i
)
{
auto
&
place
=
member_
->
places_
[
i
];
#ifdef PADDLE_WITH_CUDA
if
(
platform
::
is_gpu_place
(
place
))
{
if
(
IsFastEagerDeletionModeEnabled
())
{
member_
->
gcs_
.
emplace_back
(
new
UnsafeFastGPUGarbageCollector
<
Tensor
>
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place
),
max_memory_size
));
}
else
{
member_
->
gcs_
.
emplace_back
(
new
StreamGarbageCollector
<
Tensor
>
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place
),
max_memory_size
));
}
VLOG
(
10
)
<<
"Created "
<<
i
<<
"-th GarbageCollector at "
<<
place
;
}
else
if
(
platform
::
is_cpu_place
(
place
))
{
#endif
member_
->
gcs_
.
emplace_back
(
new
CPUGarbageCollector
<
Tensor
>
(
boost
::
get
<
platform
::
CPUPlace
>
(
place
),
max_memory_size
));
VLOG
(
10
)
<<
"Created "
<<
i
<<
"-th GarbageCollector at "
<<
place
;
#ifdef PADDLE_WITH_CUDA
}
#endif
}
}
if
(
!
member_
->
gcs_
.
empty
())
{
std
::
vector
<
details
::
LastLiveOpsOfVars
>
last_live_ops_of_vars
;
auto
ref_cnt_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"reference_count_pass"
);
ref_cnt_pass
->
SetNotOwned
(
details
::
kGlobalReferenceCount
,
&
(
member_
->
rt_ref_cnts_
));
ref_cnt_pass
->
SetNotOwned
(
details
::
kLastLiveOpsOfVars
,
&
last_live_ops_of_vars
);
graph
=
ref_cnt_pass
->
Apply
(
std
::
move
(
graph
));
VLOG
(
10
)
<<
"ReferenceCountPass Applied"
;
auto
eager_deletion_pass
=
ir
::
PassRegistry
::
Instance
().
Get
(
"eager_deletion_pass"
);
eager_deletion_pass
->
SetNotOwned
(
details
::
kCurReferenceCount
,
&
(
member_
->
rt_cur_ref_cnts_
));
eager_deletion_pass
->
SetNotOwned
(
details
::
kGarbageCollector
,
&
(
member_
->
gcs_
));
eager_deletion_pass
->
SetNotOwned
(
details
::
kLastLiveOpsOfVars
,
&
last_live_ops_of_vars
);
graph
=
eager_deletion_pass
->
Apply
(
std
::
move
(
graph
));
VLOG
(
10
)
<<
"EagerDeletionPass Applied"
;
graph
->
SetNotOwned
(
details
::
kGarbageCollector
,
&
(
member_
->
gcs_
));
graph
=
member_
->
PrepareGCAndRefCnts
(
std
::
move
(
graph
),
static_cast
<
size_t
>
(
max_memory_size
));
}
// Step 3. Create vars in each scope. Passes may also create new vars.
...
...
@@ -316,15 +342,8 @@ void ParallelExecutor::BCastParamsToDevices(
void
ParallelExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
const
std
::
string
&
fetched_var_name
)
{
platform
::
RecordBlock
b
(
0
);
if
(
!
member_
->
gcs_
.
empty
())
{
member_
->
ResetRuntimeReferenceCount
();
size_t
n
=
member_
->
rt_ref_cnts_
.
size
();
for
(
size_t
i
=
0
;
i
<
n
;
++
i
)
{
for
(
auto
&
fetch_name
:
fetch_tensors
)
{
member_
->
rt_cur_ref_cnts_
[
i
].
erase
(
fetch_name
);
}
member_
->
rt_cur_ref_cnts_
[
i
].
erase
(
fetched_var_name
);
}
if
(
member_
->
HasGarbageCollectors
())
{
member_
->
ResetRuntimeReferenceCount
(
fetch_tensors
,
fetched_var_name
);
}
auto
fetch_data
=
member_
->
executor_
->
Run
(
fetch_tensors
);
*
member_
->
global_scope_
->
Var
(
fetched_var_name
)
->
GetMutable
<
FeedFetchList
>
()
=
...
...
paddle/fluid/operators/controlflow/while_op.cc
浏览文件 @
387bac46
...
...
@@ -74,9 +74,7 @@ class WhileOp : public framework::OperatorBase {
bool
is_test
=
Attr
<
bool
>
(
"is_test"
);
auto
&
skip_vars
=
Attr
<
std
::
vector
<
std
::
string
>>
(
kSkipEagerDeletionVars
);
if
(
framework
::
GetEagerDeletionThreshold
()
>=
0
)
{
VLOG
(
2
)
<<
GetSkipEagerDeletionVarsDebugString
(
skip_vars
);
}
auto
ctx
=
executor
.
Prepare
(
*
program
,
block
->
ID
(),
skip_vars
);
while
(
cond
.
data
<
bool
>
()[
0
])
{
...
...
@@ -144,9 +142,7 @@ class WhileGradOp : public framework::OperatorBase {
auto
*
program
=
block
->
Program
();
auto
&
skip_vars
=
Attr
<
std
::
vector
<
std
::
string
>>
(
kSkipEagerDeletionVars
);
if
(
framework
::
GetEagerDeletionThreshold
()
>=
0
)
{
VLOG
(
2
)
<<
GetSkipEagerDeletionVarsDebugString
(
skip_vars
);
}
auto
ctx
=
executor
.
Prepare
(
*
program
,
block
->
ID
(),
skip_vars
);
auto
*
step_scopes
=
...
...
@@ -369,7 +365,7 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
// while operator could be renamed.
while_grad
->
SetAttr
(
"original_output_grad"
,
output_grads_list
);
/* The followi
_
ng codes are used in eager deletion mode */
/* The following codes are used in eager deletion mode */
std
::
unordered_set
<
std
::
string
>
bwd_skip_vars
;
if
(
framework
::
GetEagerDeletionThreshold
()
>=
0
)
{
std
::
unordered_set
<
std
::
string
>
fwd_skip_vars
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录