From b2704837d2f4dd094abd1f97167862e09f0fc1b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com> Date: Fri, 24 Jun 2022 14:04:24 +0800 Subject: [PATCH] add xpu support for new static alone executor. test=develop (#43076) --- paddle/fluid/framework/CMakeLists.txt | 9 + .../framework/new_executor/data_transfer.cc | 54 +++--- .../framework/new_executor/data_transfer.h | 3 - .../new_executor/interpretercore_util.cc | 173 +++++++----------- .../new_executor/new_executor_defs.h | 7 +- .../framework/new_executor/stream_analyzer.cc | 24 ++- paddle/fluid/framework/operator.cc | 17 ++ paddle/fluid/framework/operator.h | 8 +- .../memory/allocation/allocator_facade.cc | 8 +- paddle/fluid/operators/CMakeLists.txt | 3 + paddle/fluid/operators/memcpy_d2h_op.cc | 13 +- paddle/fluid/operators/memcpy_h2d_op.cc | 14 +- paddle/fluid/operators/memcpy_h2d_op.h | 2 +- paddle/fluid/platform/CMakeLists.txt | 9 + paddle/fluid/platform/device_event.h | 7 + paddle/fluid/platform/device_event_xpu.cc | 118 ++++++++++++ python/paddle/fluid/executor.py | 6 +- 17 files changed, 324 insertions(+), 151 deletions(-) create mode 100644 paddle/fluid/platform/device_event_xpu.cc diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 9fef2394c06..2aaa0c96e0a 100755 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -298,6 +298,15 @@ elseif(WITH_ROCM) data_type_transform_test SRCS data_type_transform_test.cc data_type_transform_test.cu DEPS data_type_transform) +elseif(WITH_XPU) + cc_library( + data_type_transform + SRCS data_type_transform.cc + DEPS tensor xpulib) + cc_test( + data_type_transform_test + SRCS data_type_transform_test.cc + DEPS data_type_transform) else() cc_library( data_type_transform diff --git a/paddle/fluid/framework/new_executor/data_transfer.cc b/paddle/fluid/framework/new_executor/data_transfer.cc index 171e15162fb..525c3bdbe74 100644 --- a/paddle/fluid/framework/new_executor/data_transfer.cc +++ b/paddle/fluid/framework/new_executor/data_transfer.cc @@ -277,13 +277,33 @@ std::shared_ptr TransferDevice(const std::string& var_name, // 2. Construct VariableNameMap VariableNameMap in_name_map = {{"X", {var_name}}}; VariableNameMap out_name_map = {{"Out", {*new_var_name}}}; - int dst_place_type = platform::is_cpu_place(dst_place) ? 0 - : platform::is_gpu_place(dst_place) ? 1 - : -1; - AttributeMap attr_map = {{"dst_place_type", dst_place_type}}; // 3. Create memcpy_d2h_op or memcpy_h2d_op - std::string op_type = get_memcpy_type(src_place, dst_place); + std::string op_type; + AttributeMap attr_map; + PADDLE_ENFORCE_EQ(platform::is_same_place(src_place, dst_place), false, + platform::errors::PreconditionNotMet( + "Required src_place shall be different with dst_place, " + "but received same place: %s", + src_place)); + if (IsSupportedHetePlace(dst_place)) { + op_type = kMemcpyH2D; + int dst_place_type = platform::is_gpu_place(dst_place) ? 0 + : platform::is_npu_place(dst_place) ? 1 + : platform::is_xpu_place(dst_place) ? 2 + : -1; + attr_map = {{"dst_place_type", dst_place_type}}; + } else if (IsSupportedHetePlace(src_place)) { + op_type = kMemcpyD2H; + int dst_place_type = platform::is_cpu_place(dst_place) ? 0 + : platform::is_cuda_pinned_place(dst_place) ? 
1 + : -1; + attr_map = {{"dst_place_type", dst_place_type}}; + } else { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "Not support Memcpy typ : %s -> %s", src_place, dst_place)); + } + auto& op_info = OpInfoMap::Instance().Get(op_type); auto op = std::shared_ptr( op_info.Creator()(op_type, in_name_map, out_name_map, attr_map)); @@ -434,31 +454,13 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key, if (transfered) { // NOTE(zhiqiu): UPDATE the corresponding OeratorBase to make it consistent - // with instruction. (hot fix, it is not good design here) - op_func_node->operator_base_ = - std::shared_ptr(framework::OpRegistry::CreateOp( - op_base->Type(), new_ins, new_outs, op_base->Attrs())); + // with instruction. + op_base->Inputs() = new_ins; + op_base->Outputs() = new_outs; } op_func_node->no_data_transform_index = std::move(no_data_transform_index); } -std::string get_memcpy_type(const platform::Place& src_place, - const platform::Place& dst_place) { - PADDLE_ENFORCE_EQ(platform::is_same_place(src_place, dst_place), false, - platform::errors::PreconditionNotMet( - "Required src_place shall be different with dst_place, " - "but received same place: %s", - src_place)); - if (platform::is_gpu_place(dst_place)) { - return kMemcpyH2D; - } else if (platform::is_gpu_place(src_place)) { - return kMemcpyD2H; - } else { - PADDLE_THROW(platform::errors::PreconditionNotMet( - "Not support Memcpy typ : %s -> %s", src_place, dst_place)); - } -} - void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node, const platform::Place& place, const VariableNameMap& out_names, diff --git a/paddle/fluid/framework/new_executor/data_transfer.h b/paddle/fluid/framework/new_executor/data_transfer.h index 9525ba5bc8f..52a342c9a7f 100644 --- a/paddle/fluid/framework/new_executor/data_transfer.h +++ b/paddle/fluid/framework/new_executor/data_transfer.h @@ -68,9 +68,6 @@ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node, std::vector* op_func_nodes, framework::Scope* local_scope); -std::string get_memcpy_type(const platform::Place& src_place, - const platform::Place& dst_place); - inline bool need_device_transform(const OpKernelType& kernel_type_for_var, const OpKernelType& expected_kernel_key) { auto& src_place = kernel_type_for_var.place_; diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index fb501175dde..9257cccc7ee 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -348,7 +348,7 @@ void deal_operator_base(const platform::Place& place, auto* dev_ctx = pool.Get(place); // input, output is prepared. set the other attributes. op_func_node->operator_base_ = op_base; - if (platform::is_gpu_place(place)) { + if (IsSupportedHetePlace(place)) { op_func_node->type_ = OpFuncType::kQueueAsync; } else if (platform::is_cpu_place(place)) { op_func_node->type_ = OpFuncType::kQueueSync; @@ -379,7 +379,6 @@ void build_op_func_list(const platform::Place& place, bool use_local_scope) { Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope() : var_scope->GetMutableScope(); - auto& all_op_kernels = OperatorWithKernel::AllOpKernels(); std::vector> ops_unique; // its elements will be moved to vec_func_list // Step 1: create all ops for current block. 
@@ -429,7 +428,7 @@ void build_op_func_list(const platform::Place& place, std::tie(outs_map, outs_name2id) = build_variable_map(outputs_names, var_scope, enforce_exist); - // step 2: build OpFuncNode + // step 1: build OpFuncNode OpFuncNode op_func_node; op_func_node.operator_base_ = ops[i]; op_func_node.input_index = ins_name2id; @@ -449,11 +448,7 @@ void build_op_func_list(const platform::Place& place, runtime_context.inputs.swap(ins_map); runtime_context.outputs.swap(outs_map); - platform::DeviceContextPool& pool = - platform::DeviceContextPool::Instance(); - auto* dev_ctx = pool.Get(place); - Scope scope; - Scope* runtime_scope = &scope; + Scope scope, *runtime_scope = &scope; // NOTE(Ruibiao): We do not encourage directly using scope in OP kernel. // But some OPs do have such behavior (e.g., cinn_launch OP). Here special // treatment for them. @@ -465,63 +460,17 @@ void build_op_func_list(const platform::Place& place, runtime_scope = local_scope; } - auto expected_kernel_key = op_with_kernel->GetExpectedKernelType( - ExecutionContext(*op, *runtime_scope, *dev_ctx, runtime_context)); - op_with_kernel->ResetKernelType(new OpKernelType(expected_kernel_key)); - + auto& pool = platform::DeviceContextPool::Instance(); + auto* dev_ctx = pool.Get(place); + auto exec_ctx = ExecutionContext(*op_with_kernel, *runtime_scope, + *dev_ctx, runtime_context); + auto expected_kernel_key = + op_with_kernel->GetExpectedKernelType(exec_ctx); // change device by the device_guard() apply_device_guard(op, place, &expected_kernel_key); - VLOG(3) << "expected_kernel_key : " << expected_kernel_key; - - // step 3. apply data transforms and insert data transfer ops - VariableValueMap& ins_map_temp = runtime_context.inputs; - VariableValueMap& outs_map_temp = runtime_context.outputs; - - // NOTE(zhiqiu): op_func_node->operator_base_ maybe changed in - // ApplyDataTransform - ApplyDataTransform(expected_kernel_key, - place, - &ins_map_temp, - &outs_map_temp, - var_scope, - &op_func_node, - vec_func_list, - use_local_scope); - op_with_kernel = const_cast( - static_cast( - op_func_node.operator_base_.get())); - - // step 4. Run op kernel - VLOG(3) << op_with_kernel->Type() - << " : expected_kernel_key : " << expected_kernel_key; - - if (platform::is_gpu_place(expected_kernel_key.place_)) { - op_func_node.type_ = OpFuncType::kQueueAsync; - } else if (platform::is_cpu_place(expected_kernel_key.place_)) { - op_func_node.type_ = OpFuncType::kQueueSync; - } else { - PADDLE_THROW(platform::errors::Fatal("Unsupported current place %s", - expected_kernel_key.place_)); - } - if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) { - dev_ctx = pool.Get(expected_kernel_key.place_); - } - op_func_node.dev_ctx_ = dev_ctx; - VLOG(3) << op_with_kernel->Type() - << " : expected_kernel_key : " << expected_kernel_key; - - // see OperatorWithKernel::RunImpl in operator.cc for why - if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) && - op->Attr(kAllKernelsMustComputeRuntimeShape))) { - InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context); - // TODO(Aurelius84): In case of control flow ops, they are NOT - // inheritted from OperatorWithKernel. - op_with_kernel->Info().infer_shape_(&infer_shape_ctx); - } - - auto exec_ctx = ExecutionContext( - *op_with_kernel, *runtime_scope, *dev_ctx, runtime_context); + VLOG(4) << "expected_kernel_key : " << expected_kernel_key; + // step 2. 
select op kernel auto run_phi_kernel = false; if (phi::KernelFactory::Instance().HasCompatiblePhiKernel( op_with_kernel->Type())) { @@ -531,10 +480,7 @@ void build_op_func_list(const platform::Place& place, if (op_with_kernel->PhiKernel()->IsValid()) { run_phi_kernel = true; } else { - auto kernels_iter = all_op_kernels.find(op_with_kernel->Type()); - if (kernels_iter == all_op_kernels.end() || - kernels_iter->second.find(expected_kernel_key) == - kernels_iter->second.end()) { + if (!op_with_kernel->SupportsKernelType(expected_kernel_key)) { auto pt_cpu_kernel_key = FallBackToCpu( expected_kernel_key, pt_kernel_key, *op_with_kernel); op_with_kernel->ResetPhiKernel( @@ -545,55 +491,76 @@ void build_op_func_list(const platform::Place& place, << pt_kernel_name << " | kernel key: " << pt_cpu_kernel_key << " | kernel: " << *(op_with_kernel->PhiKernel()); + op_with_kernel->ResetKernelType(new OpKernelType( + TransPhiKernelKeyToOpKernelType(pt_cpu_kernel_key))); run_phi_kernel = true; } } } } + if (!run_phi_kernel) { + op_with_kernel->ChooseKernel(exec_ctx); + op_func_node.kernel_func_ = *op_with_kernel->kernel_func(); + } else { + op_func_node.pt_kernel_ = op_with_kernel->PhiKernel(); + } + auto kernel_type = *(op_with_kernel->kernel_type()); + if (kernel_type.place_ != dev_ctx->GetPlace()) { + dev_ctx = pool.Get(kernel_type.place_); + } + op_func_node.dev_ctx_ = dev_ctx; + if (IsSupportedHetePlace(kernel_type.place_)) { + op_func_node.type_ = OpFuncType::kQueueAsync; + } else if (platform::is_cpu_place(kernel_type.place_)) { + op_func_node.type_ = OpFuncType::kQueueSync; + } else { + PADDLE_THROW(platform::errors::Fatal("Unsupported current place %s", + kernel_type.place_)); + } VLOG(3) << op_with_kernel->Type() - << " : expected_kernel_key : " << expected_kernel_key; + << " : finally selected kernel_key: " << kernel_type; + + // step 3. data transform + VariableValueMap& ins_map_temp = runtime_context.inputs; + VariableValueMap& outs_map_temp = runtime_context.outputs; + ApplyDataTransform(kernel_type, place, &ins_map_temp, &outs_map_temp, + var_scope, &op_func_node, vec_func_list, + use_local_scope); + + // step 4. infershape, see OperatorWithKernel::RunImpl in operator.cc for + // why. + if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) && + op->Attr(kAllKernelsMustComputeRuntimeShape))) { + InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context); + // TODO(Aurelius84): In case of control flow ops, they are NOT + // inheritted from OperatorWithKernel. + op_with_kernel->Info().infer_shape_(&infer_shape_ctx); + } + + // step 5. 
run kernel if (run_phi_kernel) { phi::KernelContext pt_kernel_context; - op_with_kernel->BuildPhiKernelContext( - runtime_context, dev_ctx, &pt_kernel_context); - op_func_node.pt_kernel_ = op_with_kernel->PhiKernel(); + op_with_kernel->BuildPhiKernelContext(runtime_context, dev_ctx, + &pt_kernel_context); (*op_func_node.pt_kernel_)(&pt_kernel_context); } else { - auto kernels_iter = all_op_kernels.find(op->Type()); - PADDLE_ENFORCE_NE( - kernels_iter, - all_op_kernels.end(), - platform::errors::Unavailable( - "There are no kernels which are registered in the %s operator.", - op->Type())); - OpKernelMap& kernels = kernels_iter->second; - - auto kernel_iter = kernels.find(expected_kernel_key); - PADDLE_ENFORCE_NE(kernel_iter, - kernels.end(), - platform::errors::NotFound( - "Operator (%s) does not have kernel for %s.", - op->Type(), - KernelTypeToString(expected_kernel_key))); - // TODO(zhiqiu): add fallback logic - op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second); - op_func_node.kernel_func_(exec_ctx); + // the place of exec_ctx maybe has changed. + op_func_node.kernel_func_(ExecutionContext( + *op_with_kernel, *runtime_scope, *dev_ctx, runtime_context)); } - // post-process grad_op.outputs if need cast complex grad into real grad. + // post-process grad_op.outputs if need cast complex grad into real + // grad. // NOTE(Aurelius84): insert a transfer_dtype_op inplacely to cast it. - if (framework::IsComplexType(expected_kernel_key.data_type_)) { - interpreter::HandleComplexGradToRealGrad(op_func_node, - place, - outputs_names, - &runtime_context.outputs, - var_scope, - vec_func_list, - local_scope); + if (framework::IsComplexType(kernel_type.data_type_)) { + interpreter::HandleComplexGradToRealGrad( + op_func_node, place, outputs_names, &runtime_context.outputs, + var_scope, vec_func_list, local_scope); } if (!op_func_node.inplace_back_map.empty()) { auto& m = op_func_node.inplace_back_map; - // NOTE(zhiqiu): same logic as TransferInplaceVarsBack() in operator.cc + // NOTE(zhiqiu): same logic as TransferInplaceVarsBack() in + // operator.cc for (auto& p : m) { auto* transformed_tensor = GetMutableLoDTensorOrSelectedRowsValueFromVar( @@ -899,17 +866,17 @@ std::map> build_op_downstream_map( // step2: update 2 var2xxxx data structure for (auto& item : - vec_instruction[op_idx].Inputs()) { // for all inputs(read only) + vec_instruction[op_idx].Outputs()) { // for all write vars for (auto var : item.second) { - update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var); + var2recent_write_op[var] = op_idx; + var2min_rw_op[var] = {static_cast(op_idx)}; remove_duplicate.insert(var); } } for (auto& item : - vec_instruction[op_idx].Outputs()) { // for all write vars + vec_instruction[op_idx].Inputs()) { // for all inputs(read only) for (auto var : item.second) { - var2recent_write_op[var] = op_idx; if (remove_duplicate.count(var) == 0) { // var in input list and in output list, so remove it. 
update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var); diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index 20e51145a51..30ef7eb9fcf 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -297,7 +297,7 @@ struct InstructionInfo { enum class OpFuncType { kQueueSync = 0, // CPU kernel, block host - kQueueAsync = 1, // GPU Kernel or d2h, h2d, send, recv, broadcast + kQueueAsync = 1, // GPU、XPU Kernel or d2h, h2d, send, recv, broadcast }; class RuntimeInferShapeContext; @@ -417,6 +417,11 @@ static bool IsCpuOp(const Instruction& instr) { return platform::is_cpu_place(instr.DeviceContext().GetPlace()); } +// is supported heterogeneous place +static bool IsSupportedHetePlace(const phi::Place& place) { + return platform::is_gpu_place(place) || platform::is_xpu_place(place); +} + } // namespace interpreter } // namespace framework diff --git a/paddle/fluid/framework/new_executor/stream_analyzer.cc b/paddle/fluid/framework/new_executor/stream_analyzer.cc index 469876b01f6..9b91cd928de 100644 --- a/paddle/fluid/framework/new_executor/stream_analyzer.cc +++ b/paddle/fluid/framework/new_executor/stream_analyzer.cc @@ -155,20 +155,24 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext( const OpFuncNode& op_func_node) { auto& op_type = op_func_node.operator_base_->Type(); auto* dev_ctx = op_func_node.dev_ctx_; - if (op_type == interpreter::kMemcpyD2H) { - VLOG(3) << "Get dev_ctx from d2h_context_pool_"; - dev_ctx = d2h_ctxs_[place_].get().get(); - } else if (op_type == interpreter::kMemcpyH2D) { - VLOG(3) << "Get dev_ctx from h2d_context_pool_"; - dev_ctx = h2d_ctxs_[place_].get().get(); + // only gpu need update. xpu not need, because xpu memcpy op kernel is + // synchronous. + if (platform::is_gpu_place(place_)) { + if (op_type == interpreter::kMemcpyD2H) { + VLOG(3) << "Get dev_ctx from d2h_context_pool_"; + dev_ctx = d2h_ctxs_[place_].get().get(); + } else if (op_type == interpreter::kMemcpyH2D) { + VLOG(3) << "Get dev_ctx from h2d_context_pool_"; + dev_ctx = h2d_ctxs_[place_].get().get(); + } } - return dev_ctx; } /* * NOTE(dev): The following cases are considered as directly run: * + * 0. in XPU place. because xpu memcpy op kernel is synchronous. * 1. with same dev_ctx_, such as: CPU -> CPU, GPU -> GPU * 2. 
CPU -> any (it is possible: CPU op->VAR->GPU op, when var is no need * buffer or no need data transform) @@ -177,7 +181,8 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext( */ bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr, const Instruction& next_instr) { - return (&cur_instr.DeviceContext() == &next_instr.DeviceContext() || + return platform::is_xpu_place(place_) || + (&cur_instr.DeviceContext() == &next_instr.DeviceContext() || interpreter::IsCpuOp(cur_instr) || interpreter::IsMemcpyD2H(cur_instr) || interpreter::IsMemcpyH2D(next_instr)); @@ -187,6 +192,9 @@ platform::DeviceType StreamAnalyzer::GetWaiterType(const Instruction& instr) { if (instr.KernelType() == OpFuncType::kQueueSync) { return platform::kCPU; } else { + if (platform::is_xpu_place(place_)) { + return platform::kXPU; + } return platform::kCUDA; } } diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 31c3ea7607b..2be93f0dc91 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1296,6 +1296,23 @@ bool OperatorWithKernel::SupportsMKLDNN( }); } +bool OperatorWithKernel::SupportsKernelType( + const OpKernelType& kernel_type) const { + auto& all_op_kernels = AllOpKernels(); + auto kernels_iter = all_op_kernels.find(type_); + bool support = + kernels_iter != all_op_kernels.end() && + kernels_iter->second.find(kernel_type) != kernels_iter->second.end(); +#if defined(PADDLE_WITH_XPU) + if (paddle::platform::is_xpu_place(kernel_type.place_)) { + support = support && + paddle::platform::is_xpu_support_op(type_, kernel_type) && + !paddle::platform::is_in_xpu_black_list(type_); + } +#endif + return support; +} + bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx, proto::VarType::Type data_type) const { const auto& attrs_map = ctx.Attrs(); diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index dc13287b5aa..d09e34b43f1 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -193,6 +193,8 @@ class OperatorBase { const VariableNameMap& Inputs() const { return inputs_; } const VariableNameMap& Outputs() const { return outputs_; } + VariableNameMap& Inputs() { return inputs_; } + VariableNameMap& Outputs() { return outputs_; } const OpInfo& Info() const { PADDLE_ENFORCE_NOT_NULL( @@ -579,6 +581,8 @@ class OperatorWithKernel : public OperatorBase { } bool SupportsMKLDNN(proto::VarType::Type data_type) const; + bool SupportsKernelType(const OpKernelType& kernel_type) const; + bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx, proto::VarType::Type data_type) const; @@ -621,6 +625,7 @@ class OperatorWithKernel : public OperatorBase { /* member functions for adapting to phi lib */ phi::KernelKey ChoosePhiKernel(const ExecutionContext& ctx) const; + void ChooseKernel(const ExecutionContext& ctx) const; /** * Transfer data place for phi kernel * Is this really needed? 
@@ -644,6 +649,7 @@ class OperatorWithKernel : public OperatorBase { } const OpKernelType* kernel_type() const { return kernel_type_.get(); } + const OpKernelFunc* kernel_func() const { return kernel_func_.get(); } void ResetKernelType(OpKernelType* kernel_type) { kernel_type_.reset(kernel_type); @@ -672,8 +678,6 @@ class OperatorWithKernel : public OperatorBase { OpKernelType InnerGetExpectedKernelType(const ExecutionContext& ctx) const; - void ChooseKernel(const ExecutionContext& ctx) const; - void HandleComplexGradToRealGrad(const Scope& scope, RuntimeContext* ctx) const; diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index 13536be5b40..ed49b566c47 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -704,7 +704,8 @@ class AllocatorFacadePrivate { if (platform::is_gpu_place(place)) { std::shared_ptr&& allocator = std::make_shared( - pair.second, place, /* default_stream = */ nullptr, + pair.second, place, + /* default_stream = */ nullptr, /* in_cuda_graph_capturing = */ !allow_free_idle_chunk_); pair.second = allocator; @@ -1044,8 +1045,11 @@ AllocationPtr AllocatorFacade::Alloc(const platform::Place& place, size_t size, } else { return m->GetAllocator(p, size)->Allocate(size); } +#elif defined PADDLE_WITH_XPU + return GetAllocator(place)->Allocate(size); #else - PADDLE_THROW(platform::errors::PreconditionNotMet("Not compiled with GPU.")); + PADDLE_THROW( + platform::errors::PreconditionNotMet("Not compiled with GPU or XPU.")); #endif } diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index d2d9ef1ab8f..17aabc25b3f 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -174,6 +174,9 @@ set(COMMON_OP_DEPS ${COMMON_OP_DEPS} eigen_function) if (WITH_GPU OR WITH_ROCM) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu bert_encoder_functor) endif() +if(WITH_XPU) + set(COMMON_OP_DEPS ${COMMON_OP_DEPS} xpulib) +endif() set(COMMON_OP_DEPS ${COMMON_OP_DEPS} device_memory_aligment) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} layer) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} tensor_formatter) diff --git a/paddle/fluid/operators/memcpy_d2h_op.cc b/paddle/fluid/operators/memcpy_d2h_op.cc index 9ad30d72eb3..4b22ce5acd6 100644 --- a/paddle/fluid/operators/memcpy_d2h_op.cc +++ b/paddle/fluid/operators/memcpy_d2h_op.cc @@ -95,7 +95,7 @@ class MemcpyD2HOpProtoMaker : public framework::OpProtoAndCheckerMaker { AddAttr( "dst_place_type", "Determine the dst place of tensor copy. " - "By Now it ONLY support NPUPlace/CUDAPlace <-> CUDAPinnedPlace/CPU" + "By Now it ONLY support XPU/NPUPlace/CUDAPlace <-> CUDAPinnedPlace/CPU" "Other place type is Unimplemented and will cause ERROR." "0: dst is on CPUPlace. " "1: dst is on CUDAPinnedPlace. 
"); @@ -140,6 +140,17 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR( ops::MemcpyD2HKernel, int16_t, ops::MemcpyD2HKernel); #endif +#ifdef PADDLE_WITH_XPU +REGISTER_OP_XPU_KERNEL_FUNCTOR( + memcpy_d2h, float, ops::MemcpyD2HKernel, double, ops::MemcpyD2HKernel, + int8_t, ops::MemcpyD2HKernel, uint8_t, ops::MemcpyD2HKernel, int, + ops::MemcpyD2HKernel, int64_t, ops::MemcpyD2HKernel, bool, + ops::MemcpyD2HKernel, paddle::platform::bfloat16, ops::MemcpyD2HKernel, + paddle::platform::complex, ops::MemcpyD2HKernel, + paddle::platform::complex, ops::MemcpyD2HKernel, plat::float16, + ops::MemcpyD2HKernel, int16_t, ops::MemcpyD2HKernel); +#endif + #ifdef PADDLE_WITH_ASCEND_CL REGISTER_OP_NPU_KERNEL_FUNCTOR( memcpy_d2h, float, ops::MemcpyD2HKernel, double, ops::MemcpyD2HKernel, diff --git a/paddle/fluid/operators/memcpy_h2d_op.cc b/paddle/fluid/operators/memcpy_h2d_op.cc index c8e1e17d65a..f07c1c0deff 100644 --- a/paddle/fluid/operators/memcpy_h2d_op.cc +++ b/paddle/fluid/operators/memcpy_h2d_op.cc @@ -98,7 +98,8 @@ class MemcpyH2DOpProtoMaker : public framework::OpProtoAndCheckerMaker { "By Now it ONLY support CUDAPinnedPlace/CPU <-> NPUPlace/CUDAPlace " "Other place type is Unimplemented and will cause ERROR." "0: dst is on CUDAPlace. " - "1: dst is on NPUPlace. "); + "1: dst is on NPUPlace. " + "2: dst is on XPUPlace. "); AddComment(R"DOC( MemcpyD2H Operator. By now, it ONLY supports the memcopy between CUDAPinnedPlace/CPU <-> NPUPlace/CUDAPlace. @@ -140,6 +141,17 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR( ops::MemcpyH2DKernel, int16_t, ops::MemcpyH2DKernel); #endif +#ifdef PADDLE_WITH_XPU +REGISTER_OP_XPU_KERNEL_FUNCTOR( + memcpy_h2d, float, ops::MemcpyH2DKernel, double, ops::MemcpyH2DKernel, + int8_t, ops::MemcpyH2DKernel, uint8_t, ops::MemcpyH2DKernel, int, + ops::MemcpyH2DKernel, int64_t, ops::MemcpyH2DKernel, bool, + ops::MemcpyH2DKernel, paddle::platform::bfloat16, ops::MemcpyH2DKernel, + paddle::platform::complex, ops::MemcpyH2DKernel, + paddle::platform::complex, ops::MemcpyH2DKernel, plat::float16, + ops::MemcpyH2DKernel, int16_t, ops::MemcpyH2DKernel); +#endif + #ifdef PADDLE_WITH_ASCEND_CL REGISTER_OP_NPU_KERNEL_FUNCTOR( memcpy_h2d, float, ops::MemcpyH2DKernel, double, ops::MemcpyH2DKernel, diff --git a/paddle/fluid/operators/memcpy_h2d_op.h b/paddle/fluid/operators/memcpy_h2d_op.h index 0d731426074..7ae3e9bb99a 100644 --- a/paddle/fluid/operators/memcpy_h2d_op.h +++ b/paddle/fluid/operators/memcpy_h2d_op.h @@ -49,7 +49,7 @@ class MemcpyH2DFunctor { dev_ctx_.GetPlace(), lod_tensor.dtype(), phi::Stream(reinterpret_cast(stream))); - if (dst_place_type_ == 0 || dst_place_type_ == 1) { + if (dst_place_type_ == 0 || dst_place_type_ == 1 || dst_place_type_ == 2) { framework::TensorCopy(lod_tensor, dev_ctx_.GetPlace(), dev_ctx_, &out_tensor); } else { diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index ffb3f7e6eb9..dc6911aecf1 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -270,6 +270,15 @@ cc_library( set(DEVICE_EVENT_LIBS device_event_base CACHE INTERNAL "device event libs") +if(WITH_XPU) + cc_library( + device_event_xpu + SRCS device_event_xpu.cc + DEPS device_event_base xpu_info) + set(DEVICE_EVENT_LIBS + device_event_xpu + CACHE INTERNAL "device event libs") +endif() if(WITH_GPU) nv_library( diff --git a/paddle/fluid/platform/device_event.h b/paddle/fluid/platform/device_event.h index 82d93dee398..1fd11660062 100644 --- a/paddle/fluid/platform/device_event.h +++ b/paddle/fluid/platform/device_event.h @@ -25,6 +25,7 
@@ using ::paddle::platform::kCPU; using ::paddle::platform::kCUDA; +using ::paddle::platform::kXPU; USE_EVENT(kCPU) USE_EVENT_WAIT(kCPU, kCPU) @@ -34,3 +35,9 @@ USE_EVENT(kCUDA); USE_EVENT_WAIT(kCUDA, kCUDA) USE_EVENT_WAIT(kCPU, kCUDA) #endif + +#ifdef PADDLE_WITH_XPU +USE_EVENT(kXPU); +USE_EVENT_WAIT(kXPU, kXPU) +USE_EVENT_WAIT(kCPU, kXPU) +#endif diff --git a/paddle/fluid/platform/device_event_xpu.cc b/paddle/fluid/platform/device_event_xpu.cc new file mode 100644 index 00000000000..53ac33e321b --- /dev/null +++ b/paddle/fluid/platform/device_event_xpu.cc @@ -0,0 +1,118 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/platform/device/xpu/xpu_info.h" +#include "paddle/fluid/platform/device_event_base.h" + +#ifdef PADDLE_WITH_XPU +namespace paddle { +namespace platform { + +struct XPUDeviceEventWrapper { + explicit XPUDeviceEventWrapper(const platform::Place& place) { + PADDLE_ENFORCE_EQ( + platform::is_xpu_place(place), true, + platform::errors::PreconditionNotMet( + "Required device shall be XPUPlace, but received %d. ", place)); + + device_id_ = place.device; + PADDLE_ENFORCE_GT( + device_id_, -1, + platform::errors::PreconditionNotMet( + "Required DeviceOption.device_id > -1, but received %d. ", + device_id_)); + xpu_event_create(&handle_); + } + + xpuEventHandle handle_; + int device_id_; +}; + +void DeviceEventCreateXPU(DeviceEvent* event, const platform::Place& place, + unsigned int) { + event->InitEvent(std::make_shared(place)); +} + +void DeviceEventRecordXPU(DeviceEvent* event, const DeviceContext* context) { + auto* wrapper = static_cast(event->GetEvent().get()); + PADDLE_ENFORCE_NOT_NULL( + wrapper, platform::errors::PreconditionNotMet( + "Failed to dynamic_cast event into XPUDeviceEventWrapper.")); + + auto* xpu_dev_ctx = dynamic_cast(context); + PADDLE_ENFORCE_NOT_NULL( + xpu_dev_ctx, + platform::errors::PreconditionNotMet( + "Failed to dynamic_cast context into XPUDeviceContext.")); + xpu_event_record(wrapper->handle_, xpu_dev_ctx->stream()); +} + +void DeviceEventFinishXPU(const DeviceEvent* event) { + auto* wrapper = static_cast(event->GetEvent().get()); + PADDLE_ENFORCE_NOT_NULL( + wrapper, platform::errors::PreconditionNotMet( + "Failed to dynamic_cast event into XPUDeviceEventWrapper.")); + xpu_event_wait(wrapper->handle_); +} + +// current xpu not support query, used wait to instead. 
+bool DeviceEventQueryXPU(const DeviceEvent* event) { + DeviceEventFinishXPU(event); + return true; +} + +void DeviceEventXPUWaitXPU(const DeviceEvent* event, + const DeviceContext* context) { + auto* wrapper = static_cast(event->GetEvent().get()); + PADDLE_ENFORCE_NOT_NULL( + wrapper, platform::errors::PreconditionNotMet( + "Failed to dynamic_cast event into XPUDeviceEventWrapper.")); + auto* xpu_dev_ctx = dynamic_cast(context); + PADDLE_ENFORCE_NOT_NULL( + xpu_dev_ctx, + platform::errors::PreconditionNotMet( + "Failed to dynamic_cast context into XOUDeviceContext.")); + xpu_stream_wait_event(xpu_dev_ctx->stream(), wrapper->handle_); +} + +void DeviceEventCPUWaitXPU(const DeviceEvent* event, + const DeviceContext* context) { + DeviceEventFinishXPU(event); +} + +void DeviceEventSetFinishedXPU(const DeviceEvent* event) { + // do nothing +} + +void EventResetXPU(const DeviceEvent* event) { + // do nothing +} + +} // namespace platform +} // namespace paddle + +using ::paddle::platform::kCPU; +using ::paddle::platform::kXPU; +REGISTER_EVENT_CREATE_FUNCTION(kXPU, paddle::platform::DeviceEventCreateXPU) +REGISTER_EVENT_RECORD_FUNCTION(kXPU, paddle::platform::DeviceEventRecordXPU) +REGISTER_EVENT_QUERY_FUNCTION(kXPU, paddle::platform::DeviceEventQueryXPU) +REGISTER_EVENT_FINISH_FUNCTION(kXPU, paddle::platform::DeviceEventFinishXPU) +REGISTER_EVENT_SET_FINISHED_FUNCTION( + kXPU, paddle::platform::DeviceEventSetFinishedXPU) +REGISTER_EVENT_WAIT_FUNCTION(kXPU, kXPU, + paddle::platform::DeviceEventXPUWaitXPU) +REGISTER_EVENT_WAIT_FUNCTION(kCPU, kXPU, + paddle::platform::DeviceEventCPUWaitXPU) +REGISTER_EVENT_RESET_FUNCTION(kXPU, paddle::platform::EventResetXPU) +#endif diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 860b4e3f558..62578eef86c 100755 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -1392,9 +1392,9 @@ class Executor(object): program = pruned_program def _can_use_interpreter_core(program, place): - if core.is_compiled_with_npu() or core.is_compiled_with_xpu( - ) or core.is_compiled_with_mlu() or core.is_compiled_with_ipu( - ) or isinstance(place, core.CustomPlace): + if core.is_compiled_with_npu() or core.is_compiled_with_mlu( + ) or core.is_compiled_with_ipu() or isinstance( + place, core.CustomPlace): return False compiled = isinstance(program, compiler.CompiledProgram) -- GitLab
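
The core of the data_transfer.cc change above is the inlined replacement for get_memcpy_type(): a transfer whose destination is a heterogeneous place (GPU/XPU) becomes a memcpy_h2d op, a transfer leaving one becomes a memcpy_d2h op, and the dst_place_type attribute encodes the concrete destination. The stand-alone C++ sketch below models that decision table; PlaceKind, the re-declared IsSupportedHetePlace and SelectMemcpyOp are simplified stand-ins for illustration, not Paddle's real place types or API.

#include <cassert>
#include <stdexcept>
#include <string>
#include <utility>

// Simplified stand-in for Paddle's place taxonomy (illustration only).
enum class PlaceKind { kCPU, kCUDAPinned, kGPU, kNPU, kXPU };

// Mirrors interpreter::IsSupportedHetePlace in new_executor_defs.h:
// the heterogeneous places that need async queues and memcpy ops.
bool IsSupportedHetePlace(PlaceKind p) {
  return p == PlaceKind::kGPU || p == PlaceKind::kXPU;
}

// Returns {op_type, dst_place_type} the way TransferDevice now selects them:
// H2D when the destination is a device place, D2H when the source is.
std::pair<std::string, int> SelectMemcpyOp(PlaceKind src, PlaceKind dst) {
  if (src == dst) throw std::invalid_argument("src and dst must differ");
  if (IsSupportedHetePlace(dst)) {
    int dst_place_type = dst == PlaceKind::kGPU   ? 0
                         : dst == PlaceKind::kNPU ? 1
                         : dst == PlaceKind::kXPU ? 2
                                                  : -1;
    return {"memcpy_h2d", dst_place_type};
  }
  if (IsSupportedHetePlace(src)) {
    int dst_place_type = dst == PlaceKind::kCPU          ? 0
                         : dst == PlaceKind::kCUDAPinned ? 1
                                                         : -1;
    return {"memcpy_d2h", dst_place_type};
  }
  throw std::invalid_argument("unsupported memcpy direction");
}

int main() {
  auto h2d = SelectMemcpyOp(PlaceKind::kCPU, PlaceKind::kXPU);
  assert(h2d.first == "memcpy_h2d" && h2d.second == 2);  // 2: dst is on XPUPlace
  auto d2h = SelectMemcpyOp(PlaceKind::kXPU, PlaceKind::kCPU);
  assert(d2h.first == "memcpy_d2h" && d2h.second == 0);  // 0: dst is on CPUPlace
  return 0;
}

The dst_place_type values follow what the memcpy_h2d_op proto maker now documents (0: CUDAPlace, 1: NPUPlace, 2: XPUPlace) and what MemcpyH2DFunctor accepts after the dst_place_type_ == 2 extension.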
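
The scheduling side of the patch applies the same place test: kernels selected for a heterogeneous place go to the async queue, the event used to wait on them follows the executor's place (kXPU vs kCUDA), and on XPU every instruction pair is a direct run because the XPU memcpy kernels are synchronous. A second illustrative sketch, with SelectQueue, SelectWaiter and IsDirectRunOnXPUPlace as hypothetical stand-ins for the logic in build_op_func_list and StreamAnalyzer:

#include <cassert>

// Simplified stand-ins for Paddle's place and event enums (illustration only).
enum class PlaceKind { kCPU, kGPU, kXPU };
enum class OpFuncType { kQueueSync, kQueueAsync };
enum class WaiterType { kCPU, kCUDA, kXPU };

// GPU and XPU kernels go to the async queue; CPU kernels block the host.
OpFuncType SelectQueue(PlaceKind kernel_place) {
  return (kernel_place == PlaceKind::kGPU || kernel_place == PlaceKind::kXPU)
             ? OpFuncType::kQueueAsync
             : OpFuncType::kQueueSync;
}

// Mirrors StreamAnalyzer::GetWaiterType: sync ops are waited on the CPU,
// async ops are waited with the device event of the executor's place.
WaiterType SelectWaiter(OpFuncType type, PlaceKind place) {
  if (type == OpFuncType::kQueueSync) return WaiterType::kCPU;
  return place == PlaceKind::kXPU ? WaiterType::kXPU : WaiterType::kCUDA;
}

// Mirrors the early return added to StreamAnalyzer::IsDirectRun: on an XPU
// place no cross-stream event is needed, so everything runs directly.
bool IsDirectRunOnXPUPlace(PlaceKind place) { return place == PlaceKind::kXPU; }

int main() {
  assert(SelectQueue(PlaceKind::kXPU) == OpFuncType::kQueueAsync);
  assert(SelectWaiter(OpFuncType::kQueueAsync, PlaceKind::kXPU) == WaiterType::kXPU);
  assert(IsDirectRunOnXPUPlace(PlaceKind::kXPU));
  return 0;
}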
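
Finally, operator.cc gains OperatorWithKernel::SupportsKernelType(), which build_op_func_list uses to decide whether to fall back to a CPU kernel: a kernel type counts as supported only if it is registered and, on an XPU place, the op is additionally XPU-capable and not blacklisted. A hypothetical model of that predicate; the set-based registry, xpu_supported and xpu_black_list below are simplified stand-ins, not Paddle's real data structures:

#include <cassert>
#include <set>
#include <string>

// Minimal stand-in for an (op type, place) kernel key.
struct KernelKey {
  std::string op_type;
  bool on_xpu;  // true when kernel_type.place_ is an XPUPlace
};

// Mirrors the shape of the new SupportsKernelType(): registration alone is
// not enough on XPU; the op must also be supported there and not blacklisted.
bool SupportsKernelType(const std::set<std::string>& registry,
                        const std::set<std::string>& xpu_supported,
                        const std::set<std::string>& xpu_black_list,
                        const KernelKey& key) {
  bool support = registry.count(key.op_type) > 0;
  if (key.on_xpu) {
    support = support && xpu_supported.count(key.op_type) > 0 &&
              xpu_black_list.count(key.op_type) == 0;
  }
  return support;
}

int main() {
  const std::set<std::string> registry = {"matmul", "lookup_table"};
  const std::set<std::string> xpu_supported = {"matmul"};
  const std::set<std::string> xpu_black_list = {};
  assert(SupportsKernelType(registry, xpu_supported, xpu_black_list,
                            {"matmul", /*on_xpu=*/true}));
  // Registered but not XPU-capable: build_op_func_list would fall back to CPU.
  assert(!SupportsKernelType(registry, xpu_supported, xpu_black_list,
                             {"lookup_table", /*on_xpu=*/true}));
  return 0;
}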