Unverified commit 2702af21, authored by Weilong Wu, committed by GitHub

Renamed Func and removed ENFORCE statement (#37348)

* Removed one ENFORCE statement

* Changed func name to _share_buffer_to

* Improved error reporting information

* Updated the logic of the _is_shared_buffer_with func
Parent ead89b11
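Taken together, the hunks below rename the pybind-exposed VarBase helper `_share_buffer_with` to `_share_buffer_to` (the `_is_shared_buffer_with` query keeps its name), drop the ENFORCE on the destination tensor so an uninitialized destination is accepted, and add the variable name to the "not initialized" error messages. A minimal Python sketch of the new usage, modeled on the updated unit test at the end of this diff (paddle/numpy imports assumed; these are private VarBase helpers, not public API):

    import numpy as np
    import paddle
    import paddle.fluid.core as core

    paddle.disable_static()
    src = paddle.to_tensor(np.random.random((3, 8, 8)), dtype="float64")
    dst = core.VarBase()           # empty, uninitialized destination
    src._share_buffer_to(dst)      # dst now shares src's allocation
    assert src._is_shared_buffer_with(dst)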
@@ -401,7 +401,7 @@ void VarBase::_CopyGradientFrom(const VarBase& src) {
   auto& src_tensor = src.Var().Get<framework::LoDTensor>();
   PADDLE_ENFORCE_EQ(src_tensor.IsInitialized(), true,
                     platform::errors::InvalidArgument(
-                        "tensor has not been initialized", src.Name()));
+                        "Tensor %s has not been initialized", src.Name()));
   auto* grad_t = grad_var_->MutableVar()->GetMutable<framework::LoDTensor>();
   auto* var_ = MutableVar()->GetMutable<framework::LoDTensor>();
   grad_t->ShareDataWith(src_tensor);
@@ -1910,54 +1910,50 @@ void BindImperative(py::module *m_ptr) {
       .def("_clear",
            [](const std::shared_ptr<imperative::VarBase> &self) {
              auto *t = self->MutableVar()->GetMutable<framework::LoDTensor>();
-             PADDLE_ENFORCE_EQ(t->IsInitialized(), true,
-                               platform::errors::InvalidArgument(
-                                   "tensor has not been initialized"));
+             PADDLE_ENFORCE_EQ(
+                 t->IsInitialized(), true,
+                 platform::errors::InvalidArgument(
+                     "Tensor %s has not been initialized!", self->Name()));
              t->clear();
            })
       .def("_offset",
            [](const std::shared_ptr<imperative::VarBase> &self) {
              auto *t = self->MutableVar()->GetMutable<framework::LoDTensor>();
-             PADDLE_ENFORCE_EQ(t->IsInitialized(), true,
-                               platform::errors::InvalidArgument(
-                                   "tensor has not been initialized"));
+             PADDLE_ENFORCE_EQ(
+                 t->IsInitialized(), true,
+                 platform::errors::InvalidArgument(
+                     "Tensor %s has not been initialized!", self->Name()));
              return t->offset();
            })
-      .def("_share_buffer_with",
+      .def("_share_buffer_to",
            [](const std::shared_ptr<imperative::VarBase> &self,
-              std::shared_ptr<imperative::VarBase> &target_t) {
-             auto *t = self->MutableVar()->GetMutable<framework::LoDTensor>();
-             auto *t_t =
-                 target_t->MutableVar()->GetMutable<framework::LoDTensor>();
-             PADDLE_ENFORCE_EQ(t->IsInitialized(), true,
-                               platform::errors::InvalidArgument(
-                                   "tensor has not been initialized"));
-             PADDLE_ENFORCE_EQ(t_t->IsInitialized(), true,
-                               platform::errors::InvalidArgument(
-                                   "tensor has not been initialized"));
-             t->ShareBufferWith(*t_t);
+              std::shared_ptr<imperative::VarBase> &dst) {
+             auto *src = self->MutableVar()->GetMutable<framework::LoDTensor>();
+             auto *dst_ = dst->MutableVar()->GetMutable<framework::LoDTensor>();
+             PADDLE_ENFORCE_EQ(
+                 src->IsInitialized(), true,
+                 platform::errors::InvalidArgument(
+                     "Tensor %s has not been initialized!", self->Name()));
+             dst_->ShareBufferWith(*src);
            })
       .def("_is_shared_buffer_with",
            [](const std::shared_ptr<imperative::VarBase> &self,
-              std::shared_ptr<imperative::VarBase> &target_t) {
-             auto *t = self->MutableVar()->GetMutable<framework::LoDTensor>();
-             auto *t_t =
-                 target_t->MutableVar()->GetMutable<framework::LoDTensor>();
-             PADDLE_ENFORCE_EQ(t->IsInitialized(), true,
-                               platform::errors::InvalidArgument(
-                                   "tensor has not been initialized"));
-             PADDLE_ENFORCE_EQ(t_t->IsInitialized(), true,
-                               platform::errors::InvalidArgument(
-                                   "tensor has not been initialized"));
-             return t->IsSharedBufferWith(*t_t);
+              std::shared_ptr<imperative::VarBase> &dst) {
+             auto *src = self->MutableVar()->GetMutable<framework::LoDTensor>();
+             auto *dst_ = dst->MutableVar()->GetMutable<framework::LoDTensor>();
+             if (!src->IsInitialized() || !dst_->IsInitialized()) {
+               return false;
+             }
+             return dst_->IsSharedBufferWith(*src);
            })
       .def("_slice",
            [](const std::shared_ptr<imperative::VarBase> &self,
               int64_t begin_idx, int64_t end_idx) {
              auto *t = self->MutableVar()->GetMutable<framework::LoDTensor>();
-             PADDLE_ENFORCE_EQ(t->IsInitialized(), true,
-                               platform::errors::InvalidArgument(
-                                   "tensor has not been initialized"));
+             PADDLE_ENFORCE_EQ(
+                 t->IsInitialized(), true,
+                 platform::errors::InvalidArgument(
+                     "Tensor %s has not been initialized!", self->Name()));
              return t->Slice(begin_idx, end_idx);
            })
       .def("_copy_gradient_from",
@@ -1966,9 +1962,10 @@ void BindImperative(py::module *m_ptr) {
       .def("_numel",
            [](std::shared_ptr<imperative::VarBase> &self) {
              auto *t = self->MutableVar()->GetMutable<framework::LoDTensor>();
-             PADDLE_ENFORCE_EQ(t->IsInitialized(), true,
-                               platform::errors::InvalidArgument(
-                                   "tensor has not been initialized"));
+             PADDLE_ENFORCE_EQ(
+                 t->IsInitialized(), true,
+                 platform::errors::InvalidArgument(
+                     "Tensor %s has not been initialized!", self->Name()));
              return t->numel();
            })
       .def_property("name", &imperative::VarBase::Name,
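The reworked PADDLE_ENFORCE_EQ calls interpolate the variable name into the message, so a failure identifies the offending tensor. A sketch of triggering one of these checks from Python, assuming (as is usual for Paddle's pybind error translation) that InvalidArgument surfaces as a ValueError:

    import paddle.fluid.core as core

    empty = core.VarBase()        # tensor is never initialized
    try:
        empty._numel()            # hits the PADDLE_ENFORCE_EQ in _numel above
    except ValueError as err:     # assumed mapping of InvalidArgument to ValueError
        # Expected to read roughly: "Tensor <name> has not been initialized!"
        print(err)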
@@ -1189,15 +1189,15 @@ class TestVarBaseOffset(unittest.TestCase):
         self.assertEqual(actual_x._offset(), expected_offset)
 
 
-class TestVarBaseShareBufferWith(unittest.TestCase):
-    def test_share_buffer_with(self):
+class TestVarBaseShareBufferTo(unittest.TestCase):
+    def test_share_buffer_To(self):
         paddle.disable_static()
-        np_x = np.random.random((3, 8, 8))
-        np_y = np.random.random((3, 8, 8))
-        x = paddle.to_tensor(np_x, dtype="float64")
-        y = paddle.to_tensor(np_y, dtype="float64")
-        x._share_buffer_with(y)
-        self.assertEqual(x._is_shared_buffer_with(y), True)
+        np_src = np.random.random((3, 8, 8))
+        src = paddle.to_tensor(np_src, dtype="float64")
+        # empty_var
+        dst = core.VarBase()
+        src._share_buffer_to(dst)
+        self.assertEqual(src._is_shared_buffer_with(dst), True)
 
 
 class TestVarBaseTo(unittest.TestCase):