未验证 提交 743649b5 编写于 作者: L liym27 提交者: GitHub

[Cherry-Pick 2.0][Dynamic Inplace] Support ShareInplaceVersionCounterWith for...

[Cherry-Pick 2.0][Dynamic Inplace] Support ShareInplaceVersionCounterWith for C++ Tensor (#29842) (#30105)

Before this PR, SharePlaceHolderWith share Tensor between different C++ Variable, which meas sharing the data, shape, and inplace_version_counter_ of Tensor.
But in some cases, Sharing data and inplace_version_counter_ but not sharing shape is needed. For example, inplace op reshape, can't share shape.

This PR, discard SharePlaceHolderWith, and expose ShareInplaceVersionCounterWith for C++ Tensor.
This reverts commit b10ecd9d.

* Support ShareInplaceVersionCounterWith to share the same inplace version counter for VarBase
上级 52caf787
......@@ -39,7 +39,10 @@ void Tensor::check_memory_size() const {
numel() * SizeOfType(type()), memory_size()));
}
Tensor::Tensor(const proto::VarType::Type& dtype) : type_(dtype), offset_(0) {}
Tensor::Tensor(const proto::VarType::Type& dtype)
: type_(dtype),
offset_(0),
inplace_version_counter_(std::make_shared<TensorInplaceVersion>(0)) {}
size_t Tensor::memory_size() const {
return holder_ == nullptr ? 0UL : holder_->size() - offset_;
......@@ -89,6 +92,15 @@ Tensor& Tensor::ShareDataWith(const Tensor& src) {
*this = src;
return *this;
}
Tensor& Tensor::ShareInplaceVersionCounterWith(const Tensor& src) {
PADDLE_ENFORCE_NOT_NULL(
inplace_version_counter_,
platform::errors::PreconditionNotMet(
"Tensor does not hold inplace_version_counter_."));
inplace_version_counter_ = src.inplace_version_counter_;
return *this;
}
Tensor Tensor::Slice(int64_t begin_idx, int64_t end_idx) const {
check_memory_size();
......
......@@ -120,7 +120,10 @@ class Tensor {
friend struct EigenVector;
public:
Tensor() : type_(proto::VarType::FP32), offset_(0) {}
Tensor()
: type_(proto::VarType::FP32),
offset_(0),
inplace_version_counter_(std::make_shared<TensorInplaceVersion>(0)) {}
explicit Tensor(const proto::VarType::Type&);
......@@ -171,6 +174,9 @@ class Tensor {
/*! The internal of two tensors share the same memory block. */
Tensor& ShareDataWith(const Tensor& src);
/*! The internal of two tensors share the same inplace version counter. */
Tensor& ShareInplaceVersionCounterWith(const Tensor& src);
/**
* @brief Return a sub-tensor of the given tensor.
*
......@@ -252,7 +258,7 @@ class Tensor {
const proto::VarType::Type type);
TensorInplaceVersion& InplaceVersionCounter() {
return inplace_version_counter_;
return *inplace_version_counter_;
}
private:
......@@ -290,7 +296,7 @@ class Tensor {
* PlaceHolder::ptr_ and where the tensor data really begins.
*/
size_t offset_;
TensorInplaceVersion inplace_version_counter_;
std::shared_ptr<TensorInplaceVersion> inplace_version_counter_;
};
} // namespace framework
......
......@@ -69,16 +69,6 @@ class Variable {
return holder_->Type();
}
/**
* The internal of two Variables share the same Placeholder whose type can be
* Tensor, LoDTensor, SelectedRows, LoDTensorArray, etc.
*
* NOTE(liym27): In dynamic mode, sharing the same Placeholder also means
* share the same TensorInplaceVersion, which is very important for inplace
* operations.
*/
void SharePlaceholderWith(const Variable& var);
private:
// This method hides type T, so it doesn't appear as a template parameter of
// Variable.
......@@ -123,14 +113,6 @@ class Variable {
std::shared_ptr<Placeholder> holder_;
};
inline void Variable::SharePlaceholderWith(const Variable& var) {
PADDLE_ENFORCE_EQ(var.IsInitialized(), true,
platform::errors::PreconditionNotMet(
"Variable holds no memory. "
"Call Variable::GetMutable() firstly."));
holder_ = var.holder_;
}
inline framework::TensorInplaceVersion* Variable::InplaceVersionCounter() {
framework::TensorInplaceVersion* version_counter_ptr(nullptr);
if (IsType<framework::LoDTensor>()) {
......
......@@ -696,10 +696,9 @@ void BindImperative(py::module *m_ptr) {
x = linear(data)
print(x.numpy())
)DOC")
.def(
"detach",
[](const imperative::VarBase &self)
-> std::shared_ptr<imperative::VarBase> {
.def("detach",
[](const imperative::VarBase
&self) -> std::shared_ptr<imperative::VarBase> {
PADDLE_ENFORCE_EQ(
self.Var().IsInitialized(), true,
platform::errors::InvalidArgument(
......@@ -720,15 +719,41 @@ void BindImperative(py::module *m_ptr) {
detach_var->SetType(self.Type());
detach_var->SetDataType(self.DataType());
// NOTE(liym27):
// Call Variable::SharePlaceholderWith but not
// Tensor::ShareDataWith or Tensor::ShareBufferWith, because
// `detach_var` should share the same TensorInplaceVersion with
// `self`, and only SharePlaceholderWith can also share the same
// TensorInplaceVersion, which is used to check whether inplace
if (self.Var().IsType<framework::LoDTensor>()) {
const auto &origin_tensor =
self.Var().Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(
origin_tensor.IsInitialized(), true,
platform::errors::InvalidArgument(
"Tensor %s has not been initialized!", self.Name()));
auto *detach_tensor =
detach_var->MutableVar()->GetMutable<framework::LoDTensor>();
detach_tensor->ShareDataWith(origin_tensor);
// NOTE(liym27): Call ShareInplaceVersionCounterWith to share the
// same TensorInplaceVersion, which is used to check whether
// inplace
// operations are correct.
detach_var->MutableVar()->SharePlaceholderWith(self.Var());
detach_tensor->ShareInplaceVersionCounterWith(origin_tensor);
} else {
const auto &origin_selected_rows =
self.Var().Get<framework::SelectedRows>();
PADDLE_ENFORCE_EQ(
origin_selected_rows.value().IsInitialized(), true,
platform::errors::InvalidArgument(
"Tensor %s has not been initialized!", self.Name()));
auto *detach_selected_rows =
detach_var->MutableVar()
->GetMutable<framework::SelectedRows>();
detach_selected_rows->set_height(origin_selected_rows.height());
detach_selected_rows->set_rows(origin_selected_rows.rows());
detach_selected_rows->mutable_value()->ShareDataWith(
origin_selected_rows.value());
detach_selected_rows->mutable_value()
->ShareInplaceVersionCounterWith(
origin_selected_rows.value());
}
VLOG(3) << "The detached Tensor(" << detach_var->Name()
<< ") share data with " << self.Name();
return detach_var;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册