未验证 提交 ec814cf5 编写于 作者: C csy0225 提交者: GitHub

revert operator.cc (#50895)

上级 cf209204
...@@ -1618,57 +1618,6 @@ void OperatorWithKernel::CheckWhetherPreparePhiData( ...@@ -1618,57 +1618,6 @@ void OperatorWithKernel::CheckWhetherPreparePhiData(
} }
} }
// When do we need to reset runtime context?
// 1. When enable cache runtime context, if the program runs for the first time,
// runtime_ctx_.get() == nullptr, we need to create a new runtime context.
// 2. When enable cache runtime context, if the program is not running for the
// first time,
// but the input shape or tensor layout of the operator has changed, we cannot
// use the runtime context stored in the cache at this time, and need to
// create a new one.
bool OperatorWithKernel::NeedResetRuntimeContext(const Scope& scope) const {
if (runtime_ctx_.get() == nullptr) return true;
const auto& name_map = Inputs();
for (auto& var_name_item : name_map) {
auto& name_vec = var_name_item.second;
std::vector<Variable*>& cache_input_vars =
runtime_ctx_->inputs[var_name_item.first];
PADDLE_ENFORCE_EQ(
name_vec.size(),
cache_input_vars.size(),
platform::errors::InvalidArgument(
"The size of input variable names (%d) must be equal to "
"the size of cache input variable ptrs (%d).",
name_vec.size(),
cache_input_vars.size()));
for (size_t i = 0; i < name_vec.size(); i++) {
auto var_name = name_vec[i];
auto* cache_input_var = cache_input_vars[i];
if (!VarIsTensor(*cache_input_var)) continue;
auto* cache_input_tensor =
GetMutableLoDTensorOrSelectedRowsValueFromVar(cache_input_var);
auto cache_input_tensor_dims = cache_input_tensor->dims();
auto* current_input_var = scope.FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(
current_input_var,
platform::errors::NotFound(
"The variable %s is not found when "
"enable_cache_runtime_context_cache in origin scope.",
var_name));
auto* current_input_tensor =
GetMutableLoDTensorOrSelectedRowsValueFromVar(current_input_var);
auto current_input_tensor_dims = current_input_tensor->dims();
if (cache_input_tensor_dims != current_input_tensor_dims ||
NeedTransformLayout(current_input_tensor->layout(),
cache_input_tensor->layout())) {
need_prepare_data_ = true;
return true;
}
}
}
return false;
}
void OperatorWithKernel::RunImpl(const Scope& scope, void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const { const platform::Place& place) const {
// To reduce the elapsed time of HasAttr, we use bool variable to record the // To reduce the elapsed time of HasAttr, we use bool variable to record the
...@@ -1678,6 +1627,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1678,6 +1627,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
if (!all_kernels_must_compute_runtime_shape_ && if (!all_kernels_must_compute_runtime_shape_ &&
HasAttr(kAllKernelsMustComputeRuntimeShape)) HasAttr(kAllKernelsMustComputeRuntimeShape))
all_kernels_must_compute_runtime_shape_ = true; all_kernels_must_compute_runtime_shape_ = true;
const Scope* cur_scope = &scope;
CheckWhetherPreparePhiData(Inputs(), Outputs(), scope); CheckWhetherPreparePhiData(Inputs(), Outputs(), scope);
if (!enable_cache_runtime_context_) { if (!enable_cache_runtime_context_) {
RuntimeContext ctx(Inputs(), Outputs(), scope); RuntimeContext ctx(Inputs(), Outputs(), scope);
...@@ -1689,9 +1639,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1689,9 +1639,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
} }
(*phi_kernel_)(impl_->getKernelContext()); (*phi_kernel_)(impl_->getKernelContext());
} else { } else {
if (NeedResetRuntimeContext(scope)) { if (runtime_ctx_.get() == nullptr || pre_scope_ != cur_scope) {
std::lock_guard<std::mutex> lock(cache_update_mutex_); std::lock_guard<std::mutex> lock(cache_update_mutex_);
if (runtime_ctx_.get() == nullptr || pre_scope_ != cur_scope) {
runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope)); runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope));
pre_scope_ = cur_scope;
}
} }
RunImpl(scope, place, runtime_ctx_.get()); RunImpl(scope, place, runtime_ctx_.get());
} }
...@@ -2086,9 +2039,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -2086,9 +2039,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
// To solve issue #15032, have a discussion with @Luotao for cpu inference, // To solve issue #15032, have a discussion with @Luotao for cpu inference,
// do not cache transfer scope, hence in this case delete transfer scope // do not cache transfer scope, hence in this case delete transfer scope
// after run to avoid memory leak // after run to avoid memory leak
if (cache_transfer_scope_ && !run_by_executor_ && if (transfer_scope && !run_by_executor_ && !enable_cache_transfer_scope_) {
!enable_cache_transfer_scope_) { scope.DeleteScope(transfer_scope);
scope.DeleteScope(cache_transfer_scope_);
} }
} }
...@@ -2623,25 +2575,33 @@ Scope* OperatorWithKernel::PrepareData( ...@@ -2623,25 +2575,33 @@ Scope* OperatorWithKernel::PrepareData(
kernel_type_for_var.backend() == phi::Backend::GPUDNN || kernel_type_for_var.backend() == phi::Backend::GPUDNN ||
new_expected_kernel_key->backend() == phi::Backend::GPU || new_expected_kernel_key->backend() == phi::Backend::GPU ||
new_expected_kernel_key->backend() == phi::Backend::GPUDNN) { new_expected_kernel_key->backend() == phi::Backend::GPUDNN) {
cache_transfer_scope_ = TryCreateTransferScope( new_scope = TryCreateTransferScope(
kernel_type_for_var, *new_expected_kernel_key, &scope); kernel_type_for_var, *new_expected_kernel_key, &scope);
enable_cache_transfer_scope_ = true; enable_cache_transfer_scope_ = true;
new_scope = cache_transfer_scope_;
} }
} else if (kernel_type_for_var.backend() == phi::Backend::GPU || } else if (kernel_type_for_var.backend() == phi::Backend::GPU ||
kernel_type_for_var.backend() == phi::Backend::GPUDNN || kernel_type_for_var.backend() == phi::Backend::GPUDNN ||
expected_kernel_key.backend() == phi::Backend::GPU || expected_kernel_key.backend() == phi::Backend::GPU ||
expected_kernel_key.backend() == phi::Backend::GPUDNN) { expected_kernel_key.backend() == phi::Backend::GPUDNN) {
cache_transfer_scope_ = TryCreateTransferScope( new_scope = TryCreateTransferScope(
kernel_type_for_var, expected_kernel_key, &scope); kernel_type_for_var, expected_kernel_key, &scope);
enable_cache_transfer_scope_ = true; enable_cache_transfer_scope_ = true;
new_scope = cache_transfer_scope_;
} }
} }
if (!new_scope) { if (!new_scope) {
new_scope = &scope.NewScope(); new_scope = &scope.NewScope();
} }
// For inference, if a gpu model has an op which could only run on CPU,
// each result of different input will be the same with the first one.
// The reason is that if a gpu tensor is the input of a cpu kernel,
// we will create a new cpu tensor in new scope.
// However, if enable_cache_runtime_context_, we get the cpu tensor each
// time, not the gpu tensor. Thus, we set pre_scope_ = nullptr
// to trigger `new RuntimeContext()` in RunImpl().
if (enable_cache_runtime_context_) {
pre_scope_ = nullptr;
}
// Create new var with the same name in transfer scopes // Create new var with the same name in transfer scopes
auto* trans_var = new_scope->Var(var_name); auto* trans_var = new_scope->Var(var_name);
...@@ -2727,13 +2687,18 @@ Scope* OperatorWithKernel::PrepareData( ...@@ -2727,13 +2687,18 @@ Scope* OperatorWithKernel::PrepareData(
} }
} }
// If pre_scope = &scope, it means that scope is cached and the op is not in
// while block. If new_scope = nullptr, it means that for each input of this
// Op, there is no need to do PrepareData. So PrepareData could be skipped at
// the rest iterations to save the elapsed time.
// We do not support skipping PrepareData in while block, because the Op's // We do not support skipping PrepareData in while block, because the Op's
// input may be changed by subsequent Ops, which may cause an error. // input may be changed by subsequent Ops, which may cause an error.
// For inference, ops that behind conditional branch aren't supported well, // For inference, ops that behind conditional branch aren't supported well,
// so disable prepare optimization conservatively. // so disable prepare optimization conservatively.
bool force_prepare_data = HasAttr("inference_force_prepare_data") && bool force_prepare_data = HasAttr("inference_force_prepare_data") &&
Attr<bool>("inference_force_prepare_data"); Attr<bool>("inference_force_prepare_data");
if (enable_cache_runtime_context_ && !force_prepare_data) { if (pre_scope_ == &scope && new_scope == nullptr && !force_prepare_data) {
need_prepare_data_ = false; need_prepare_data_ = false;
} }
......
...@@ -781,19 +781,18 @@ class OperatorWithKernel : public OperatorBase { ...@@ -781,19 +781,18 @@ class OperatorWithKernel : public OperatorBase {
// used for IndicateOrPromoteVarDataTypes // used for IndicateOrPromoteVarDataTypes
phi::DenseTensor* GetTensorFormInputSafely(const ExecutionContext& ctx, phi::DenseTensor* GetTensorFormInputSafely(const ExecutionContext& ctx,
const std::string& name) const; const std::string& name) const;
bool NeedResetRuntimeContext(const Scope& scope) const;
protected: protected:
mutable std::unique_ptr<OpKernelType> kernel_type_; mutable std::unique_ptr<OpKernelType> kernel_type_;
mutable std::unique_ptr<OpKernelFunc> kernel_func_; mutable std::unique_ptr<OpKernelFunc> kernel_func_;
mutable std::unique_ptr<RuntimeContext> runtime_ctx_; mutable std::unique_ptr<RuntimeContext> runtime_ctx_;
mutable const Scope* pre_scope_ = nullptr;
mutable bool need_prepare_data_ = true; mutable bool need_prepare_data_ = true;
mutable bool need_prepare_phi_data_ = false; mutable bool need_prepare_phi_data_ = false;
mutable bool enable_cache_runtime_context_ = false; mutable bool enable_cache_runtime_context_ = false;
mutable bool all_kernels_must_compute_runtime_shape_ = false; mutable bool all_kernels_must_compute_runtime_shape_ = false;
mutable std::mutex cache_update_mutex_; mutable std::mutex cache_update_mutex_;
mutable bool enable_cache_transfer_scope_ = false; mutable bool enable_cache_transfer_scope_ = false;
mutable Scope* cache_transfer_scope_ = nullptr;
// NOTE(jiahongyu): Whether fallback to plain kernel after calling // NOTE(jiahongyu): Whether fallback to plain kernel after calling
// GetExpectedKernelType, use this bool flag to solve mkldnn and cudnn hard // GetExpectedKernelType, use this bool flag to solve mkldnn and cudnn hard
// code // code
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册