Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
f855c05f
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f855c05f
编写于
9月 21, 2018
作者:
Z
Zeng Jinle
提交者:
GitHub
9月 21, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #13520 from sneaxiy/enhance_eager_delete
Enhance eager delete and sparse Adam
上级
3043f51b
192c49cb
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
123 addition
and
37 deletion
+123
-37
paddle/fluid/framework/details/reference_count_op_handle.h
paddle/fluid/framework/details/reference_count_op_handle.h
+28
-13
paddle/fluid/framework/details/reference_count_pass.cc
paddle/fluid/framework/details/reference_count_pass.cc
+64
-11
paddle/fluid/operators/adam_op.h
paddle/fluid/operators/adam_op.h
+31
-13
未找到文件。
paddle/fluid/framework/details/reference_count_op_handle.h
浏览文件 @
f855c05f
...
...
@@ -22,6 +22,7 @@
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
namespace
paddle
{
...
...
@@ -46,17 +47,15 @@ class ReferenceCountOpHandle : public OpHandleBase {
const
std
::
vector
<
std
::
string
>
&
var_names
,
GarbageCollector
<
Tensor
>
*
gc
,
AtomicReferenceCountMap
*
ref_cnts
)
:
OpHandleBase
(
node
),
scope_
(
scope
),
var_names_
(
var_names
),
gc_
(
gc
),
ref_cnts_
(
ref_cnts
)
{
:
OpHandleBase
(
node
),
scope_
(
scope
),
gc_
(
gc
),
ref_cnts_
(
ref_cnts
)
{
dev_ctx_
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
));
if
(
IsStreamGarabageCollector
())
{
PADDLE_ENFORCE
(
cudaSetDevice
(
place
.
device
));
PADDLE_ENFORCE
(
cudaEventCreateWithFlags
(
&
event_
,
cudaEventDisableTiming
));
}
for
(
auto
&
name
:
var_names
)
AddVar
(
name
);
}
~
ReferenceCountOpHandle
()
{
...
...
@@ -69,19 +68,35 @@ class ReferenceCountOpHandle : public OpHandleBase {
std
::
string
Name
()
const
override
{
return
"reference_count"
;
}
void
AddVar
(
const
std
::
string
&
name
)
{
auto
it
=
var_names_
.
find
(
name
);
if
(
it
!=
var_names_
.
end
())
++
(
it
->
second
);
else
var_names_
[
name
]
=
1
;
}
protected:
void
RunImpl
()
override
{
auto
*
exec_scope
=
scope_
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
();
std
::
vector
<
LoDTensor
*>
tensors
;
for
(
auto
&
name
:
var_names_
)
{
std
::
vector
<
Tensor
*>
tensors
;
for
(
auto
&
pair
:
var_names_
)
{
auto
&
name
=
pair
.
first
;
auto
it
=
ref_cnts_
->
find
(
name
);
if
(
it
==
ref_cnts_
->
end
())
continue
;
auto
*
var
=
exec_scope
->
FindVar
(
name
);
if
(
var
==
nullptr
||
!
var
->
IsType
<
LoDTensor
>
())
continue
;
if
(
it
->
second
.
fetch_sub
(
1
)
<=
1
)
{
tensors
.
emplace_back
(
var
->
GetMutable
<
LoDTensor
>
());
if
(
var
==
nullptr
)
continue
;
if
(
var
->
IsType
<
LoDTensor
>
())
{
if
(
it
->
second
.
fetch_sub
(
pair
.
second
)
<=
pair
.
second
)
{
tensors
.
emplace_back
(
var
->
GetMutable
<
LoDTensor
>
());
}
}
else
if
(
var
->
IsType
<
SelectedRows
>
())
{
if
(
it
->
second
.
fetch_sub
(
pair
.
second
)
<=
pair
.
second
)
{
tensors
.
emplace_back
(
var
->
GetMutable
<
SelectedRows
>
()
->
mutable_value
());
}
}
}
...
...
@@ -91,7 +106,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
}
private:
void
ClearTensors
(
const
std
::
vector
<
LoD
Tensor
*>
&
tensors
)
{
void
ClearTensors
(
const
std
::
vector
<
Tensor
*>
&
tensors
)
{
auto
*
gc
=
dynamic_cast
<
StreamGarbageCollector
<
Tensor
>
*>
(
gc_
);
if
(
gc
!=
nullptr
)
{
auto
compute_stream
=
dev_ctx_
->
stream
();
...
...
@@ -112,7 +127,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
const
Scope
*
scope_
;
platform
::
CUDADeviceContext
*
dev_ctx_
;
std
::
vector
<
std
::
string
>
var_names_
;
std
::
unordered_map
<
std
::
string
,
int
>
var_names_
;
GarbageCollector
<
Tensor
>
*
gc_
;
// not own
AtomicReferenceCountMap
*
ref_cnts_
;
// not own
cudaEvent_t
event_
;
...
...
paddle/fluid/framework/details/reference_count_pass.cc
浏览文件 @
f855c05f
...
...
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <queue>
#include <string>
#include <vector>
...
...
@@ -23,6 +24,25 @@ namespace paddle {
namespace
framework
{
namespace
details
{
static
ComputationOpHandle
*
FindNextComputationOpHandle
(
VarHandle
*
var_in
)
{
std
::
queue
<
VarHandleBase
*>
queue
;
queue
.
push
(
var_in
);
do
{
auto
*
var
=
queue
.
front
();
queue
.
pop
();
for
(
auto
*
op
:
var
->
PendingOps
())
{
auto
*
compute_op
=
dynamic_cast
<
ComputationOpHandle
*>
(
op
);
if
(
compute_op
!=
nullptr
&&
compute_op
->
GetPlace
()
==
var_in
->
place_
)
{
return
compute_op
;
}
for
(
auto
*
out_var
:
op
->
Outputs
())
{
queue
.
push
(
out_var
);
}
}
}
while
(
!
queue
.
empty
());
return
nullptr
;
}
std
::
unique_ptr
<
ir
::
Graph
>
ReferenceCountPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
auto
&
ref_cnts
=
Get
<
DeviceReferenceCountMap
>
(
kGlobalReferenceCount
);
...
...
@@ -34,6 +54,9 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
// Step 2: Find all variables in non-computation ops which refers to variables
// in computation ops
std
::
unordered_set
<
std
::
string
>
names
;
std
::
unordered_map
<
OpHandleBase
*
,
std
::
unique_ptr
<
ReferenceCountOpHandle
>>
compute_ref_cnt_map
;
auto
get_ref_cnts_from_compute_op
=
[
&
](
const
std
::
unique_ptr
<
OpHandleBase
>
&
op
,
const
std
::
vector
<
VarHandleBase
*>
&
vars
)
{
...
...
@@ -54,15 +77,18 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
VarDesc
*
var_desc
=
var_handle
->
Node
()
->
Var
();
auto
var_name
=
var_handle
->
Node
()
->
Name
();
// This is w
ie
rd but there is really some variables without var_desc
// This is w
ei
rd but there is really some variables without var_desc
// in computation_op
if
(
var_desc
==
nullptr
)
{
if
(
compute_op
->
Node
()
->
Op
()
->
Block
()
->
FindVar
(
var_name
)
==
nullptr
)
continue
;
}
else
{
if
(
var_desc
->
Persistable
()
||
var_desc
->
Proto
()
->
type
().
type
()
!=
proto
::
VarType
::
LOD_TENSOR
)
if
(
var_desc
->
Persistable
())
continue
;
auto
var_type
=
var_desc
->
Proto
()
->
type
().
type
();
if
(
var_type
!=
proto
::
VarType
::
LOD_TENSOR
&&
var_type
!=
proto
::
VarType
::
SELECTED_ROWS
)
{
continue
;
}
}
// compute op only runs in one device
...
...
@@ -93,12 +119,33 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
if
(
ref_cnts
.
count
(
place
.
device
)
&&
ref_cnts
[
place
.
device
]
->
count
(
var_name
))
{
++
(
*
ref_cnts
[
place
.
device
])[
var_name
];
auto
*
next_compute_op
=
FindNextComputationOpHandle
(
var_handle
);
if
(
next_compute_op
!=
nullptr
)
{
if
(
compute_ref_cnt_map
.
count
(
next_compute_op
))
{
compute_ref_cnt_map
[
next_compute_op
]
->
AddVar
(
var_name
);
VLOG
(
5
)
<<
"Add reference count of "
<<
var_name
<<
" to Operator "
<<
next_compute_op
->
Name
();
}
else
{
// Create new reference_count_op_handle
ir
::
Node
*
ref_cnt_node
=
graph
->
CreateEmptyNode
(
"reference_count"
,
ir
::
Node
::
Type
::
kOperation
);
auto
*
ref_cnt_handle
=
new
ReferenceCountOpHandle
(
ref_cnt_node
,
next_compute_op
->
GetScope
(),
place
,
{
var_name
},
gcs
[
place
.
device
].
get
(),
cur_ref_cnts
[
place
.
device
].
get
());
if
(
next_compute_op
->
Outputs
().
empty
())
{
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
next_compute_op
->
AddOutput
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
}
ref_cnt_handle
->
AddInput
(
next_compute_op
->
Outputs
().
front
());
compute_ref_cnt_map
[
next_compute_op
].
reset
(
ref_cnt_handle
);
}
}
}
}
};
std
::
unordered_map
<
OpHandleBase
*
,
ReferenceCountOpHandle
*>
compute_ref_cnt_map
;
auto
&
all_ops
=
graph
->
Get
<
GraphOps
>
(
kGraphOps
);
for
(
auto
&
op
:
all_ops
)
{
auto
in_var_names
=
get_ref_cnts_from_compute_op
(
op
,
op
->
Inputs
());
...
...
@@ -113,11 +160,13 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
auto
*
ref_cnt_handle
=
new
ReferenceCountOpHandle
(
ref_cnt_node
,
compute_op
->
GetScope
(),
place
,
in_var_names
,
gcs
[
place
.
device
].
get
(),
cur_ref_cnts
[
place
.
device
].
get
());
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
compute_op
->
AddOutput
(
dep_var
);
ref_cnt_handle
->
AddInput
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
compute_ref_cnt_map
[
compute_op
]
=
ref_cnt_handle
;
if
(
compute_op
->
Outputs
().
empty
())
{
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
compute_op
->
AddOutput
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
}
ref_cnt_handle
->
AddInput
(
compute_op
->
Outputs
().
front
());
compute_ref_cnt_map
[
compute_op
].
reset
(
ref_cnt_handle
);
}
for
(
auto
&
op
:
all_ops
)
{
...
...
@@ -131,7 +180,11 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
new_all_ops
.
emplace_back
(
std
::
move
(
op
));
auto
it
=
compute_ref_cnt_map
.
find
(
new_all_ops
.
back
().
get
());
if
(
it
!=
compute_ref_cnt_map
.
end
())
{
new_all_ops
.
emplace_back
(
it
->
second
);
// Add LeafNode to ReferenceCountOpHandle
auto
*
dummy_leaf
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dummy_leaf
);
it
->
second
->
AddOutput
(
dummy_leaf
);
new_all_ops
.
emplace_back
(
std
::
move
(
it
->
second
));
}
}
...
...
paddle/fluid/operators/adam_op.h
浏览文件 @
f855c05f
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <math.h> // for sqrt in CPU and CUDA
#include <Eigen/Dense>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
...
...
@@ -306,26 +307,43 @@ class AdamOpKernel : public framework::OpKernel<T> {
VLOG
(
3
)
<<
"grad row size is 0!!"
;
return
;
}
// merge duplicated rows if any.
// The rows of grad_merge have been sorted inside MergeAdd functor
scatter
::
MergeAdd
<
DeviceContext
,
T
>
merge_func
;
auto
&
grad_merge
=
*
(
ctx
.
scope
()
.
NewScope
()
.
Var
(
"sparse_adam_grad_merge"
)
->
GetMutable
<
framework
::
SelectedRows
>
());
merge_func
(
ctx
.
template
device_context
<
DeviceContext
>(),
grad
,
&
grad_merge
);
std
::
vector
<
int64_t
>
cpu_rows
(
grad
.
rows
().
begin
(),
grad
.
rows
().
end
());
bool
is_strict_sorted
=
true
;
for
(
size_t
i
=
1
;
i
<
cpu_rows
.
size
();
++
i
)
{
if
(
cpu_rows
[
i
-
1
]
>=
cpu_rows
[
i
])
{
is_strict_sorted
=
false
;
break
;
}
}
const
framework
::
SelectedRows
*
grad_merge_ptr
;
if
(
is_strict_sorted
)
{
grad_merge_ptr
=
&
grad
;
}
else
{
// merge duplicated rows if any.
// The rows of grad_merge have been sorted inside MergeAdd functor
scatter
::
MergeAdd
<
DeviceContext
,
T
>
merge_func
;
auto
*
grad_merge_var
=
const_cast
<
framework
::
Scope
&>
(
ctx
.
scope
())
.
Var
()
->
GetMutable
<
framework
::
SelectedRows
>
();
merge_func
(
ctx
.
template
device_context
<
DeviceContext
>(),
grad
,
grad_merge_var
);
grad_merge_ptr
=
grad_merge_var
;
}
auto
&
grad_merge
=
*
grad_merge_ptr
;
auto
&
grad_tensor
=
grad_merge
.
value
();
const
T
*
grad_data
=
grad_tensor
.
template
data
<
T
>();
int64_t
*
rows
=
nullptr
;
// When compiled without CUDA, the CUDA
Mutable
Data() interface should not be
const
int64_t
*
rows
=
nullptr
;
// When compiled without CUDA, the CUDAData() interface should not be
// provided.
#if defined(PADDLE_WITH_CUDA)
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
rows
=
grad_merge
.
mutable_rows
()
->
CUDAMutable
Data
(
ctx
.
GetPlace
());
rows
=
grad_merge
.
rows
().
CUDA
Data
(
ctx
.
GetPlace
());
}
else
{
#endif
rows
=
grad_merge
.
mutable_rows
()
->
data
();
rows
=
grad_merge
.
rows
().
data
();
#if defined(PADDLE_WITH_CUDA)
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录