Unverified commit ca3b6bcf authored by chentianyu03, committed by GitHub

add cache for VariableWrapper (#30880)

* add cache for VariableWrapper

* modify args names and vlog level

* format code style

* add log when set cache to variable_wrapper

* add log when set cache to variable_wrapper

* add comment to variableWrapper cache

* format code style
Parent f114c3f8
@@ -65,6 +65,10 @@ class OpKernelType {
  size_t hash_key() const { return Hash()(*this); }

  bool operator<(const OpKernelType& o) const {
    return hash_key() < o.hash_key();
  }

  bool operator==(const OpKernelType& o) const;
  bool operator!=(const OpKernelType& o) const { return !(*this == o); }
......
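The `operator<` added above gives `OpKernelType` the strict weak ordering that `std::map` requires, so the kernel type can serve as the key of the new per-variable cache; since it compares `hash_key()` values, two distinct kernel types that happened to hash equally would compare equivalent, which the cache tolerates. Below is a minimal standalone sketch of the same idea, using a hypothetical `KernelKey` type rather than the real Paddle class:

```cpp
// Standalone sketch (not the actual Paddle class): a key type ordered by its
// hash, mirroring how the added operator< lets OpKernelType index a std::map.
#include <functional>
#include <iostream>
#include <map>
#include <string>

struct KernelKey {
  int data_type;
  int place;

  size_t hash_key() const {
    // Combine the fields into a single hash, similar in spirit to OpKernelType::Hash.
    return std::hash<int>()(data_type) * 31 + std::hash<int>()(place);
  }

  // std::map only needs a strict weak ordering; comparing hash keys is enough.
  bool operator<(const KernelKey& o) const { return hash_key() < o.hash_key(); }
};

int main() {
  std::map<KernelKey, std::string> cache;
  cache[{/*data_type=*/5, /*place=*/0}] = "float32 on CPU";
  cache[{/*data_type=*/6, /*place=*/1}] = "float64 on GPU";
  std::cout << cache.size() << " cached entries\n";  // prints "2 cached entries"
  return 0;
}
```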
@@ -20,6 +20,16 @@
namespace paddle {
namespace imperative {

const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
    const std::shared_ptr<paddle::imperative::VarBase>& var) {
  return var->SharedVar();
}

const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
    const std::shared_ptr<VariableWrapper>& var) {
  return var;
}

const framework::Tensor* GetTensorFromVar(const framework::Variable& var) {
  if (var.IsType<framework::LoDTensor>()) {
    return &(var.Get<framework::LoDTensor>());
......
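The two `GetVariableWrapper` overloads give the templated `PrepareData` a single spelling for reaching the underlying `VariableWrapper`, whether the input map holds `VarBase` objects or `VariableWrapper` objects. A minimal sketch of that overload trick follows, with hypothetical `Handle`/`Wrapper` stand-ins rather than the actual Paddle types:

```cpp
// Minimal standalone sketch of the overload trick: templated code calls
// GetWrapper(var) and overload resolution picks the right accessor for
// either holder type. Handle and Wrapper are hypothetical stand-ins.
#include <iostream>
#include <memory>
#include <string>

struct Wrapper {
  std::string name;
};

// A handle type that owns a Wrapper, standing in for VarBase.
struct Handle {
  std::shared_ptr<Wrapper> shared;
  const std::shared_ptr<Wrapper>& SharedVar() const { return shared; }
};

const std::shared_ptr<Wrapper>& GetWrapper(const std::shared_ptr<Handle>& var) {
  return var->SharedVar();
}
const std::shared_ptr<Wrapper>& GetWrapper(const std::shared_ptr<Wrapper>& var) {
  return var;  // already a wrapper, return it unchanged
}

template <typename VarType>
void PrintName(const std::shared_ptr<VarType>& var) {
  // Works for both Handle and Wrapper without specializing the template.
  std::cout << GetWrapper(var)->name << "\n";
}

int main() {
  auto w = std::make_shared<Wrapper>(Wrapper{"x"});
  auto h = std::make_shared<Handle>(Handle{w});
  PrintName(h);  // prints "x"
  PrintName(w);  // prints "x"
  return 0;
}
```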
@@ -64,6 +64,11 @@ void SetForwardDataTypeOfGradVar<VarBase>(const std::shared_ptr<VarBase>& var) {
  }
}

extern const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
    const std::shared_ptr<paddle::imperative::VarBase>& var);
extern const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
    const std::shared_ptr<VariableWrapper>& var);

template <typename VarType>
std::shared_ptr<NameVarMap<VarType>> PrepareData(
    const framework::OperatorWithKernel& op, const NameVarMap<VarType>& ins,
@@ -82,12 +87,32 @@ std::shared_ptr<NameVarMap<VarType>> PrepareData(
      } else {
        VLOG(3) << "Transform Variable " << var_base->Name() << " from "
                << kernel_type_for_var << " to " << expected_kernel_key;
        if (GetVariableWrapper(var_base)->hasCacheKey(expected_kernel_key)) {
          VLOG(3) << "Hit variable_wrapper cache: key="
                  << expected_kernel_key;
          std::shared_ptr<VariableWrapper> cache_var =
              GetVariableWrapper(var_base)->getCacheValue(
                  expected_kernel_key);
          if (tmp_ins_ptr == nullptr) {
            tmp_ins_ptr = std::make_shared<NameVarMap<VarType>>(ins);
          }
          const auto* tensor = GetTensorFromVar(cache_var->Var());
          auto tmp_var = std::make_shared<VarType>(var_base->Name());
          tmp_var->SetType(var_base->Type());
          SetTensorToVariable(cache_var->Var(), *tensor,
                              tmp_var->MutableVar());
          (*tmp_ins_ptr)[name_pair.first][i] = tmp_var;
        } else {
          framework::Tensor out;
          TransformData(expected_kernel_key, kernel_type_for_var, *tensor,
                        &out);
          if (NeedTransformDataType(kernel_type_for_var,
                                    expected_kernel_key)) {
            // To avoid NameVarMap copy construction overhead in general
            // scenarios, if inplace transformed, return original input
            // directly
            if (tmp_ins_ptr == nullptr) {
              tmp_ins_ptr = std::make_shared<NameVarMap<VarType>>(ins);
            }
@@ -95,8 +120,14 @@ std::shared_ptr<NameVarMap<VarType>> PrepareData(
            tmp_var->SetType(var_base->Type());
            SetTensorToVariable(var_base->Var(), out, tmp_var->MutableVar());
            (*tmp_ins_ptr)[name_pair.first][i] = tmp_var;
            GetVariableWrapper(var_base)->setCacheValue(
                expected_kernel_key, GetVariableWrapper(tmp_var));
            VLOG(3) << "Set cache to variable_wrapper: key="
                    << expected_kernel_key;
          } else {
            // if dtype is same, transform inplace will not change the
            // original
            // value, transform inplace to avoid multiple copy
            SetTensorToVariable(var_base->Var(), out, var_base->MutableVar());
          }
@@ -104,6 +135,7 @@ std::shared_ptr<NameVarMap<VarType>> PrepareData(
        }
      }
    }
  }
  return tmp_ins_ptr;
}
......
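Taken together, the new branch in `PrepareData` first asks the input's `VariableWrapper` whether a copy transformed for `expected_kernel_key` already exists: on a hit it rewires the temporary input to that cached copy, and on a miss it runs `TransformData` as before and records the result via `setCacheValue`, so repeated launches of the same op pay the transform only once. Below is a self-contained sketch of that cache-around-transform flow using hypothetical `Var`/`Tensor` types; the real code keys the map on `OpKernelType` and stores `VariableWrapper` shared pointers:

```cpp
// Standalone sketch of the cache-around-transform pattern used in PrepareData
// (hypothetical types, not the real Paddle classes).
#include <iostream>
#include <map>
#include <memory>
#include <vector>

using Tensor = std::vector<float>;

struct Var {
  Tensor tensor;
  // Cache of already-transformed copies, keyed by the target "kernel type".
  std::map<int, std::shared_ptr<Var>> cache;
};

Tensor ExpensiveTransform(const Tensor& in, int target_type) {
  std::cout << "transforming to type " << target_type << "\n";
  Tensor out = in;  // pretend this casts / relayouts the data
  return out;
}

std::shared_ptr<Var> Prepare(const std::shared_ptr<Var>& var, int target_type) {
  auto it = var->cache.find(target_type);
  if (it != var->cache.end()) {
    return it->second;  // cache hit: reuse the transformed copy
  }
  auto transformed = std::make_shared<Var>();
  transformed->tensor = ExpensiveTransform(var->tensor, target_type);
  var->cache[target_type] = transformed;  // cache miss: store for next call
  return transformed;
}

int main() {
  auto x = std::make_shared<Var>(Var{{1.f, 2.f, 3.f}, {}});
  Prepare(x, /*target_type=*/1);  // transforms and caches
  Prepare(x, /*target_type=*/1);  // second call reuses the cache, no transform
  return 0;
}
```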
@@ -14,10 +14,12 @@
#pragma once

#include <map>
#include <memory>
#include <string>
#include <utility>

#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/hooks.h"
#include "paddle/fluid/imperative/op_base.h"
@@ -238,6 +240,21 @@ class VariableWrapper {
    inplace_version_snapshot_ = new_version;
  }

  bool hasCacheKey(const paddle::framework::OpKernelType& key) {
    return var_cache.find(key) != var_cache.end();
  }

  std::shared_ptr<VariableWrapper> getCacheValue(
      const paddle::framework::OpKernelType& key) {
    return var_cache[key];
  }

  void setCacheValue(const paddle::framework::OpKernelType& key,
                     std::shared_ptr<VariableWrapper> val) {
    var_cache[key] = val;
    return;
  }

 private:
  void SetGradVar(const std::shared_ptr<VariableWrapper>& var) {
    auto shared_var = grad_var_.lock();
@@ -311,6 +328,10 @@ class VariableWrapper {
  framework::Variable var_;
  std::string name_;

  // Used to cache the dtype-promoted VariableWrapper produced by the real and
  // complex computation in Paddle Quantum
  std::map<paddle::framework::OpKernelType, std::shared_ptr<VariableWrapper>>
      var_cache;

  // add this property because users may set stop_gradient themselves and this
  // should override the frameworks setting (-1) unset, (1) true, (0) false
  int overrided_stop_gradient_{-1};
......
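The cache itself lives on `VariableWrapper` as a `std::map` from kernel type to a `shared_ptr` of the transformed wrapper, exposed through the three small accessors above. A simplified mock of that interface follows, with the `OpKernelType` key reduced to an `int` for brevity (`MockWrapper` is hypothetical, not the real class):

```cpp
// Simplified mock of the new cache interface on VariableWrapper
// (hypothetical MockWrapper; the real key is framework::OpKernelType).
#include <cassert>
#include <map>
#include <memory>

class MockWrapper {
 public:
  bool hasCacheKey(int key) { return var_cache.find(key) != var_cache.end(); }

  std::shared_ptr<MockWrapper> getCacheValue(int key) { return var_cache[key]; }

  void setCacheValue(int key, std::shared_ptr<MockWrapper> val) {
    var_cache[key] = val;
  }

 private:
  // Holding shared_ptr values keeps each cached copy alive as long as the
  // owning wrapper does, so later kernel launches can keep reusing it.
  std::map<int, std::shared_ptr<MockWrapper>> var_cache;
};

int main() {
  MockWrapper var;
  assert(!var.hasCacheKey(42));
  var.setCacheValue(42, std::make_shared<MockWrapper>());
  assert(var.hasCacheKey(42));
  assert(var.getCacheValue(42) != nullptr);
  return 0;
}
```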