From 5df922621017f1983d11e76808b8e962d6f1b96d Mon Sep 17 00:00:00 2001 From: huzhiqiang <912790387@qq.com> Date: Mon, 30 May 2022 16:02:46 +0800 Subject: [PATCH] [Framework]accelerate inference period (#42400) --- paddle/fluid/framework/operator.cc | 42 ++++++++++++++++++++++++++---- paddle/fluid/framework/operator.h | 4 +++ 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index afd1bf338c4..7dc885f54ab 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1116,6 +1116,21 @@ class RuntimeInferShapeContext : public InferShapeContext { const RuntimeContext& ctx_; }; +struct OperatorWithKernel::CacheImpl { + explicit CacheImpl(phi::KernelContext* kernel_ctx, + RuntimeInferShapeContext* infer_shape_ctx) + : kernel_ctx_(kernel_ctx), infer_shape_ctx_(infer_shape_ctx) {} + + phi::KernelContext* getKernelContext() { return kernel_ctx_.get(); } + RuntimeInferShapeContext* getRuntimeInferShapeContext() { + return infer_shape_ctx_.get(); + } + + private: + std::unique_ptr kernel_ctx_; + std::unique_ptr infer_shape_ctx_; +}; + static void CheckTensorNANOrInf(const std::string& op_type, const std::string& name, const framework::Tensor& tensor) { @@ -1244,6 +1259,11 @@ void OperatorWithKernel::RunImpl(const Scope& scope, RuntimeContext ctx(Inputs(), Outputs(), scope); RunImpl(scope, place, &ctx); pre_scope_ = cur_scope; + } else if (run_phi_kernel_ && impl_ != nullptr && !need_prepare_data_ && + !need_prepare_phi_data_) { + if (!all_kernels_must_compute_runtime_shape_) + this->Info().infer_shape_(impl_->getRuntimeInferShapeContext()); + (*pt_kernel_)(impl_->getKernelContext()); } else { if (runtime_ctx_.get() == nullptr || pre_scope_ != cur_scope) { std::lock_guard lock(cache_update_mutex_); @@ -1508,12 +1528,22 @@ void OperatorWithKernel::RunImpl(const Scope& scope, platform::TracerEventType::OperatorInner, 1, platform::EventRole::kInnerOp); if (run_phi_kernel_) { - phi::KernelContext pt_kernel_context; - // Do data transform before building KernelContext - // TODO(zhiqiu): support TransferInplaceVarsBack PreparePhiData(exec_scope, *pt_kernel_, *kernel_signature_, runtime_ctx); - BuildPhiKernelContext(*runtime_ctx, dev_ctx, &pt_kernel_context); - (*pt_kernel_)(&pt_kernel_context); + if (enable_cache_runtime_context_ && !need_prepare_phi_data_ && + !need_prepare_data_) { + impl_ = + new CacheImpl(new phi::KernelContext(), + new RuntimeInferShapeContext(*this, *runtime_ctx)); + BuildPhiKernelContext(*runtime_ctx, dev_ctx, impl_->getKernelContext()); + + (*pt_kernel_)(impl_->getKernelContext()); + } else { + phi::KernelContext pt_kernel_context; + // Do data transform before building KernelContext + // TODO(zhiqiu): support TransferInplaceVarsBack + BuildPhiKernelContext(*runtime_ctx, dev_ctx, &pt_kernel_context); + (*pt_kernel_)(&pt_kernel_context); + } } else { (*kernel_func_)( ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx)); @@ -2323,6 +2353,8 @@ Scope* OperatorWithKernel::PreparePhiData( Tensor out; framework::TensorCopySync(*tensor_in, expected_place, &out); SetTensorToVariable(*var, out, trans_var); + + need_prepare_phi_data_ = true; } } diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 2e00e07535b..2efa2e4bd8a 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -698,6 +698,7 @@ class OperatorWithKernel : public OperatorBase { mutable std::unique_ptr runtime_ctx_; mutable const Scope* pre_scope_ = nullptr; mutable bool need_prepare_data_ = true; + mutable bool need_prepare_phi_data_ = false; mutable bool enable_cache_runtime_context_ = false; mutable bool all_kernels_must_compute_runtime_shape_ = false; mutable std::mutex cache_update_mutex_; @@ -710,6 +711,9 @@ class OperatorWithKernel : public OperatorBase { mutable std::unique_ptr kernel_signature_; mutable std::unique_ptr pt_kernel_; mutable std::unique_ptr arg_map_fn_; + + struct CacheImpl; + mutable CacheImpl* impl_{nullptr}; }; extern bool OpSupportGPU(const std::string& op_type); -- GitLab