Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
5be6f762
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5be6f762
编写于
10月 25, 2018
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
remove_lock_in_some_ops
test=develop
上级
88376697
变更
15
显示空白变更内容
内联
并排
Showing
15 changed file
with
615 addition
and
34 deletion
+615
-34
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+5
-2
paddle/fluid/framework/details/computation_op_handle.cc
paddle/fluid/framework/details/computation_op_handle.cc
+12
-4
paddle/fluid/framework/details/computation_op_handle.h
paddle/fluid/framework/details/computation_op_handle.h
+12
-1
paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.cc
...framework/details/modify_op_lock_and_record_event_pass.cc
+62
-0
paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h
.../framework/details/modify_op_lock_and_record_event_pass.h
+32
-0
paddle/fluid/framework/details/multi_devices_graph_pass.cc
paddle/fluid/framework/details/multi_devices_graph_pass.cc
+3
-3
paddle/fluid/framework/details/op_handle_graph.cc
paddle/fluid/framework/details/op_handle_graph.cc
+294
-0
paddle/fluid/framework/details/op_handle_graph.h
paddle/fluid/framework/details/op_handle_graph.h
+87
-0
paddle/fluid/framework/details/reference_count_op_handle.h
paddle/fluid/framework/details/reference_count_op_handle.h
+2
-2
paddle/fluid/framework/details/reference_count_pass.cc
paddle/fluid/framework/details/reference_count_pass.cc
+19
-12
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+6
-0
paddle/fluid/operators/conv_cudnn_op.cu.cc
paddle/fluid/operators/conv_cudnn_op.cu.cc
+5
-3
paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc
paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc
+5
-3
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+35
-4
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+36
-0
未找到文件。
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
5be6f762
cc_library
(
var_handle SRCS var_handle.cc DEPS place framework_proto node
)
cc_library
(
var_handle SRCS var_handle.cc DEPS place framework_proto node
)
cc_library
(
op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor
)
cc_library
(
op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor
)
cc_library
(
op_handle_graph SRCS op_handle_graph.cc DEPS op_handle_base
)
cc_library
(
scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
)
cc_library
(
scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
)
cc_library
(
fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
)
cc_library
(
fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
)
cc_library
(
computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry
)
cc_library
(
computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry
)
...
@@ -28,6 +29,8 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
...
@@ -28,6 +29,8 @@ cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_
cc_library
(
gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor
)
cc_library
(
gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor
)
cc_library
(
fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope
)
cc_library
(
fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope
)
cc_library
(
modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_handle_graph multi_devices_helper
)
if
(
WITH_GPU
)
if
(
WITH_GPU
)
cc_library
(
reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle
cc_library
(
reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass
)
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass
)
...
@@ -37,9 +40,9 @@ cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_
...
@@ -37,9 +40,9 @@ cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle
)
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle
)
if
(
WITH_GPU
)
if
(
WITH_GPU
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass
)
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass
modify_op_lock_and_record_event_pass
)
else
()
else
()
cc_library
(
ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto
)
# Non-GPU build of ssa_graph_executor.
# The pass library defined earlier in this file is named
# modify_op_lock_and_record_event_pass; the dependency must use that exact
# target name (the previous "modify_op_lock_and_wait_pass" is not defined
# anywhere, so the pass was never actually pulled in as a dependency here).
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto
        modify_op_lock_and_record_event_pass)
endif
()
endif
()
cc_library
(
threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
cc_library
(
threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
...
...
paddle/fluid/framework/details/computation_op_handle.cc
浏览文件 @
5be6f762
...
@@ -20,18 +20,26 @@ namespace paddle {
...
@@ -20,18 +20,26 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
namespace
details
{
namespace
details
{
ComputationOpHandle
::
ComputationOpHandle
(
ir
::
Node
*
node
,
Scope
*
scope
,
ComputationOpHandle
::
ComputationOpHandle
(
ir
::
Node
*
node
,
Scope
*
scope
,
platform
::
Place
place
)
platform
::
Place
place
,
size_t
scope_idx
)
:
OpHandleBase
(
node
),
:
OpHandleBase
(
node
),
op_
(
framework
::
OpRegistry
::
CreateOp
(
*
node
->
Op
())),
op_
(
framework
::
OpRegistry
::
CreateOp
(
*
node
->
Op
())),
scope_
(
scope
),
scope_
(
scope
),
place_
(
place
)
{}
place_
(
place
),
scope_idx_
(
scope_idx
)
{}
void
ComputationOpHandle
::
RunImpl
()
{
void
ComputationOpHandle
::
RunImpl
()
{
WaitInputVarGenerated
(
place_
);
WaitInputVarGenerated
(
place_
);
this
->
RunAndRecordEvent
([
this
]
{
auto
run_func
=
[
this
]()
{
op_
->
Run
(
*
scope_
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
(),
place_
);
op_
->
Run
(
*
scope_
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
(),
place_
);
});
};
if
(
is_lock_and_record_event_free_
)
{
run_func
();
}
else
{
this
->
RunAndRecordEvent
(
run_func
);
}
}
}
bool
ComputationOpHandle
::
NeedWait
(
VarHandleBase
*
in_var
)
{
bool
ComputationOpHandle
::
NeedWait
(
VarHandleBase
*
in_var
)
{
...
...
paddle/fluid/framework/details/computation_op_handle.h
浏览文件 @
5be6f762
...
@@ -28,7 +28,8 @@ namespace framework {
...
@@ -28,7 +28,8 @@ namespace framework {
namespace
details
{
namespace
details
{
struct
ComputationOpHandle
:
public
OpHandleBase
{
struct
ComputationOpHandle
:
public
OpHandleBase
{
public:
public:
ComputationOpHandle
(
ir
::
Node
*
node
,
Scope
*
scope
,
platform
::
Place
place
);
ComputationOpHandle
(
ir
::
Node
*
node
,
Scope
*
scope
,
platform
::
Place
place
,
size_t
scope_idx
);
std
::
string
Name
()
const
override
;
std
::
string
Name
()
const
override
;
...
@@ -36,6 +37,14 @@ struct ComputationOpHandle : public OpHandleBase {
...
@@ -36,6 +37,14 @@ struct ComputationOpHandle : public OpHandleBase {
const
platform
::
Place
&
GetPlace
()
const
{
return
place_
;
}
const
platform
::
Place
&
GetPlace
()
const
{
return
place_
;
}
// Returns the scope index stored in scope_idx_ (set from the `scope_idx`
// constructor argument).
size_t GetScopeIdx() const {
  return scope_idx_;
}
// Mutable access to the operator owned by this handle (dereferences op_).
OperatorBase &GetOp() {
  return *op_;
}
// Read-only access to the operator owned by this handle (dereferences op_).
const OperatorBase &GetOp() const {
  return *op_;
}
// Stores `b` in is_lock_and_record_event_free_.
// When the flag is true, RunImpl calls the operator directly; otherwise the
// call is routed through RunAndRecordEvent.
void SetLockAndRecordEventFree(bool b) {
  is_lock_and_record_event_free_ = b;
}
protected:
protected:
void
RunImpl
()
override
;
void
RunImpl
()
override
;
...
@@ -45,6 +54,8 @@ struct ComputationOpHandle : public OpHandleBase {
...
@@ -45,6 +54,8 @@ struct ComputationOpHandle : public OpHandleBase {
std
::
unique_ptr
<
OperatorBase
>
op_
;
std
::
unique_ptr
<
OperatorBase
>
op_
;
Scope
*
scope_
;
Scope
*
scope_
;
platform
::
Place
place_
;
platform
::
Place
place_
;
size_t
scope_idx_
{
0
};
bool
is_lock_and_record_event_free_
{
false
};
};
};
}
// namespace details
}
// namespace details
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.cc
0 → 100644
浏览文件 @
5be6f762
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_handle_graph.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
// Downcasts a generic op handle to ComputationOpHandle.
// The result is nullptr when `op` is not a ComputationOpHandle.
static ComputationOpHandle *ConvertToComputationOpHandle(OpHandleBase *op) {
  auto *as_compute_op = dynamic_cast<ComputationOpHandle *>(op);
  return as_compute_op;
}
// Decides whether `op` may run without the lock / event recording.
//
// Returns true only if every op that directly depends on `op` (its pending
// ops in `graph`) is itself a ComputationOpHandle placed on the same Place as
// `op`. An op with no pending ops therefore also yields true.
static bool IsLockAndRecordEventFreeComputationOpHandle(
    ComputationOpHandle *op, const OpHandleGraph &graph) {
  const auto &successors = graph.PendingOps(op);
  for (auto *successor : successors) {
    auto *compute_successor = ConvertToComputationOpHandle(successor);
    if (compute_successor == nullptr) {
      // A non-computation consumer disqualifies `op`.
      return false;
    }
    const bool same_place = (compute_successor->GetPlace() == op->GetPlace());
    if (!same_place) {
      // A consumer on a different place disqualifies `op`.
      return false;
    }
  }
  return true;
}
// Walks all op handles stored in the graph under kGraphOps and, for each
// ComputationOpHandle, sets its lock-and-record-event-free flag according to
// IsLockAndRecordEventFreeComputationOpHandle. Non-computation handles are
// left untouched. The (same) graph is returned to the caller.
std::unique_ptr<ir::Graph> ModifyOpLockAndRecordEventPass::ApplyImpl(
    std::unique_ptr<ir::Graph> ir_graph) const {
  auto &op_handles = ir_graph->Get<GraphOps>(kGraphOps);
  OpHandleGraph dependency_graph(op_handles);

  for (auto &handle : op_handles) {
    auto *compute_op = ConvertToComputationOpHandle(handle.get());
    if (compute_op != nullptr) {
      const bool lock_free = IsLockAndRecordEventFreeComputationOpHandle(
          compute_op, dependency_graph);
      compute_op->SetLockAndRecordEventFree(lock_free);
      if (lock_free) {
        VLOG(10) << "Set is_lock_and_record_event_free be true in op "
                 << compute_op->DebugString();
      }
    }
  }

  return ir_graph;
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
// Registers ModifyOpLockAndRecordEventPass under the name
// modify_op_lock_and_record_event_pass. ParallelExecutor looks the pass up by
// this exact name, so the identifier must stay in sync with that lookup and
// with the matching USE_PASS declaration.
REGISTER_PASS
(
modify_op_lock_and_record_event_pass
,
paddle
::
framework
::
details
::
ModifyOpLockAndRecordEventPass
);
paddle/fluid/framework/details/modify_op_lock_and_record_event_pass.h
0 → 100644
浏览文件 @
5be6f762
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
// Graph pass that inspects every ComputationOpHandle of a graph and marks the
// ones whose direct consumers are all ComputationOpHandles on the same place
// as "lock and record-event free".
class ModifyOpLockAndRecordEventPass : public ir::Pass {
 protected:
  // Applies the pass to `graph` and hands the graph back to the caller.
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/multi_devices_graph_pass.cc
浏览文件 @
5be6f762
...
@@ -513,7 +513,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
...
@@ -513,7 +513,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
int
dev_id
)
const
{
int
dev_id
)
const
{
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ComputationOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
new
ComputationOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
local_scopes_
[
dev_id
],
places_
[
dev_id
]));
local_scopes_
[
dev_id
],
places_
[
dev_id
]
,
dev_id
));
CreateOpHandleIOs
(
result
,
node
,
dev_id
);
CreateOpHandleIOs
(
result
,
node
,
dev_id
);
}
}
...
@@ -630,8 +630,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
...
@@ -630,8 +630,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result,
for
(
size_t
scope_idx
=
0
;
scope_idx
<
num_places
;
++
scope_idx
)
{
for
(
size_t
scope_idx
=
0
;
scope_idx
<
num_places
;
++
scope_idx
)
{
auto
p
=
places_
[
scope_idx
];
auto
p
=
places_
[
scope_idx
];
auto
s
=
local_scopes_
[
scope_idx
];
auto
s
=
local_scopes_
[
scope_idx
];
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
ComputationOpHandle
(
new
ComputationOpHandle
(
result
->
CreateOpNode
(
node
->
Op
()),
s
,
p
));
result
->
CreateOpNode
(
node
->
Op
()),
s
,
p
,
scope_idx
));
CreateOpHandleIOs
(
result
,
node
,
scope_idx
);
CreateOpHandleIOs
(
result
,
node
,
scope_idx
);
}
}
}
}
...
...
paddle/fluid/framework/details/op_handle_graph.cc
0 → 100644
浏览文件 @
5be6f762
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/op_handle_graph.h"
#include <queue>
#include <utility>
namespace
paddle
{
namespace
framework
{
namespace
details
{
// Constructs the dependency view over `ops`; all work is done by BuildGraph.
OpHandleGraph::OpHandleGraph(
    const std::vector<std::unique_ptr<OpHandleBase>> &ops) {
  this->BuildGraph(ops);
}
void
OpHandleGraph
::
BuildGraph
(
const
std
::
vector
<
std
::
unique_ptr
<
OpHandleBase
>>
&
ops
)
{
for
(
auto
&
op
:
ops
)
{
preceding_ops_
[
op
.
get
()];
pending_ops_
[
op
.
get
()];
for
(
auto
&
var
:
op
->
Outputs
())
{
for
(
auto
&
pending_op
:
var
->
PendingOps
())
{
preceding_ops_
[
pending_op
].
insert
(
op
.
get
());
pending_ops_
[
op
.
get
()].
insert
(
pending_op
);
}
}
}
PADDLE_ENFORCE
(
preceding_ops_
.
size
()
==
ops
.
size
()
&&
pending_ops_
.
size
()
==
ops
.
size
(),
"There are duplicate ops in graph."
);
}
// Number of ops known to the graph (one key per op in preceding_ops_).
size_t OpHandleGraph::OpNumber() const {
  const size_t op_count = preceding_ops_.size();
  return op_count;
}
std
::
unordered_set
<
OpHandleBase
*>
OpHandleGraph
::
AllOps
()
const
{
std
::
unordered_set
<
OpHandleBase
*>
ret
;
for
(
auto
&
pair
:
preceding_ops_
)
{
ret
.
insert
(
pair
.
first
);
}
return
ret
;
}
// True when `op` is one of the ops the graph was built from.
bool OpHandleGraph::HasOp(OpHandleBase *op) const {
  return preceding_ops_.find(op) != preceding_ops_.end();
}
// Fails via PADDLE_ENFORCE when `op` is not part of the graph. A null `op`
// is reported as "nullptr" in the message instead of being dereferenced.
void OpHandleGraph::EnforceHasOp(OpHandleBase *op) const {
  const bool is_known = HasOp(op);
  PADDLE_ENFORCE(is_known, "Cannot found op %s in OpHandleGraph",
                 op == nullptr ? "nullptr" : op->DebugString());
}
// Direct predecessors of `op`. Enforces that `op` belongs to the graph.
// The returned reference points into preceding_ops_.
const std::unordered_set<OpHandleBase *> &OpHandleGraph::PrecedingOps(
    OpHandleBase *op) const {
  EnforceHasOp(op);
  const auto &predecessors = preceding_ops_.at(op);
  return predecessors;
}
// Direct successors of `op` (ops pending on its outputs). Enforces that `op`
// belongs to the graph. The returned reference points into pending_ops_.
const std::unordered_set<OpHandleBase *> &OpHandleGraph::PendingOps(
    OpHandleBase *op) const {
  EnforceHasOp(op);
  const auto &successors = pending_ops_.at(op);
  return successors;
}
// Collects the transitive predecessors of `op`, grouped by BFS depth.
//
// The walk starts from the direct predecessors of `op`. ret[0] holds the ops
// first reached one step beyond them, ret[1] the ops first reached two steps
// beyond them, and so on. Each op appears in at most one set (the one for the
// depth at which it is first reached).
//
// NOTE(review): as in the previous implementation, the direct predecessors
// themselves are NOT part of the result (use PrecedingOps() for those) --
// confirm this is what callers expect.
//
// Fix over the previous version: the old loop popped a single node per
// iteration and then switched queues, terminating as soon as the *current*
// queue was empty. Nodes left in the other queue were never expanded, so
// ancestors could be silently missed, and the returned sets were per-node
// rather than per-depth. The whole frontier is now drained each round.
std::vector<std::unordered_set<OpHandleBase *>> OpHandleGraph::AllPrecedingOps(
    OpHandleBase *op) const {
  EnforceHasOp(op);
  std::vector<std::unordered_set<OpHandleBase *>> ret;

  const auto &direct_ops = preceding_ops_.at(op);
  std::unordered_set<OpHandleBase *> visited_ops(direct_ops.begin(),
                                                 direct_ops.end());
  std::vector<OpHandleBase *> frontier(direct_ops.begin(), direct_ops.end());

  while (!frontier.empty()) {
    std::unordered_set<OpHandleBase *> next_level_ops;
    for (auto *cur_op : frontier) {
      for (auto *preceding_op : preceding_ops_.at(cur_op)) {
        // insert().second is true only the first time an op is seen.
        if (visited_ops.insert(preceding_op).second) {
          next_level_ops.insert(preceding_op);
        }
      }
    }
    frontier.assign(next_level_ops.begin(), next_level_ops.end());
    if (!next_level_ops.empty()) {
      ret.emplace_back(std::move(next_level_ops));
    }
  }
  return ret;
}
// Collects the transitive successors of `op`, grouped by BFS depth.
//
// The walk starts from the direct successors (pending ops) of `op`. ret[0]
// holds the ops first reached one step beyond them, ret[1] the ops first
// reached two steps beyond them, and so on. Each op appears in at most one
// set.
//
// NOTE(review): mirroring AllPrecedingOps, the direct successors themselves
// are NOT part of the result (use PendingOps() for those) -- confirm this is
// what callers expect.
//
// Fixes over the previous version:
//  * the walk was seeded from preceding_ops_.at(op) (a copy-paste slip), so
//    it started from the predecessors instead of the successors of `op`;
//  * the old loop popped one node per iteration and then switched queues,
//    stopping as soon as the current queue was empty, which could leave
//    nodes unexpanded. The whole frontier is now drained each round.
std::vector<std::unordered_set<OpHandleBase *>> OpHandleGraph::AllPendingOps(
    OpHandleBase *op) const {
  EnforceHasOp(op);
  std::vector<std::unordered_set<OpHandleBase *>> ret;

  const auto &direct_ops = pending_ops_.at(op);
  std::unordered_set<OpHandleBase *> visited_ops(direct_ops.begin(),
                                                 direct_ops.end());
  std::vector<OpHandleBase *> frontier(direct_ops.begin(), direct_ops.end());

  while (!frontier.empty()) {
    std::unordered_set<OpHandleBase *> next_level_ops;
    for (auto *cur_op : frontier) {
      for (auto *next_op : pending_ops_.at(cur_op)) {
        // insert().second is true only the first time an op is seen.
        if (visited_ops.insert(next_op).second) {
          next_level_ops.insert(next_op);
        }
      }
    }
    frontier.assign(next_level_ops.begin(), next_level_ops.end());
    if (!next_level_ops.empty()) {
      ret.emplace_back(std::move(next_level_ops));
    }
  }
  return ret;
}
// Classifies how `op1` relates to `op2` in the dependency graph:
//   kSame   - both pointers are equal;
//   kBefore - op2 is reachable from op1 through pending edges;
//   kAfter  - op1 is reachable from op2 through pending edges;
//   kNoDeps - neither is reachable from the other.
// Both ops must belong to the graph.
OpHandleGraph::Relation OpHandleGraph::RelationBetween(
    OpHandleBase *op1, OpHandleBase *op2) const {
  EnforceHasOp(op1);
  EnforceHasOp(op2);

  if (op1 == op2) {
    return kSame;
  }
  if (IsBeforeOrSameImpl(op1, op2)) {
    return kBefore;
  }
  if (IsBeforeOrSameImpl(op2, op1)) {
    return kAfter;
  }
  return kNoDeps;
}
// True when `op1` and `op2` are the same op (pointer equality).
// Both ops must belong to the graph.
bool OpHandleGraph::IsSame(OpHandleBase *op1, OpHandleBase *op2) const {
  EnforceHasOp(op1);
  EnforceHasOp(op2);
  const bool identical = (op1 == op2);
  return identical;
}
// True when `op1` equals `op2` or `op2` is reachable from `op1` through
// pending edges. Both ops must belong to the graph.
bool OpHandleGraph::IsBeforeOrSame(OpHandleBase *op1,
                                   OpHandleBase *op2) const {
  EnforceHasOp(op1);
  EnforceHasOp(op2);
  const bool reachable = IsBeforeOrSameImpl(op1, op2);
  return reachable;
}
// True when `op1` is a different op than `op2` and `op2` is reachable from
// `op1` through pending edges. Both ops must belong to the graph.
bool OpHandleGraph::IsBefore(OpHandleBase *op1, OpHandleBase *op2) const {
  EnforceHasOp(op1);
  EnforceHasOp(op2);
  if (op1 == op2) {
    return false;
  }
  return IsBeforeOrSameImpl(op1, op2);
}
// Breadth-first search along pending edges starting at `op1`.
// Returns true when `op2` is `op1` itself or is reachable from it.
// No membership check is done here; callers enforce that both ops belong to
// the graph.
//
// A visited set guarantees that each op is expanded at most once. The
// previous version re-enqueued an op once per incoming path, which is
// exponential on graphs with many diamond-shaped dependencies and would never
// terminate if the graph contained a cycle. The result is unchanged.
bool OpHandleGraph::IsBeforeOrSameImpl(OpHandleBase *op1,
                                       OpHandleBase *op2) const {
  std::queue<OpHandleBase *> queue;
  std::unordered_set<OpHandleBase *> visited_ops;

  queue.push(op1);
  visited_ops.insert(op1);
  while (!queue.empty()) {
    auto *op = queue.front();
    queue.pop();
    if (op == op2) return true;
    for (auto *pending_op : pending_ops_.at(op)) {
      // insert().second is true only the first time an op is seen.
      if (visited_ops.insert(pending_op).second) {
        queue.push(pending_op);
      }
    }
  }
  return false;
}
// True when `op1` equals `op2` or `op1` is reachable from `op2` through
// pending edges. Both ops must belong to the graph.
bool OpHandleGraph::IsAfterOrSame(OpHandleBase *op1, OpHandleBase *op2) const {
  EnforceHasOp(op1);
  EnforceHasOp(op2);
  // "op1 after op2" is "op2 before op1": swap the arguments.
  const bool reachable = IsBeforeOrSameImpl(op2, op1);
  return reachable;
}
// True when `op1` is strictly after `op2`; implemented as IsBefore with the
// arguments swapped (which also performs the membership checks).
bool OpHandleGraph::IsAfter(OpHandleBase *op1, OpHandleBase *op2) const {
  const bool op2_is_before_op1 = IsBefore(op2, op1);
  return op2_is_before_op1;
}
// True when RelationBetween(op1, op2) reports kNoDeps, i.e. neither op is
// reachable from the other.
bool OpHandleGraph::IsNoDeps(OpHandleBase *op1, OpHandleBase *op2) const {
  const Relation relation = RelationBetween(op1, op2);
  return relation == kNoDeps;
}
std
::
unordered_set
<
OpHandleBase
*>
OpHandleGraph
::
NoPendingOpSet
()
const
{
std
::
unordered_set
<
OpHandleBase
*>
ret
;
for
(
auto
&
pair
:
pending_ops_
)
{
if
(
pair
.
second
.
empty
())
ret
.
insert
(
pair
.
first
);
}
return
ret
;
}
std
::
unordered_set
<
OpHandleBase
*>
OpHandleGraph
::
NoPrecedingOpSet
()
const
{
std
::
unordered_set
<
OpHandleBase
*>
ret
;
for
(
auto
&
pair
:
preceding_ops_
)
{
if
(
pair
.
second
.
empty
())
ret
.
insert
(
pair
.
first
);
}
return
ret
;
}
// Finds a common ancestor of `op1` and `op2`.
//
// Step 1: BFS from `op1` along preceding edges and record every op reached
//         (including `op1` itself).
// Step 2: BFS from `op2` along preceding edges and return the first op
//         (in BFS order, `op2` itself included) that was recorded in step 1.
// Returns nullptr when the two ops share no ancestor.
// Both ops must belong to the graph.
//
// Each BFS keeps a visited set so every op is expanded at most once; the
// previous version re-enqueued an op once per incoming path, which is
// exponential on graphs with many diamond-shaped dependencies. The returned
// op is unchanged because the order in which ops are first reached is the
// same.
OpHandleBase *OpHandleGraph::NearestCommonParent(OpHandleBase *op1,
                                                 OpHandleBase *op2) const {
  EnforceHasOp(op1);
  EnforceHasOp(op2);

  // Step 1: ancestors of op1 (the set doubles as the visited set).
  std::unordered_set<OpHandleBase *> all_preceding_ops;
  std::queue<OpHandleBase *> queue;
  all_preceding_ops.insert(op1);
  queue.push(op1);
  while (!queue.empty()) {
    auto *op = queue.front();
    queue.pop();
    for (auto *preceding_op : preceding_ops_.at(op)) {
      if (all_preceding_ops.insert(preceding_op).second) {
        queue.push(preceding_op);
      }
    }
  }

  // Step 2: walk up from op2 until an ancestor of op1 is met.
  std::unordered_set<OpHandleBase *> visited_ops;
  visited_ops.insert(op2);
  queue.push(op2);
  while (!queue.empty()) {
    auto *op = queue.front();
    queue.pop();
    if (all_preceding_ops.count(op)) return op;
    for (auto *preceding_op : preceding_ops_.at(op)) {
      if (visited_ops.insert(preceding_op).second) {
        queue.push(preceding_op);
      }
    }
  }
  return nullptr;
}
// Looks for a common ancestor of `op1` and `op2` that lies after `op`.
//
// Step 1: BFS from `op1` along preceding edges, recording the depth at which
//         each op is first reached, until `op` is found. The depth of `op`
//         becomes max_depth. If `op` is never reached, nullptr is returned.
// Step 2: BFS from `op2` along preceding edges and return the first op that
//         was recorded in step 1 and is either `op` itself or was recorded at
//         a depth smaller than max_depth.
// Returns nullptr when no such op exists. All three ops must belong to the
// graph.
//
// Fixes over the previous version:
//  * step 1 compared against `op1` (its own start node), so it always stopped
//    immediately with max_depth == 0 and never walked towards `op`;
//  * step 2 never enqueued the predecessors of the op it inspected, so only
//    `op2` itself was ever tested.
// Visited checks were added so each op is expanded at most once.
//
// NOTE(review): "recorded depth < max_depth" is the original heuristic for
// "lies after op"; it does not prove the candidate is a descendant of `op` --
// confirm this matches the intended contract.
OpHandleBase *OpHandleGraph::NearestCommonParentAfter(OpHandleBase *op,
                                                      OpHandleBase *op1,
                                                      OpHandleBase *op2) const {
  EnforceHasOp(op);
  EnforceHasOp(op1);
  EnforceHasOp(op2);

  // Step 1: depth of every ancestor of op1 met before reaching `op`.
  std::unordered_map<OpHandleBase *, int> all_preceding_ops;
  int max_depth = -1;
  std::queue<std::pair<OpHandleBase *, int>> queue;
  queue.push(std::make_pair(op1, 0));
  while (!queue.empty()) {
    auto tmp = queue.front();
    queue.pop();
    // insert() keeps the first (shallowest) depth; skip ops already expanded.
    if (!all_preceding_ops.insert(tmp).second) continue;
    if (tmp.first == op) {
      max_depth = tmp.second;
      break;
    }
    for (auto *preceding_op : preceding_ops_.at(tmp.first)) {
      queue.push(std::make_pair(preceding_op, tmp.second + 1));
    }
  }

  if (max_depth == -1) {
    return nullptr;
  }

  // Step 2: walk up from op2 looking for a recorded op that qualifies.
  std::queue<OpHandleBase *> queue2;
  std::unordered_set<OpHandleBase *> visited_ops;
  queue2.push(op2);
  visited_ops.insert(op2);
  while (!queue2.empty()) {
    auto *tmp = queue2.front();
    queue2.pop();
    auto found = all_preceding_ops.find(tmp);
    if (found != all_preceding_ops.end() &&
        (tmp == op || found->second < max_depth)) {
      return tmp;
    }
    for (auto *preceding_op : preceding_ops_.at(tmp)) {
      if (visited_ops.insert(preceding_op).second) {
        queue2.push(preceding_op);
      }
    }
  }
  return nullptr;
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/op_handle_graph.h
0 → 100644
浏览文件 @
5be6f762
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
class
OpHandleGraph
{
public:
enum
Relation
{
kSame
=
0
,
kBefore
=
1
,
kAfter
=
2
,
kNoDeps
=
3
};
explicit
OpHandleGraph
(
const
std
::
vector
<
std
::
unique_ptr
<
OpHandleBase
>>
&
ops
);
size_t
OpNumber
()
const
;
std
::
unordered_set
<
OpHandleBase
*>
AllOps
()
const
;
const
std
::
unordered_set
<
OpHandleBase
*>
&
PrecedingOps
(
OpHandleBase
*
op
)
const
;
const
std
::
unordered_set
<
OpHandleBase
*>
&
PendingOps
(
OpHandleBase
*
op
)
const
;
std
::
vector
<
std
::
unordered_set
<
OpHandleBase
*>>
AllPrecedingOps
(
OpHandleBase
*
op
)
const
;
std
::
vector
<
std
::
unordered_set
<
OpHandleBase
*>>
AllPendingOps
(
OpHandleBase
*
op
)
const
;
bool
HasOp
(
OpHandleBase
*
op
)
const
;
Relation
RelationBetween
(
OpHandleBase
*
op1
,
OpHandleBase
*
op2
)
const
;
bool
IsSame
(
OpHandleBase
*
op1
,
OpHandleBase
*
op2
)
const
;
bool
IsBeforeOrSame
(
OpHandleBase
*
op1
,
OpHandleBase
*
op2
)
const
;
bool
IsBefore
(
OpHandleBase
*
op1
,
OpHandleBase
*
op2
)
const
;
bool
IsAfterOrSame
(
OpHandleBase
*
op1
,
OpHandleBase
*
op2
)
const
;
bool
IsAfter
(
OpHandleBase
*
op1
,
OpHandleBase
*
op2
)
const
;
bool
IsNoDeps
(
OpHandleBase
*
op1
,
OpHandleBase
*
op2
)
const
;
OpHandleBase
*
NearestCommonParent
(
OpHandleBase
*
op1
,
OpHandleBase
*
op2
)
const
;
// Find an operator that is after op and before op1, op2
OpHandleBase
*
NearestCommonParentAfter
(
OpHandleBase
*
op
,
OpHandleBase
*
op1
,
OpHandleBase
*
op2
)
const
;
std
::
unordered_set
<
OpHandleBase
*>
NoPendingOpSet
()
const
;
std
::
unordered_set
<
OpHandleBase
*>
NoPrecedingOpSet
()
const
;
private:
void
BuildGraph
(
const
std
::
vector
<
std
::
unique_ptr
<
OpHandleBase
>>
&
ops
);
void
EnforceHasOp
(
OpHandleBase
*
op
)
const
;
bool
IsBeforeOrSameImpl
(
OpHandleBase
*
op1
,
OpHandleBase
*
op2
)
const
;
std
::
unordered_map
<
OpHandleBase
*
,
std
::
unordered_set
<
OpHandleBase
*>>
preceding_ops_
;
std
::
unordered_map
<
OpHandleBase
*
,
std
::
unordered_set
<
OpHandleBase
*>>
pending_ops_
;
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/reference_count_op_handle.h
浏览文件 @
5be6f762
...
@@ -51,7 +51,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
...
@@ -51,7 +51,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
dev_ctx_
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
dev_ctx_
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
));
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
));
if
(
IsStreamGarabageCollector
())
{
if
(
IsStreamGarabageCollector
())
{
PADDLE_ENFORCE
(
cudaSetDevice
(
place
.
device
)
);
platform
::
SetDeviceId
(
place
.
device
);
PADDLE_ENFORCE
(
cudaEventCreateWithFlags
(
&
event_
,
cudaEventDisableTiming
));
PADDLE_ENFORCE
(
cudaEventCreateWithFlags
(
&
event_
,
cudaEventDisableTiming
));
}
}
...
@@ -61,7 +61,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
...
@@ -61,7 +61,7 @@ class ReferenceCountOpHandle : public OpHandleBase {
~
ReferenceCountOpHandle
()
{
~
ReferenceCountOpHandle
()
{
if
(
IsStreamGarabageCollector
())
{
if
(
IsStreamGarabageCollector
())
{
auto
gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
dev_ctx_
->
GetPlace
());
auto
gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
dev_ctx_
->
GetPlace
());
PADDLE_ENFORCE
(
cudaSetDevice
(
gpu_place
.
device
)
);
platform
::
SetDeviceId
(
gpu_place
.
device
);
PADDLE_ENFORCE
(
cudaEventDestroy
(
event_
));
PADDLE_ENFORCE
(
cudaEventDestroy
(
event_
));
}
}
}
}
...
...
paddle/fluid/framework/details/reference_count_pass.cc
浏览文件 @
5be6f762
...
@@ -43,6 +43,23 @@ static ComputationOpHandle *FindNextComputationOpHandle(VarHandle *var_in) {
...
@@ -43,6 +43,23 @@ static ComputationOpHandle *FindNextComputationOpHandle(VarHandle *var_in) {
return
nullptr
;
return
nullptr
;
}
}
// Makes `out` depend on `in` through a dummy (control-dependency) variable.
//
// If `in` already has a DummyVarHandle among its outputs, the first one found
// is reused as an input of `out`. Otherwise a new DummyVarHandle is created
// on a fresh control-dependency node of `graph`, stored in the graph's
// kGraphDepVars collection, and wired as output of `in` and input of `out`.
static void AddDependencyBetween(OpHandleBase *in, OpHandleBase *out,
                                 ir::Graph *graph) {
  for (auto *out_var : in->Outputs()) {
    if (dynamic_cast<DummyVarHandle *>(out_var) != nullptr) {
      // Reuse the existing dummy output of `in`.
      out->AddInput(out_var);
      return;
    }
  }

  // No dummy output yet: create one and connect both ops to it.
  auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
  graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
  in->AddOutput(dep_var);
  out->AddInput(dep_var);
}
std
::
unique_ptr
<
ir
::
Graph
>
ReferenceCountPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
ReferenceCountPass
::
ApplyImpl
(
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
std
::
unique_ptr
<
ir
::
Graph
>
graph
)
const
{
auto
&
ref_cnts
=
Get
<
DeviceReferenceCountMap
>
(
kGlobalReferenceCount
);
auto
&
ref_cnts
=
Get
<
DeviceReferenceCountMap
>
(
kGlobalReferenceCount
);
...
@@ -133,12 +150,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
...
@@ -133,12 +150,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
auto
*
ref_cnt_handle
=
new
ReferenceCountOpHandle
(
auto
*
ref_cnt_handle
=
new
ReferenceCountOpHandle
(
ref_cnt_node
,
next_compute_op
->
GetScope
(),
place
,
{
var_name
},
ref_cnt_node
,
next_compute_op
->
GetScope
(),
place
,
{
var_name
},
gcs
[
place
.
device
].
get
(),
cur_ref_cnts
[
place
.
device
].
get
());
gcs
[
place
.
device
].
get
(),
cur_ref_cnts
[
place
.
device
].
get
());
if
(
next_compute_op
->
Outputs
().
empty
())
{
AddDependencyBetween
(
next_compute_op
,
ref_cnt_handle
,
graph
.
get
());
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
next_compute_op
->
AddOutput
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
}
ref_cnt_handle
->
AddInput
(
next_compute_op
->
Outputs
().
front
());
compute_ref_cnt_map
[
next_compute_op
].
reset
(
ref_cnt_handle
);
compute_ref_cnt_map
[
next_compute_op
].
reset
(
ref_cnt_handle
);
}
}
}
}
...
@@ -160,12 +172,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
...
@@ -160,12 +172,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
auto
*
ref_cnt_handle
=
new
ReferenceCountOpHandle
(
auto
*
ref_cnt_handle
=
new
ReferenceCountOpHandle
(
ref_cnt_node
,
compute_op
->
GetScope
(),
place
,
in_var_names
,
ref_cnt_node
,
compute_op
->
GetScope
(),
place
,
in_var_names
,
gcs
[
place
.
device
].
get
(),
cur_ref_cnts
[
place
.
device
].
get
());
gcs
[
place
.
device
].
get
(),
cur_ref_cnts
[
place
.
device
].
get
());
if
(
compute_op
->
Outputs
().
empty
())
{
AddDependencyBetween
(
compute_op
,
ref_cnt_handle
,
graph
.
get
());
auto
*
dep_var
=
new
DummyVarHandle
(
graph
->
CreateControlDepVar
());
compute_op
->
AddOutput
(
dep_var
);
graph
->
Get
<
GraphDepVars
>
(
kGraphDepVars
).
emplace
(
dep_var
);
}
ref_cnt_handle
->
AddInput
(
compute_op
->
Outputs
().
front
());
compute_ref_cnt_map
[
compute_op
].
reset
(
ref_cnt_handle
);
compute_ref_cnt_map
[
compute_op
].
reset
(
ref_cnt_handle
);
}
}
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
5be6f762
...
@@ -156,6 +156,10 @@ ParallelExecutor::ParallelExecutor(
...
@@ -156,6 +156,10 @@ ParallelExecutor::ParallelExecutor(
params
,
member_
->
local_scopes_
,
member_
->
use_cuda_
);
params
,
member_
->
local_scopes_
,
member_
->
use_cuda_
);
#endif
#endif
graph
=
ir
::
PassRegistry
::
Instance
()
.
Get
(
"modify_op_lock_and_record_event_pass"
)
->
Apply
(
std
::
move
(
graph
));
// If the loss_var_name is given, the number of graph should be only one.
// If the loss_var_name is given, the number of graph should be only one.
if
(
loss_var_name
.
size
())
{
if
(
loss_var_name
.
size
())
{
PADDLE_ENFORCE_EQ
(
ir
::
GraphNum
(
*
graph
),
1
,
PADDLE_ENFORCE_EQ
(
ir
::
GraphNum
(
*
graph
),
1
,
...
@@ -319,6 +323,8 @@ ParallelExecutor::~ParallelExecutor() {
...
@@ -319,6 +323,8 @@ ParallelExecutor::~ParallelExecutor() {
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
USE_PASS
(
modify_op_lock_and_record_event_pass
);
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
USE_PASS
(
reference_count_pass
);
USE_PASS
(
reference_count_pass
);
#endif
#endif
paddle/fluid/operators/conv_cudnn_op.cu.cc
浏览文件 @
5be6f762
...
@@ -160,6 +160,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
...
@@ -160,6 +160,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv forward ---------------------
// ------------------- cudnn conv forward ---------------------
ScalingParamType
<
T
>
alpha
=
1.0
f
,
beta
=
0.0
f
;
ScalingParamType
<
T
>
alpha
=
1.0
f
,
beta
=
0.0
f
;
auto
workspace_handle
=
dev_ctx
.
cudnn_workspace_handle
();
for
(
int
i
=
0
;
i
<
groups
;
i
++
)
{
for
(
int
i
=
0
;
i
<
groups
;
i
++
)
{
auto
cudnn_func
=
[
&
](
void
*
cudnn_workspace
)
{
auto
cudnn_func
=
[
&
](
void
*
cudnn_workspace
)
{
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionForward
(
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionForward
(
...
@@ -168,7 +169,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
...
@@ -168,7 +169,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
cudnn_conv_desc
,
algo
,
cudnn_workspace
,
workspace_size_in_bytes
,
cudnn_conv_desc
,
algo
,
cudnn_workspace
,
workspace_size_in_bytes
,
&
beta
,
cudnn_output_desc
,
output_data
+
i
*
group_offset_out
));
&
beta
,
cudnn_output_desc
,
output_data
+
i
*
group_offset_out
));
};
};
dev_ctx
.
RunCudnnFuncWithWorkspace
(
cudnn_func
,
workspace_size_in_bytes
);
workspace_handle
.
RunFunc
(
cudnn_func
,
workspace_size_in_bytes
);
}
}
}
}
};
};
...
@@ -314,6 +315,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
...
@@ -314,6 +315,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv backward data ---------------------
// ------------------- cudnn conv backward data ---------------------
ScalingParamType
<
T
>
alpha
=
1.0
f
,
beta
=
0.0
f
;
ScalingParamType
<
T
>
alpha
=
1.0
f
,
beta
=
0.0
f
;
auto
workspace_handle
=
dev_ctx
.
cudnn_workspace_handle
();
if
(
input_grad
)
{
if
(
input_grad
)
{
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// Because beta is zero, it is unnecessary to reset input_grad.
// Because beta is zero, it is unnecessary to reset input_grad.
...
@@ -327,7 +329,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
...
@@ -327,7 +329,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
data_algo
,
cudnn_workspace
,
workspace_size_in_bytes
,
&
beta
,
data_algo
,
cudnn_workspace
,
workspace_size_in_bytes
,
&
beta
,
cudnn_input_desc
,
input_grad_data
+
i
*
group_offset_in
));
cudnn_input_desc
,
input_grad_data
+
i
*
group_offset_in
));
};
};
dev_ctx
.
RunCudnnFuncWithWorkspace
(
cudnn_func
,
workspace_size_in_bytes
);
workspace_handle
.
RunFunc
(
cudnn_func
,
workspace_size_in_bytes
);
}
}
}
}
// ------------------- cudnn conv backward filter ---------------------
// ------------------- cudnn conv backward filter ---------------------
...
@@ -343,7 +345,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
...
@@ -343,7 +345,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
filter_algo
,
cudnn_workspace
,
workspace_size_in_bytes
,
&
beta
,
filter_algo
,
cudnn_workspace
,
workspace_size_in_bytes
,
&
beta
,
cudnn_filter_desc
,
filter_grad_data
+
i
*
group_offset_filter
));
cudnn_filter_desc
,
filter_grad_data
+
i
*
group_offset_filter
));
};
};
dev_ctx
.
RunCudnnFuncWithWorkspace
(
cudnn_func
,
workspace_size_in_bytes
);
workspace_handle
.
RunFunc
(
cudnn_func
,
workspace_size_in_bytes
);
}
}
}
}
}
}
...
...
paddle/fluid/operators/conv_transpose_cudnn_op.cu.cc
浏览文件 @
5be6f762
...
@@ -104,6 +104,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
...
@@ -104,6 +104,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
int
output_offset
=
output
->
numel
()
/
output
->
dims
()[
0
]
/
groups
;
int
output_offset
=
output
->
numel
()
/
output
->
dims
()[
0
]
/
groups
;
int
filter_offset
=
filter
->
numel
()
/
groups
;
int
filter_offset
=
filter
->
numel
()
/
groups
;
T
alpha
=
1.0
f
,
beta
=
0.0
f
;
T
alpha
=
1.0
f
,
beta
=
0.0
f
;
auto
workspace_handle
=
dev_ctx
.
cudnn_workspace_handle
();
for
(
int
g
=
0
;
g
<
groups
;
g
++
)
{
for
(
int
g
=
0
;
g
<
groups
;
g
++
)
{
auto
cudnn_func
=
[
&
](
void
*
cudnn_workspace
)
{
auto
cudnn_func
=
[
&
](
void
*
cudnn_workspace
)
{
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionBackwardData
(
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnConvolutionBackwardData
(
...
@@ -112,7 +113,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
...
@@ -112,7 +113,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
algo
,
cudnn_workspace
,
workspace_size_in_bytes
,
&
beta
,
algo
,
cudnn_workspace
,
workspace_size_in_bytes
,
&
beta
,
cudnn_output_desc
,
output_data
+
output_offset
*
g
));
cudnn_output_desc
,
output_data
+
output_offset
*
g
));
};
};
dev_ctx
.
RunCudnnFuncWithWorkspace
(
cudnn_func
,
workspace_size_in_bytes
);
workspace_handle
.
RunFunc
(
cudnn_func
,
workspace_size_in_bytes
);
}
}
}
}
};
};
...
@@ -208,6 +209,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
...
@@ -208,6 +209,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
output_grad
->
numel
()
/
output_grad
->
dims
()[
0
]
/
groups
;
output_grad
->
numel
()
/
output_grad
->
dims
()[
0
]
/
groups
;
int
filter_offset
=
filter
->
numel
()
/
groups
;
int
filter_offset
=
filter
->
numel
()
/
groups
;
T
alpha
=
1.0
f
,
beta
=
0.0
f
;
T
alpha
=
1.0
f
,
beta
=
0.0
f
;
auto
workspace_handle
=
dev_ctx
.
cudnn_workspace_handle
();
if
(
input_grad
)
{
if
(
input_grad
)
{
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// Because beta is zero, it is unnecessary to reset input_grad.
// Because beta is zero, it is unnecessary to reset input_grad.
...
@@ -220,7 +222,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
...
@@ -220,7 +222,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
cudnn_workspace
,
workspace_size_in_bytes
,
&
beta
,
cudnn_input_desc
,
cudnn_workspace
,
workspace_size_in_bytes
,
&
beta
,
cudnn_input_desc
,
input_grad_data
+
input_offset
*
g
));
input_grad_data
+
input_offset
*
g
));
};
};
dev_ctx
.
RunCudnnFuncWithWorkspace
(
cudnn_func
,
workspace_size_in_bytes
);
workspace_handle
.
RunFunc
(
cudnn_func
,
workspace_size_in_bytes
);
}
}
}
}
...
@@ -238,7 +240,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
...
@@ -238,7 +240,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
cudnn_workspace
,
workspace_size_in_bytes
,
&
beta
,
cudnn_workspace
,
workspace_size_in_bytes
,
&
beta
,
cudnn_filter_desc
,
filter_grad_data
+
filter_offset
*
g
));
cudnn_filter_desc
,
filter_grad_data
+
filter_offset
*
g
));
};
};
dev_ctx
.
RunCudnnFuncWithWorkspace
(
cudnn_func
,
workspace_size_in_bytes
);
workspace_handle
.
RunFunc
(
cudnn_func
,
workspace_size_in_bytes
);
}
}
}
}
}
}
...
...
paddle/fluid/platform/device_context.cc
浏览文件 @
5be6f762
...
@@ -168,10 +168,7 @@ class CudnnHolder {
...
@@ -168,10 +168,7 @@ class CudnnHolder {
void
RunFunc
(
const
std
::
function
<
void
(
void
*
)
>&
cudnn_func
,
void
RunFunc
(
const
std
::
function
<
void
(
void
*
)
>&
cudnn_func
,
size_t
required_workspace_len
)
{
size_t
required_workspace_len
)
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mtx_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
mtx_
);
if
(
required_workspace_len
>
workspace_len_
)
{
RunFuncImpl
(
cudnn_func
,
required_workspace_len
);
ReallocateWorkspace
(
required_workspace_len
);
}
cudnn_func
(
workspace_
);
}
}
~
CudnnHolder
()
{
~
CudnnHolder
()
{
...
@@ -182,6 +179,16 @@ class CudnnHolder {
...
@@ -182,6 +179,16 @@ class CudnnHolder {
}
}
private:
private:
// Returns a reference to mtx_, the mutex that RunFunc() locks.
std::mutex& Mutex() {
  std::mutex& holder_mutex = mtx_;
  return holder_mutex;
}
// Invokes cudnn_func with the holder's workspace pointer, enlarging the
// workspace first when it is smaller than required_workspace_len.
// Takes no lock itself.
void RunFuncImpl(const std::function<void(void*)>& cudnn_func,
                 size_t required_workspace_len) {
  const bool workspace_too_small = workspace_len_ < required_workspace_len;
  if (workspace_too_small) {
    ReallocateWorkspace(required_workspace_len);
  }
  cudnn_func(workspace_);
}
void
ReallocateWorkspace
(
size_t
required_workspace_len
)
{
void
ReallocateWorkspace
(
size_t
required_workspace_len
)
{
if
(
required_workspace_len
<=
workspace_len_
)
{
if
(
required_workspace_len
<=
workspace_len_
)
{
return
;
return
;
...
@@ -195,6 +202,8 @@ class CudnnHolder {
...
@@ -195,6 +202,8 @@ class CudnnHolder {
workspace_len_
=
required_workspace_len
;
workspace_len_
=
required_workspace_len
;
}
}
friend
class
CudnnWorkspaceHandle
;
cudnnHandle_t
cudnn_handle_
;
cudnnHandle_t
cudnn_handle_
;
void
*
workspace_
;
void
*
workspace_
;
size_t
workspace_len_
;
size_t
workspace_len_
;
...
@@ -205,6 +214,24 @@ class CudnnHolder {
...
@@ -205,6 +214,24 @@ class CudnnHolder {
std
::
mutex
mtx_
;
std
::
mutex
mtx_
;
};
};
CudnnWorkspaceHandle
::
CudnnWorkspaceHandle
(
CudnnHolder
*
holder
)
:
holder_
(
holder
)
{}
void
CudnnWorkspaceHandle
::
RunFunc
(
const
std
::
function
<
void
(
void
*
)
>&
cudnn_func
,
size_t
required_workspace_len
)
{
// defer lock when the function is invoked first time
BeginCallGuard
();
holder_
->
RunFuncImpl
(
cudnn_func
,
required_workspace_len
);
}
// Acquires the holder's mutex unless this handle already holds a guard.
void CudnnWorkspaceHandle::BeginCallGuard() {
  if (guard_) {
    return;  // a guard already exists; do not lock again
  }
  guard_.reset(new std::lock_guard<std::mutex>(holder_->Mutex()));
}
// Destroys the guard (if any), which releases the holder's mutex.
void CudnnWorkspaceHandle::EndCallGuard() {
  guard_ = nullptr;
}
CUDADeviceContext
::
CUDADeviceContext
(
CUDAPlace
place
)
CUDADeviceContext
::
CUDADeviceContext
(
CUDAPlace
place
)
:
place_
(
place
),
cudnn_holder_
(
nullptr
)
{
:
place_
(
place
),
cudnn_holder_
(
nullptr
)
{
SetDeviceId
(
place_
.
device
);
SetDeviceId
(
place_
.
device
);
...
@@ -271,6 +298,10 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
...
@@ -271,6 +298,10 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
return
cudnn_holder_
->
cudnn_handle
();
return
cudnn_holder_
->
cudnn_handle
();
}
}
// Builds a workspace handle bound to this context's cudnn holder.
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
  auto* holder = cudnn_holder_.get();
  return CudnnWorkspaceHandle(holder);
}
void
CUDADeviceContext
::
RunCudnnFuncWithWorkspace
(
void
CUDADeviceContext
::
RunCudnnFuncWithWorkspace
(
const
std
::
function
<
void
(
void
*
)
>&
cudnn_func
,
size_t
workspace_len
)
const
{
const
std
::
function
<
void
(
void
*
)
>&
cudnn_func
,
size_t
workspace_len
)
const
{
cudnn_holder_
->
RunFunc
(
cudnn_func
,
workspace_len
);
cudnn_holder_
->
RunFunc
(
cudnn_func
,
workspace_len
);
...
...
paddle/fluid/platform/device_context.h
浏览文件 @
5be6f762
...
@@ -74,6 +74,33 @@ struct DefaultDeviceContextType<platform::CPUPlace> {
...
@@ -74,6 +74,33 @@ struct DefaultDeviceContextType<platform::CPUPlace> {
class
EigenCudaStreamDevice
;
class
EigenCudaStreamDevice
;
class
CudnnHolder
;
class
CudnnHolder
;
/*! \brief Handle used to run cudnn functions on the workspace of a
* CudnnHolder. It keeps a non-owning pointer to the holder and a lock guard
* that is destroyed together with the handle. */
class
CudnnWorkspaceHandle
{
public:
/*! \brief The lock is not acquired by the constructor.
* It is acquired the first time RunFunc() is called. */
explicit
CudnnWorkspaceHandle
(
CudnnHolder
*
holder
);
/*! \brief The thread that calls RunFunc() acquires the lock first
* (if this handle does not hold it yet) before invoking the cudnn function. */
void
RunFunc
(
const
std
::
function
<
void
(
void
*
)
>&
cudnn_func
,
size_t
required_workspace_len
);
/*! \brief Acquires the lock manually. This is usually unnecessary,
* because RunFunc() acquires the lock before invoking cudnn functions. */
void
BeginCallGuard
();
/*! \brief Releases the lock manually. This is usually unnecessary,
* because the lock is released once the handle is destructed, but it
* can be used to release the lock as early as possible. */
void
EndCallGuard
();
private:
CudnnHolder
*
holder_
;
// not owned
// Lock guard on the holder's mutex; empty until the lock is acquired.
std
::
unique_ptr
<
std
::
lock_guard
<
std
::
mutex
>>
guard_
;
};
class
CUDADeviceContext
:
public
DeviceContext
{
class
CUDADeviceContext
:
public
DeviceContext
{
public:
public:
explicit
CUDADeviceContext
(
CUDAPlace
place
);
explicit
CUDADeviceContext
(
CUDAPlace
place
);
...
@@ -100,6 +127,15 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -100,6 +127,15 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cudnn handle in the device context. */
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t
cudnn_handle
()
const
;
cudnnHandle_t
cudnn_handle
()
const
;
/*! \brief Return a cudnn workspace handle to call multiple cudnn
* functions without being interrupted by other threads.
* Once the first cudnn function is called by the handle, a lock
* would be acquired to prevent other threads from accessing the
* workspace. Once the handle is destructed, the lock would be released.
* CudnnWorkspaceHandle is an RAII object to implement thread-safe
* sequential cudnn function calls. */
CudnnWorkspaceHandle
cudnn_workspace_handle
()
const
;
/*! \brief Run a cudnn function with the workspace provided by
/*! \brief Run a cudnn function with the workspace provided by
* CUDADeviceContext */
* CUDADeviceContext */
void
RunCudnnFuncWithWorkspace
(
const
std
::
function
<
void
(
void
*
)
>&
cudnn_func
,
void
RunCudnnFuncWithWorkspace
(
const
std
::
function
<
void
(
void
*
)
>&
cudnn_func
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录