Unverified commit f88713e1, authored by Wilber, committed by GitHub

[Inference] Enable infer shape cache. (#48312)

Parent: fe86771a
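This commit lets the inference executor skip an operator's InferShape call when none of its input dims have changed since the previous run: the cached-runtime-context path records each input DenseTensor's dims and re-runs InferShape only when they differ, and runtime_context_cache_pass marks ops that share an output variable as ineligible for the cache. The snippet below is a minimal, self-contained sketch of the dims-comparison idea only, written in plain C++ with std::vector<std::int64_t> standing in for phi::DDim; it is an illustration, not the Paddle code shown in the diff.

// Minimal sketch (assumption: dims fully determine whether InferShape
// can be skipped, as in this commit's DenseTensor-only cache).
#include <cstdint>
#include <iostream>
#include <vector>

using Dims = std::vector<std::int64_t>;

struct InferShapeCache {
  std::vector<Dims> last_dims;

  // Returns true if InferShape must run again for the given input dims.
  bool NeedInferShape(const std::vector<Dims>& cur_dims) {
    bool changed = last_dims.empty() || last_dims.size() != cur_dims.size();
    if (!changed) {
      for (size_t i = 0; i < cur_dims.size(); ++i) {
        if (cur_dims[i] != last_dims[i]) {
          changed = true;
          break;
        }
      }
    }
    if (changed) last_dims = cur_dims;  // remember dims for the next run
    return changed;
  }
};

int main() {
  InferShapeCache cache;
  std::cout << cache.NeedInferShape({{1, 3, 224, 224}}) << "\n";  // 1: first run
  std::cout << cache.NeedInferShape({{1, 3, 224, 224}}) << "\n";  // 0: unchanged
  std::cout << cache.NeedInferShape({{2, 3, 224, 224}}) << "\n";  // 1: batch changed
}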
@@ -14,6 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/ir/runtime_context_cache_pass.h"
 
+#include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/operator.h"
 
 namespace paddle {
@@ -21,10 +22,28 @@ namespace framework {
 namespace ir {
 
 void RuntimeContextCachePass::ApplyImpl(ir::Graph* graph) const {
+  static constexpr char kNotAllowInferShapeCahce[] =
+      "@NOT_ALLOW_INFERSHAPE_CACHE@";
   VLOG(3) << "Applies Runtime Context Cache strategy.";
   for (const Node* n : graph->Nodes()) {
     if (n->IsOp() && n->Op()) {
-      n->Op()->SetAttr(kEnableCacheRuntimeContext, true);
+      n->Op()->SetAttr(framework::kEnableCacheRuntimeContext, true);
+    }
+  }
+
+  // If op1 -> var0 and op2 -> var0, then op1 and op2 do not support
+  // InferShape caching.
+  std::unordered_map<std::string, std::vector<Node*>> var2ops;
+  for (auto* op_node : TopologySortOperations(*graph)) {
+    for (auto* var_node : op_node->outputs) {
+      var2ops[var_node->Name()].push_back(op_node);
+    }
+  }
+  for (auto& it : var2ops) {
+    if (it.second.size() > 1) {
+      for (auto op_node : it.second) {
+        op_node->Op()->SetAttr(kNotAllowInferShapeCahce, true);
+      }
     }
   }
 }
......
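As the comment in the pass above notes, when two ops write the same variable (op1 -> var0 and op2 -> var0), a cached shape for either op could be stale, so the pass tags every writer of a multi-writer variable with the @NOT_ALLOW_INFERSHAPE_CACHE@ attribute. Below is a toy illustration of that detection rule over a hypothetical graph given as op-name/output-name pairs, rather than Paddle's ir::Graph; it only demonstrates the rule.

// Toy sketch of the multi-writer check (hypothetical op/output data).
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

int main() {
  // op name -> names of its output variables.
  std::vector<std::pair<std::string, std::vector<std::string>>> ops = {
      {"op1", {"var0"}}, {"op2", {"var0"}}, {"op3", {"var1"}}};

  // Collect, per variable, the ops that write it.
  std::unordered_map<std::string, std::vector<std::string>> var2ops;
  for (const auto& op : ops)
    for (const auto& var : op.second) var2ops[var].push_back(op.first);

  // Any variable with more than one writer disqualifies all of its writers.
  std::unordered_set<std::string> no_cache;
  for (const auto& it : var2ops)
    if (it.second.size() > 1)
      no_cache.insert(it.second.begin(), it.second.end());

  for (const auto& name : no_cache)
    std::cout << name << " cannot cache InferShape\n";  // op1, op2
}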
@@ -15,6 +15,7 @@ limitations under the License. */
 #include <sstream>
 #include <string>
+#include <unordered_set>
 
 #include "gflags/gflags.h"
 #include "paddle/fluid/framework/convert_utils.h"
@@ -36,6 +37,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/profiler/supplement_tracing.h"
 #include "paddle/phi/common/int_array.h"
 #include "paddle/phi/common/scalar.h"
+#include "paddle/phi/core/ddim.h"
 #include "paddle/phi/core/kernel_context.h"
 #include "paddle/phi/core/kernel_factory.h"
 #include "paddle/phi/ops/compat/signatures.h"
@@ -562,6 +564,14 @@ phi::DenseTensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var) {
   }
 }
 
+OperatorWithKernel::OperatorWithKernel(const std::string& type,
+                                       const VariableNameMap& inputs,
+                                       const VariableNameMap& outputs,
+                                       const AttributeMap& attrs)
+    : OperatorBase(type, inputs, outputs, attrs) {}
+
+OperatorWithKernel::~OperatorWithKernel() = default;
+
 bool ExecutionContext::HasInput(const std::string& name) const {
   auto* var = InputVar(name);
   return var != nullptr;
@@ -1204,19 +1214,54 @@ class RuntimeInferShapeContext : public InferShapeContext {
 };
 
 struct OperatorWithKernel::CacheImpl {
+  static const char kNotAllowInferShapeCahce[];
+
   explicit CacheImpl(phi::KernelContext* kernel_ctx,
-                     RuntimeInferShapeContext* infer_shape_ctx)
-      : kernel_ctx_(kernel_ctx), infer_shape_ctx_(infer_shape_ctx) {}
+                     RuntimeInferShapeContext* infer_shape_ctx,
+                     const std::vector<phi::DenseTensor*>& tensors,
+                     bool not_allow_infer_shape_cache)
+      : kernel_ctx_(kernel_ctx),
+        infer_shape_ctx_(infer_shape_ctx),
+        tensors_(tensors),
+        not_allow_infer_shape_cache_(not_allow_infer_shape_cache) {}
 
   phi::KernelContext* getKernelContext() { return kernel_ctx_.get(); }
   RuntimeInferShapeContext* getRuntimeInferShapeContext() {
     return infer_shape_ctx_.get();
   }
 
+  bool NeedInferShape() {
+    if (not_allow_infer_shape_cache_) return true;
+
+    bool ret{false};
+    if (last_ddims_.empty() || tensors_.empty()) ret = true;
+    if (!ret) {
+      CHECK_EQ(last_ddims_.size(), tensors_.size());
+      for (size_t i = 0; i < last_ddims_.size(); ++i) {
+        if (tensors_[i]->dims() != last_ddims_[i]) {
+          ret = true;
+          break;
+        }
+      }
+    }
+    if (ret) {
+      last_ddims_.resize(tensors_.size());
+      for (size_t i = 0; i < last_ddims_.size(); ++i) {
+        last_ddims_[i] = tensors_[i]->dims();
+      }
+    }
+    VLOG(3) << "need infer shape is " << ret;
+    return ret;
+  }
+
  private:
   std::unique_ptr<phi::KernelContext> kernel_ctx_;
   std::unique_ptr<RuntimeInferShapeContext> infer_shape_ctx_;
+  std::vector<phi::DenseTensor*> tensors_;
+  bool not_allow_infer_shape_cache_;
+  std::vector<phi::DDim> last_ddims_;
 };
+const char OperatorWithKernel::CacheImpl::kNotAllowInferShapeCahce[] =
+    "@NOT_ALLOW_INFERSHAPE_CACHE@";
 
 static void CheckTensorNANOrInf(const std::string& op_type,
                                 const std::string& name,
@@ -1524,8 +1569,9 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
     pre_scope_ = cur_scope;
   } else if (run_phi_kernel_ && impl_ != nullptr && !need_prepare_data_ &&
              !need_prepare_phi_data_) {
-    if (!all_kernels_must_compute_runtime_shape_)
+    if (!all_kernels_must_compute_runtime_shape_ && impl_->NeedInferShape()) {
       this->Info().infer_shape_(impl_->getRuntimeInferShapeContext());
+    }
     (*phi_kernel_)(impl_->getKernelContext());
   } else {
     if (runtime_ctx_.get() == nullptr || pre_scope_ != cur_scope) {
@@ -1828,9 +1874,31 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
     phi::KernelContext phi_kernel_context;
     if (enable_cache_runtime_context_ && !need_prepare_phi_data_ &&
         !need_prepare_data_) {
-      impl_ =
+      // TODO(inference): Now we only support the dense_tensor cache; we may
+      // support ScalarTensor and SparseTensor in the future.
+      bool all_dense_tensor_input_{true};
+      for (auto& iter : Inputs()) {
+        for (auto& name : iter.second) {
+          all_dense_tensor_input_ &=
+              scope.FindVar(name)->IsType<phi::DenseTensor>();
+        }
+      }
+
+      std::vector<phi::DenseTensor*> tensors;
+      if (all_dense_tensor_input_) {
+        for (auto& iter : Inputs()) {
+          for (auto& name : iter.second) {
+            auto* t = scope.FindVar(name)->GetMutable<phi::DenseTensor>();
+            tensors.push_back(t);
+          }
+        }
+      }
+
+      impl_.reset(
           new CacheImpl(new phi::KernelContext(),
-                        new RuntimeInferShapeContext(*this, *runtime_ctx));
+                        new RuntimeInferShapeContext(*this, *runtime_ctx),
+                        tensors,
+                        HasAttr(CacheImpl::kNotAllowInferShapeCahce)));
       BuildPhiKernelContext(*runtime_ctx, dev_ctx, impl_->getKernelContext());
       (*phi_kernel_)(impl_->getKernelContext());
     } else {
@@ -3246,6 +3314,7 @@ void OperatorWithKernel::BuildPhiKernelContext(
   if (phi::OneDNNContext::classof(dev_ctx)) {
     phi::OneDNNContext* one_dnn_ctx = static_cast<phi::OneDNNContext*>(dev_ctx);
     one_dnn_ctx->ClearDnnAttr();
+    if (!RuntimeAttrs().empty()) need_prepare_phi_data_ = true;
   }
 #endif
@@ -3267,7 +3336,6 @@ void OperatorWithKernel::BuildPhiKernelContext(
 #if defined(PADDLE_WITH_MKLDNN) || defined(PADDLE_WITH_CUDA)
   auto& runtime_attrs = RuntimeAttrs();
   for (const auto& attr_iter : runtime_attrs) {
-    need_prepare_phi_data_ = true;
     auto& attr_name = attr_iter.first;
     auto& attr = attr_iter.second;
     auto attr_propertys = paddle::operators::GetExtraAttrProperties(attr_name);
......
@@ -612,8 +612,9 @@ class OperatorWithKernel : public OperatorBase {
   OperatorWithKernel(const std::string& type,
                      const VariableNameMap& inputs,
                      const VariableNameMap& outputs,
-                     const AttributeMap& attrs)
-      : OperatorBase(type, inputs, outputs, attrs) {}
+                     const AttributeMap& attrs);
+
+  virtual ~OperatorWithKernel();
 
   static paddle::flat_hash_map<std::string /* op_type */, OpKernelMap>&
   AllOpKernels() {
@@ -785,8 +786,9 @@ class OperatorWithKernel : public OperatorBase {
   mutable std::unique_ptr<phi::Kernel> phi_kernel_;
   mutable std::unique_ptr<phi::ArgumentMappingFn> arg_map_fn_;
 
+ private:
   struct CacheImpl;
-  mutable CacheImpl* impl_{nullptr};
+  mutable std::unique_ptr<CacheImpl> impl_;
 };
 
 extern bool OpSupportGPU(const std::string& op_type);
......
@@ -23,6 +23,8 @@ namespace inference {
 namespace analysis {
 
 void IrGraphToProgramPass::RunImpl(Argument *argument) {
+  auto cache_pass =
+      framework::ir::PassRegistry::Instance().Get("runtime_context_cache_pass");
   auto pass =
       framework::ir::PassRegistry::Instance().Get("graph_to_program_pass");
@@ -31,14 +33,12 @@ void IrGraphToProgramPass::RunImpl(Argument *argument) {
         new int(argument->memory_optim_sort_kind()));
   }
 
-  std::unique_ptr<framework::ir::Graph> graph(argument->main_graph_ptr());
-
   // Direct using ProgramDesc desc(argument->main_program()) may cause
   // incomplete copies of information.
   framework::ProgramDesc desc;
   desc.CopyFrom(*argument->main_program().Proto());
   pass->SetNotOwned("program", &desc);
-  pass->Apply(graph.release());  // the argument still own the graph.
+  pass->Apply(cache_pass->Apply(argument->main_graph_ptr()));
 
   argument->SetIrAnalyzedProgram(
       new framework::proto::ProgramDesc(*desc.Proto()));
......
@@ -188,7 +188,6 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
     "fc_fuse_pass",
     "fc_elementwise_layernorm_fuse_pass",
     "embedding_eltwise_layernorm_fuse_pass",
-    "runtime_context_cache_pass",
 };
 
 const std::vector<std::string> kTrtLowerPrecisionPasses{
@@ -254,10 +253,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
 #endif  //
         "transpose_flatten_concat_fuse_pass",  //
         "constant_folding_pass",               //
-        // following pass should be located in the last, since it will
-        // work on all fused ops.
-        "float_to_half_pass",                  //
-        "runtime_context_cache_pass"
+        "float_to_half_pass",                  //
   });
 
   use_gpu_ = true;
@@ -322,10 +318,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
        "conv_transpose_bn_fuse_pass",             //
        "conv_transpose_eltwiseadd_bn_fuse_pass",  //
        "is_test_pass",                            //
-       "constant_folding_pass",
-       // following pass should be located in the last, since
-       // it will work on all fused ops.
-       "runtime_context_cache_pass"});
+       "constant_folding_pass"});
 
   use_gpu_ = false;
 }
@@ -475,7 +468,6 @@ void CpuPassStrategy::EnableMkldnnInt8() {
     passes_.push_back("int8_scale_calculation_mkldnn_pass");
     passes_.push_back("params_quantization_mkldnn_pass");
     passes_.push_back("mkldnn_inplace_pass");
-    passes_.push_back("runtime_context_cache_pass");
   }
   use_mkldnn_int8_ = true;
 #else
......