未验证 提交 2702af21 编写于 作者: W Weilong Wu 提交者: GitHub

Renamed Func and removed ENFORCE statement (#37348)

* Removed one ENFORCE statement

* Changed func name to _share_buffer_to

* Improve error reporting information

* Updated the logic of _is_share_buffer_to func
上级 ead89b11
...@@ -401,7 +401,7 @@ void VarBase::_CopyGradientFrom(const VarBase& src) { ...@@ -401,7 +401,7 @@ void VarBase::_CopyGradientFrom(const VarBase& src) {
auto& src_tensor = src.Var().Get<framework::LoDTensor>(); auto& src_tensor = src.Var().Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(src_tensor.IsInitialized(), true, PADDLE_ENFORCE_EQ(src_tensor.IsInitialized(), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"tensor has not been initialized", src.Name())); "Tensor %s has not been initialized", src.Name()));
auto* grad_t = grad_var_->MutableVar()->GetMutable<framework::LoDTensor>(); auto* grad_t = grad_var_->MutableVar()->GetMutable<framework::LoDTensor>();
auto* var_ = MutableVar()->GetMutable<framework::LoDTensor>(); auto* var_ = MutableVar()->GetMutable<framework::LoDTensor>();
grad_t->ShareDataWith(src_tensor); grad_t->ShareDataWith(src_tensor);
......
...@@ -1910,54 +1910,50 @@ void BindImperative(py::module *m_ptr) { ...@@ -1910,54 +1910,50 @@ void BindImperative(py::module *m_ptr) {
.def("_clear", .def("_clear",
[](const std::shared_ptr<imperative::VarBase> &self) { [](const std::shared_ptr<imperative::VarBase> &self) {
auto *t = self->MutableVar()->GetMutable<framework::LoDTensor>(); auto *t = self->MutableVar()->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(t->IsInitialized(), true, PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( t->IsInitialized(), true,
"tensor has not been initialized")); platform::errors::InvalidArgument(
"Tensor %s has not been initialized!", self->Name()));
t->clear(); t->clear();
}) })
.def("_offset", .def("_offset",
[](const std::shared_ptr<imperative::VarBase> &self) { [](const std::shared_ptr<imperative::VarBase> &self) {
auto *t = self->MutableVar()->GetMutable<framework::LoDTensor>(); auto *t = self->MutableVar()->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(t->IsInitialized(), true, PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( t->IsInitialized(), true,
"tensor has not been initialized")); platform::errors::InvalidArgument(
"Tensor %s has not been initialized!", self->Name()));
return t->offset(); return t->offset();
}) })
.def("_share_buffer_with", .def("_share_buffer_to",
[](const std::shared_ptr<imperative::VarBase> &self, [](const std::shared_ptr<imperative::VarBase> &self,
std::shared_ptr<imperative::VarBase> &target_t) { std::shared_ptr<imperative::VarBase> &dst) {
auto *t = self->MutableVar()->GetMutable<framework::LoDTensor>(); auto *src = self->MutableVar()->GetMutable<framework::LoDTensor>();
auto *t_t = auto *dst_ = dst->MutableVar()->GetMutable<framework::LoDTensor>();
target_t->MutableVar()->GetMutable<framework::LoDTensor>(); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(t->IsInitialized(), true, src->IsInitialized(), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"tensor has not been initialized")); "Tensor %s has not been initialized!", self->Name()));
PADDLE_ENFORCE_EQ(t_t->IsInitialized(), true, dst_->ShareBufferWith(*src);
platform::errors::InvalidArgument(
"tensor has not been initialized"));
t->ShareBufferWith(*t_t);
}) })
.def("_is_shared_buffer_with", .def("_is_shared_buffer_with",
[](const std::shared_ptr<imperative::VarBase> &self, [](const std::shared_ptr<imperative::VarBase> &self,
std::shared_ptr<imperative::VarBase> &target_t) { std::shared_ptr<imperative::VarBase> &dst) {
auto *t = self->MutableVar()->GetMutable<framework::LoDTensor>(); auto *src = self->MutableVar()->GetMutable<framework::LoDTensor>();
auto *t_t = auto *dst_ = dst->MutableVar()->GetMutable<framework::LoDTensor>();
target_t->MutableVar()->GetMutable<framework::LoDTensor>(); if (!src->IsInitialized() || !dst_->IsInitialized()) {
PADDLE_ENFORCE_EQ(t->IsInitialized(), true, return false;
platform::errors::InvalidArgument( }
"tensor has not been initialized")); return dst_->IsSharedBufferWith(*src);
PADDLE_ENFORCE_EQ(t_t->IsInitialized(), true,
platform::errors::InvalidArgument(
"tensor has not been initialized"));
return t->IsSharedBufferWith(*t_t);
}) })
.def("_slice", .def("_slice",
[](const std::shared_ptr<imperative::VarBase> &self, [](const std::shared_ptr<imperative::VarBase> &self,
int64_t begin_idx, int64_t end_idx) { int64_t begin_idx, int64_t end_idx) {
auto *t = self->MutableVar()->GetMutable<framework::LoDTensor>(); auto *t = self->MutableVar()->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(t->IsInitialized(), true, PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( t->IsInitialized(), true,
"tensor has not been initialized")); platform::errors::InvalidArgument(
"Tensor %s has not been initialized!", self->Name()));
return t->Slice(begin_idx, end_idx); return t->Slice(begin_idx, end_idx);
}) })
.def("_copy_gradient_from", .def("_copy_gradient_from",
...@@ -1966,9 +1962,10 @@ void BindImperative(py::module *m_ptr) { ...@@ -1966,9 +1962,10 @@ void BindImperative(py::module *m_ptr) {
.def("_numel", .def("_numel",
[](std::shared_ptr<imperative::VarBase> &self) { [](std::shared_ptr<imperative::VarBase> &self) {
auto *t = self->MutableVar()->GetMutable<framework::LoDTensor>(); auto *t = self->MutableVar()->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(t->IsInitialized(), true, PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( t->IsInitialized(), true,
"tensor has not been initialized")); platform::errors::InvalidArgument(
"Tensor %s has not been initialized!", self->Name()));
return t->numel(); return t->numel();
}) })
.def_property("name", &imperative::VarBase::Name, .def_property("name", &imperative::VarBase::Name,
......
...@@ -1189,15 +1189,15 @@ class TestVarBaseOffset(unittest.TestCase): ...@@ -1189,15 +1189,15 @@ class TestVarBaseOffset(unittest.TestCase):
self.assertEqual(actual_x._offset(), expected_offset) self.assertEqual(actual_x._offset(), expected_offset)
class TestVarBaseShareBufferWith(unittest.TestCase): class TestVarBaseShareBufferTo(unittest.TestCase):
def test_share_buffer_with(self): def test_share_buffer_To(self):
paddle.disable_static() paddle.disable_static()
np_x = np.random.random((3, 8, 8)) np_src = np.random.random((3, 8, 8))
np_y = np.random.random((3, 8, 8)) src = paddle.to_tensor(np_src, dtype="float64")
x = paddle.to_tensor(np_x, dtype="float64") # empty_var
y = paddle.to_tensor(np_y, dtype="float64") dst = core.VarBase()
x._share_buffer_with(y) src._share_buffer_to(dst)
self.assertEqual(x._is_shared_buffer_with(y), True) self.assertEqual(src._is_shared_buffer_with(dst), True)
class TestVarBaseTo(unittest.TestCase): class TestVarBaseTo(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册