未验证 提交 5988553f 编写于 作者: 王明冬 提交者: GitHub

[NPU] add npu support for new executor. test=develop (#43403)

上级 0a04b8a9
...@@ -137,6 +137,13 @@ void DataTranferHelper::RunAndConstructOpFuncNode( ...@@ -137,6 +137,13 @@ void DataTranferHelper::RunAndConstructOpFuncNode(
new_op_func_node.output_index["Out"] = {var_scope_->VarId(new_var_name)}; new_op_func_node.output_index["Out"] = {var_scope_->VarId(new_var_name)};
new_op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second); new_op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
new_op_func_node.kernel_func_(exec_ctx); new_op_func_node.kernel_func_(exec_ctx);
// NOTE(winter-wang): in npu device, D2H kernel is asynchronous. need to
// explicit synchronization.
#ifdef PADDLE_WITH_ASCEND_CL
if (op_type == kMemcpyD2H) {
dev_ctx->Wait();
}
#endif
// NOTE(Aurelius84): data_transform_op is expensive operation, so we tag them // NOTE(Aurelius84): data_transform_op is expensive operation, so we tag them
// as kQueueSync and execute them in thread pool. // as kQueueSync and execute them in thread pool.
new_op_func_node.type_ = OpFuncType::kQueueSync; new_op_func_node.type_ = OpFuncType::kQueueSync;
......
...@@ -90,6 +90,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place, ...@@ -90,6 +90,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
auto local_scope = &var_scope_.GetMutableScope()->NewScope(); auto local_scope = &var_scope_.GetMutableScope()->NewScope();
local_scope_ = local_scope; local_scope_ = local_scope;
} }
var_scope_.SetLocalScope(local_scope_);
// prune // prune
...@@ -115,7 +116,6 @@ InterpreterCore::~InterpreterCore() { ...@@ -115,7 +116,6 @@ InterpreterCore::~InterpreterCore() {
interpreter::CostInfo InterpreterCore::DryRun( interpreter::CostInfo InterpreterCore::DryRun(
const std::vector<std::string>& feed_names, const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors) { const std::vector<framework::LoDTensor>& feed_tensors) {
var_scope_.SetLocalScope(local_scope_);
Prepare(feed_names, feed_tensors, true); Prepare(feed_names, feed_tensors, true);
interpreter::CostInfo cost_info; interpreter::CostInfo cost_info;
{ {
...@@ -144,7 +144,6 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -144,7 +144,6 @@ paddle::framework::FetchList InterpreterCore::Run(
platform::AttachPointerHashToMKLDNNKey(this, place_); platform::AttachPointerHashToMKLDNNKey(this, place_);
#endif #endif
bool is_build = is_build_; bool is_build = is_build_;
var_scope_.SetLocalScope(local_scope_);
Prepare(feed_names, feed_tensors, is_build); Prepare(feed_names, feed_tensors, is_build);
if (is_build) { if (is_build) {
...@@ -153,8 +152,10 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -153,8 +152,10 @@ paddle::framework::FetchList InterpreterCore::Run(
// until the second step run. // until the second step run.
async_work_queue_ = GetWorkQueue(); async_work_queue_ = GetWorkQueue();
ExecuteInstructionList(vec_instruction_); ExecuteInstructionList(vec_instruction_);
#ifdef PADDLE_WITH_ASCEND_CL
platform::DeviceContextPool::Instance().Get(place_)->Wait();
#endif
} }
if (create_local_scope_) { if (create_local_scope_) {
ClearLoDTensorArrayInLocalScope(); ClearLoDTensorArrayInLocalScope();
} }
...@@ -174,7 +175,6 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -174,7 +175,6 @@ paddle::framework::FetchList InterpreterCore::Run(
platform::AttachPointerHashToMKLDNNKey(this, place_); platform::AttachPointerHashToMKLDNNKey(this, place_);
#endif #endif
if (!is_build_) { if (!is_build_) {
var_scope_.SetLocalScope(local_scope_);
paddle::framework::interpreter::build_variable_scope(block_, &var_scope_); paddle::framework::interpreter::build_variable_scope(block_, &var_scope_);
std::vector<paddle::framework::OpFuncNode> op_func_nodes; std::vector<paddle::framework::OpFuncNode> op_func_nodes;
...@@ -196,12 +196,14 @@ paddle::framework::FetchList InterpreterCore::Run( ...@@ -196,12 +196,14 @@ paddle::framework::FetchList InterpreterCore::Run(
async_work_queue_ = GetWorkQueue(); async_work_queue_ = GetWorkQueue();
ExecuteInstructionList(vec_instruction_); ExecuteInstructionList(vec_instruction_);
#ifdef PADDLE_WITH_ASCEND_CL
platform::DeviceContextPool::Instance().Get(place_)->Wait();
#endif
} }
if (create_local_scope_) { if (create_local_scope_) {
ClearLoDTensorArrayInLocalScope(); ClearLoDTensorArrayInLocalScope();
} }
// return Fetch Tensors // return Fetch Tensors
auto* fetch_var = local_scope_->FindVar(interpreter::kFetchVarName); auto* fetch_var = local_scope_->FindVar(interpreter::kFetchVarName);
if (fetch_var) { if (fetch_var) {
...@@ -528,6 +530,17 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { ...@@ -528,6 +530,17 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope_); VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope_);
Scope* local_scope = create_local_scope_ ? var_scope_.GetMutableLocalScope() Scope* local_scope = create_local_scope_ ? var_scope_.GetMutableLocalScope()
: var_scope_.GetMutableScope(); : var_scope_.GetMutableScope();
#ifdef PADDLE_WITH_ASCEND_CL
// NOTE(wangxi): nan/inf cannot be detected on NPU by checking the variable
// values, but only through special `float_status` to checks whether
// the operation is overflow. More about `float_status`, see:
// https://gitee.com/ascend/modelzoo/issues/I3NF8V?from=project-issue
if (FLAGS_check_nan_inf) {
framework::details::NPUAllocAndClearFloatStatus(*op, *local_scope, place);
}
#endif
auto op_with_kernel = dynamic_cast<const framework::OperatorWithKernel*>(op); auto op_with_kernel = dynamic_cast<const framework::OperatorWithKernel*>(op);
{ {
// If it is OperatorBase, InferShape do nothing. // If it is OperatorBase, InferShape do nothing.
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include <algorithm> #include <algorithm>
#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/executor_gc_helper.h" #include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/new_executor/data_transfer.h" #include "paddle/fluid/framework/new_executor/data_transfer.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h" #include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
...@@ -43,6 +44,7 @@ PADDLE_DEFINE_EXPORTED_bool( ...@@ -43,6 +44,7 @@ PADDLE_DEFINE_EXPORTED_bool(
"Enable serial execution for standalone executor, used for debug."); "Enable serial execution for standalone executor, used for debug.");
DECLARE_bool(use_mkldnn); DECLARE_bool(use_mkldnn);
DECLARE_bool(check_nan_inf);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -446,11 +448,19 @@ void build_op_func_list(const platform::Place& place, ...@@ -446,11 +448,19 @@ void build_op_func_list(const platform::Place& place,
op_func_node.output_index = outs_name2id; op_func_node.output_index = outs_name2id;
VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope); VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope);
#ifdef PADDLE_WITH_ASCEND_CL
// NOTE(wangxi): nan/inf cannot be detected on NPU by checking the variable
// values, but only through special `float_status` to checks whether
// the operation is overflow. More about `float_status`, see:
// https://gitee.com/ascend/modelzoo/issues/I3NF8V?from=project-issue
if (FLAGS_check_nan_inf) {
framework::details::NPUAllocAndClearFloatStatus(*op, *local_scope, place);
}
#endif
if (dynamic_cast<framework::OperatorWithKernel*>(op) == nullptr) { if (dynamic_cast<framework::OperatorWithKernel*>(op) == nullptr) {
// op is not a operatorwithkernel, so direcly run OperatorBase::Run() // op is not a operatorwithkernel, so direcly run OperatorBase::Run()
deal_operator_base(place, var_scope, ops[i], &op_func_node, local_scope); deal_operator_base(place, var_scope, ops[i], &op_func_node, local_scope);
VLOG(4) << "End run " << place << " "
<< op_func_node.operator_base_->DebugStringEx(local_scope);
} else { } else {
auto op_with_kernel = const_cast<framework::OperatorWithKernel*>( auto op_with_kernel = const_cast<framework::OperatorWithKernel*>(
static_cast<const framework::OperatorWithKernel*>(op)); static_cast<const framework::OperatorWithKernel*>(op));
...@@ -593,6 +603,12 @@ void build_op_func_list(const platform::Place& place, ...@@ -593,6 +603,12 @@ void build_op_func_list(const platform::Place& place,
<< var_scope->GetNameById(p.second); << var_scope->GetNameById(p.second);
} }
} }
// for debug nan/inf
if (FLAGS_check_nan_inf) {
VLOG(4) << "Check nan/inf";
framework::details::CheckOpHasNanOrInf(*op, *runtime_scope, place);
}
} }
VLOG(4) << "End run " << place << " " VLOG(4) << "End run " << place << " "
...@@ -768,12 +784,7 @@ void ShrinkDownstreamMap(std::map<int, std::list<int>>* downstream_map, ...@@ -768,12 +784,7 @@ void ShrinkDownstreamMap(std::map<int, std::list<int>>* downstream_map,
// b: c // b: c
// happens_before[i][j] means i should be executed before j // happens_before[i][j] means i should be executed before j
op_happens_before->resize(op_num); op_happens_before->assign(op_num, std::vector<bool>(op_num, false));
for (size_t i = 0; i < op_num; ++i) {
(*op_happens_before)[i].resize(op_num);
std::fill(
(*op_happens_before)[i].begin(), (*op_happens_before)[i].end(), false);
}
// bfs to get all next ops // bfs to get all next ops
auto bfs = [&](size_t op_idx) { auto bfs = [&](size_t op_idx) {
...@@ -883,6 +894,18 @@ std::map<int, std::list<int>> build_op_downstream_map( ...@@ -883,6 +894,18 @@ std::map<int, std::list<int>> build_op_downstream_map(
} }
} }
} }
// the original output of inplace op is also change.
if (!vec_instruction[op_idx].InplaceBackMap().empty()) {
auto& m = vec_instruction[op_idx].InplaceBackMap();
for (auto& p : m) {
auto& var = p.second;
if (var2min_rw_op.count(var)) {
for (auto dep_op : var2min_rw_op[var]) {
op2dependences[op_idx].insert(dep_op);
}
}
}
}
// step2: update 2 var2xxxx data structure // step2: update 2 var2xxxx data structure
for (auto& item : for (auto& item :
...@@ -894,16 +917,6 @@ std::map<int, std::list<int>> build_op_downstream_map( ...@@ -894,16 +917,6 @@ std::map<int, std::list<int>> build_op_downstream_map(
} }
} }
for (auto& item :
vec_instruction[op_idx].Inputs()) { // for all inputs(read only)
for (auto var : item.second) {
if (remove_duplicate.count(var) ==
0) { // var in input list and in output list, so remove it.
update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
}
}
}
// NOTE(zhiqiu): The inplace op with `transfer` also changes // NOTE(zhiqiu): The inplace op with `transfer` also changes
// original output after that so add original output as well // original output after that so add original output as well
// original: a->op->a // original: a->op->a
...@@ -914,8 +927,16 @@ std::map<int, std::list<int>> build_op_downstream_map( ...@@ -914,8 +927,16 @@ std::map<int, std::list<int>> build_op_downstream_map(
for (auto& p : m) { for (auto& p : m) {
auto var = p.second; auto var = p.second;
var2recent_write_op[var] = op_idx; var2recent_write_op[var] = op_idx;
// var in input list and in output list, so remove it. var2min_rw_op[var] = {static_cast<int>(op_idx)};
if (remove_duplicate.count(var) == 0) { remove_duplicate.insert(var);
}
}
for (auto& item :
vec_instruction[op_idx].Inputs()) { // for all inputs(read only)
for (auto var : item.second) {
if (remove_duplicate.count(var) ==
0) { // var in input list and in output list, so remove it.
update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var); update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
} }
} }
......
...@@ -389,7 +389,8 @@ static bool IsCpuOp(const Instruction& instr) { ...@@ -389,7 +389,8 @@ static bool IsCpuOp(const Instruction& instr) {
// is supported heterogeneous place // is supported heterogeneous place
static bool IsSupportedHetePlace(const phi::Place& place) { static bool IsSupportedHetePlace(const phi::Place& place) {
return platform::is_gpu_place(place) || platform::is_xpu_place(place); return platform::is_gpu_place(place) || platform::is_npu_place(place) ||
platform::is_xpu_place(place);
} }
} // namespace interpreter } // namespace interpreter
......
...@@ -21,23 +21,37 @@ ...@@ -21,23 +21,37 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace {
std::map<Place, std::shared_future<std::unique_ptr<platform::DeviceContext>>>*
d2h_ctxs = nullptr;
std::map<Place, std::shared_future<std::unique_ptr<platform::DeviceContext>>>*
h2d_ctxs = nullptr;
std::mutex ctx_mtx;
} // namespace
StreamAnalyzer::StreamAnalyzer(const platform::Place& place) : place_(place) { StreamAnalyzer::StreamAnalyzer(const platform::Place& place) : place_(place) {
if (platform::is_gpu_place(place)) { if (platform::is_gpu_place(place) || platform::is_npu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) std::lock_guard<std::mutex> lk(ctx_mtx);
platform::EmplaceDeviceContexts( if (d2h_ctxs == nullptr) {
&d2h_ctxs_, d2h_ctxs = new std::map<
{place}, Place,
/*disable_setting_default_stream_for_allocator=*/true); std::shared_future<std::unique_ptr<platform::DeviceContext>>>();
platform::EmplaceDeviceContexts( h2d_ctxs = new std::map<
&h2d_ctxs_, Place,
{place}, std::shared_future<std::unique_ptr<platform::DeviceContext>>>();
/*disable_setting_default_stream_for_allocator=*/true); }
#else if (d2h_ctxs->find(place) == d2h_ctxs->end()) {
PADDLE_THROW( platform::EmplaceDeviceContexts(
platform::errors::Unimplemented("CUDAPlace is not supported. Please " d2h_ctxs,
"re-compile with WITH_GPU option.")); {place},
#endif /*disable_setting_default_stream_for_allocator=*/true);
platform::EmplaceDeviceContexts(
h2d_ctxs,
{place},
/*disable_setting_default_stream_for_allocator=*/true);
}
d2h_ctx_ = (*d2h_ctxs)[place];
h2d_ctx_ = (*h2d_ctxs)[place];
} }
} }
...@@ -162,15 +176,15 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext( ...@@ -162,15 +176,15 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
const OpFuncNode& op_func_node) { const OpFuncNode& op_func_node) {
auto& op_type = op_func_node.operator_base_->Type(); auto& op_type = op_func_node.operator_base_->Type();
auto* dev_ctx = op_func_node.dev_ctx_; auto* dev_ctx = op_func_node.dev_ctx_;
// only gpu need update. xpu not need, because xpu memcpy op kernel is // only gpu/npu need update. xpu not need, because xpu memcpy op kernel is
// synchronous. // synchronous.
if (platform::is_gpu_place(place_)) { if (platform::is_gpu_place(place_) || platform::is_npu_place(place_)) {
if (op_type == interpreter::kMemcpyD2H) { if (op_type == interpreter::kMemcpyD2H) {
VLOG(3) << "Get dev_ctx from d2h_context_pool_"; VLOG(3) << "Get dev_ctx from d2h_context_pool_";
dev_ctx = d2h_ctxs_[place_].get().get(); dev_ctx = d2h_ctx_.get().get();
} else if (op_type == interpreter::kMemcpyH2D) { } else if (op_type == interpreter::kMemcpyH2D) {
VLOG(3) << "Get dev_ctx from h2d_context_pool_"; VLOG(3) << "Get dev_ctx from h2d_context_pool_";
dev_ctx = h2d_ctxs_[place_].get().get(); dev_ctx = h2d_ctx_.get().get();
} }
} }
return dev_ctx; return dev_ctx;
...@@ -188,11 +202,20 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext( ...@@ -188,11 +202,20 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
*/ */
bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr, bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr,
const Instruction& next_instr) { const Instruction& next_instr) {
return platform::is_xpu_place(place_) || if (&cur_instr.DeviceContext() == &next_instr.DeviceContext()) return true;
(&cur_instr.DeviceContext() == &next_instr.DeviceContext() ||
interpreter::IsCpuOp(cur_instr) || // xpu memcpy kerenl is synchronous.
interpreter::IsMemcpyD2H(cur_instr) || if (platform::is_xpu_place(place_)) return true;
interpreter::IsMemcpyH2D(next_instr));
// npu d2h kernel is asynchronous.
if (platform::is_npu_place(place_)) {
return interpreter::IsCpuOp(cur_instr) ||
interpreter::IsMemcpyH2D(next_instr);
}
// gpu or cpu
return interpreter::IsCpuOp(cur_instr) ||
interpreter::IsMemcpyD2H(cur_instr) ||
interpreter::IsMemcpyH2D(next_instr);
} }
platform::DeviceType StreamAnalyzer::GetWaiterType(const Instruction& instr) { platform::DeviceType StreamAnalyzer::GetWaiterType(const Instruction& instr) {
...@@ -201,6 +224,8 @@ platform::DeviceType StreamAnalyzer::GetWaiterType(const Instruction& instr) { ...@@ -201,6 +224,8 @@ platform::DeviceType StreamAnalyzer::GetWaiterType(const Instruction& instr) {
} else { } else {
if (platform::is_xpu_place(place_)) { if (platform::is_xpu_place(place_)) {
return platform::kXPU; return platform::kXPU;
} else if (platform::is_npu_place(place_)) {
return platform::kNPU;
} }
return platform::kCUDA; return platform::kCUDA;
} }
......
...@@ -53,9 +53,9 @@ class StreamAnalyzer { ...@@ -53,9 +53,9 @@ class StreamAnalyzer {
platform::DeviceType GetWaiterType(const Instruction& instr); platform::DeviceType GetWaiterType(const Instruction& instr);
Place place_; const Place place_;
std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>> d2h_ctxs_; std::shared_future<std::unique_ptr<platform::DeviceContext>> d2h_ctx_;
std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>> h2d_ctxs_; std::shared_future<std::unique_ptr<platform::DeviceContext>> h2d_ctx_;
std::map<size_t, std::shared_ptr<platform::DeviceEvent>> var_id2event_; std::map<size_t, std::shared_ptr<platform::DeviceEvent>> var_id2event_;
}; };
......
...@@ -1080,11 +1080,11 @@ AllocationPtr AllocatorFacade::Alloc(const platform::Place& place, ...@@ -1080,11 +1080,11 @@ AllocationPtr AllocatorFacade::Alloc(const platform::Place& place,
} else { } else {
return m->GetAllocator(p, size)->Allocate(size); return m->GetAllocator(p, size)->Allocate(size);
} }
#elif defined PADDLE_WITH_XPU #elif defined(PADDLE_WITH_XPU) || defined(PADDLE_WITH_ASCEND_CL)
return GetAllocator(place)->Allocate(size); return GetAllocator(place)->Allocate(size);
#else #else
PADDLE_THROW( PADDLE_THROW(platform::errors::PreconditionNotMet(
platform::errors::PreconditionNotMet("Not compiled with GPU or XPU.")); "Not compiled with GPU or XPU or NPU."));
#endif #endif
} }
......
...@@ -70,8 +70,12 @@ class CropNPUKernel : public framework::OpKernel<T> { ...@@ -70,8 +70,12 @@ class CropNPUKernel : public framework::OpKernel<T> {
shape->dims().size(), shape->dims().size(),
x->dims().size())); x->dims().size()));
// shape memory maybe have gc.
Tensor tmp_shape(*shape);
tmp_shape.mutable_data<T>(ctx.GetPlace());
const auto& runner = const auto& runner =
NpuOpRunner("Crop", {*x, *shape}, {*out}, attr_input); NpuOpRunner("Crop", {*x, tmp_shape}, {*out}, attr_input);
auto stream = auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>() ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream(); .stream();
......
...@@ -94,14 +94,13 @@ class MemcpyH2DOpProtoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -94,14 +94,13 @@ class MemcpyH2DOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", AddOutput("Out",
"(LoDTensor) The type of output " "(LoDTensor) The type of output "
"is the same as input X."); "is the same as input X.");
AddAttr<int>( AddAttr<int>("dst_place_type",
"dst_place_type", "Determine the dst place of tensor copy. "
"Determine the dst place of tensor copy. " "By Now it support:"
"By Now it ONLY support CUDAPinnedPlace/CPU <-> NPUPlace/CUDAPlace " "0. CUDAPinnedPlace/CPU <->CUDAPlace"
"Other place type is Unimplemented and will cause ERROR." "1. NPUPinnedPlace/CPU <-> NPUPlace"
"0: dst is on CUDAPlace. " "2. CPU <->XPUPlace"
"1: dst is on NPUPlace. " "Other place type is Unimplemented and will cause ERROR.");
"2: dst is on XPUPlace. ");
AddComment(R"DOC( AddComment(R"DOC(
MemcpyD2H Operator. MemcpyD2H Operator.
By now, it ONLY supports the memcopy between CUDAPinnedPlace/CPU <-> NPUPlace/CUDAPlace. By now, it ONLY supports the memcopy between CUDAPinnedPlace/CPU <-> NPUPlace/CUDAPlace.
......
...@@ -280,6 +280,16 @@ if(WITH_XPU) ...@@ -280,6 +280,16 @@ if(WITH_XPU)
CACHE INTERNAL "device event libs") CACHE INTERNAL "device event libs")
endif() endif()
if(WITH_ASCEND_CL)
cc_library(
device_event_npu
SRCS device_event_npu.cc
DEPS device_event_base npu_resource_pool)
set(DEVICE_EVENT_LIBS
device_event_npu
CACHE INTERNAL "device event libs")
endif()
if(WITH_GPU) if(WITH_GPU)
nv_library( nv_library(
device_event_gpu device_event_gpu
......
...@@ -285,6 +285,10 @@ void NPUEventQuery(aclrtEvent event, aclrtEventStatus *status) { ...@@ -285,6 +285,10 @@ void NPUEventQuery(aclrtEvent event, aclrtEventStatus *status) {
PADDLE_ENFORCE_NPU_SUCCESS(aclrtQueryEvent(event, status)); PADDLE_ENFORCE_NPU_SUCCESS(aclrtQueryEvent(event, status));
} }
void NPUEventSynchronize(aclrtEvent event) {
PADDLE_ENFORCE_NPU_SUCCESS(aclrtSynchronizeEvent(event));
}
void NPUStreamWaitEvent(aclrtStream stream, aclrtEvent event) { void NPUStreamWaitEvent(aclrtStream stream, aclrtEvent event) {
PADDLE_ENFORCE_NPU_SUCCESS(aclrtStreamWaitEvent(stream, event)); PADDLE_ENFORCE_NPU_SUCCESS(aclrtStreamWaitEvent(stream, event));
} }
......
...@@ -138,6 +138,9 @@ void NPUEventQuery(aclrtEvent event, aclrtEventStatus *status); ...@@ -138,6 +138,9 @@ void NPUEventQuery(aclrtEvent event, aclrtEventStatus *status);
//! Record NPU event in the stream. //! Record NPU event in the stream.
void NPUEventRecord(aclrtEvent event, aclrtStream stream); void NPUEventRecord(aclrtEvent event, aclrtStream stream);
//! Synchronize NPU event.
void NPUEventSynchronize(aclrtEvent event);
//! Makes a stream wait on an event. //! Makes a stream wait on an event.
void NPUStreamWaitEvent(aclrtStream stream, aclrtEvent event); void NPUStreamWaitEvent(aclrtStream stream, aclrtEvent event);
......
...@@ -125,6 +125,8 @@ DeviceType Place2DeviceType(const platform::Place& place) { ...@@ -125,6 +125,8 @@ DeviceType Place2DeviceType(const platform::Place& place) {
return platform::DeviceType::XPU; return platform::DeviceType::XPU;
} else if (platform::is_ipu_place(place)) { } else if (platform::is_ipu_place(place)) {
return platform::DeviceType::IPU; return platform::DeviceType::IPU;
} else if (platform::is_npu_place(place)) {
return platform::DeviceType::NPU;
} else if (platform::is_mlu_place(place)) { } else if (platform::is_mlu_place(place)) {
return platform::DeviceType::MLU; return platform::DeviceType::MLU;
} else { } else {
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
using ::paddle::platform::kCPU; using ::paddle::platform::kCPU;
using ::paddle::platform::kCUDA; using ::paddle::platform::kCUDA;
using ::paddle::platform::kNPU;
using ::paddle::platform::kXPU; using ::paddle::platform::kXPU;
USE_EVENT(kCPU) USE_EVENT(kCPU)
...@@ -41,3 +42,9 @@ USE_EVENT(kXPU); ...@@ -41,3 +42,9 @@ USE_EVENT(kXPU);
USE_EVENT_WAIT(kXPU, kXPU) USE_EVENT_WAIT(kXPU, kXPU)
USE_EVENT_WAIT(kCPU, kXPU) USE_EVENT_WAIT(kCPU, kXPU)
#endif #endif
#ifdef PADDLE_WITH_ASCEND_CL
USE_EVENT(kNPU);
USE_EVENT_WAIT(kNPU, kNPU)
USE_EVENT_WAIT(kCPU, kNPU)
#endif
...@@ -66,7 +66,7 @@ class DeviceEvent { ...@@ -66,7 +66,7 @@ class DeviceEvent {
type_id_)); type_id_));
// TODO(Aurelius84): only support CPU/CUDA, need consider XPU/NPU later // TODO(Aurelius84): only support CPU/CUDA, need consider XPU/NPU later
PADDLE_ENFORCE_LT(type_id_, PADDLE_ENFORCE_LT(type_id_,
3, 4,
platform::errors::Unavailable( platform::errors::Unavailable(
"Currently DeviceEvent do not support %s", place)); "Currently DeviceEvent do not support %s", place));
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/device/npu/npu_resource_pool.h"
#include "paddle/fluid/platform/device_event_base.h"
#include "paddle/fluid/platform/event.h"
namespace paddle {
namespace platform {
struct NPUDeviceEventWrapper {
explicit NPUDeviceEventWrapper(const platform::Place& place) {
PADDLE_ENFORCE_EQ(
platform::is_npu_place(place),
true,
platform::errors::PreconditionNotMet(
"Required device shall be NPUPlace, but received %d. ", place));
device_id_ = place.device;
PADDLE_ENFORCE_GT(
device_id_,
-1,
platform::errors::PreconditionNotMet(
"Required DeviceOption.device_id > -1, but received %d. ",
device_id_));
inner_event_ = NpuEventResourcePool::Instance().New(device_id_);
}
std::shared_ptr<NpuEventObject> inner_event_;
int device_id_;
};
void DeviceEventCreateNPU(DeviceEvent* event,
const platform::Place& place,
unsigned int) {
event->InitEvent(std::make_shared<NPUDeviceEventWrapper>(place));
}
void DeviceEventRecordNPU(DeviceEvent* event, const DeviceContext* context) {
auto* wrapper = static_cast<NPUDeviceEventWrapper*>(event->GetEvent().get());
auto* npu_dev_ctx = dynamic_cast<const platform::NPUDeviceContext*>(context);
PADDLE_ENFORCE_NOT_NULL(
npu_dev_ctx,
platform::errors::PreconditionNotMet(
"Failed to dynamic_cast context into NPUDeviceContext."));
NPUEventRecord(wrapper->inner_event_.get(), npu_dev_ctx->stream());
}
bool DeviceEventQueryNPU(const DeviceEvent* event) {
auto* wrapper = static_cast<NPUDeviceEventWrapper*>(event->GetEvent().get());
PADDLE_ENFORCE_NOT_NULL(
wrapper,
platform::errors::PreconditionNotMet(
"Failed to dynamic_cast event into NPUDeviceEventWrapper."));
aclrtEventStatus status = ACL_EVENT_STATUS_COMPLETE;
platform::NPUEventQuery(wrapper->inner_event_.get(), &status);
return ACL_EVENT_STATUS_COMPLETE == status;
}
void DeviceEventFinishNPU(const DeviceEvent* event) {
auto* wrapper = static_cast<NPUDeviceEventWrapper*>(event->GetEvent().get());
NPUEventSynchronize(wrapper->inner_event_.get());
}
void DeviceEventNPUWaitNPU(const DeviceEvent* event,
const DeviceContext* context) {
auto* wrapper = static_cast<NPUDeviceEventWrapper*>(event->GetEvent().get());
auto* npu_dev_ctx = dynamic_cast<const platform::NPUDeviceContext*>(context);
PADDLE_ENFORCE_NOT_NULL(
npu_dev_ctx,
platform::errors::PreconditionNotMet(
"Failed to dynamic_cast context into NPUDeviceContext."));
NPUStreamWaitEvent(npu_dev_ctx->stream(), wrapper->inner_event_.get());
}
void DeviceEventCPUWaitNPU(const DeviceEvent* event,
const DeviceContext* context) {
DeviceEventFinishNPU(event);
}
void DeviceEventSetFinishedNPU(const DeviceEvent* event) {
// do nothing
}
void EventResetNPU(const DeviceEvent* event) {
// do nothing
}
} // namespace platform
} // namespace paddle
using ::paddle::platform::kCPU;
using ::paddle::platform::kNPU;
REGISTER_EVENT_CREATE_FUNCTION(kNPU, paddle::platform::DeviceEventCreateNPU)
REGISTER_EVENT_RECORD_FUNCTION(kNPU, paddle::platform::DeviceEventRecordNPU)
REGISTER_EVENT_QUERY_FUNCTION(kNPU, paddle::platform::DeviceEventQueryNPU)
REGISTER_EVENT_FINISH_FUNCTION(kNPU, paddle::platform::DeviceEventFinishNPU)
REGISTER_EVENT_SET_FINISHED_FUNCTION(
kNPU, paddle::platform::DeviceEventSetFinishedNPU)
REGISTER_EVENT_WAIT_FUNCTION(kNPU,
kNPU,
paddle::platform::DeviceEventNPUWaitNPU)
REGISTER_EVENT_WAIT_FUNCTION(kCPU,
kNPU,
paddle::platform::DeviceEventCPUWaitNPU)
REGISTER_EVENT_RESET_FUNCTION(kNPU, paddle::platform::EventResetNPU)
#endif
...@@ -1400,9 +1400,8 @@ class Executor(object): ...@@ -1400,9 +1400,8 @@ class Executor(object):
program = pruned_program program = pruned_program
def _can_use_interpreter_core(program, place): def _can_use_interpreter_core(program, place):
if core.is_compiled_with_npu() or core.is_compiled_with_mlu( if core.is_compiled_with_mlu() or core.is_compiled_with_ipu(
) or core.is_compiled_with_ipu() or isinstance( ) or isinstance(place, core.CustomPlace):
place, core.CustomPlace):
return False return False
compiled = isinstance(program, compiler.CompiledProgram) compiled = isinstance(program, compiler.CompiledProgram)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册