未验证 提交 9602a182 编写于 作者: L liym27 提交者: GitHub

[Dynamic Inplace] Support ShareInplaceVersionCounterWith for C++ Tensor (#29842)

* Revert "[inplace] Add ShareHolderWith for class Variable and SharePlaceholderWith in VarBase.detach() to share the same Tensor/SelectedRows (#29267)"

This reverts commit b10ecd9d.

* Support ShareInplaceVersionCounterWith to share the same inplace version counter for VarBase
上级 4427df37
...@@ -39,7 +39,10 @@ void Tensor::check_memory_size() const { ...@@ -39,7 +39,10 @@ void Tensor::check_memory_size() const {
numel() * SizeOfType(type()), memory_size())); numel() * SizeOfType(type()), memory_size()));
} }
Tensor::Tensor(const proto::VarType::Type& dtype) : type_(dtype), offset_(0) {} Tensor::Tensor(const proto::VarType::Type& dtype)
: type_(dtype),
offset_(0),
inplace_version_counter_(std::make_shared<TensorInplaceVersion>(0)) {}
size_t Tensor::memory_size() const { size_t Tensor::memory_size() const {
return holder_ == nullptr ? 0UL : holder_->size() - offset_; return holder_ == nullptr ? 0UL : holder_->size() - offset_;
...@@ -89,6 +92,15 @@ Tensor& Tensor::ShareDataWith(const Tensor& src) { ...@@ -89,6 +92,15 @@ Tensor& Tensor::ShareDataWith(const Tensor& src) {
*this = src; *this = src;
return *this; return *this;
} }
Tensor& Tensor::ShareInplaceVersionCounterWith(const Tensor& src) {
PADDLE_ENFORCE_NOT_NULL(
inplace_version_counter_,
platform::errors::PreconditionNotMet(
"Tensor does not hold inplace_version_counter_."));
inplace_version_counter_ = src.inplace_version_counter_;
return *this;
}
Tensor Tensor::Slice(int64_t begin_idx, int64_t end_idx) const { Tensor Tensor::Slice(int64_t begin_idx, int64_t end_idx) const {
check_memory_size(); check_memory_size();
......
...@@ -120,7 +120,10 @@ class Tensor { ...@@ -120,7 +120,10 @@ class Tensor {
friend struct EigenVector; friend struct EigenVector;
public: public:
Tensor() : type_(proto::VarType::FP32), offset_(0) {} Tensor()
: type_(proto::VarType::FP32),
offset_(0),
inplace_version_counter_(std::make_shared<TensorInplaceVersion>(0)) {}
explicit Tensor(const proto::VarType::Type&); explicit Tensor(const proto::VarType::Type&);
...@@ -171,6 +174,9 @@ class Tensor { ...@@ -171,6 +174,9 @@ class Tensor {
/*! The internal of two tensors share the same memory block. */ /*! The internal of two tensors share the same memory block. */
Tensor& ShareDataWith(const Tensor& src); Tensor& ShareDataWith(const Tensor& src);
/*! The internal of two tensors share the same inplace version counter. */
Tensor& ShareInplaceVersionCounterWith(const Tensor& src);
/** /**
* @brief Return a sub-tensor of the given tensor. * @brief Return a sub-tensor of the given tensor.
* *
...@@ -252,7 +258,7 @@ class Tensor { ...@@ -252,7 +258,7 @@ class Tensor {
const proto::VarType::Type type); const proto::VarType::Type type);
TensorInplaceVersion& InplaceVersionCounter() { TensorInplaceVersion& InplaceVersionCounter() {
return inplace_version_counter_; return *inplace_version_counter_;
} }
private: private:
...@@ -290,7 +296,7 @@ class Tensor { ...@@ -290,7 +296,7 @@ class Tensor {
* PlaceHolder::ptr_ and where the tensor data really begins. * PlaceHolder::ptr_ and where the tensor data really begins.
*/ */
size_t offset_; size_t offset_;
TensorInplaceVersion inplace_version_counter_; std::shared_ptr<TensorInplaceVersion> inplace_version_counter_;
}; };
} // namespace framework } // namespace framework
......
...@@ -69,16 +69,6 @@ class Variable { ...@@ -69,16 +69,6 @@ class Variable {
return holder_->Type(); return holder_->Type();
} }
/**
* The internal of two Variables share the same Placeholder whose type can be
* Tensor, LoDTensor, SelectedRows, LoDTensorArray, etc.
*
* NOTE(liym27): In dynamic mode, sharing the same Placeholder also means
* share the same TensorInplaceVersion, which is very important for inplace
* operations.
*/
void SharePlaceholderWith(const Variable& var);
private: private:
// This method hides type T, so it doesn't appear as a template parameter of // This method hides type T, so it doesn't appear as a template parameter of
// Variable. // Variable.
...@@ -123,14 +113,6 @@ class Variable { ...@@ -123,14 +113,6 @@ class Variable {
std::shared_ptr<Placeholder> holder_; std::shared_ptr<Placeholder> holder_;
}; };
inline void Variable::SharePlaceholderWith(const Variable& var) {
PADDLE_ENFORCE_EQ(var.IsInitialized(), true,
platform::errors::PreconditionNotMet(
"Variable holds no memory. "
"Call Variable::GetMutable() firstly."));
holder_ = var.holder_;
}
inline framework::TensorInplaceVersion* Variable::InplaceVersionCounter() { inline framework::TensorInplaceVersion* Variable::InplaceVersionCounter() {
framework::TensorInplaceVersion* version_counter_ptr(nullptr); framework::TensorInplaceVersion* version_counter_ptr(nullptr);
if (IsType<framework::LoDTensor>()) { if (IsType<framework::LoDTensor>()) {
......
...@@ -696,44 +696,69 @@ void BindImperative(py::module *m_ptr) { ...@@ -696,44 +696,69 @@ void BindImperative(py::module *m_ptr) {
x = linear(data) x = linear(data)
print(x.numpy()) print(x.numpy())
)DOC") )DOC")
.def( .def("detach",
"detach", [](const imperative::VarBase
[](const imperative::VarBase &self) &self) -> std::shared_ptr<imperative::VarBase> {
-> std::shared_ptr<imperative::VarBase> { PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ( self.Var().IsInitialized(), true,
self.Var().IsInitialized(), true, platform::errors::InvalidArgument(
platform::errors::InvalidArgument( "Tensor %s has not been initialized!", self.Name()));
"Tensor %s has not been initialized!", self.Name()));
PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ( self.Var().IsType<framework::LoDTensor>() ||
self.Var().IsType<framework::LoDTensor>() || self.Var().IsType<framework::SelectedRows>(),
self.Var().IsType<framework::SelectedRows>(), true,
true, platform::errors::InvalidArgument(
platform::errors::InvalidArgument( "Type of Tensor[%s] must be LoDTensor or SelectedRows!",
"Type of Tensor[%s] must be LoDTensor or SelectedRows!", self.Name()));
self.Name()));
auto detach_var = std::make_shared<imperative::VarBase>(
auto detach_var = std::make_shared<imperative::VarBase>( true, "detach_" + self.Name());
true, "detach_" + self.Name());
detach_var->SetPersistable(self.Persistable());
detach_var->SetPersistable(self.Persistable()); detach_var->SetType(self.Type());
detach_var->SetType(self.Type()); detach_var->SetDataType(self.DataType());
detach_var->SetDataType(self.DataType());
if (self.Var().IsType<framework::LoDTensor>()) {
// NOTE(liym27): const auto &origin_tensor =
// Call Variable::SharePlaceholderWith but not self.Var().Get<framework::LoDTensor>();
// Tensor::ShareDataWith or Tensor::ShareBufferWith, because PADDLE_ENFORCE_EQ(
// `detach_var` should share the same TensorInplaceVersion with origin_tensor.IsInitialized(), true,
// `self`, and only SharePlaceholderWith can also share the same platform::errors::InvalidArgument(
// TensorInplaceVersion, which is used to check whether inplace "Tensor %s has not been initialized!", self.Name()));
// operations are correct.
detach_var->MutableVar()->SharePlaceholderWith(self.Var()); auto *detach_tensor =
detach_var->MutableVar()->GetMutable<framework::LoDTensor>();
VLOG(3) << "The detached Tensor(" << detach_var->Name() detach_tensor->ShareDataWith(origin_tensor);
<< ") share data with " << self.Name(); // NOTE(liym27): Call ShareInplaceVersionCounterWith to share the
return detach_var; // same TensorInplaceVersion, which is used to check whether
}, // inplace
py::return_value_policy::take_ownership, R"DOC( // operations are correct.
detach_tensor->ShareInplaceVersionCounterWith(origin_tensor);
} else {
const auto &origin_selected_rows =
self.Var().Get<framework::SelectedRows>();
PADDLE_ENFORCE_EQ(
origin_selected_rows.value().IsInitialized(), true,
platform::errors::InvalidArgument(
"Tensor %s has not been initialized!", self.Name()));
auto *detach_selected_rows =
detach_var->MutableVar()
->GetMutable<framework::SelectedRows>();
detach_selected_rows->set_height(origin_selected_rows.height());
detach_selected_rows->set_rows(origin_selected_rows.rows());
detach_selected_rows->mutable_value()->ShareDataWith(
origin_selected_rows.value());
detach_selected_rows->mutable_value()
->ShareInplaceVersionCounterWith(
origin_selected_rows.value());
}
VLOG(3) << "The detached Tensor(" << detach_var->Name()
<< ") share data with " << self.Name();
return detach_var;
},
py::return_value_policy::take_ownership, R"DOC(
Returns a new Tensor, detached from the current graph. Returns a new Tensor, detached from the current graph.
It will share data with origin Tensor and always doesn't have a Tensor copy. It will share data with origin Tensor and always doesn't have a Tensor copy.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册