Unverified  Commit 5df92262 authored by huzhiqiang, committed by GitHub

[Framework] Accelerate inference period (#42400)

Parent 8cc40f47
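In short, this change caches the phi::KernelContext and RuntimeInferShapeContext that an operator builds for its phi kernel, so later inference runs can reuse them instead of rebuilding them. The cached fast path is taken only when enable_cache_runtime_context_ is set and neither need_prepare_data_ nor need_prepare_phi_data_ is true, i.e. when no input had to be transformed or copied. Below is a minimal, self-contained sketch of that caching pattern; Op, Cache, Run, and need_prepare_data are illustrative stand-ins, not the real Paddle classes or members.

// Minimal sketch of the caching pattern (assumed simplification; the real
// change lives in OperatorWithKernel::RunImpl and CacheImpl in the diff below).
#include <iostream>
#include <memory>

struct KernelContext {};      // stands in for phi::KernelContext
struct InferShapeContext {};  // stands in for RuntimeInferShapeContext

class Op {
 public:
  void Run(bool need_prepare_data) {
    if (cache_ != nullptr && !need_prepare_data) {
      // Fast path: reuse the contexts built on an earlier run.
      InferShape(cache_->infer_shape_ctx.get());
      Compute(cache_->kernel_ctx.get());
      return;
    }
    // Slow path: build fresh contexts for this run.
    auto kernel_ctx = std::make_unique<KernelContext>();
    auto infer_shape_ctx = std::make_unique<InferShapeContext>();
    InferShape(infer_shape_ctx.get());
    Compute(kernel_ctx.get());
    if (!need_prepare_data) {
      // Reuse is safe, so keep the contexts for the next run.
      cache_ = std::make_unique<Cache>();
      cache_->kernel_ctx = std::move(kernel_ctx);
      cache_->infer_shape_ctx = std::move(infer_shape_ctx);
    }
  }

 private:
  struct Cache {
    std::unique_ptr<KernelContext> kernel_ctx;
    std::unique_ptr<InferShapeContext> infer_shape_ctx;
  };
  std::unique_ptr<Cache> cache_;

  void InferShape(InferShapeContext*) { std::cout << "infer shape\n"; }
  void Compute(KernelContext*) { std::cout << "run kernel\n"; }
};

int main() {
  Op op;
  op.Run(/*need_prepare_data=*/false);  // first run: builds and caches contexts
  op.Run(/*need_prepare_data=*/false);  // later runs: reuse the cached contexts
  return 0;
}

In the actual diff, the cached pair lives in the new OperatorWithKernel::CacheImpl (the impl_ member), and the fast path additionally skips InferShape when all_kernels_must_compute_runtime_shape_ is set.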
@@ -1116,6 +1116,21 @@ class RuntimeInferShapeContext : public InferShapeContext {
   const RuntimeContext& ctx_;
 };
 
+struct OperatorWithKernel::CacheImpl {
+  explicit CacheImpl(phi::KernelContext* kernel_ctx,
+                     RuntimeInferShapeContext* infer_shape_ctx)
+      : kernel_ctx_(kernel_ctx), infer_shape_ctx_(infer_shape_ctx) {}
+
+  phi::KernelContext* getKernelContext() { return kernel_ctx_.get(); }
+  RuntimeInferShapeContext* getRuntimeInferShapeContext() {
+    return infer_shape_ctx_.get();
+  }
+
+ private:
+  std::unique_ptr<phi::KernelContext> kernel_ctx_;
+  std::unique_ptr<RuntimeInferShapeContext> infer_shape_ctx_;
+};
+
 static void CheckTensorNANOrInf(const std::string& op_type,
                                 const std::string& name,
                                 const framework::Tensor& tensor) {
@@ -1244,6 +1259,11 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
     RuntimeContext ctx(Inputs(), Outputs(), scope);
     RunImpl(scope, place, &ctx);
     pre_scope_ = cur_scope;
+  } else if (run_phi_kernel_ && impl_ != nullptr && !need_prepare_data_ &&
+             !need_prepare_phi_data_) {
+    if (!all_kernels_must_compute_runtime_shape_)
+      this->Info().infer_shape_(impl_->getRuntimeInferShapeContext());
+    (*pt_kernel_)(impl_->getKernelContext());
   } else {
     if (runtime_ctx_.get() == nullptr || pre_scope_ != cur_scope) {
       std::lock_guard<std::mutex> lock(cache_update_mutex_);
@@ -1508,12 +1528,22 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
                                      platform::TracerEventType::OperatorInner,
                                      1, platform::EventRole::kInnerOp);
     if (run_phi_kernel_) {
-      phi::KernelContext pt_kernel_context;
-      // Do data transform before building KernelContext
-      // TODO(zhiqiu): support TransferInplaceVarsBack
       PreparePhiData(exec_scope, *pt_kernel_, *kernel_signature_, runtime_ctx);
-      BuildPhiKernelContext(*runtime_ctx, dev_ctx, &pt_kernel_context);
-      (*pt_kernel_)(&pt_kernel_context);
+      if (enable_cache_runtime_context_ && !need_prepare_phi_data_ &&
+          !need_prepare_data_) {
+        impl_ =
+            new CacheImpl(new phi::KernelContext(),
+                          new RuntimeInferShapeContext(*this, *runtime_ctx));
+        BuildPhiKernelContext(*runtime_ctx, dev_ctx, impl_->getKernelContext());
+
+        (*pt_kernel_)(impl_->getKernelContext());
+      } else {
+        phi::KernelContext pt_kernel_context;
+        // Do data transform before building KernelContext
+        // TODO(zhiqiu): support TransferInplaceVarsBack
+        BuildPhiKernelContext(*runtime_ctx, dev_ctx, &pt_kernel_context);
+        (*pt_kernel_)(&pt_kernel_context);
+      }
     } else {
       (*kernel_func_)(
           ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx));
@@ -2323,6 +2353,8 @@ Scope* OperatorWithKernel::PreparePhiData(
         Tensor out;
         framework::TensorCopySync(*tensor_in, expected_place, &out);
         SetTensorToVariable(*var, out, trans_var);
+
+        need_prepare_phi_data_ = true;
       }
     }
......
@@ -698,6 +698,7 @@ class OperatorWithKernel : public OperatorBase {
   mutable std::unique_ptr<RuntimeContext> runtime_ctx_;
   mutable const Scope* pre_scope_ = nullptr;
   mutable bool need_prepare_data_ = true;
+  mutable bool need_prepare_phi_data_ = false;
   mutable bool enable_cache_runtime_context_ = false;
   mutable bool all_kernels_must_compute_runtime_shape_ = false;
   mutable std::mutex cache_update_mutex_;
@@ -710,6 +711,9 @@ class OperatorWithKernel : public OperatorBase {
   mutable std::unique_ptr<phi::KernelSignature> kernel_signature_;
   mutable std::unique_ptr<phi::Kernel> pt_kernel_;
   mutable std::unique_ptr<phi::ArgumentMappingFn> arg_map_fn_;
+
+  struct CacheImpl;
+  mutable CacheImpl* impl_{nullptr};
 };
 
 extern bool OpSupportGPU(const std::string& op_type);
......
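Note on the new flag: need_prepare_phi_data_ defaults to false and is set to true inside PreparePhiData whenever an input tensor has to be copied to the kernel's expected place (the TensorCopySync branch above), so operators that require such data movement keep taking the uncached path on every run.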