未验证 提交 5c364cda 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #16711 from luotao1/has_attr

reduce hasAttr elapsed time in RunImpl
...@@ -880,7 +880,16 @@ std::vector<KernelConfig>* OperatorWithKernel::GetKernelConfig( ...@@ -880,7 +880,16 @@ std::vector<KernelConfig>* OperatorWithKernel::GetKernelConfig(
void OperatorWithKernel::RunImpl(const Scope& scope, void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const { const platform::Place& place) const {
if (!HasAttr(kEnableCacheRuntimeContext)) { // To reduce the elapsed time of HasAttr, we use bool variable to record the
// result of HasAttr.
if (!enable_cache_runtime_context && HasAttr(kEnableCacheRuntimeContext))
enable_cache_runtime_context = true;
if (!enable_cache_expected_kernel && HasAttr(kEnableCacheExpectedKernel))
enable_cache_expected_kernel = true;
if (!all_kernels_must_compute_runtime_shape &&
HasAttr(kAllKernelsMustComputeRuntimeShape))
all_kernels_must_compute_runtime_shape = true;
if (!enable_cache_runtime_context) {
RuntimeContext ctx(Inputs(), Outputs(), scope); RuntimeContext ctx(Inputs(), Outputs(), scope);
RunImpl(scope, place, &ctx); RunImpl(scope, place, &ctx);
} else { } else {
...@@ -899,7 +908,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -899,7 +908,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
if (!HasAttr(kEnableCacheExpectedKernel) || !kernel_type_) { if (!enable_cache_expected_kernel || !kernel_type_) {
ChooseKernel(*runtime_ctx, scope, place); ChooseKernel(*runtime_ctx, scope, place);
} }
...@@ -918,7 +927,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -918,7 +927,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
dev_ctx = pool.Get(kernel_type_->place_); dev_ctx = pool.Get(kernel_type_->place_);
} }
if (!HasAttr(kAllKernelsMustComputeRuntimeShape)) { if (!all_kernels_must_compute_runtime_shape) {
RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, *runtime_ctx); RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, *runtime_ctx);
this->InferShape(&infer_shape_ctx); this->InferShape(&infer_shape_ctx);
} }
......
...@@ -506,6 +506,9 @@ class OperatorWithKernel : public OperatorBase { ...@@ -506,6 +506,9 @@ class OperatorWithKernel : public OperatorBase {
mutable std::unique_ptr<OpKernelFunc> kernel_func_; mutable std::unique_ptr<OpKernelFunc> kernel_func_;
mutable std::unique_ptr<RuntimeContext> runtime_ctx_; mutable std::unique_ptr<RuntimeContext> runtime_ctx_;
mutable const Scope* pre_scope_ = nullptr; mutable const Scope* pre_scope_ = nullptr;
mutable bool enable_cache_runtime_context = false;
mutable bool enable_cache_expected_kernel = false;
mutable bool all_kernels_must_compute_runtime_shape = false;
}; };
extern bool OpSupportGPU(const std::string& op_type); extern bool OpSupportGPU(const std::string& op_type);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册