Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
0a36ef3c
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
1 年多 前同步成功
通知
696
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0a36ef3c
编写于
9月 20, 2018
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
enhance eager deletion
上级
d775087d
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
125 addition
and
37 deletion
+125
-37
paddle/fluid/framework/details/reference_count_op_handle.h
paddle/fluid/framework/details/reference_count_op_handle.h
+28
-13
paddle/fluid/framework/details/reference_count_pass.cc
paddle/fluid/framework/details/reference_count_pass.cc
+64
-11
paddle/fluid/operators/adam_op.h
paddle/fluid/operators/adam_op.h
+33
-13
未找到文件。
paddle/fluid/framework/details/reference_count_op_handle.h
浏览文件 @
0a36ef3c
...
@@ -22,6 +22,7 @@
...
@@ -22,6 +22,7 @@
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -46,17 +47,15 @@ class ReferenceCountOpHandle : public OpHandleBase {
...
@@ -46,17 +47,15 @@ class ReferenceCountOpHandle : public OpHandleBase {
const
std
::
vector
<
std
::
string
>
&
var_names
,
const
std
::
vector
<
std
::
string
>
&
var_names
,
GarbageCollector
<
Tensor
>
*
gc
,
GarbageCollector
<
Tensor
>
*
gc
,
AtomicReferenceCountMap
*
ref_cnts
)
AtomicReferenceCountMap
*
ref_cnts
)
:
OpHandleBase
(
node
),
:
OpHandleBase
(
node
),
scope_
(
scope
),
gc_
(
gc
),
ref_cnts_
(
ref_cnts
)
{
scope_
(
scope
),
var_names_
(
var_names
),
gc_
(
gc
),
ref_cnts_
(
ref_cnts
)
{
dev_ctx_
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
dev_ctx_
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
));
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
));
if
(
IsStreamGarabageCollector
())
{
if
(
IsStreamGarabageCollector
())
{
PADDLE_ENFORCE
(
cudaSetDevice
(
place
.
device
));
PADDLE_ENFORCE
(
cudaSetDevice
(
place
.
device
));
PADDLE_ENFORCE
(
cudaEventCreateWithFlags
(
&
event_
,
cudaEventDisableTiming
));
PADDLE_ENFORCE
(
cudaEventCreateWithFlags
(
&
event_
,
cudaEventDisableTiming
));
}
}
for
(
auto
&
name
:
var_names
)
AddVar
(
name
);
}
}
~
ReferenceCountOpHandle
()
{
~
ReferenceCountOpHandle
()
{
...
@@ -69,19 +68,35 @@ class ReferenceCountOpHandle : public OpHandleBase {
...
@@ -69,19 +68,35 @@ class ReferenceCountOpHandle : public OpHandleBase {
std
::
string
Name
()
const
override
{
return
"reference_count"
;
}
std
::
string
Name
()
const
override
{
return
"reference_count"
;
}
void
AddVar
(
const
std
::
string
&
name
)
{
auto
it
=
var_names_
.
find
(
name
);
if
(
it
!=
var_names_
.
end
())
++
(
it
->
second
);
else
var_names_
[
name
]
=
1
;
}
protected:
protected:
void
RunImpl
()
override
{
void
RunImpl
()
override
{
auto
*
exec_scope
=
scope_
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
();
auto
*
exec_scope
=
scope_
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
();
std
::
vector
<
LoDTensor
*>
tensors
;
std
::
vector
<
Tensor
*>
tensors
;
for
(
auto
&
name
:
var_names_
)
{
for
(
auto
&
pair
:
var_names_
)
{
auto
&
name
=
pair
.
first
;
auto
it
=
ref_cnts_
->
find
(
name
);
auto
it
=
ref_cnts_
->
find
(
name
);
if
(
it
==
ref_cnts_
->
end
())
continue
;
if
(
it
==
ref_cnts_
->
end
())
continue
;
auto
*
var
=
exec_scope
->
FindVar
(
name
);
auto
*
var
=
exec_scope
->
FindVar
(
name
);
if
(
var
==
nullptr
||
!
var
->
IsType
<
LoDTensor
>
())
continue
;
if
(
var
==
nullptr
)
continue
;
if
(
it
->
second
.
fetch_sub
(
1
)
<=
1
)
{
if
(
var
->
IsType
<
LoDTensor
>
())
{
tensors
.
emplace_back
(
var
->
GetMutable
<
LoDTensor
>
());
if
(
it
->
second
.
fetch_sub
(
pair
.
second
)
<=
pair
.
second
)
{
tensors
.
emplace_back
(
var
->
GetMutable
<
LoDTensor
>
());
}
}
else
if
(
var
->
IsType
<
SelectedRows
>
())
{
if
(
it
->
second
.
fetch_sub
(
pair
.
second
)
<=
pair
.
second
)
{
tensors
.
emplace_back
(
var
->
GetMutable
<
SelectedRows
>
()
->
mutable_value
());
}
}
}
}
}
...
@@ -91,7 +106,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
...
@@ -91,7 +106,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
}
}
private:
private:
void
ClearTensors
(
const
std
::
vector
<
LoD
Tensor
*>
&
tensors
)
{
void
ClearTensors
(
const
std
::
vector
<
Tensor
*>
&
tensors
)
{
auto
*
gc
=
dynamic_cast
<
StreamGarbageCollector
<
Tensor
>
*>
(
gc_
);
auto
*
gc
=
dynamic_cast
<
StreamGarbageCollector
<
Tensor
>
*>
(
gc_
);
if
(
gc
!=
nullptr
)
{
if
(
gc
!=
nullptr
)
{
auto
compute_stream
=
dev_ctx_
->
stream
();
auto
compute_stream
=
dev_ctx_
->
stream
();
...
@@ -112,7 +127,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
...
@@ -112,7 +127,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
const
Scope
*
scope_
;
const
Scope
*
scope_
;
platform
::
CUDADeviceContext
*
dev_ctx_
;
platform
::
CUDADeviceContext
*
dev_ctx_
;
std
::
vector
<
std
::
string
>
var_names_
;
std
::
unordered_map
<
std
::
string
,
int
>
var_names_
;
GarbageCollector
<
Tensor
>
*
gc_
;
// not own
GarbageCollector
<
Tensor
>
*
gc_
;
// not own
AtomicReferenceCountMap
*
ref_cnts_
;
// not own
AtomicReferenceCountMap
*
ref_cnts_
;
// not own
cudaEvent_t
event_
;
cudaEvent_t
event_
;
...
...
paddle/fluid/framework/details/reference_count_pass.cc
浏览文件 @
0a36ef3c
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include <queue>
#include <string>
#include <string>
#include <vector>
#include <vector>
...
@@ -23,6 +24,25 @@ namespace paddle {
...
@@ -23,6 +24,25 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
static
ComputationOpHandle
*
FindNextComputationOpHandle
(
VarHandle
*
var_in
)
{
std
::
queue
<
VarHandleBase
*>
queue
;
queue
.
push
(
var_in
);
do
{
auto
*
var
=
queue
.
front
();
queue
.
pop
();
for
(
auto
*
op
:
var
->
PendingOps
())
{
auto
*
compute_op
=
dynamic_cast
<
ComputationOpHandle
*>
(
op
);
if
(
compute_op
!=
nullptr
&&
compute_op
->
GetPlace
()
==
var_in
->
place_
)
{
return
compute_op
;
}
for
(
auto
*
out_var
:
op
->
Outputs
())
{
queue
.
push
(
out_var
);
}
}
}
while
(
!
queue
.
empty
());
return
nullptr
;
}
std
::
unique_ptr
<
ir
::
Graph
>
ReferenceCountPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
ReferenceCountPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
auto
&
ref_cnts
=
Get
<
DeviceReferenceCountMap
>
(
kGlobalReferenceCount
);
auto
&
ref_cnts
=
Get
<
DeviceReferenceCountMap
>
(
kGlobalReferenceCount
);
...
@@ -34,6 +54,9 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
...
@@ -34,6 +54,9 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
// Step 2: Find all variables in non-computation ops which refers to variables
// Step 2: Find all variables in non-computation ops which refers to variables
// in computation ops
// in computation ops
std
::
unordered_set
<
std
::
string
>
names
;
std
::
unordered_set
<
std
::
string
>
names
;
std
::
unordered_map
<
OpHandleBase
*
,
std
::
unique_ptr
<
ReferenceCountOpHandle
>>
compute_ref_cnt_map
;
auto
get_ref_cnts_from_compute_op
=
[
&
](
auto
get_ref_cnts_from_compute_op
=
[
&
](
const
std
::
unique_ptr
<
OpHandleBase
>
&
op
,
const
std
::
unique_ptr
<
OpHandleBase
>
&
op
,
const
std
::
vector
<
VarHandleBase
*>
&
vars
)
{
const
std
::
vector
<
VarHandleBase
*>
&
vars
)
{
...
@@ -54,15 +77,18 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
...
@@ -54,15 +77,18 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
VarDesc
*
var_desc
=
var_handle
->
Node
()
->
Var
();
VarDesc
*
var_desc
=
var_handle
->
Node
()
->
Var
();
auto
var_name
=
var_handle
->
Node
()
->
Name
();
auto
var_name
=
var_handle
->
Node
()
->
Name
();
// This is w
ie
rd but there is really some variables without var_desc
// This is w
ei
rd but there is really some variables without var_desc
// in computation_op
// in computation_op
if
(
var_desc
==
nullptr
)
{
if
(
var_desc
==
nullptr
)
{
if
(
compute_op
->
Node
()
->
Op
()
->
Block
()
->
FindVar
(
var_name
)
==
nullptr
)
if
(
compute_op
->
Node
()
->
Op
()
->
Block
()
->
FindVar
(
var_name
)
==
nullptr
)
continue
;
continue
;
}
else
{
}
else
{
if
(
var_desc
->
Persistable
()
||
if
(
var_desc
->
Persistable
())
continue
;
var_desc
->
Proto
()
->
type
().
type
()
!=
proto
::
VarType
::
LOD_TENSOR
)
auto
var_type
=
var_desc
->
Proto
()
->
type
().
type
();
if
(
var_type
!=
proto
::
VarType
::
LOD_TENSOR
&&
var_type
!=
proto
::
VarType
::
SELECTED_ROWS
)
{
continue
;
continue
;
}
}
}
// compute op only runs in one device
// compute op only runs in one device
...
@@ -93,12 +119,33 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
...
@@ -93,12 +119,33 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
if
(
ref_cnts
.
count
(
place
.
device
)
&&
if
(
ref_cnts
.
count
(
place
.
device
)
&&
ref_cnts
[
place
.
device
]
->
count
(
var_name
))
{
ref_cnts
[
place
.
device
]
->
count
(
var_name
))
{
++
(
*
ref_cnts
[
place
.
device
])[
var_name
];
++
(
*
ref_cnts
[
place
.
device
])[
var_name
];
auto
*
next_compute_op
=
FindNextComputationOpHandle
(
var_handle
);
if
(
next_compute_op
!=
nullptr
)
{
if
(
compute_ref_cnt_map
.
count
(
next_compute_op
))
{
compute_ref_cnt_map
[
next_compute_op
]
->
AddVar
(
var_name
);
VLOG
(
5
)
<<
"Add reference count of "
<<
var_name
<<
" to Operator "
<<
next_compute_op
->
Name
();
}
else
{
// Create new reference_count_op_handle
ir
::
Node
*
ref_cnt_node
=
graph
->
CreateEmptyNode
(
"reference_count"
,
ir
::
Node
::
Type
::
kOperation
);
auto
*
ref_cnt_handle
=
new
ReferenceCountOpHandle
(
ref_cnt_node
,
next_compute_op
->
GetScope
(),
place
,
{
var_name
},
gcs
[
place
.
device
].
get
(),
cur_ref_cnts
[
place
.
device
].
get
());
if
(
next_compute_op
->
Outputs
().
empty
())
{
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
next_compute_op
->
AddOutput
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
}
ref_cnt_handle
->
AddInput
(
next_compute_op
->
Outputs
().
front
());
compute_ref_cnt_map
[
next_compute_op
].
reset
(
ref_cnt_handle
);
}
}
}
}
}
}
};
};
std
::
unordered_map
<
OpHandleBase
*
,
ReferenceCountOpHandle
*>
compute_ref_cnt_map
;
auto
&
all_ops
=
graph
->
Get
<
GraphOps
>
(
kGraphOps
);
auto
&
all_ops
=
graph
->
Get
<
GraphOps
>
(
kGraphOps
);
for
(
auto
&
op
:
all_ops
)
{
for
(
auto
&
op
:
all_ops
)
{
auto
in_var_names
=
get_ref_cnts_from_compute_op
(
op
,
op
->
Inputs
());
auto
in_var_names
=
get_ref_cnts_from_compute_op
(
op
,
op
->
Inputs
());
...
@@ -113,11 +160,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
...
@@ -113,11 +160,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
auto
*
ref_cnt_handle
=
new
ReferenceCountOpHandle
(
auto
*
ref_cnt_handle
=
new
ReferenceCountOpHandle
(
ref_cnt_node
,
compute_op
->
GetScope
(),
place
,
in_var_names
,
ref_cnt_node
,
compute_op
->
GetScope
(),
place
,
in_var_names
,
gcs
[
place
.
device
].
get
(),
cur_ref_cnts
[
place
.
device
].
get
());
gcs
[
place
.
device
].
get
(),
cur_ref_cnts
[
place
.
device
].
get
());
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
if
(
compute_op
->
Outputs
().
empty
())
{
compute_op
->
AddOutput
(
dep_var
);
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
ref_cnt_handle
->
AddInput
(
dep_var
);
compute_op
->
AddOutput
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
compute_ref_cnt_map
[
compute_op
]
=
ref_cnt_handle
;
}
ref_cnt_handle
->
AddInput
(
compute_op
->
Outputs
().
front
());
compute_ref_cnt_map
[
compute_op
].
reset
(
ref_cnt_handle
);
}
}
for
(
auto
&
op
:
all_ops
)
{
for
(
auto
&
op
:
all_ops
)
{
...
@@ -131,7 +180,11 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
...
@@ -131,7 +180,11 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
new_all_ops
.
emplace_back
(
std
::
move
(
op
));
new_all_ops
.
emplace_back
(
std
::
move
(
op
));
auto
it
=
compute_ref_cnt_map
.
find
(
new_all_ops
.
back
().
get
());
auto
it
=
compute_ref_cnt_map
.
find
(
new_all_ops
.
back
().
get
());
if
(
it
!=
compute_ref_cnt_map
.
end
())
{
if
(
it
!=
compute_ref_cnt_map
.
end
())
{
new_all_ops
.
emplace_back
(
it
->
second
);
// Add LeafNode to ReferenceCountOpHandle
auto
*
dummy_leaf
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dummy_leaf
);
it
->
second
->
AddOutput
(
dummy_leaf
);
new_all_ops
.
emplace_back
(
std
::
move
(
it
->
second
));
}
}
}
}
...
...
paddle/fluid/operators/adam_op.h
浏览文件 @
0a36ef3c
...
@@ -15,6 +15,7 @@ limitations under the License. */
...
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#pragma once
#include <math.h> // for sqrt in CPU and CUDA
#include <math.h> // for sqrt in CPU and CUDA
#include <Eigen/Dense>
#include <Eigen/Dense>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
...
@@ -306,26 +307,45 @@ class AdamOpKernel : public framework::OpKernel<T> {
...
@@ -306,26 +307,45 @@ class AdamOpKernel : public framework::OpKernel<T> {
VLOG
(
3
)
<<
"grad row size is 0!!"
;
VLOG
(
3
)
<<
"grad row size is 0!!"
;
return
;
return
;
}
}
// merge duplicated rows if any.
// The rows of grad_merge have been sorted inside MergeAdd functor
std
::
vector
<
int64_t
>
cpu_rows
(
grad
.
rows
().
begin
(),
grad
.
rows
().
end
());
scatter
::
MergeAdd
<
DeviceContext
,
T
>
merge_func
;
bool
is_strict_sorted
=
true
;
auto
&
grad_merge
=
*
(
ctx
.
scope
()
for
(
size_t
i
=
1
;
i
<
cpu_rows
.
size
();
++
i
)
{
.
NewScope
()
if
(
cpu_rows
[
i
-
1
]
>=
cpu_rows
[
i
])
{
.
Var
(
"sparse_adam_grad_merge"
)
is_strict_sorted
=
false
;
->
GetMutable
<
framework
::
SelectedRows
>
());
break
;
merge_func
(
ctx
.
template
device_context
<
DeviceContext
>(),
grad
,
}
&
grad_merge
);
}
const
framework
::
SelectedRows
*
grad_merge_ptr
;
if
(
is_strict_sorted
)
{
grad_merge_ptr
=
&
grad
;
}
else
{
// merge duplicated rows if any.
// The rows of grad_merge have been sorted inside MergeAdd functor
scatter
::
MergeAdd
<
DeviceContext
,
T
>
merge_func
;
auto
*
grad_merge_var
=
const_cast
<
framework
::
Scope
&>
(
ctx
.
scope
())
.
Var
()
->
GetMutable
<
framework
::
SelectedRows
>
();
merge_func
(
ctx
.
template
device_context
<
DeviceContext
>(),
grad
,
grad_merge_var
);
grad_merge_ptr
=
grad_merge_var
;
std
::
cerr
<<
"Create new variables in adam_op"
<<
std
::
endl
;
}
auto
&
grad_merge
=
*
grad_merge_ptr
;
auto
&
grad_tensor
=
grad_merge
.
value
();
auto
&
grad_tensor
=
grad_merge
.
value
();
const
T
*
grad_data
=
grad_tensor
.
template
data
<
T
>();
const
T
*
grad_data
=
grad_tensor
.
template
data
<
T
>();
int64_t
*
rows
=
nullptr
;
const
int64_t
*
rows
=
nullptr
;
// When compiled without CUDA, the CUDA
Mutable
Data() interface should not be
// When compiled without CUDA, the CUDAData() interface should not be
// provided.
// provided.
#if defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_CUDA)
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
rows
=
grad_merge
.
mutable_rows
()
->
CUDAMutable
Data
(
ctx
.
GetPlace
());
rows
=
grad_merge
.
rows
().
CUDA
Data
(
ctx
.
GetPlace
());
}
else
{
}
else
{
#endif
#endif
rows
=
grad_merge
.
mutable_rows
()
->
data
();
rows
=
grad_merge
.
rows
().
data
();
#if defined(PADDLE_WITH_CUDA)
#if defined(PADDLE_WITH_CUDA)
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录