diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index d2c63d5b21008363e8c91f96b96c586f57384b5e..f47b024973ba7899ebf5040a09702f5bab83fe32 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -409,6 +409,8 @@ void VarBase::_CopyGradientFrom(const VarBase& src) { } } +pten::KernelContext OpBase::pt_kernel_context_; + void OpBase::SetType(const std::string& type) { op_ = framework::OpRegistry::CreateOp(type, {}, {}, {}, false); } diff --git a/paddle/fluid/imperative/op_base.h b/paddle/fluid/imperative/op_base.h index 3d0847605566b0f6df63eb48db4cb38df18dc8da..8d27e4f42a5ad5b9126a0d6b68a7d3178d00d37a 100644 --- a/paddle/fluid/imperative/op_base.h +++ b/paddle/fluid/imperative/op_base.h @@ -210,6 +210,9 @@ class OpBase { std::unique_ptr op_; platform::Place place_; size_t id_{-1UL}; + // In order to reduce the compatibility phase + // performance overhead, temporarily cache KernelContext + static pten::KernelContext pt_kernel_context_; std::vector>> void_function_post_hooks_; };