未验证 提交 880eb04d 编写于 作者: Z Zhang Ting 提交者: GitHub

skip PrepareData when it is unnecessary (#22839)

* remove unnecessary prepare data, test=develop

* Op in while block will not skip PrepareData, test=develop
上级 2403362d
......@@ -924,11 +924,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
if (!all_kernels_must_compute_runtime_shape_ &&
HasAttr(kAllKernelsMustComputeRuntimeShape))
all_kernels_must_compute_runtime_shape_ = true;
const Scope* cur_scope = &scope;
if (!enable_cache_runtime_context_) {
RuntimeContext ctx(Inputs(), Outputs(), scope);
RunImpl(scope, place, &ctx);
pre_scope_ = cur_scope;
} else {
const Scope* cur_scope = &scope;
if (runtime_ctx_.get() == nullptr || pre_scope_ != cur_scope) {
std::lock_guard<std::mutex> lock(cache_update_mutex_);
if (runtime_ctx_.get() == nullptr || pre_scope_ != cur_scope) {
......@@ -958,8 +959,10 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
{
platform::RecordEvent record_event("prepare_data",
platform::EventRole::kInnerOp);
transfer_scope = PrepareData(scope, *kernel_type_, &transfered_inplace_vars,
runtime_ctx);
if (need_prepare_data_) {
transfer_scope = PrepareData(scope, *kernel_type_,
&transfered_inplace_vars, runtime_ctx);
}
}
// exec scope is the scope that kernel actually executed on.
const Scope& exec_scope =
......@@ -1252,6 +1255,15 @@ Scope* OperatorWithKernel::PrepareData(
SetTensorToVariable(*var, out, trans_var);
}
}
// If pre_scope = &scope, it means that scope is cached and the op is not in
// while block. If new_scope = nullptr, it means that for each input of this
// Op, there is no need to do PrepareData. So PrepareData could be skipped at
// the rest iterations to save the elapsed time.
// We do not support skipping PrepareData in while block, because the Op's
// input may be changed by subsequent Ops, which may cause an error.
if (pre_scope_ == &scope && new_scope == nullptr) {
need_prepare_data_ = false;
}
return new_scope;
}
......
......@@ -185,6 +185,7 @@ class OperatorBase {
virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
void SetIsCalledByExecutor(bool x) { run_by_executor_ = x; }
virtual void RuntimeInferShape(const Scope& scope,
const platform::Place& place,
const RuntimeContext& ctx) const {}
......@@ -541,6 +542,7 @@ class OperatorWithKernel : public OperatorBase {
mutable std::unique_ptr<OpKernelFunc> kernel_func_;
mutable std::unique_ptr<RuntimeContext> runtime_ctx_;
mutable const Scope* pre_scope_ = nullptr;
mutable bool need_prepare_data_ = true;
mutable bool enable_cache_runtime_context_ = false;
mutable bool all_kernels_must_compute_runtime_shape_ = false;
mutable std::mutex cache_update_mutex_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册