未验证 提交 420fdbb2 编写于 作者: L liuyuhui 提交者: GitHub

[Kunlun]PR3: add xpu executor, multi xpu card train function optimization (#30317) (#30535)

上级 7a4ccf59
......@@ -267,7 +267,7 @@ target_link_libraries(executor while_op_helper executor_gc_helper recurrent_op_h
cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor async_ssa_graph_executor
graph build_strategy collective_helper
graph build_strategy bind_threaded_ssa_graph_executor collective_helper
fast_threaded_ssa_graph_executor variable_helper)
cc_test(dist_multi_trainer_test SRCS dist_multi_trainer_test.cc DEPS
......
......@@ -101,6 +101,8 @@ cc_library(scope_buffered_monitor SRCS scope_buffered_monitor.cc DEPS scope prof
cc_library(scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_executor.cc DEPS ssa_graph_executor scope_buffered_monitor)
#cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
# device_context reduce_op_handle )
cc_library(bind_threaded_ssa_graph_executor SRCS bind_threaded_ssa_graph_executor.cc
DEPS fetch_op_handle gflags ssa_graph_executor scope simple_threadpool device_context)
cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executor.cc
DEPS fetch_async_op_handle ssa_graph_executor scope simple_threadpool device_context)
cc_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fused_broadcast_op_handle)
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h"
#include <deque>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"
#if defined(PADDLE_WITH_XPU)
namespace paddle {
namespace framework {
namespace details {
static std::atomic<unsigned int> exec_op_count_;
static std::atomic<int> error_state;
BindThreadedSSAGraphExecutor::BindThreadedSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places, ir::Graph *graph)
: strategy_(strategy),
local_scopes_(local_scopes),
local_exec_scopes_(local_exec_scopes),
places_(places),
graph_(graph),
prepare_pool_(1),
multi_device_op_pool_(1) {
for (uint32_t i = 0; i < places.size(); i++) {
pool_.emplace_back(std::unique_ptr<::ThreadPool>(new ::ThreadPool(1)));
}
int index = 0;
for (uint32_t i = 0; i < places.size(); i++) {
int id = BOOST_GET_CONST(platform::XPUPlace, places_[i]).device;
if (place_to_index_.find(id) == place_to_index_.end()) {
place_to_index_[id] = index;
index++;
}
}
for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
int dep = static_cast<int>(op->NotReadyInputSize());
op_deps_.emplace(op, dep);
if (dep == 0) {
bootstrap_ops_.emplace_back(op);
}
}
PADDLE_ENFORCE_GT(op_deps_.size(), 0,
platform::errors::PreconditionNotMet(
"The graph doesn't have operators."));
PrepareAtomicOpDeps();
}
static std::vector<OpHandleBase *> get_children(OpHandleBase *op) {
auto &outputs = op->Outputs();
std::vector<OpHandleBase *> ret;
for (auto &output : outputs) {
ret.insert(ret.end(), output->PendingOps().begin(),
output->PendingOps().end());
}
return ret;
}
static std::vector<OpHandleBase *> get_parents(OpHandleBase *op) {
auto &inputs = op->Inputs();
std::vector<OpHandleBase *> ret;
for (auto &input : inputs) {
if (input->GeneratedOp() != nullptr) {
ret.push_back(input->GeneratedOp());
}
}
return ret;
}
FetchResultType BindThreadedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors, bool return_merged) {
VLOG(3) << "enter BindThreadedSSAGraphExecutor Run";
return RunMainStream(fetch_tensors, return_merged);
}
// use 2 streams to run op. The first stream is main stream and will run
// most op exclude op depending on multi device(e.g., all_reduce, fetch op)
FetchResultType BindThreadedSSAGraphExecutor::RunMainStream(
const std::vector<std::string> &fetch_tensors, bool return_merged) {
VLOG(3) << "enter MainStream Run";
std::unique_ptr<std::unordered_map<OpHandleBase *, struct RunningItem>>
op_deps = atomic_op_deps_.get();
PrepareAtomicOpDeps();
error_state = 0;
paddle::framework::FetchResultType fetches;
if (return_merged) {
fetches = FetchList(fetch_tensors.size());
} else {
fetches = FetchUnmergedList(fetch_tensors.size());
}
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
std::vector<OpHandleBase *> fetch_ops;
std::vector<OpHandleBase *> ready_fetch_ops;
auto ready_ops = std::make_shared<BlockingQueue<OpHandleBase *>>();
exception_.Clear();
InsertFetchOps(fetch_tensors, &fetches, &fetched_vars, op_deps.get(),
&fetch_ops, &ready_fetch_ops, return_merged);
for (auto cur_op : bootstrap_ops_) {
ready_ops->Push(cur_op);
}
for (auto cur_op : ready_fetch_ops) {
ready_ops->Push(cur_op);
}
exec_op_count_ = 0;
platform::XPUPlace cur_place;
std::size_t cur_count = 0;
while (cur_count < op_deps_.size()) {
cur_count++;
auto cur_op = ready_ops->Pop();
if (cur_op == nullptr) {
// sleep a while to make sure worker thread quit
sleep(10);
exec_op_count_ = op_deps_.size();
break;
}
auto dev_ctxes_ = cur_op->DeviceContext();
if (cur_op->IsMultiDeviceTransfer()) {
RunMultiDeviceOpAsync(cur_op, op_deps.get(), ready_ops);
continue;
} else {
cur_place =
BOOST_GET_CONST(platform::XPUPlace, dev_ctxes_.begin()->first);
int cur_index = place_to_index_[cur_place.device];
RunOpAsyncMainStream(cur_op, op_deps.get(), ready_ops, cur_index);
}
}
while (exec_op_count_ < op_deps_.size()) {
}
// Wait FetchOps.
ClearFetchOp(graph_, &fetch_ops);
if (exception_.IsCaught()) {
ExecutionFinal(&fetch_ops);
}
return fetches;
}
void BindThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors, FetchResultType *fetches,
std::unordered_map<std::string, std::vector<VarHandleBase *>> *fetched_vars,
std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps,
std::vector<OpHandleBase *> *fetch_ops,
std::vector<OpHandleBase *> *ready_fetch_ops, bool return_merged) {
std::unordered_set<std::string> fetch_tensor_set(fetch_tensors.begin(),
fetch_tensors.end());
for (auto &fetch_var_name : fetch_tensor_set) {
for (auto &var_map : graph_->Get<GraphVars>(kGraphVars)) {
auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) {
(*fetched_vars)[fetch_var_name].push_back(*it->second.rbegin());
}
}
}
for (size_t i = 0; i < fetch_tensors.size(); ++i) {
auto &var_name = fetch_tensors.at(i);
auto fetched_var_it = fetched_vars->find(var_name);
PADDLE_ENFORCE_NE(
fetched_var_it, fetched_vars->end(),
platform::errors::PreconditionNotMet(
"Cannot find fetched variable(%s) in current computation graph. "
"Possible reasons are:\n"
" 1. The variable to be fetched is not defined in main program.\n"
" 2. The variable to be fetched is not an input or output of any "
"operator.\n"
" 3. Confirm that you have used the fetch `Variable` format "
"instead of the string literal('%s') in `fetch_list` parameter "
"when using `executor.run` method. In other words, the format of "
"`executor.run(fetch_list=[fetch_var])`(fetch_var is a Variable) "
"is recommended.",
var_name, var_name));
auto &vars = fetched_var_it->second;
ir::Node *fetch_node =
graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
auto *op = new FetchOpHandle(fetch_node, fetches, i, &local_scopes_,
&local_exec_scopes_, return_merged);
fetch_ops->emplace_back(op);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
for (auto &p : places_) {
op->SetDeviceContext(p, pool.Get(p));
}
for (auto *var : vars) {
op->AddInput(var);
}
int dep = static_cast<int>(op->NotReadyInputSize());
(*op_deps)[op].dep_num = dep;
(*op_deps)[op].op = op;
if (dep == 0) {
ready_fetch_ops->emplace_back(op);
}
}
}
void BindThreadedSSAGraphExecutor::RunMultiDeviceOpAsync(
OpHandleBase *op,
std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps,
std::shared_ptr<BlockingQueue<OpHandleBase *>> ready_ops) {
multi_device_op_pool_.enqueue([=] {
try {
if (error_state == 0 && LIKELY(!strategy_.dry_run_)) {
auto dev_ctxes = op->DeviceContext();
auto &inputs = op->Inputs();
for (auto &input : inputs) {
auto dev_ctxes = input->GeneratedOp()->DeviceContext();
for (auto &item : dev_ctxes) {
((platform::XPUDeviceContext *)(item.second))->Wait();
}
}
op->Run(strategy_.use_device_);
auto &outputs = op->Outputs();
for (auto &output : outputs) {
for (auto &pending_op : output->PendingOps()) {
std::atomic<int> &deps = op_deps->at(pending_op).dep_num;
if (deps.fetch_sub(1) == 1) {
ready_ops->Push(pending_op);
}
}
}
} else if (error_state) {
ready_ops->Push(nullptr);
}
} catch (...) {
error_state = 1;
ready_ops->Push(nullptr);
exception_.Catch(std::current_exception());
}
exec_op_count_++;
});
}
void BindThreadedSSAGraphExecutor::RunOpAsyncMainStream(
OpHandleBase *op,
std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps,
std::shared_ptr<BlockingQueue<OpHandleBase *>> ready_ops, int index) {
pool_[index]->enqueue([=] {
try {
if (error_state == 0 && LIKELY(!strategy_.dry_run_)) {
op->Run(strategy_.use_device_);
auto &outputs = op->Outputs();
for (auto &output : outputs) {
for (auto &pending_op : output->PendingOps()) {
std::atomic<int> &deps = op_deps->at(pending_op).dep_num;
if (deps.fetch_sub(1) == 1) {
ready_ops->Push(pending_op);
}
}
}
} else if (error_state) {
ready_ops->Push(nullptr);
}
} catch (...) {
error_state = 1;
ready_ops->Push(nullptr);
exception_.Catch(std::current_exception());
}
exec_op_count_++;
});
}
void BindThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
atomic_op_deps_ = prepare_pool_.enqueue([&] {
auto *op_deps = new std::unordered_map<OpHandleBase *, struct RunningItem>;
for (auto &pair : op_deps_) {
(*op_deps)[pair.first].dep_num = pair.second;
(*op_deps)[pair.first].op = pair.first;
}
return std::unique_ptr<
std::unordered_map<OpHandleBase *, struct RunningItem>>(op_deps);
});
}
const ir::Graph &BindThreadedSSAGraphExecutor::Graph() const { return *graph_; }
void BindThreadedSSAGraphExecutor::ExecutionFinal(
std::vector<OpHandleBase *> *fetch_ops) {
VLOG(3) << "caught exception " << exception_.Type() << ", rethrow it";
ClearFetchOp(graph_, fetch_ops);
exception_.ReThrow();
}
} // namespace details
} // namespace framework
} // namespace paddle
#endif
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/details/exception_holder.h"
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
#if defined(PADDLE_WITH_XPU)
namespace paddle {
namespace framework {
class Scope;
namespace details {
struct RunningItem {
std::atomic<int> dep_num;
OpHandleBase *op;
};
class OpHandleBase;
class BindThreadedSSAGraphExecutor : public SSAGraphExecutor {
public:
BindThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places,
ir::Graph *graph);
// FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
// Run a SSAGraph by a thread pool
// Use topological sort algorithm
FetchResultType Run(const std::vector<std::string> &fetch_tensors,
bool return_merged) override;
const ir::Graph &Graph() const override;
private:
FetchResultType RunMainStream(const std::vector<std::string> &fetch_tensors,
bool return_merged);
// Note(zcd): the ThreadPool should be placed last so that ThreadPool should
// be destroyed first.
ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_;
std::vector<Scope *> local_exec_scopes_;
std::vector<platform::Place> places_;
ir::Graph *graph_;
std::unordered_map<OpHandleBase *, int> op_deps_;
std::unordered_map<int, int> place_to_index_;
std::vector<OpHandleBase *> bootstrap_ops_;
std::unique_ptr<int[]> stream_op_count_;
std::future<
std::unique_ptr<std::unordered_map<OpHandleBase *, struct RunningItem>>>
atomic_op_deps_;
ExceptionHolder exception_;
std::vector<std::unique_ptr<::ThreadPool>> pool_;
::ThreadPool prepare_pool_;
::ThreadPool multi_device_op_pool_;
void RunOpAsyncMainStream(
OpHandleBase *op,
std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps,
std::shared_ptr<BlockingQueue<OpHandleBase *>> ready_ops, int index);
void RunMultiDeviceOpAsync(
OpHandleBase *op,
std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps,
std::shared_ptr<BlockingQueue<OpHandleBase *>> ready_ops);
void PrepareAtomicOpDeps();
int get_pool_thread_index(int device_id);
inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops);
void InsertFetchOps(
const std::vector<std::string> &fetch_tensors, FetchResultType *fetches,
std::unordered_map<std::string, std::vector<VarHandleBase *>>
*fetched_vars,
std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps,
std::vector<OpHandleBase *> *fetch_ops,
std::vector<OpHandleBase *> *ready_fetch_ops, bool return_merged);
};
} // namespace details
} // namespace framework
} // namespace paddle
#endif
......@@ -215,13 +215,6 @@ void OpHandleBase::WaitInputVarGenerated(bool wait_for_feed) {
#else
PADDLE_THROW(
platform::errors::PreconditionNotMet("Not compiled with CUDA."));
#endif
} else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
dev_ctxes_.at(place)->Wait();
#else
PADDLE_THROW(
platform::errors::PreconditionNotMet("Not compiled with XPU."));
#endif
}
// There are nothing to do when the place is CPUPlace.
......@@ -271,19 +264,6 @@ void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) {
#else
PADDLE_THROW(
platform::errors::PreconditionNotMet("Not compiled with CUDA."));
#endif
} else if (platform::is_xpu_place(in_var_handle->place())) {
#ifdef PADDLE_WITH_XPU
PADDLE_ENFORCE_EQ(
platform::is_same_place(place, in_var_handle->place()), true,
platform::errors::InvalidArgument(
"The place of output(%s) is not consistent with the "
"place of current op(%s).",
in_var_handle->Name(), Name()));
dev_ctxes_.at(place)->Wait();
#else
PADDLE_THROW(
platform::errors::PreconditionNotMet("Not compiled with XPU."));
#endif
}
// There are nothing to do when the place is CPUPlace.
......
......@@ -22,6 +22,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_handle_base.h"
......@@ -933,10 +934,23 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
member_->places_, graph));
} else {
VLOG(3) << "use FastThreadedSSAGraphExecutor";
member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
member_->places_, graph));
if (member_->use_device_ == p::kXPU) {
#if defined(PADDLE_WITH_XPU)
VLOG(3) << "use BindThreadedSSAGraphExecutor";
member_->executor_.reset(new details::BindThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_,
member_->local_exec_scopes_, member_->places_, graph));
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't use XPU device since it's not compiled with XPU,"
"Please recompile or reinstall Paddle with XPU support."));
#endif
} else {
VLOG(3) << "use FastThreadedSSAGraphExecutor";
member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_,
member_->local_exec_scopes_, member_->places_, graph));
}
}
final_graphs.emplace_back(graph);
}
......
......@@ -210,7 +210,7 @@ void XPUDeviceContext::Wait() const {
"XPU API return wrong value[%d], please check whether "
"Baidu Kunlun Card is properly installed.",
ret));
xpu_wait();
xpu_wait(context_->xpu_stream);
}
Place XPUDeviceContext::GetPlace() const { return place_; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册