Unverified commit 5df92262, authored by H huzhiqiang, committed by GitHub

[Framework] Accelerate the inference period (#42400)

Parent 8cc40f47
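
The change caches the phi::KernelContext and RuntimeInferShapeContext that an operator builds on its first run and reuses them on later runs when no data preparation is needed, so repeated inference steps skip the per-run context construction. Below is a minimal, self-contained sketch of that caching pattern; the names used here (CachedOp, Cache, the stand-in context structs) are simplified illustrations, not the actual Paddle API.

// Sketch of the run-context caching pattern: build the contexts once,
// reuse them on every later run. Stand-in types, not the real Paddle classes.
#include <iostream>
#include <memory>

struct KernelContext {};      // stand-in for phi::KernelContext
struct InferShapeContext {};  // stand-in for RuntimeInferShapeContext

class CachedOp {
 public:
  void Run() {
    if (cache_ == nullptr) {
      // First run: build both contexts and keep them for reuse.
      cache_ = std::make_unique<Cache>();
      cache_->kernel_ctx = std::make_unique<KernelContext>();
      cache_->infer_shape_ctx = std::make_unique<InferShapeContext>();
      std::cout << "built contexts" << std::endl;
    } else {
      // Later runs: skip the rebuild and reuse the cached contexts.
      std::cout << "reused cached contexts" << std::endl;
    }
    // ... run InferShape with cache_->infer_shape_ctx and launch the kernel
    // with cache_->kernel_ctx ...
  }

 private:
  struct Cache {
    std::unique_ptr<KernelContext> kernel_ctx;
    std::unique_ptr<InferShapeContext> infer_shape_ctx;
  };
  std::unique_ptr<Cache> cache_;
};

int main() {
  CachedOp op;
  op.Run();  // builds the contexts
  op.Run();  // reuses them
  return 0;
}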
@@ -1116,6 +1116,21 @@ class RuntimeInferShapeContext : public InferShapeContext {
const RuntimeContext& ctx_;
};
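
// Caches the phi::KernelContext and RuntimeInferShapeContext built for an
// operator so that later runs can reuse them instead of rebuilding them.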
struct OperatorWithKernel::CacheImpl {
  explicit CacheImpl(phi::KernelContext* kernel_ctx,
                     RuntimeInferShapeContext* infer_shape_ctx)
      : kernel_ctx_(kernel_ctx), infer_shape_ctx_(infer_shape_ctx) {}

  phi::KernelContext* getKernelContext() { return kernel_ctx_.get(); }
  RuntimeInferShapeContext* getRuntimeInferShapeContext() {
    return infer_shape_ctx_.get();
  }

 private:
  std::unique_ptr<phi::KernelContext> kernel_ctx_;
  std::unique_ptr<RuntimeInferShapeContext> infer_shape_ctx_;
};

static void CheckTensorNANOrInf(const std::string& op_type,
const std::string& name,
const framework::Tensor& tensor) {
@@ -1244,6 +1259,11 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
    RuntimeContext ctx(Inputs(), Outputs(), scope);
    RunImpl(scope, place, &ctx);
    pre_scope_ = cur_scope;
  } else if (run_phi_kernel_ && impl_ != nullptr && !need_prepare_data_ &&
             !need_prepare_phi_data_) {
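    // Fast path: reuse the kernel context and infer-shape context cached on an
    // earlier run instead of rebuilding them.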
    if (!all_kernels_must_compute_runtime_shape_)
      this->Info().infer_shape_(impl_->getRuntimeInferShapeContext());
    (*pt_kernel_)(impl_->getKernelContext());
  } else {
    if (runtime_ctx_.get() == nullptr || pre_scope_ != cur_scope) {
      std::lock_guard<std::mutex> lock(cache_update_mutex_);
@@ -1508,12 +1528,22 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
platform::TracerEventType::OperatorInner,
1, platform::EventRole::kInnerOp);
    if (run_phi_kernel_) {
      PreparePhiData(exec_scope, *pt_kernel_, *kernel_signature_, runtime_ctx);
      if (enable_cache_runtime_context_ && !need_prepare_phi_data_ &&
          !need_prepare_data_) {
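        // First cached run: build the contexts once, store them in impl_, and
        // use them directly on this run.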
        impl_ =
            new CacheImpl(new phi::KernelContext(),
                          new RuntimeInferShapeContext(*this, *runtime_ctx));
        BuildPhiKernelContext(*runtime_ctx, dev_ctx, impl_->getKernelContext());
        (*pt_kernel_)(impl_->getKernelContext());
      } else {
        phi::KernelContext pt_kernel_context;
        // Do data transform before building KernelContext
        // TODO(zhiqiu): support TransferInplaceVarsBack
        PreparePhiData(exec_scope, *pt_kernel_, *kernel_signature_, runtime_ctx);
        BuildPhiKernelContext(*runtime_ctx, dev_ctx, &pt_kernel_context);
        (*pt_kernel_)(&pt_kernel_context);
      }
    } else {
      (*kernel_func_)(
          ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx));
@@ -2323,6 +2353,8 @@ Scope* OperatorWithKernel::PreparePhiData(
Tensor out;
framework::TensorCopySync(*tensor_in, expected_place, &out);
SetTensorToVariable(*var, out, trans_var);
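// A data transfer was performed for this op, so it must prepare phi data on
// every run; this keeps the cached-context fast path in RunImpl disabled.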
need_prepare_phi_data_ = true;
}
}
@@ -698,6 +698,7 @@ class OperatorWithKernel : public OperatorBase {
  mutable std::unique_ptr<RuntimeContext> runtime_ctx_;
  mutable const Scope* pre_scope_ = nullptr;
  mutable bool need_prepare_data_ = true;
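  // Set when PreparePhiData has to transfer data; keeps the cached-context
  // fast path disabled for this op.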
  mutable bool need_prepare_phi_data_ = false;
  mutable bool enable_cache_runtime_context_ = false;
  mutable bool all_kernels_must_compute_runtime_shape_ = false;
  mutable std::mutex cache_update_mutex_;
@@ -710,6 +711,9 @@ class OperatorWithKernel : public OperatorBase {
  mutable std::unique_ptr<phi::KernelSignature> kernel_signature_;
  mutable std::unique_ptr<phi::Kernel> pt_kernel_;
  mutable std::unique_ptr<phi::ArgumentMappingFn> arg_map_fn_;
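
  // Lazily built cache of the kernel and infer-shape contexts; defined in the
  // implementation file and reused by RunImpl on repeated runs.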
  struct CacheImpl;
  mutable CacheImpl* impl_{nullptr};
};
extern bool OpSupportGPU(const std::string& op_type);