From 2b5771c4af99b003ee00c06678de6e257b0a6229 Mon Sep 17 00:00:00 2001 From: huzhiqiang <912790387@qq.com> Date: Wed, 15 Jun 2022 20:39:51 +0800 Subject: [PATCH] op cache supports un-persistable attributes (#43221) --- paddle/fluid/framework/operator.cc | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 7395a8e0da8..dbf6bec676c 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1529,8 +1529,20 @@ void OperatorWithKernel::RunImpl(const Scope& scope, // Do data transform before building KernelContext // TODO(zhiqiu): support TransferInplaceVarsBack PreparePhiData(exec_scope, *pt_kernel_, *kernel_signature_, runtime_ctx); - BuildPhiKernelContext(*runtime_ctx, dev_ctx, &pt_kernel_context); - (*pt_kernel_)(&pt_kernel_context); + if (enable_cache_runtime_context_ && !need_prepare_phi_data_ && + !need_prepare_data_) { + impl_ = + new CacheImpl(new phi::KernelContext(), + new RuntimeInferShapeContext(*this, *runtime_ctx)); + BuildPhiKernelContext(*runtime_ctx, dev_ctx, impl_->getKernelContext()); + (*pt_kernel_)(impl_->getKernelContext()); + } else { + phi::KernelContext pt_kernel_context; + // Do data transform before building KernelContext + // TODO(zhiqiu): support TransferInplaceVarsBack + BuildPhiKernelContext(*runtime_ctx, dev_ctx, &pt_kernel_context); + (*pt_kernel_)(&pt_kernel_context); + } } else { (*kernel_func_)( ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx)); @@ -2386,7 +2398,6 @@ void OperatorWithKernel::BuildPhiKernelContext( // calcute the start and end index of the input tensors size_t start_idx = (i == 0 ? 0 : pt_kernel_context->InputRangeAt(i - 1).second); - // deal with optional here if ((it == ctx.inputs.end() || it->second.size() == 0) && (input_defs[i].type_index == @@ -2400,6 +2411,7 @@ void OperatorWithKernel::BuildPhiKernelContext( auto end_idx = start_idx + 1; pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx), i); + continue; } auto ins_vector = it->second; @@ -2414,6 +2426,7 @@ void OperatorWithKernel::BuildPhiKernelContext( tensor_in = &(var->Get()); pt_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in); } else if (var->IsType()) { + need_prepare_phi_data_ = true; paddle::small_vector tensor_vector; auto& tensor_array = var->Get(); for (auto& t : tensor_array) { @@ -2514,6 +2527,7 @@ void OperatorWithKernel::BuildPhiKernelContext( attr_names[i])); } } else { // scalar is in the input + need_prepare_phi_data_ = true; auto& ins_vector = ctx.inputs.at(attr_names[i]); pt_kernel_context->EmplaceBackAttr(std::move( experimental::MakePhiScalarFromVar(*ins_vector.front()))); @@ -2545,6 +2559,7 @@ void OperatorWithKernel::BuildPhiKernelContext( attr_names[i])); } } else { // shape is in the input + need_prepare_phi_data_ = true; auto& ins_vector = ctx.inputs.at(attr_names[i]); if (ins_vector.size() == 1) { // ShapeTensor pt_kernel_context->EmplaceBackAttr(std::move( -- GitLab