未验证 提交 cdeb3167 编写于 作者: Y YuanRisheng 提交者: GitHub

[BugFix]Fix dims mismatch when run rec_svtrnet model in eager mode (#43373)

* change tensor name

* fix unittest bugs
上级 acfd7129
...@@ -57,7 +57,7 @@ class Controller { ...@@ -57,7 +57,7 @@ class Controller {
} }
bool HasGrad() const { return tracer_->HasGrad(); } bool HasGrad() const { return tracer_->HasGrad(); }
void SetHasGrad(bool has_grad) { tracer_->SetHasGrad(has_grad); } void SetHasGrad(bool has_grad) { tracer_->SetHasGrad(has_grad); }
std::string GenerateUniqueName(std::string key = "eager_tmp") { std::string GenerateUniqueName(std::string key = "eager_in_tmp") {
return tracer_->GenerateUniqueName(key); return tracer_->GenerateUniqueName(key);
} }
const std::shared_ptr<paddle::imperative::Tracer>& GetCurrentTracer() { const std::shared_ptr<paddle::imperative::Tracer>& GetCurrentTracer() {
......
...@@ -1776,7 +1776,7 @@ class TestEagerTensorGradNameValue(unittest.TestCase): ...@@ -1776,7 +1776,7 @@ class TestEagerTensorGradNameValue(unittest.TestCase):
b = a**2 b = a**2
self.assertEqual(a._grad_value(), None) self.assertEqual(a._grad_value(), None)
b.backward() b.backward()
self.assertEqual('eager_tmp' in a._grad_name(), True) self.assertEqual('eager_in_tmp' in a._grad_name(), True)
self.assertNotEqual(a._grad_value(), None) self.assertNotEqual(a._grad_value(), None)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册