未验证 提交 93cb2350 编写于 作者: P pangyoki 提交者: GitHub

unify inplace_version checking log in new and old dygraph framework (#41209)

* change inplace_version checking log

* fix
上级 c86e3a11
...@@ -121,10 +121,10 @@ class TensorWrapper { ...@@ -121,10 +121,10 @@ class TensorWrapper {
static_cast<phi::DenseTensor*>(intermidiate_tensor_.impl().get()); static_cast<phi::DenseTensor*>(intermidiate_tensor_.impl().get());
auto& inplace_version_counter = dense_tensor->InplaceVersionCounter(); auto& inplace_version_counter = dense_tensor->InplaceVersionCounter();
uint32_t current_inplace_version = uint32_t wrapper_version_snapshot = inplace_version_snapshot_;
inplace_version_counter.CurrentVersion(); uint32_t tensor_version = inplace_version_counter.CurrentVersion();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
current_inplace_version, inplace_version_snapshot_, tensor_version, wrapper_version_snapshot,
paddle::platform::errors::PermissionDenied( paddle::platform::errors::PermissionDenied(
"Tensor '%s' used in gradient computation has been " "Tensor '%s' used in gradient computation has been "
"modified by an inplace operation. " "modified by an inplace operation. "
...@@ -132,14 +132,14 @@ class TensorWrapper { ...@@ -132,14 +132,14 @@ class TensorWrapper {
"Please fix your code to void calling an inplace operator " "Please fix your code to void calling an inplace operator "
"after using the Tensor which will used in gradient " "after using the Tensor which will used in gradient "
"computation.", "computation.",
intermidiate_tensor_.name(), current_inplace_version, intermidiate_tensor_.name(), tensor_version,
inplace_version_snapshot_)); wrapper_version_snapshot));
VLOG(6) << " The inplace_version_snapshot_ of Tensor '" VLOG(6) << " The wrapper_version_snapshot of Tensor '"
<< intermidiate_tensor_.name() << "' is [ " << intermidiate_tensor_.name() << "' is [ "
<< inplace_version_snapshot_ << " ]"; << wrapper_version_snapshot << " ]";
VLOG(6) << " The current_inplace_version of Tensor '" VLOG(6) << " The tensor_version of Tensor '"
<< intermidiate_tensor_.name() << "' is [ " << intermidiate_tensor_.name() << "' is [ " << tensor_version
<< current_inplace_version << " ]"; << " ]";
} }
} }
......
...@@ -61,13 +61,6 @@ class TestInplace(unittest.TestCase): ...@@ -61,13 +61,6 @@ class TestInplace(unittest.TestCase):
var_d = var_b**2 var_d = var_b**2
loss = paddle.nn.functional.relu(var_c + var_d) loss = paddle.nn.functional.relu(var_c + var_d)
if in_dygraph_mode():
with self.assertRaisesRegexp(
RuntimeError,
"received current_inplace_version:{} != inplace_version_snapshot_:{}".
format(1, 0)):
loss.backward()
else:
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
RuntimeError, RuntimeError,
"received tensor_version:{} != wrapper_version_snapshot:{}". "received tensor_version:{} != wrapper_version_snapshot:{}".
...@@ -203,13 +196,6 @@ class TestDygraphInplace(unittest.TestCase): ...@@ -203,13 +196,6 @@ class TestDygraphInplace(unittest.TestCase):
self.inplace_api_processing(var_b) self.inplace_api_processing(var_b)
loss = paddle.nn.functional.relu(var_c) loss = paddle.nn.functional.relu(var_c)
if in_dygraph_mode():
with self.assertRaisesRegexp(
RuntimeError,
"received current_inplace_version:{} != inplace_version_snapshot_:{}".
format(1, 0)):
loss.backward()
else:
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
RuntimeError, RuntimeError,
"received tensor_version:{} != wrapper_version_snapshot:{}". "received tensor_version:{} != wrapper_version_snapshot:{}".
......
...@@ -487,7 +487,7 @@ class TestPyLayer(unittest.TestCase): ...@@ -487,7 +487,7 @@ class TestPyLayer(unittest.TestCase):
z = layer(data) z = layer(data)
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
RuntimeError, RuntimeError,
"received current_inplace_version:{} != inplace_version_snapshot_:{}". "received tensor_version:{} != wrapper_version_snapshot:{}".
format(1, 0)): format(1, 0)):
z.backward() z.backward()
......
...@@ -91,13 +91,6 @@ class TestDygraphViewReuseAllocation(unittest.TestCase): ...@@ -91,13 +91,6 @@ class TestDygraphViewReuseAllocation(unittest.TestCase):
view_var_b[0] = 2. # var_b is modified inplace view_var_b[0] = 2. # var_b is modified inplace
loss = paddle.nn.functional.relu(var_c) loss = paddle.nn.functional.relu(var_c)
if in_dygraph_mode():
with self.assertRaisesRegexp(
RuntimeError,
"received current_inplace_version:{} != inplace_version_snapshot_:{}".
format(1, 0)):
loss.backward()
else:
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
RuntimeError, RuntimeError,
"received tensor_version:{} != wrapper_version_snapshot:{}". "received tensor_version:{} != wrapper_version_snapshot:{}".
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册