From ca3b6bcf789e97288f75e8c1ae03edb88e2e5636 Mon Sep 17 00:00:00 2001
From: chentianyu03
Date: Thu, 25 Feb 2021 14:08:52 +0800
Subject: [PATCH] add cache for VariableWrapper (#30880)

* add cache for VariableWrapper

* modify args names and vlog level

* format code style

* add log when set cache to variable_wrapper

* add log when set cache to variable_wrapper

* add comment to variableWrapper cache

* format code style
---
 paddle/fluid/framework/op_kernel_type.h      |  4 ++
 paddle/fluid/imperative/prepared_operator.cc | 10 ++++
 paddle/fluid/imperative/prepared_operator.h  | 52 ++++++++++++++++----
 paddle/fluid/imperative/variable_wrapper.h   | 21 ++++++++
 4 files changed, 77 insertions(+), 10 deletions(-)

diff --git a/paddle/fluid/framework/op_kernel_type.h b/paddle/fluid/framework/op_kernel_type.h
index e903b079c27..a2e9d972c48 100644
--- a/paddle/fluid/framework/op_kernel_type.h
+++ b/paddle/fluid/framework/op_kernel_type.h
@@ -65,6 +65,10 @@ class OpKernelType {
 
   size_t hash_key() const { return Hash()(*this); }
 
+  bool operator<(const OpKernelType& o) const {
+    return hash_key() < o.hash_key();
+  }
+
   bool operator==(const OpKernelType& o) const;
 
   bool operator!=(const OpKernelType& o) const { return !(*this == o); }
diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc
index 0e7ded56302..e6e5135316a 100644
--- a/paddle/fluid/imperative/prepared_operator.cc
+++ b/paddle/fluid/imperative/prepared_operator.cc
@@ -20,6 +20,16 @@
 namespace paddle {
 namespace imperative {
 
+const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
+    const std::shared_ptr<VarBase>& var) {
+  return var->SharedVar();
+}
+
+const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
+    const std::shared_ptr<VariableWrapper>& var) {
+  return var;
+}
+
 const framework::Tensor* GetTensorFromVar(const framework::Variable& var) {
   if (var.IsType<framework::LoDTensor>()) {
     return &(var.Get<framework::LoDTensor>());
diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h
index d6a72f586b5..1f6be5483be 100644
--- a/paddle/fluid/imperative/prepared_operator.h
+++ b/paddle/fluid/imperative/prepared_operator.h
@@ -64,6 +64,11 @@ void SetForwardDataTypeOfGradVar<VarBase>(const std::shared_ptr<VarBase>& var) {
   }
 }
 
+extern const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
+    const std::shared_ptr<VarBase>& var);
+extern const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
+    const std::shared_ptr<VariableWrapper>& var);
+
 template <typename VarType>
 std::shared_ptr<NameVarMap<VarType>> PrepareData(
     const framework::OperatorWithKernel& op, const NameVarMap<VarType>& ins,
@@ -82,23 +87,50 @@ std::shared_ptr<NameVarMap<VarType>> PrepareData(
         } else {
           VLOG(3) << "Transform Variable " << var_base->Name() << " from "
                   << kernel_type_for_var << " to " << expected_kernel_key;
-          framework::Tensor out;
-          TransformData(expected_kernel_key, kernel_type_for_var, *tensor,
-                        &out);
-          if (NeedTransformDataType(kernel_type_for_var, expected_kernel_key)) {
-            // To avoid NameVarMap copy construction overhead in general
-            // scenarios, if inplace transformed, return original input directly
+
+          if (GetVariableWrapper(var_base)->hasCacheKey(expected_kernel_key)) {
+            VLOG(3) << "Hit variable_wrapper cache: key="
+                    << expected_kernel_key;
+            std::shared_ptr<VariableWrapper> cache_var =
+                GetVariableWrapper(var_base)->getCacheValue(
+                    expected_kernel_key);
             if (tmp_ins_ptr == nullptr) {
               tmp_ins_ptr = std::make_shared<NameVarMap<VarType>>(ins);
             }
+
+            const auto* tensor = GetTensorFromVar(cache_var->Var());
             auto tmp_var = std::make_shared<VarType>(var_base->Name());
             tmp_var->SetType(var_base->Type());
-            SetTensorToVariable(var_base->Var(), out, tmp_var->MutableVar());
+            SetTensorToVariable(cache_var->Var(), *tensor,
+                                tmp_var->MutableVar());
             (*tmp_ins_ptr)[name_pair.first][i] = tmp_var;
           } else {
-            // if dtype is same, transform inplace will not change the original
-            // value, transform inplace to avoid multiple copy
-            SetTensorToVariable(var_base->Var(), out, var_base->MutableVar());
+            framework::Tensor out;
+            TransformData(expected_kernel_key, kernel_type_for_var, *tensor,
+                          &out);
+            if (NeedTransformDataType(kernel_type_for_var,
+                                      expected_kernel_key)) {
+              // To avoid NameVarMap copy construction overhead in general
+              // scenarios, if inplace transformed, return original input
+              // directly
+              if (tmp_ins_ptr == nullptr) {
+                tmp_ins_ptr = std::make_shared<NameVarMap<VarType>>(ins);
+              }
+              auto tmp_var = std::make_shared<VarType>(var_base->Name());
+              tmp_var->SetType(var_base->Type());
+              SetTensorToVariable(var_base->Var(), out, tmp_var->MutableVar());
+              (*tmp_ins_ptr)[name_pair.first][i] = tmp_var;
+
+              GetVariableWrapper(var_base)->setCacheValue(
+                  expected_kernel_key, GetVariableWrapper(tmp_var));
+              VLOG(3) << "Set cache to variable_wrapper: key="
+                      << expected_kernel_key;
+            } else {
+              // if dtype is same, transform inplace will not change the
+              // original
+              // value, transform inplace to avoid multiple copy
+              SetTensorToVariable(var_base->Var(), out, var_base->MutableVar());
+            }
           }
         }
       }
diff --git a/paddle/fluid/imperative/variable_wrapper.h b/paddle/fluid/imperative/variable_wrapper.h
index 1e900a34456..b42f25dcc88 100644
--- a/paddle/fluid/imperative/variable_wrapper.h
+++ b/paddle/fluid/imperative/variable_wrapper.h
@@ -14,10 +14,12 @@
 
 #pragma once
 
+#include <map>
 #include <memory>
 #include <string>
 #include <utility>
 
+#include "paddle/fluid/framework/op_kernel_type.h"
 #include "paddle/fluid/framework/variable.h"
 #include "paddle/fluid/imperative/hooks.h"
 #include "paddle/fluid/imperative/op_base.h"
@@ -238,6 +240,21 @@ class VariableWrapper {
     inplace_version_snapshot_ = new_version;
   }
 
+  bool hasCacheKey(const paddle::framework::OpKernelType& key) {
+    return var_cache.find(key) != var_cache.end();
+  }
+
+  std::shared_ptr<VariableWrapper> getCacheValue(
+      const paddle::framework::OpKernelType& key) {
+    return var_cache[key];
+  }
+
+  void setCacheValue(const paddle::framework::OpKernelType& key,
+                     std::shared_ptr<VariableWrapper> val) {
+    var_cache[key] = val;
+    return;
+  }
+
 private:
   void SetGradVar(const std::shared_ptr<VariableWrapper>& var) {
     auto shared_var = grad_var_.lock();
@@ -311,6 +328,10 @@ class VariableWrapper {
   framework::Variable var_;
   std::string name_;
 
+  // Used to cache the dtype-promoted VariableWrapper in the real and complex
+  // compute of Paddle Quantum
+  std::map<paddle::framework::OpKernelType, std::shared_ptr<VariableWrapper>>
+      var_cache;
   // add this property for users may set stop_gradient themselves and this
   // should override the frameworks setting (-1) unset, (1) true, (0) false
   int overrided_stop_gradient_{-1};
-- 
GitLab
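
For reference, here is a minimal, self-contained C++ sketch of the caching pattern this patch introduces. The names KernelKey, VarWrapper, and PrepareInput are illustrative stand-ins, not Paddle APIs: a kernel-type key ordered by its hash (as the new operator< does for OpKernelType), a per-variable std::map cache of already-transformed copies, and a prepare step that reuses the cached copy when the same expected kernel type is requested again.

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Stand-in for framework::OpKernelType: ordered by hash so it can key a std::map.
struct KernelKey {
  std::string place;
  std::string dtype;
  size_t hash_key() const {
    return std::hash<std::string>()(place + "/" + dtype);
  }
  bool operator<(const KernelKey& o) const { return hash_key() < o.hash_key(); }
};

// Stand-in for VariableWrapper: holds a payload plus a cache of transformed copies.
struct VarWrapper {
  std::string data;  // pretend tensor payload
  std::map<KernelKey, std::shared_ptr<VarWrapper>> var_cache;

  bool HasCacheKey(const KernelKey& key) const {
    return var_cache.find(key) != var_cache.end();
  }
  std::shared_ptr<VarWrapper> GetCacheValue(const KernelKey& key) {
    return var_cache[key];
  }
  void SetCacheValue(const KernelKey& key, std::shared_ptr<VarWrapper> val) {
    var_cache[key] = val;
  }
};

// Mirrors the PrepareData control flow: reuse the cached transformed variable on
// a hit, otherwise transform once and remember the result for the next call.
std::shared_ptr<VarWrapper> PrepareInput(const std::shared_ptr<VarWrapper>& in,
                                         const KernelKey& expected) {
  if (in->HasCacheKey(expected)) {
    std::cout << "cache hit\n";
    return in->GetCacheValue(expected);
  }
  std::cout << "cache miss, transforming\n";
  auto out = std::make_shared<VarWrapper>();
  out->data = in->data + " -> " + expected.dtype;  // fake dtype transform
  in->SetCacheValue(expected, out);
  return out;
}

int main() {
  auto x = std::make_shared<VarWrapper>();
  x->data = "x(float32)";
  KernelKey complex_key{"CPUPlace", "complex64"};
  PrepareInput(x, complex_key);  // miss: transforms and caches
  PrepareInput(x, complex_key);  // hit: returns the cached copy
  return 0;
}

One design note: ordering OpKernelType by hash_key() is enough for std::map lookups, but two distinct keys whose hashes collide would compare equivalent and share a cache slot; the chance is negligible in practice, and OpKernelType still provides operator== for exact comparison.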