未验证 提交 ca3b6bcf 编写于 作者: C chentianyu03 提交者: GitHub

add cache for VariableWrapper (#30880)

* add cache for VariableWrapper

* modify args names and vlog level

* format code style

* add log when set cache to variable_wrapper

* add log when set cache to variable_wrapper

* add comment to variableWrapper cache

* format code style
上级 f114c3f8
......@@ -65,6 +65,10 @@ class OpKernelType {
size_t hash_key() const { return Hash()(*this); }
bool operator<(const OpKernelType& o) const {
return hash_key() < o.hash_key();
}
bool operator==(const OpKernelType& o) const;
bool operator!=(const OpKernelType& o) const { return !(*this == o); }
......
......@@ -20,6 +20,16 @@
namespace paddle {
namespace imperative {
const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
const std::shared_ptr<paddle::imperative::VarBase>& var) {
return var->SharedVar();
}
const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
const std::shared_ptr<VariableWrapper>& var) {
return var;
}
const framework::Tensor* GetTensorFromVar(const framework::Variable& var) {
if (var.IsType<framework::LoDTensor>()) {
return &(var.Get<framework::LoDTensor>());
......
......@@ -64,6 +64,11 @@ void SetForwardDataTypeOfGradVar<VarBase>(const std::shared_ptr<VarBase>& var) {
}
}
extern const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
const std::shared_ptr<paddle::imperative::VarBase>& var);
extern const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
const std::shared_ptr<VariableWrapper>& var);
template <typename VarType>
std::shared_ptr<NameVarMap<VarType>> PrepareData(
const framework::OperatorWithKernel& op, const NameVarMap<VarType>& ins,
......@@ -82,12 +87,32 @@ std::shared_ptr<NameVarMap<VarType>> PrepareData(
} else {
VLOG(3) << "Transform Variable " << var_base->Name() << " from "
<< kernel_type_for_var << " to " << expected_kernel_key;
if (GetVariableWrapper(var_base)->hasCacheKey(expected_kernel_key)) {
VLOG(3) << "Hit variable_wrapper cache: key="
<< expected_kernel_key;
std::shared_ptr<VariableWrapper> cache_var =
GetVariableWrapper(var_base)->getCacheValue(
expected_kernel_key);
if (tmp_ins_ptr == nullptr) {
tmp_ins_ptr = std::make_shared<NameVarMap<VarType>>(ins);
}
const auto* tensor = GetTensorFromVar(cache_var->Var());
auto tmp_var = std::make_shared<VarType>(var_base->Name());
tmp_var->SetType(var_base->Type());
SetTensorToVariable(cache_var->Var(), *tensor,
tmp_var->MutableVar());
(*tmp_ins_ptr)[name_pair.first][i] = tmp_var;
} else {
framework::Tensor out;
TransformData(expected_kernel_key, kernel_type_for_var, *tensor,
&out);
if (NeedTransformDataType(kernel_type_for_var, expected_kernel_key)) {
if (NeedTransformDataType(kernel_type_for_var,
expected_kernel_key)) {
// To avoid NameVarMap copy construction overhead in general
// scenarios, if inplace transformed, return original input directly
// scenarios, if inplace transformed, return original input
// directly
if (tmp_ins_ptr == nullptr) {
tmp_ins_ptr = std::make_shared<NameVarMap<VarType>>(ins);
}
......@@ -95,8 +120,14 @@ std::shared_ptr<NameVarMap<VarType>> PrepareData(
tmp_var->SetType(var_base->Type());
SetTensorToVariable(var_base->Var(), out, tmp_var->MutableVar());
(*tmp_ins_ptr)[name_pair.first][i] = tmp_var;
GetVariableWrapper(var_base)->setCacheValue(
expected_kernel_key, GetVariableWrapper(tmp_var));
VLOG(3) << "Set cache to variable_wrapper: key="
<< expected_kernel_key;
} else {
// if dtype is same, transform inplace will not change the original
// if dtype is same, transform inplace will not change the
// original
// value, transform inplace to avoid multiple copy
SetTensorToVariable(var_base->Var(), out, var_base->MutableVar());
}
......@@ -104,6 +135,7 @@ std::shared_ptr<NameVarMap<VarType>> PrepareData(
}
}
}
}
return tmp_ins_ptr;
}
......
......@@ -14,10 +14,12 @@
#pragma once
#include <map>
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/hooks.h"
#include "paddle/fluid/imperative/op_base.h"
......@@ -238,6 +240,21 @@ class VariableWrapper {
inplace_version_snapshot_ = new_version;
}
bool hasCacheKey(const paddle::framework::OpKernelType& key) {
return var_cache.find(key) != var_cache.end();
}
std::shared_ptr<VariableWrapper> getCacheValue(
const paddle::framework::OpKernelType& key) {
return var_cache[key];
}
void setCacheValue(const paddle::framework::OpKernelType& key,
std::shared_ptr<VariableWrapper> val) {
var_cache[key] = val;
return;
}
private:
void SetGradVar(const std::shared_ptr<VariableWrapper>& var) {
auto shared_var = grad_var_.lock();
......@@ -311,6 +328,10 @@ class VariableWrapper {
framework::Variable var_;
std::string name_;
// Used for cache the dtype promotioned variableWrapper in real and complex
// compute of Paddle Quantum
std::map<paddle::framework::OpKernelType, std::shared_ptr<VariableWrapper>>
var_cache;
// add this property for users may set stop_gradient themselves and this
// should override the frameworks setting (-1) unset, (1) true, (0) false
int overrided_stop_gradient_{-1};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册