From f88713e1707d0f2b2806d21c13973035ea19a796 Mon Sep 17 00:00:00 2001
From: Wilber
Date: Thu, 8 Dec 2022 12:56:34 +0800
Subject: [PATCH] [Inference] Enable infer shape cache. (#48312)

---
 .../ir/runtime_context_cache_pass.cc          | 21 ++++-
 paddle/fluid/framework/operator.cc            | 80 +++++++++++++++++--
 paddle/fluid/framework/operator.h             |  8 +-
 .../passes/ir_graph_to_program_pass.cc        |  6 +-
 .../inference/api/paddle_pass_builder.cc      | 12 +--
 5 files changed, 104 insertions(+), 23 deletions(-)

diff --git a/paddle/fluid/framework/ir/runtime_context_cache_pass.cc b/paddle/fluid/framework/ir/runtime_context_cache_pass.cc
index 451e41e767..4f5e5edb89 100644
--- a/paddle/fluid/framework/ir/runtime_context_cache_pass.cc
+++ b/paddle/fluid/framework/ir/runtime_context_cache_pass.cc
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/ir/runtime_context_cache_pass.h"
 
+#include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/operator.h"
 
 namespace paddle {
@@ -21,10 +22,28 @@ namespace framework {
 namespace ir {
 
 void RuntimeContextCachePass::ApplyImpl(ir::Graph* graph) const {
+  static constexpr char kNotAllowInferShapeCahce[] =
+      "@NOT_ALLOW_INFERSHAPE_CACHE@";
   VLOG(3) << "Applies Runtime Context Cache strategy.";
   for (const Node* n : graph->Nodes()) {
     if (n->IsOp() && n->Op()) {
-      n->Op()->SetAttr(kEnableCacheRuntimeContext, true);
+      n->Op()->SetAttr(framework::kEnableCacheRuntimeContext, true);
+    }
+  }
+
+  // If op1 -> var0 and op2 -> var0, then op1 and op2 do not support
+  // InferShapeCache.
+  std::unordered_map<std::string, std::vector<Node*>> var2ops;
+  for (auto* op_node : TopologySortOperations(*graph)) {
+    for (auto* var_node : op_node->outputs) {
+      var2ops[var_node->Name()].push_back(op_node);
+    }
+  }
+  for (auto& it : var2ops) {
+    if (it.second.size() > 1) {
+      for (auto op_node : it.second) {
+        op_node->Op()->SetAttr(kNotAllowInferShapeCahce, true);
+      }
     }
   }
 }
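
The exclusion rule added above can be read on its own: group operators by the variables they write, and any variable produced by more than one operator disqualifies all of its producers from infer-shape caching. The following is a minimal standalone sketch of that grouping idea; Op, OpsWithSharedOutputs, and the sample data are illustrative stand-ins, not Paddle's ir::Node/OpDesc API.

// Standalone sketch of the "shared output variable" rule from
// RuntimeContextCachePass::ApplyImpl. Types are simplified stand-ins,
// not Paddle's ir::Node / OpDesc.
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

struct Op {
  std::string name;
  std::vector<std::string> outputs;
};

// Returns the names of ops that must NOT cache InferShape results,
// i.e. ops that share at least one output variable with another op.
std::unordered_set<std::string> OpsWithSharedOutputs(
    const std::vector<Op>& ops) {
  std::unordered_map<std::string, std::vector<const Op*>> var2ops;
  for (const Op& op : ops) {
    for (const std::string& var : op.outputs) {
      var2ops[var].push_back(&op);
    }
  }
  std::unordered_set<std::string> blocked;
  for (const auto& it : var2ops) {
    if (it.second.size() > 1) {
      for (const Op* op : it.second) blocked.insert(op->name);
    }
  }
  return blocked;
}

int main() {
  // op1 and op2 both write var0, so neither may cache infer-shape results.
  std::vector<Op> ops = {
      {"op1", {"var0"}}, {"op2", {"var0"}}, {"op3", {"var1"}}};
  for (const std::string& name : OpsWithSharedOutputs(ops)) {
    std::cout << name << " cannot use the infer-shape cache\n";
  }
  return 0;
}

In the pass itself the same decision is recorded by setting the @NOT_ALLOW_INFERSHAPE_CACHE@ attribute on every affected op, which operator.cc later reads back via HasAttr.
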
*/ #include "paddle/fluid/platform/profiler/supplement_tracing.h" #include "paddle/phi/common/int_array.h" #include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/kernel_context.h" #include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/ops/compat/signatures.h" @@ -562,6 +564,14 @@ phi::DenseTensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var) { } } +OperatorWithKernel::OperatorWithKernel(const std::string& type, + const VariableNameMap& inputs, + const VariableNameMap& outputs, + const AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + +OperatorWithKernel::~OperatorWithKernel() = default; + bool ExecutionContext::HasInput(const std::string& name) const { auto* var = InputVar(name); return var != nullptr; @@ -1204,19 +1214,54 @@ class RuntimeInferShapeContext : public InferShapeContext { }; struct OperatorWithKernel::CacheImpl { + static const char kNotAllowInferShapeCahce[]; explicit CacheImpl(phi::KernelContext* kernel_ctx, - RuntimeInferShapeContext* infer_shape_ctx) - : kernel_ctx_(kernel_ctx), infer_shape_ctx_(infer_shape_ctx) {} + RuntimeInferShapeContext* infer_shape_ctx, + const std::vector& tensors, + bool not_allow_infer_shape_cache) + : kernel_ctx_(kernel_ctx), + infer_shape_ctx_(infer_shape_ctx), + tensors_(tensors), + not_allow_infer_shape_cache_(not_allow_infer_shape_cache) {} phi::KernelContext* getKernelContext() { return kernel_ctx_.get(); } RuntimeInferShapeContext* getRuntimeInferShapeContext() { return infer_shape_ctx_.get(); } + bool NeedInferShape() { + if (not_allow_infer_shape_cache_) return true; + + bool ret{false}; + if (last_ddims_.empty() || tensors_.empty()) ret = true; + if (!ret) { + CHECK_EQ(last_ddims_.size(), tensors_.size()); + for (size_t i = 0; i < last_ddims_.size(); ++i) { + if (tensors_[i]->dims() != last_ddims_[i]) { + ret = true; + break; + } + } + } + if (ret) { + last_ddims_.resize(tensors_.size()); + for (size_t i = 0; i < last_ddims_.size(); ++i) { + last_ddims_[i] = tensors_[i]->dims(); + } + } + VLOG(3) << "need infer shape is " << ret; + return ret; + } + private: std::unique_ptr kernel_ctx_; std::unique_ptr infer_shape_ctx_; + std::vector tensors_; + bool not_allow_infer_shape_cache_; + std::vector last_ddims_; }; +const char OperatorWithKernel::CacheImpl::kNotAllowInferShapeCahce[] = + "@NOT_ALLOW_INFERSHAPE_CACHE@"; static void CheckTensorNANOrInf(const std::string& op_type, const std::string& name, @@ -1524,8 +1569,9 @@ void OperatorWithKernel::RunImpl(const Scope& scope, pre_scope_ = cur_scope; } else if (run_phi_kernel_ && impl_ != nullptr && !need_prepare_data_ && !need_prepare_phi_data_) { - if (!all_kernels_must_compute_runtime_shape_) + if (!all_kernels_must_compute_runtime_shape_ && impl_->NeedInferShape()) { this->Info().infer_shape_(impl_->getRuntimeInferShapeContext()); + } (*phi_kernel_)(impl_->getKernelContext()); } else { if (runtime_ctx_.get() == nullptr || pre_scope_ != cur_scope) { @@ -1828,9 +1874,31 @@ void OperatorWithKernel::RunImpl(const Scope& scope, phi::KernelContext phi_kernel_context; if (enable_cache_runtime_context_ && !need_prepare_phi_data_ && !need_prepare_data_) { - impl_ = + // TODO(inference): Now we only suppor dense_tensor cache, we may be + // support ScalarTensor, SparseTensor in future. 
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index 236ff7af8d..07e1a26c7c 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -612,8 +612,9 @@ class OperatorWithKernel : public OperatorBase {
   OperatorWithKernel(const std::string& type,
                      const VariableNameMap& inputs,
                      const VariableNameMap& outputs,
-                     const AttributeMap& attrs)
-      : OperatorBase(type, inputs, outputs, attrs) {}
+                     const AttributeMap& attrs);
+
+  virtual ~OperatorWithKernel();
 
   static paddle::flat_hash_map<std::string, OpKernelMap>&
   AllOpKernels() {
@@ -785,8 +786,9 @@ class OperatorWithKernel : public OperatorBase {
   mutable std::unique_ptr<phi::Kernel> phi_kernel_;
   mutable std::unique_ptr<phi::ArgumentMappingFn> arg_map_fn_;
 
+ private:
   struct CacheImpl;
-  mutable CacheImpl* impl_{nullptr};
+  mutable std::unique_ptr<CacheImpl> impl_;
 };
 
 extern bool OpSupportGPU(const std::string& op_type);
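
The header change above (constructor and destructor declared here but defined in operator.cc, and impl_ switched from a raw pointer to std::unique_ptr<CacheImpl>) follows the usual pimpl-with-unique_ptr pattern: std::unique_ptr<T> requires T to be complete wherever the owner is destroyed, so the destructor must be defined after CacheImpl's definition. A generic sketch of the pattern, with Widget and Impl as hypothetical names unrelated to Paddle:

// Pimpl-with-unique_ptr sketch: the owning class only forward-declares its
// implementation in the "header" part, so its destructor must be defined
// after the implementation type is complete (as operator.h/operator.cc do).
#include <iostream>
#include <memory>

// --- "header" part ---------------------------------------------------------
class Widget {
 public:
  Widget();
  ~Widget();  // declared here, defined below where Impl is a complete type
  void Run() const;

 private:
  struct Impl;                  // forward declaration only
  std::unique_ptr<Impl> impl_;  // a unique_ptr to an incomplete type is fine
};

// --- "source" part ---------------------------------------------------------
struct Widget::Impl {
  int cached_value = 42;
};

Widget::Widget() : impl_(new Impl()) {}  // mirrors impl_.reset(new CacheImpl(...))
Widget::~Widget() = default;             // Impl is complete here, so deletion compiles

void Widget::Run() const { std::cout << impl_->cached_value << "\n"; }

int main() {
  Widget w;
  w.Run();
  return 0;
}
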
diff --git a/paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.cc b/paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.cc
index 3d86f7bf39..2f7f61406b 100644
--- a/paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.cc
+++ b/paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.cc
@@ -23,6 +23,8 @@ namespace inference {
 namespace analysis {
 
 void IrGraphToProgramPass::RunImpl(Argument *argument) {
+  auto cache_pass =
+      framework::ir::PassRegistry::Instance().Get("runtime_context_cache_pass");
   auto pass =
       framework::ir::PassRegistry::Instance().Get("graph_to_program_pass");
 
@@ -31,14 +33,12 @@ void IrGraphToProgramPass::RunImpl(Argument *argument) {
         new int(argument->memory_optim_sort_kind()));
   }
 
-  std::unique_ptr<framework::ir::Graph> graph(argument->main_graph_ptr());
-
   // Direct using ProgramDesc desc(argument->main_program()) may cause
   // incomplete copies of information.
   framework::ProgramDesc desc;
   desc.CopyFrom(*argument->main_program().Proto());
   pass->SetNotOwned("program", &desc);
-  pass->Apply(graph.release());  // the argument still own the graph.
+  pass->Apply(cache_pass->Apply(argument->main_graph_ptr()));
 
   argument->SetIrAnalyzedProgram(
       new framework::proto::ProgramDesc(*desc.Proto()));
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 4e397fbd04..2fa9620542 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -188,7 +188,6 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
     "fc_fuse_pass",
     "fc_elementwise_layernorm_fuse_pass",
     "embedding_eltwise_layernorm_fuse_pass",
-    "runtime_context_cache_pass",
 };
 
 const std::vector<std::string> kTrtLowerPrecisionPasses{
@@ -254,10 +253,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
 #endif                                         //
         "transpose_flatten_concat_fuse_pass",  //
         "constant_folding_pass",               //
-        // following pass should be located in the last, since it will
-        // work on all fused ops.
-        "float_to_half_pass",  //
-        "runtime_context_cache_pass"
+        "float_to_half_pass",  //
   });
 
   use_gpu_ = true;
@@ -322,10 +318,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
                    "conv_transpose_bn_fuse_pass",             //
                    "conv_transpose_eltwiseadd_bn_fuse_pass",  //
                    "is_test_pass",                            //
-                   "constant_folding_pass",
-                   // following pass should be located in the last, since
-                   // it will work on all fused ops.
-                   "runtime_context_cache_pass"});
+                   "constant_folding_pass"});
 
   use_gpu_ = false;
 }
@@ -475,7 +468,6 @@ void CpuPassStrategy::EnableMkldnnInt8() {
     passes_.push_back("int8_scale_calculation_mkldnn_pass");
     passes_.push_back("params_quantization_mkldnn_pass");
     passes_.push_back("mkldnn_inplace_pass");
-    passes_.push_back("runtime_context_cache_pass");
   }
   use_mkldnn_int8_ = true;
 #else
-- 
GitLab
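
The ir_graph_to_program_pass change depends on a pass's Apply returning the graph it was applied to, which is what lets runtime_context_cache_pass be chained directly into the graph_to_program application and dropped from the per-target pass lists above. A simplified sketch of that chaining style, where Graph and Pass are toy stand-ins for Paddle's ir::Graph and ir::Pass:

// Simplified model of chaining graph passes when Apply() returns its input
// graph, as in pass->Apply(cache_pass->Apply(argument->main_graph_ptr())).
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct Graph {
  std::vector<std::string> applied_passes;
};

class Pass {
 public:
  explicit Pass(std::string name) : name_(std::move(name)) {}

  // Returning the same graph makes nested application possible.
  Graph* Apply(Graph* graph) const {
    graph->applied_passes.push_back(name_);
    return graph;
  }

 private:
  std::string name_;
};

int main() {
  Graph graph;
  Pass cache_pass("runtime_context_cache_pass");
  Pass to_program_pass("graph_to_program_pass");

  // Equivalent to applying cache_pass first, then graph_to_program_pass.
  to_program_pass.Apply(cache_pass.Apply(&graph));

  for (const auto& name : graph.applied_passes) std::cout << name << "\n";
  return 0;
}
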