未验证 提交 93cb2350 编写于 作者: P pangyoki 提交者: GitHub

unify inplace_version checking log in new and old dygraph framework (#41209)

* change inplace_version checking log

* fix
上级 c86e3a11
......@@ -121,10 +121,10 @@ class TensorWrapper {
static_cast<phi::DenseTensor*>(intermidiate_tensor_.impl().get());
auto& inplace_version_counter = dense_tensor->InplaceVersionCounter();
uint32_t current_inplace_version =
inplace_version_counter.CurrentVersion();
uint32_t wrapper_version_snapshot = inplace_version_snapshot_;
uint32_t tensor_version = inplace_version_counter.CurrentVersion();
PADDLE_ENFORCE_EQ(
current_inplace_version, inplace_version_snapshot_,
tensor_version, wrapper_version_snapshot,
paddle::platform::errors::PermissionDenied(
"Tensor '%s' used in gradient computation has been "
"modified by an inplace operation. "
......@@ -132,14 +132,14 @@ class TensorWrapper {
"Please fix your code to void calling an inplace operator "
"after using the Tensor which will used in gradient "
"computation.",
intermidiate_tensor_.name(), current_inplace_version,
inplace_version_snapshot_));
VLOG(6) << " The inplace_version_snapshot_ of Tensor '"
intermidiate_tensor_.name(), tensor_version,
wrapper_version_snapshot));
VLOG(6) << " The wrapper_version_snapshot of Tensor '"
<< intermidiate_tensor_.name() << "' is [ "
<< inplace_version_snapshot_ << " ]";
VLOG(6) << " The current_inplace_version of Tensor '"
<< intermidiate_tensor_.name() << "' is [ "
<< current_inplace_version << " ]";
<< wrapper_version_snapshot << " ]";
VLOG(6) << " The tensor_version of Tensor '"
<< intermidiate_tensor_.name() << "' is [ " << tensor_version
<< " ]";
}
}
......
......@@ -61,18 +61,11 @@ class TestInplace(unittest.TestCase):
var_d = var_b**2
loss = paddle.nn.functional.relu(var_c + var_d)
if in_dygraph_mode():
with self.assertRaisesRegexp(
RuntimeError,
"received current_inplace_version:{} != inplace_version_snapshot_:{}".
format(1, 0)):
loss.backward()
else:
with self.assertRaisesRegexp(
RuntimeError,
"received tensor_version:{} != wrapper_version_snapshot:{}".
format(1, 0)):
loss.backward()
with self.assertRaisesRegexp(
RuntimeError,
"received tensor_version:{} != wrapper_version_snapshot:{}".
format(1, 0)):
loss.backward()
def test_backward_error(self):
with _test_eager_guard():
......@@ -203,18 +196,11 @@ class TestDygraphInplace(unittest.TestCase):
self.inplace_api_processing(var_b)
loss = paddle.nn.functional.relu(var_c)
if in_dygraph_mode():
with self.assertRaisesRegexp(
RuntimeError,
"received current_inplace_version:{} != inplace_version_snapshot_:{}".
format(1, 0)):
loss.backward()
else:
with self.assertRaisesRegexp(
RuntimeError,
"received tensor_version:{} != wrapper_version_snapshot:{}".
format(1, 0)):
loss.backward()
with self.assertRaisesRegexp(
RuntimeError,
"received tensor_version:{} != wrapper_version_snapshot:{}".
format(1, 0)):
loss.backward()
def test_backward_error(self):
with _test_eager_guard():
......
......@@ -487,7 +487,7 @@ class TestPyLayer(unittest.TestCase):
z = layer(data)
with self.assertRaisesRegexp(
RuntimeError,
"received current_inplace_version:{} != inplace_version_snapshot_:{}".
"received tensor_version:{} != wrapper_version_snapshot:{}".
format(1, 0)):
z.backward()
......
......@@ -91,18 +91,11 @@ class TestDygraphViewReuseAllocation(unittest.TestCase):
view_var_b[0] = 2. # var_b is modified inplace
loss = paddle.nn.functional.relu(var_c)
if in_dygraph_mode():
with self.assertRaisesRegexp(
RuntimeError,
"received current_inplace_version:{} != inplace_version_snapshot_:{}".
format(1, 0)):
loss.backward()
else:
with self.assertRaisesRegexp(
RuntimeError,
"received tensor_version:{} != wrapper_version_snapshot:{}".
format(1, 0)):
loss.backward()
with self.assertRaisesRegexp(
RuntimeError,
"received tensor_version:{} != wrapper_version_snapshot:{}".
format(1, 0)):
loss.backward()
def test_backward_error(self):
with _test_eager_guard():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册