Crayon鑫 / Paddle (forked from PaddlePaddle / Paddle)

Commit 420fdbb2 (unverified)
Authored by liuyuhui on Jan 19, 2021; committed via GitHub on Jan 19, 2021

[Kunlun]PR3: add xpu executor, multi xpu card train function optimization (#30317) (#30535)
Parent: 7a4ccf59

Showing 7 changed files with 445 additions and 26 deletions (+445 -26)
paddle/fluid/framework/CMakeLists.txt                               +1    -1
paddle/fluid/framework/details/CMakeLists.txt                       +2    -0
paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc  +316  -0
paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h   +107  -0
paddle/fluid/framework/details/op_handle_base.cc                    +0    -20
paddle/fluid/framework/parallel_executor.cc                         +18   -4
paddle/fluid/platform/device_context.cc                             +1    -1

paddle/fluid/framework/CMakeLists.txt
@@ -267,7 +267,7 @@ target_link_libraries(executor while_op_helper executor_gc_helper recurrent_op_h
 cc_library(parallel_executor SRCS parallel_executor.cc DEPS
         threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor async_ssa_graph_executor
-        graph build_strategy collective_helper
+        graph build_strategy bind_threaded_ssa_graph_executor collective_helper
         fast_threaded_ssa_graph_executor variable_helper)
 cc_test(dist_multi_trainer_test SRCS dist_multi_trainer_test.cc DEPS

paddle/fluid/framework/details/CMakeLists.txt
@@ -101,6 +101,8 @@ cc_library(scope_buffered_monitor SRCS scope_buffered_monitor.cc DEPS scope prof
 cc_library(scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_executor.cc DEPS ssa_graph_executor scope_buffered_monitor)
 #cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
 #       device_context reduce_op_handle )
+cc_library(bind_threaded_ssa_graph_executor SRCS bind_threaded_ssa_graph_executor.cc
+        DEPS fetch_op_handle gflags ssa_graph_executor scope simple_threadpool device_context)
 cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executor.cc
         DEPS fetch_async_op_handle ssa_graph_executor scope simple_threadpool device_context)
 cc_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fused_broadcast_op_handle)

paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.cc
new file mode 100644

// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h"
#include <deque>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"

#if defined(PADDLE_WITH_XPU)
namespace paddle {
namespace framework {
namespace details {

static std::atomic<unsigned int> exec_op_count_;
static std::atomic<int> error_state;

BindThreadedSSAGraphExecutor::BindThreadedSSAGraphExecutor(
    const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
    const std::vector<Scope *> &local_exec_scopes,
    const std::vector<platform::Place> &places, ir::Graph *graph)
    : strategy_(strategy),
      local_scopes_(local_scopes),
      local_exec_scopes_(local_exec_scopes),
      places_(places),
      graph_(graph),
      prepare_pool_(1),
      multi_device_op_pool_(1) {
  for (uint32_t i = 0; i < places.size(); i++) {
    pool_.emplace_back(std::unique_ptr<::ThreadPool>(new ::ThreadPool(1)));
  }
  int index = 0;
  for (uint32_t i = 0; i < places.size(); i++) {
    int id = BOOST_GET_CONST(platform::XPUPlace, places_[i]).device;
    if (place_to_index_.find(id) == place_to_index_.end()) {
      place_to_index_[id] = index;
      index++;
    }
  }
  for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
    int dep = static_cast<int>(op->NotReadyInputSize());
    op_deps_.emplace(op, dep);
    if (dep == 0) {
      bootstrap_ops_.emplace_back(op);
    }
  }
  PADDLE_ENFORCE_GT(op_deps_.size(), 0,
                    platform::errors::PreconditionNotMet(
                        "The graph doesn't have operators."));
  PrepareAtomicOpDeps();
}

static std::vector<OpHandleBase *> get_children(OpHandleBase *op) {
  auto &outputs = op->Outputs();
  std::vector<OpHandleBase *> ret;
  for (auto &output : outputs) {
    ret.insert(ret.end(), output->PendingOps().begin(),
               output->PendingOps().end());
  }
  return ret;
}

static std::vector<OpHandleBase *> get_parents(OpHandleBase *op) {
  auto &inputs = op->Inputs();
  std::vector<OpHandleBase *> ret;
  for (auto &input : inputs) {
    if (input->GeneratedOp() != nullptr) {
      ret.push_back(input->GeneratedOp());
    }
  }
  return ret;
}

FetchResultType BindThreadedSSAGraphExecutor::Run(
    const std::vector<std::string> &fetch_tensors, bool return_merged) {
  VLOG(3) << "enter BindThreadedSSAGraphExecutor Run";
  return RunMainStream(fetch_tensors, return_merged);
}

// use 2 streams to run op. The first stream is main stream and will run
// most op exclude op depending on multi device(e.g., all_reduce, fetch op)
FetchResultType BindThreadedSSAGraphExecutor::RunMainStream(
    const std::vector<std::string> &fetch_tensors, bool return_merged) {
  VLOG(3) << "enter MainStream Run";
  std::unique_ptr<std::unordered_map<OpHandleBase *, struct RunningItem>>
      op_deps = atomic_op_deps_.get();
  PrepareAtomicOpDeps();

  error_state = 0;
  paddle::framework::FetchResultType fetches;
  if (return_merged) {
    fetches = FetchList(fetch_tensors.size());
  } else {
    fetches = FetchUnmergedList(fetch_tensors.size());
  }
  std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
  std::vector<OpHandleBase *> fetch_ops;
  std::vector<OpHandleBase *> ready_fetch_ops;
  auto ready_ops = std::make_shared<BlockingQueue<OpHandleBase *>>();
  exception_.Clear();

  InsertFetchOps(fetch_tensors, &fetches, &fetched_vars, op_deps.get(),
                 &fetch_ops, &ready_fetch_ops, return_merged);
  for (auto cur_op : bootstrap_ops_) {
    ready_ops->Push(cur_op);
  }
  for (auto cur_op : ready_fetch_ops) {
    ready_ops->Push(cur_op);
  }

  exec_op_count_ = 0;

  platform::XPUPlace cur_place;
  std::size_t cur_count = 0;

  while (cur_count < op_deps_.size()) {
    cur_count++;
    auto cur_op = ready_ops->Pop();
    if (cur_op == nullptr) {
      // sleep a while to make sure worker thread quit
      sleep(10);
      exec_op_count_ = op_deps_.size();
      break;
    }
    auto dev_ctxes_ = cur_op->DeviceContext();
    if (cur_op->IsMultiDeviceTransfer()) {
      RunMultiDeviceOpAsync(cur_op, op_deps.get(), ready_ops);
      continue;
    } else {
      cur_place =
          BOOST_GET_CONST(platform::XPUPlace, dev_ctxes_.begin()->first);
      int cur_index = place_to_index_[cur_place.device];
      RunOpAsyncMainStream(cur_op, op_deps.get(), ready_ops, cur_index);
    }
  }
  while (exec_op_count_ < op_deps_.size()) {
  }

  // Wait FetchOps.
  ClearFetchOp(graph_, &fetch_ops);
  if (exception_.IsCaught()) {
    ExecutionFinal(&fetch_ops);
  }
  return fetches;
}

void BindThreadedSSAGraphExecutor::InsertFetchOps(
    const std::vector<std::string> &fetch_tensors, FetchResultType *fetches,
    std::unordered_map<std::string, std::vector<VarHandleBase *>>
        *fetched_vars,
    std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps,
    std::vector<OpHandleBase *> *fetch_ops,
    std::vector<OpHandleBase *> *ready_fetch_ops, bool return_merged) {
  std::unordered_set<std::string> fetch_tensor_set(fetch_tensors.begin(),
                                                   fetch_tensors.end());
  for (auto &fetch_var_name : fetch_tensor_set) {
    for (auto &var_map : graph_->Get<GraphVars>(kGraphVars)) {
      auto it = var_map.find(fetch_var_name);
      if (it != var_map.end()) {
        (*fetched_vars)[fetch_var_name].push_back(*it->second.rbegin());
      }
    }
  }

  for (size_t i = 0; i < fetch_tensors.size(); ++i) {
    auto &var_name = fetch_tensors.at(i);
    auto fetched_var_it = fetched_vars->find(var_name);
    PADDLE_ENFORCE_NE(
        fetched_var_it, fetched_vars->end(),
        platform::errors::PreconditionNotMet(
            "Cannot find fetched variable(%s) in current computation graph. "
            "Possible reasons are:\n"
            "  1. The variable to be fetched is not defined in main program.\n"
            "  2. The variable to be fetched is not an input or output of any "
            "operator.\n"
            "  3. Confirm that you have used the fetch `Variable` format "
            "instead of the string literal('%s') in `fetch_list` parameter "
            "when using `executor.run` method. In other words, the format of "
            "`executor.run(fetch_list=[fetch_var])`(fetch_var is a Variable) "
            "is recommended.",
            var_name, var_name));

    auto &vars = fetched_var_it->second;

    ir::Node *fetch_node =
        graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
    auto *op = new FetchOpHandle(fetch_node, fetches, i, &local_scopes_,
                                 &local_exec_scopes_, return_merged);
    fetch_ops->emplace_back(op);

    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
    for (auto &p : places_) {
      op->SetDeviceContext(p, pool.Get(p));
    }

    for (auto *var : vars) {
      op->AddInput(var);
    }

    int dep = static_cast<int>(op->NotReadyInputSize());
    (*op_deps)[op].dep_num = dep;
    (*op_deps)[op].op = op;
    if (dep == 0) {
      ready_fetch_ops->emplace_back(op);
    }
  }
}

void BindThreadedSSAGraphExecutor::RunMultiDeviceOpAsync(
    OpHandleBase *op,
    std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps,
    std::shared_ptr<BlockingQueue<OpHandleBase *>> ready_ops) {
  multi_device_op_pool_.enqueue([=] {
    try {
      if (error_state == 0 && LIKELY(!strategy_.dry_run_)) {
        auto dev_ctxes = op->DeviceContext();
        auto &inputs = op->Inputs();
        for (auto &input : inputs) {
          auto dev_ctxes = input->GeneratedOp()->DeviceContext();
          for (auto &item : dev_ctxes) {
            ((platform::XPUDeviceContext *)(item.second))->Wait();
          }
        }
        op->Run(strategy_.use_device_);
        auto &outputs = op->Outputs();
        for (auto &output : outputs) {
          for (auto &pending_op : output->PendingOps()) {
            std::atomic<int> &deps = op_deps->at(pending_op).dep_num;
            if (deps.fetch_sub(1) == 1) {
              ready_ops->Push(pending_op);
            }
          }
        }
      } else if (error_state) {
        ready_ops->Push(nullptr);
      }
    } catch (...) {
      error_state = 1;
      ready_ops->Push(nullptr);
      exception_.Catch(std::current_exception());
    }
    exec_op_count_++;
  });
}

void BindThreadedSSAGraphExecutor::RunOpAsyncMainStream(
    OpHandleBase *op,
    std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps,
    std::shared_ptr<BlockingQueue<OpHandleBase *>> ready_ops, int index) {
  pool_[index]->enqueue([=] {
    try {
      if (error_state == 0 && LIKELY(!strategy_.dry_run_)) {
        op->Run(strategy_.use_device_);
        auto &outputs = op->Outputs();
        for (auto &output : outputs) {
          for (auto &pending_op : output->PendingOps()) {
            std::atomic<int> &deps = op_deps->at(pending_op).dep_num;
            if (deps.fetch_sub(1) == 1) {
              ready_ops->Push(pending_op);
            }
          }
        }
      } else if (error_state) {
        ready_ops->Push(nullptr);
      }
    } catch (...) {
      error_state = 1;
      ready_ops->Push(nullptr);
      exception_.Catch(std::current_exception());
    }
    exec_op_count_++;
  });
}

void BindThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
  atomic_op_deps_ = prepare_pool_.enqueue([&] {
    auto *op_deps = new std::unordered_map<OpHandleBase *, struct RunningItem>;
    for (auto &pair : op_deps_) {
      (*op_deps)[pair.first].dep_num = pair.second;
      (*op_deps)[pair.first].op = pair.first;
    }
    return std::unique_ptr<
        std::unordered_map<OpHandleBase *, struct RunningItem>>(op_deps);
  });
}

const ir::Graph &BindThreadedSSAGraphExecutor::Graph() const { return *graph_; }

void BindThreadedSSAGraphExecutor::ExecutionFinal(
    std::vector<OpHandleBase *> *fetch_ops) {
  VLOG(3) << "caught exception " << exception_.Type() << ", rethrow it";
  ClearFetchOp(graph_, fetch_ops);
  exception_.ReThrow();
}

}  // namespace details
}  // namespace framework
}  // namespace paddle
#endif
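
For readers skimming the new executor: its scheduling core is a per-op atomic dependency count drained through a blocking queue. RunMainStream pops ready ops, each worker runs an op and then calls fetch_sub(1) on every successor's dep_num, pushing a successor once its count reaches zero. The following is a minimal, self-contained sketch of that pattern in standard C++; Task, ReadyQueue, and RunAndRelease are illustrative stand-ins, not Paddle types, and this block is not part of the commit.

// Minimal sketch (illustrative, not Paddle code): schedule a DAG of tasks by
// decrementing each successor's atomic dependency count and pushing it onto a
// shared queue once the count reaches zero, mirroring the op_deps/ready_ops
// logic in BindThreadedSSAGraphExecutor::RunOpAsyncMainStream above.
#include <atomic>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <vector>

struct Task {
  std::function<void()> work;
  std::atomic<int> dep_num{0};   // how many predecessors are still unfinished
  std::vector<Task *> pending;   // successors to notify when this task is done
};

class ReadyQueue {               // stand-in for Paddle's BlockingQueue
 public:
  void Push(Task *t) {
    { std::lock_guard<std::mutex> g(mu_); q_.push(t); }
    cv_.notify_one();
  }
  Task *Pop() {
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [&] { return !q_.empty(); });
    Task *t = q_.front();
    q_.pop();
    return t;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<Task *> q_;
};

// Run one task, then release every successor whose last dependency this was.
inline void RunAndRelease(Task *t, ReadyQueue *ready) {
  t->work();
  for (Task *next : t->pending) {
    if (next->dep_num.fetch_sub(1) == 1) {  // same fetch_sub(1) == 1 test
      ready->Push(next);
    }
  }
}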
paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h
new file mode 100644

// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <ThreadPool.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/details/exception_holder.h"
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"

#if defined(PADDLE_WITH_XPU)
namespace paddle {
namespace framework {
class Scope;
namespace details {

struct RunningItem {
  std::atomic<int> dep_num;
  OpHandleBase *op;
};

class OpHandleBase;
class BindThreadedSSAGraphExecutor : public SSAGraphExecutor {
 public:
  BindThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
                               const std::vector<Scope *> &local_scopes,
                               const std::vector<Scope *> &local_exec_scopes,
                               const std::vector<platform::Place> &places,
                               ir::Graph *graph);
  // FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
  // Run a SSAGraph by a thread pool
  // Use topological sort algorithm
  FetchResultType Run(const std::vector<std::string> &fetch_tensors,
                      bool return_merged) override;
  const ir::Graph &Graph() const override;

 private:
  FetchResultType RunMainStream(const std::vector<std::string> &fetch_tensors,
                                bool return_merged);

  // Note(zcd): the ThreadPool should be placed last so that ThreadPool should
  // be destroyed first.
  ExecutionStrategy strategy_;
  std::vector<Scope *> local_scopes_;
  std::vector<Scope *> local_exec_scopes_;
  std::vector<platform::Place> places_;
  ir::Graph *graph_;

  std::unordered_map<OpHandleBase *, int> op_deps_;
  std::unordered_map<int, int> place_to_index_;
  std::vector<OpHandleBase *> bootstrap_ops_;

  std::unique_ptr<int[]> stream_op_count_;

  std::future<
      std::unique_ptr<std::unordered_map<OpHandleBase *, struct RunningItem>>>
      atomic_op_deps_;
  ExceptionHolder exception_;

  std::vector<std::unique_ptr<::ThreadPool>> pool_;
  ::ThreadPool prepare_pool_;
  ::ThreadPool multi_device_op_pool_;

  void RunOpAsyncMainStream(
      OpHandleBase *op,
      std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps,
      std::shared_ptr<BlockingQueue<OpHandleBase *>> ready_ops, int index);

  void RunMultiDeviceOpAsync(
      OpHandleBase *op,
      std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps,
      std::shared_ptr<BlockingQueue<OpHandleBase *>> ready_ops);

  void PrepareAtomicOpDeps();

  int get_pool_thread_index(int device_id);

  inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops);

  void InsertFetchOps(
      const std::vector<std::string> &fetch_tensors, FetchResultType *fetches,
      std::unordered_map<std::string, std::vector<VarHandleBase *>>
          *fetched_vars,
      std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps,
      std::vector<OpHandleBase *> *fetch_ops,
      std::vector<OpHandleBase *> *ready_fetch_ops, bool return_merged);
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
#endif
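
The header also explains the "bind" in the executor's name: pool_ holds one single-thread ::ThreadPool per place, and place_to_index_ maps an XPU device id to its pool, so the main-stream ops of a given device are issued in order on a dedicated thread (see the constructor and RunOpAsyncMainStream above). Below is a rough standard-library sketch of that per-device binding; SingleThreadWorker and DeviceBoundPools are hypothetical illustration types, not the ThreadPool.h dependency the commit actually uses.

// Illustrative sketch (not Paddle code): give every device id its own
// single-thread worker so that work bound to that device runs serially, in
// enqueue order, mirroring pool_ (one ::ThreadPool(1) per place) and
// place_to_index_ above.
#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <unordered_map>
#include <vector>

class SingleThreadWorker {
 public:
  SingleThreadWorker() : thread_([this] { Loop(); }) {}
  ~SingleThreadWorker() {
    { std::lock_guard<std::mutex> g(mu_); stop_ = true; }
    cv_.notify_one();
    thread_.join();
  }
  void Enqueue(std::function<void()> job) {
    { std::lock_guard<std::mutex> g(mu_); jobs_.push(std::move(job)); }
    cv_.notify_one();
  }

 private:
  void Loop() {
    for (;;) {
      std::function<void()> job;
      {
        std::unique_lock<std::mutex> lk(mu_);
        cv_.wait(lk, [&] { return stop_ || !jobs_.empty(); });
        if (stop_ && jobs_.empty()) return;
        job = std::move(jobs_.front());
        jobs_.pop();
      }
      job();  // jobs for one device never run concurrently or out of order
    }
  }
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<std::function<void()>> jobs_;
  bool stop_ = false;
  std::thread thread_;  // declared last so the other members exist before Loop() starts
};

// Map each distinct device id to a worker index, as the executor's constructor
// does with place_to_index_, then dispatch work through that index.
struct DeviceBoundPools {
  std::unordered_map<int, int> place_to_index;
  std::vector<std::unique_ptr<SingleThreadWorker>> workers;

  void AddDevice(int device_id) {
    if (place_to_index.count(device_id) == 0) {
      place_to_index[device_id] = static_cast<int>(workers.size());
      workers.emplace_back(new SingleThreadWorker());
    }
  }
  void Run(int device_id, std::function<void()> job) {
    workers[place_to_index.at(device_id)]->Enqueue(std::move(job));
  }
};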
paddle/fluid/framework/details/op_handle_base.cc
@@ -215,13 +215,6 @@ void OpHandleBase::WaitInputVarGenerated(bool wait_for_feed) {
 #else
       PADDLE_THROW(
           platform::errors::PreconditionNotMet("Not compiled with CUDA."));
 #endif
-    } else if (platform::is_xpu_place(place)) {
-#ifdef PADDLE_WITH_XPU
-      dev_ctxes_.at(place)->Wait();
-#else
-      PADDLE_THROW(
-          platform::errors::PreconditionNotMet("Not compiled with XPU."));
-#endif
     }
     // There are nothing to do when the place is CPUPlace.
@@ -271,19 +264,6 @@ void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) {
 #else
         PADDLE_THROW(
             platform::errors::PreconditionNotMet("Not compiled with CUDA."));
 #endif
-      } else if (platform::is_xpu_place(in_var_handle->place())) {
-#ifdef PADDLE_WITH_XPU
-        PADDLE_ENFORCE_EQ(
-            platform::is_same_place(place, in_var_handle->place()), true,
-            platform::errors::InvalidArgument(
-                "The place of output(%s) is not consistent with the "
-                "place of current op(%s).",
-                in_var_handle->Name(), Name()));
-        dev_ctxes_.at(place)->Wait();
-#else
-        PADDLE_THROW(
-            platform::errors::PreconditionNotMet("Not compiled with XPU."));
-#endif
       }
       // There are nothing to do when the place is CPUPlace.

paddle/fluid/framework/parallel_executor.cc
@@ -22,6 +22,7 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
+#include "paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h"
 #include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
 #include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/details/op_handle_base.h"
@@ -933,10 +934,23 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
             exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
             member_->places_, graph));
       } else {
-        VLOG(3) << "use FastThreadedSSAGraphExecutor";
-        member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
-            exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
-            member_->places_, graph));
+        if (member_->use_device_ == p::kXPU) {
+#if defined(PADDLE_WITH_XPU)
+          VLOG(3) << "use BindThreadedSSAGraphExecutor";
+          member_->executor_.reset(new details::BindThreadedSSAGraphExecutor(
+              exec_strategy, member_->local_scopes_,
+              member_->local_exec_scopes_, member_->places_, graph));
+#else
+          PADDLE_THROW(platform::errors::PermissionDenied(
+              "Paddle can't use XPU device since it's not compiled with XPU,"
+              "Please recompile or reinstall Paddle with XPU support."));
+#endif
+        } else {
+          VLOG(3) << "use FastThreadedSSAGraphExecutor";
+          member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
+              exec_strategy, member_->local_scopes_,
+              member_->local_exec_scopes_, member_->places_, graph));
+        }
       }
       final_graphs.emplace_back(graph);

paddle/fluid/platform/device_context.cc
@@ -210,7 +210,7 @@ void XPUDeviceContext::Wait() const {
                         "XPU API return wrong value[%d], please check whether "
                         "Baidu Kunlun Card is properly installed.",
                         ret));
-  xpu_wait();
+  xpu_wait(context_->xpu_stream);
 }

 Place XPUDeviceContext::GetPlace() const { return place_; }
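
This last hunk narrows the scope of XPUDeviceContext::Wait(): it now passes the context's own xpu_stream to xpu_wait() instead of calling the zero-argument form, so waiting on a context blocks only on work queued to that context's stream. The executor above relies on exactly this when it waits on the producers of a multi-device op. A hedged sketch of that call path, using only symbols that appear elsewhere in this diff (the wrapper function name SyncXpuStream is hypothetical):

// Hedged sketch, not part of the commit: obtain an XPU device context the way
// bind_threaded_ssa_graph_executor.cc does (DeviceContextPool::Instance().Get
// and the XPUDeviceContext cast both appear in this diff) and wait on it.
#include "paddle/fluid/platform/device_context.h"

void SyncXpuStream(const paddle::platform::Place &place) {
  auto &pool = paddle::platform::DeviceContextPool::Instance();
  auto *ctx =
      static_cast<paddle::platform::XPUDeviceContext *>(pool.Get(place));
  // After this change, Wait() calls xpu_wait(context_->xpu_stream), i.e. it
  // blocks only on work queued to this context's stream.
  ctx->Wait();
}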