Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
420fdbb2
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
420fdbb2
编写于
1月 19, 2021
作者:
L
liuyuhui
提交者:
GitHub
1月 19, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Kunlun]PR3: add xpu executor, multi xpu card train function optimization (#30317) (#30535)
上级
7a4ccf59
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
445 addition
and
26 deletion
+445
-26
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-1
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+2
-0
paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc
...uid/framework/details/bind_threaded_ssa_graph_executor.cc
+316
-0
paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h
...luid/framework/details/bind_threaded_ssa_graph_executor.h
+107
-0
paddle/fluid/framework/details/op_handle_base.cc
paddle/fluid/framework/details/op_handle_base.cc
+0
-20
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+18
-4
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+1
-1
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
420fdbb2
...
...
@@ -267,7 +267,7 @@ target_link_libraries(executor while_op_helper executor_gc_helper recurrent_op_h
cc_library
(
parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor async_ssa_graph_executor
graph build_strategy collective_helper
graph build_strategy
bind_threaded_ssa_graph_executor
collective_helper
fast_threaded_ssa_graph_executor variable_helper
)
cc_test
(
dist_multi_trainer_test SRCS dist_multi_trainer_test.cc DEPS
...
...
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
420fdbb2
...
...
@@ -101,6 +101,8 @@ cc_library(scope_buffered_monitor SRCS scope_buffered_monitor.cc DEPS scope prof
cc_library
(
scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_executor.cc DEPS ssa_graph_executor scope_buffered_monitor
)
#cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
# device_context reduce_op_handle )
cc_library
(
bind_threaded_ssa_graph_executor SRCS bind_threaded_ssa_graph_executor.cc
DEPS fetch_op_handle gflags ssa_graph_executor scope simple_threadpool device_context
)
cc_library
(
fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executor.cc
DEPS fetch_async_op_handle ssa_graph_executor scope simple_threadpool device_context
)
cc_test
(
fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fused_broadcast_op_handle
)
...
...
paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc
0 → 100644
浏览文件 @
420fdbb2
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h"
#include <deque>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"
#if defined(PADDLE_WITH_XPU)
namespace
paddle
{
namespace
framework
{
namespace
details
{
static
std
::
atomic
<
unsigned
int
>
exec_op_count_
;
static
std
::
atomic
<
int
>
error_state
;
BindThreadedSSAGraphExecutor
::
BindThreadedSSAGraphExecutor
(
const
ExecutionStrategy
&
strategy
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
Scope
*>
&
local_exec_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
ir
::
Graph
*
graph
)
:
strategy_
(
strategy
),
local_scopes_
(
local_scopes
),
local_exec_scopes_
(
local_exec_scopes
),
places_
(
places
),
graph_
(
graph
),
prepare_pool_
(
1
),
multi_device_op_pool_
(
1
)
{
for
(
uint32_t
i
=
0
;
i
<
places
.
size
();
i
++
)
{
pool_
.
emplace_back
(
std
::
unique_ptr
<::
ThreadPool
>
(
new
::
ThreadPool
(
1
)));
}
int
index
=
0
;
for
(
uint32_t
i
=
0
;
i
<
places
.
size
();
i
++
)
{
int
id
=
BOOST_GET_CONST
(
platform
::
XPUPlace
,
places_
[
i
]).
device
;
if
(
place_to_index_
.
find
(
id
)
==
place_to_index_
.
end
())
{
place_to_index_
[
id
]
=
index
;
index
++
;
}
}
for
(
auto
&
op
:
ir
::
FilterByNodeWrapper
<
OpHandleBase
>
(
*
graph_
))
{
int
dep
=
static_cast
<
int
>
(
op
->
NotReadyInputSize
());
op_deps_
.
emplace
(
op
,
dep
);
if
(
dep
==
0
)
{
bootstrap_ops_
.
emplace_back
(
op
);
}
}
PADDLE_ENFORCE_GT
(
op_deps_
.
size
(),
0
,
platform
::
errors
::
PreconditionNotMet
(
"The graph doesn't have operators."
));
PrepareAtomicOpDeps
();
}
static
std
::
vector
<
OpHandleBase
*>
get_children
(
OpHandleBase
*
op
)
{
auto
&
outputs
=
op
->
Outputs
();
std
::
vector
<
OpHandleBase
*>
ret
;
for
(
auto
&
output
:
outputs
)
{
ret
.
insert
(
ret
.
end
(),
output
->
PendingOps
().
begin
(),
output
->
PendingOps
().
end
());
}
return
ret
;
}
static
std
::
vector
<
OpHandleBase
*>
get_parents
(
OpHandleBase
*
op
)
{
auto
&
inputs
=
op
->
Inputs
();
std
::
vector
<
OpHandleBase
*>
ret
;
for
(
auto
&
input
:
inputs
)
{
if
(
input
->
GeneratedOp
()
!=
nullptr
)
{
ret
.
push_back
(
input
->
GeneratedOp
());
}
}
return
ret
;
}
FetchResultType
BindThreadedSSAGraphExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
bool
return_merged
)
{
VLOG
(
3
)
<<
"enter BindThreadedSSAGraphExecutor Run"
;
return
RunMainStream
(
fetch_tensors
,
return_merged
);
}
// use 2 streams to run op. The first stream is main stream and will run
// most op exclude op depending on multi device(e.g., all_reduce, fetch op)
FetchResultType
BindThreadedSSAGraphExecutor
::
RunMainStream
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
bool
return_merged
)
{
VLOG
(
3
)
<<
"enter MainStream Run"
;
std
::
unique_ptr
<
std
::
unordered_map
<
OpHandleBase
*
,
struct
RunningItem
>>
op_deps
=
atomic_op_deps_
.
get
();
PrepareAtomicOpDeps
();
error_state
=
0
;
paddle
::
framework
::
FetchResultType
fetches
;
if
(
return_merged
)
{
fetches
=
FetchList
(
fetch_tensors
.
size
());
}
else
{
fetches
=
FetchUnmergedList
(
fetch_tensors
.
size
());
}
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
VarHandleBase
*>>
fetched_vars
;
std
::
vector
<
OpHandleBase
*>
fetch_ops
;
std
::
vector
<
OpHandleBase
*>
ready_fetch_ops
;
auto
ready_ops
=
std
::
make_shared
<
BlockingQueue
<
OpHandleBase
*>>
();
exception_
.
Clear
();
InsertFetchOps
(
fetch_tensors
,
&
fetches
,
&
fetched_vars
,
op_deps
.
get
(),
&
fetch_ops
,
&
ready_fetch_ops
,
return_merged
);
for
(
auto
cur_op
:
bootstrap_ops_
)
{
ready_ops
->
Push
(
cur_op
);
}
for
(
auto
cur_op
:
ready_fetch_ops
)
{
ready_ops
->
Push
(
cur_op
);
}
exec_op_count_
=
0
;
platform
::
XPUPlace
cur_place
;
std
::
size_t
cur_count
=
0
;
while
(
cur_count
<
op_deps_
.
size
())
{
cur_count
++
;
auto
cur_op
=
ready_ops
->
Pop
();
if
(
cur_op
==
nullptr
)
{
// sleep a while to make sure worker thread quit
sleep
(
10
);
exec_op_count_
=
op_deps_
.
size
();
break
;
}
auto
dev_ctxes_
=
cur_op
->
DeviceContext
();
if
(
cur_op
->
IsMultiDeviceTransfer
())
{
RunMultiDeviceOpAsync
(
cur_op
,
op_deps
.
get
(),
ready_ops
);
continue
;
}
else
{
cur_place
=
BOOST_GET_CONST
(
platform
::
XPUPlace
,
dev_ctxes_
.
begin
()
->
first
);
int
cur_index
=
place_to_index_
[
cur_place
.
device
];
RunOpAsyncMainStream
(
cur_op
,
op_deps
.
get
(),
ready_ops
,
cur_index
);
}
}
while
(
exec_op_count_
<
op_deps_
.
size
())
{
}
// Wait FetchOps.
ClearFetchOp
(
graph_
,
&
fetch_ops
);
if
(
exception_
.
IsCaught
())
{
ExecutionFinal
(
&
fetch_ops
);
}
return
fetches
;
}
void
BindThreadedSSAGraphExecutor
::
InsertFetchOps
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
FetchResultType
*
fetches
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
VarHandleBase
*>>
*
fetched_vars
,
std
::
unordered_map
<
OpHandleBase
*
,
struct
RunningItem
>
*
op_deps
,
std
::
vector
<
OpHandleBase
*>
*
fetch_ops
,
std
::
vector
<
OpHandleBase
*>
*
ready_fetch_ops
,
bool
return_merged
)
{
std
::
unordered_set
<
std
::
string
>
fetch_tensor_set
(
fetch_tensors
.
begin
(),
fetch_tensors
.
end
());
for
(
auto
&
fetch_var_name
:
fetch_tensor_set
)
{
for
(
auto
&
var_map
:
graph_
->
Get
<
GraphVars
>
(
kGraphVars
))
{
auto
it
=
var_map
.
find
(
fetch_var_name
);
if
(
it
!=
var_map
.
end
())
{
(
*
fetched_vars
)[
fetch_var_name
].
push_back
(
*
it
->
second
.
rbegin
());
}
}
}
for
(
size_t
i
=
0
;
i
<
fetch_tensors
.
size
();
++
i
)
{
auto
&
var_name
=
fetch_tensors
.
at
(
i
);
auto
fetched_var_it
=
fetched_vars
->
find
(
var_name
);
PADDLE_ENFORCE_NE
(
fetched_var_it
,
fetched_vars
->
end
(),
platform
::
errors
::
PreconditionNotMet
(
"Cannot find fetched variable(%s) in current computation graph. "
"Possible reasons are:
\n
"
" 1. The variable to be fetched is not defined in main program.
\n
"
" 2. The variable to be fetched is not an input or output of any "
"operator.
\n
"
" 3. Confirm that you have used the fetch `Variable` format "
"instead of the string literal('%s') in `fetch_list` parameter "
"when using `executor.run` method. In other words, the format of "
"`executor.run(fetch_list=[fetch_var])`(fetch_var is a Variable) "
"is recommended."
,
var_name
,
var_name
));
auto
&
vars
=
fetched_var_it
->
second
;
ir
::
Node
*
fetch_node
=
graph_
->
CreateEmptyNode
(
"fetch"
,
ir
::
Node
::
Type
::
kOperation
);
auto
*
op
=
new
FetchOpHandle
(
fetch_node
,
fetches
,
i
,
&
local_scopes_
,
&
local_exec_scopes_
,
return_merged
);
fetch_ops
->
emplace_back
(
op
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
for
(
auto
&
p
:
places_
)
{
op
->
SetDeviceContext
(
p
,
pool
.
Get
(
p
));
}
for
(
auto
*
var
:
vars
)
{
op
->
AddInput
(
var
);
}
int
dep
=
static_cast
<
int
>
(
op
->
NotReadyInputSize
());
(
*
op_deps
)[
op
].
dep_num
=
dep
;
(
*
op_deps
)[
op
].
op
=
op
;
if
(
dep
==
0
)
{
ready_fetch_ops
->
emplace_back
(
op
);
}
}
}
void
BindThreadedSSAGraphExecutor
::
RunMultiDeviceOpAsync
(
OpHandleBase
*
op
,
std
::
unordered_map
<
OpHandleBase
*
,
struct
RunningItem
>
*
op_deps
,
std
::
shared_ptr
<
BlockingQueue
<
OpHandleBase
*>>
ready_ops
)
{
multi_device_op_pool_
.
enqueue
([
=
]
{
try
{
if
(
error_state
==
0
&&
LIKELY
(
!
strategy_
.
dry_run_
))
{
auto
dev_ctxes
=
op
->
DeviceContext
();
auto
&
inputs
=
op
->
Inputs
();
for
(
auto
&
input
:
inputs
)
{
auto
dev_ctxes
=
input
->
GeneratedOp
()
->
DeviceContext
();
for
(
auto
&
item
:
dev_ctxes
)
{
((
platform
::
XPUDeviceContext
*
)(
item
.
second
))
->
Wait
();
}
}
op
->
Run
(
strategy_
.
use_device_
);
auto
&
outputs
=
op
->
Outputs
();
for
(
auto
&
output
:
outputs
)
{
for
(
auto
&
pending_op
:
output
->
PendingOps
())
{
std
::
atomic
<
int
>
&
deps
=
op_deps
->
at
(
pending_op
).
dep_num
;
if
(
deps
.
fetch_sub
(
1
)
==
1
)
{
ready_ops
->
Push
(
pending_op
);
}
}
}
}
else
if
(
error_state
)
{
ready_ops
->
Push
(
nullptr
);
}
}
catch
(...)
{
error_state
=
1
;
ready_ops
->
Push
(
nullptr
);
exception_
.
Catch
(
std
::
current_exception
());
}
exec_op_count_
++
;
});
}
void
BindThreadedSSAGraphExecutor
::
RunOpAsyncMainStream
(
OpHandleBase
*
op
,
std
::
unordered_map
<
OpHandleBase
*
,
struct
RunningItem
>
*
op_deps
,
std
::
shared_ptr
<
BlockingQueue
<
OpHandleBase
*>>
ready_ops
,
int
index
)
{
pool_
[
index
]
->
enqueue
([
=
]
{
try
{
if
(
error_state
==
0
&&
LIKELY
(
!
strategy_
.
dry_run_
))
{
op
->
Run
(
strategy_
.
use_device_
);
auto
&
outputs
=
op
->
Outputs
();
for
(
auto
&
output
:
outputs
)
{
for
(
auto
&
pending_op
:
output
->
PendingOps
())
{
std
::
atomic
<
int
>
&
deps
=
op_deps
->
at
(
pending_op
).
dep_num
;
if
(
deps
.
fetch_sub
(
1
)
==
1
)
{
ready_ops
->
Push
(
pending_op
);
}
}
}
}
else
if
(
error_state
)
{
ready_ops
->
Push
(
nullptr
);
}
}
catch
(...)
{
error_state
=
1
;
ready_ops
->
Push
(
nullptr
);
exception_
.
Catch
(
std
::
current_exception
());
}
exec_op_count_
++
;
});
}
void
BindThreadedSSAGraphExecutor
::
PrepareAtomicOpDeps
()
{
atomic_op_deps_
=
prepare_pool_
.
enqueue
([
&
]
{
auto
*
op_deps
=
new
std
::
unordered_map
<
OpHandleBase
*
,
struct
RunningItem
>
;
for
(
auto
&
pair
:
op_deps_
)
{
(
*
op_deps
)[
pair
.
first
].
dep_num
=
pair
.
second
;
(
*
op_deps
)[
pair
.
first
].
op
=
pair
.
first
;
}
return
std
::
unique_ptr
<
std
::
unordered_map
<
OpHandleBase
*
,
struct
RunningItem
>>
(
op_deps
);
});
}
const
ir
::
Graph
&
BindThreadedSSAGraphExecutor
::
Graph
()
const
{
return
*
graph_
;
}
void
BindThreadedSSAGraphExecutor
::
ExecutionFinal
(
std
::
vector
<
OpHandleBase
*>
*
fetch_ops
)
{
VLOG
(
3
)
<<
"caught exception "
<<
exception_
.
Type
()
<<
", rethrow it"
;
ClearFetchOp
(
graph_
,
fetch_ops
);
exception_
.
ReThrow
();
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
#endif
paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h
0 → 100644
浏览文件 @
420fdbb2
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/details/exception_holder.h"
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
#if defined(PADDLE_WITH_XPU)
namespace
paddle
{
namespace
framework
{
class
Scope
;
namespace
details
{
struct
RunningItem
{
std
::
atomic
<
int
>
dep_num
;
OpHandleBase
*
op
;
};
class
OpHandleBase
;
class
BindThreadedSSAGraphExecutor
:
public
SSAGraphExecutor
{
public:
BindThreadedSSAGraphExecutor
(
const
ExecutionStrategy
&
strategy
,
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
Scope
*>
&
local_exec_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
ir
::
Graph
*
graph
);
// FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
// Run a SSAGraph by a thread pool
// Use topological sort algorithm
FetchResultType
Run
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
bool
return_merged
)
override
;
const
ir
::
Graph
&
Graph
()
const
override
;
private:
FetchResultType
RunMainStream
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
bool
return_merged
);
// Note(zcd): the ThreadPool should be placed last so that ThreadPool should
// be destroyed first.
ExecutionStrategy
strategy_
;
std
::
vector
<
Scope
*>
local_scopes_
;
std
::
vector
<
Scope
*>
local_exec_scopes_
;
std
::
vector
<
platform
::
Place
>
places_
;
ir
::
Graph
*
graph_
;
std
::
unordered_map
<
OpHandleBase
*
,
int
>
op_deps_
;
std
::
unordered_map
<
int
,
int
>
place_to_index_
;
std
::
vector
<
OpHandleBase
*>
bootstrap_ops_
;
std
::
unique_ptr
<
int
[]
>
stream_op_count_
;
std
::
future
<
std
::
unique_ptr
<
std
::
unordered_map
<
OpHandleBase
*
,
struct
RunningItem
>>>
atomic_op_deps_
;
ExceptionHolder
exception_
;
std
::
vector
<
std
::
unique_ptr
<::
ThreadPool
>>
pool_
;
::
ThreadPool
prepare_pool_
;
::
ThreadPool
multi_device_op_pool_
;
void
RunOpAsyncMainStream
(
OpHandleBase
*
op
,
std
::
unordered_map
<
OpHandleBase
*
,
struct
RunningItem
>
*
op_deps
,
std
::
shared_ptr
<
BlockingQueue
<
OpHandleBase
*>>
ready_ops
,
int
index
);
void
RunMultiDeviceOpAsync
(
OpHandleBase
*
op
,
std
::
unordered_map
<
OpHandleBase
*
,
struct
RunningItem
>
*
op_deps
,
std
::
shared_ptr
<
BlockingQueue
<
OpHandleBase
*>>
ready_ops
);
void
PrepareAtomicOpDeps
();
int
get_pool_thread_index
(
int
device_id
);
inline
void
ExecutionFinal
(
std
::
vector
<
OpHandleBase
*>
*
fetch_ops
);
void
InsertFetchOps
(
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
,
FetchResultType
*
fetches
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
VarHandleBase
*>>
*
fetched_vars
,
std
::
unordered_map
<
OpHandleBase
*
,
struct
RunningItem
>
*
op_deps
,
std
::
vector
<
OpHandleBase
*>
*
fetch_ops
,
std
::
vector
<
OpHandleBase
*>
*
ready_fetch_ops
,
bool
return_merged
);
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
#endif
paddle/fluid/framework/details/op_handle_base.cc
浏览文件 @
420fdbb2
...
...
@@ -215,13 +215,6 @@ void OpHandleBase::WaitInputVarGenerated(bool wait_for_feed) {
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"Not compiled with CUDA."
));
#endif
}
else
if
(
platform
::
is_xpu_place
(
place
))
{
#ifdef PADDLE_WITH_XPU
dev_ctxes_
.
at
(
place
)
->
Wait
();
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"Not compiled with XPU."
));
#endif
}
// There are nothing to do when the place is CPUPlace.
...
...
@@ -271,19 +264,6 @@ void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) {
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"Not compiled with CUDA."
));
#endif
}
else
if
(
platform
::
is_xpu_place
(
in_var_handle
->
place
()))
{
#ifdef PADDLE_WITH_XPU
PADDLE_ENFORCE_EQ
(
platform
::
is_same_place
(
place
,
in_var_handle
->
place
()),
true
,
platform
::
errors
::
InvalidArgument
(
"The place of output(%s) is not consistent with the "
"place of current op(%s)."
,
in_var_handle
->
Name
(),
Name
()));
dev_ctxes_
.
at
(
place
)
->
Wait
();
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"Not compiled with XPU."
));
#endif
}
// There are nothing to do when the place is CPUPlace.
...
...
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
420fdbb2
...
...
@@ -22,6 +22,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_handle_base.h"
...
...
@@ -933,10 +934,23 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
exec_strategy
,
member_
->
local_scopes_
,
member_
->
local_exec_scopes_
,
member_
->
places_
,
graph
));
}
else
{
VLOG
(
3
)
<<
"use FastThreadedSSAGraphExecutor"
;
member_
->
executor_
.
reset
(
new
details
::
FastThreadedSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
member_
->
local_exec_scopes_
,
member_
->
places_
,
graph
));
if
(
member_
->
use_device_
==
p
::
kXPU
)
{
#if defined(PADDLE_WITH_XPU)
VLOG
(
3
)
<<
"use BindThreadedSSAGraphExecutor"
;
member_
->
executor_
.
reset
(
new
details
::
BindThreadedSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
member_
->
local_exec_scopes_
,
member_
->
places_
,
graph
));
#else
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"Paddle can't use XPU device since it's not compiled with XPU,"
"Please recompile or reinstall Paddle with XPU support."
));
#endif
}
else
{
VLOG
(
3
)
<<
"use FastThreadedSSAGraphExecutor"
;
member_
->
executor_
.
reset
(
new
details
::
FastThreadedSSAGraphExecutor
(
exec_strategy
,
member_
->
local_scopes_
,
member_
->
local_exec_scopes_
,
member_
->
places_
,
graph
));
}
}
final_graphs
.
emplace_back
(
graph
);
}
...
...
paddle/fluid/platform/device_context.cc
浏览文件 @
420fdbb2
...
...
@@ -210,7 +210,7 @@ void XPUDeviceContext::Wait() const {
"XPU API return wrong value[%d], please check whether "
"Baidu Kunlun Card is properly installed."
,
ret
));
xpu_wait
();
xpu_wait
(
context_
->
xpu_stream
);
}
Place
XPUDeviceContext
::
GetPlace
()
const
{
return
place_
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录