Unverified commit 82374dc1, authored by Zhen Wang, committed by GitHub

Add some error messages for ops without double grads. (#25951)

* Add some error messages for ops without double grads.

* Fix the test_imperative_double_grad UT.
Parent 948bc8b7
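A minimal sketch of the behavior the new check enforces (illustrative only, not part of this commit; `op_without_double_grad` is a hypothetical placeholder for any operator whose grad op has no double-grad op registered):

```python
import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    a = fluid.dygraph.to_variable(np.random.rand(4, 4).astype('float32'))
    a.stop_gradient = False
    out = op_without_double_grad(a)  # hypothetical op lacking a double grad
    # Before this commit, create_graph=True silently skipped grad ops for
    # which CreateGradOpNode returned null; now PADDLE_ENFORCE_NOT_NULL
    # raises NotFound: "The Op %s doesn't have any grad op."
    dx = fluid.dygraph.grad(outputs=[out], inputs=[a], create_graph=True)
```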
...
@@ -885,12 +885,14 @@ void PartialGradTask::RunEachOp(OpBase *op) {
   if (create_graph_) {
     auto double_grad_node = CreateGradOpNode(op->InnerOp(), tmp_ins, tmp_outs,
                                              op->Attrs(), op->place());
-    if (double_grad_node) {
-      VLOG(10) << "Create " << double_grad_node->size()
-               << " double grad op(s) for " << op->Type()
-               << ", pending ops: " << GradPendingOpTypes(*double_grad_node);
-      double_grad_nodes_.emplace_back(std::move(double_grad_node));
-    }
+    PADDLE_ENFORCE_NOT_NULL(
+        double_grad_node,
+        platform::errors::NotFound("The Op %s doesn't have any grad op.",
+                                   op->Type()));
+    VLOG(10) << "Create " << double_grad_node->size()
+             << " double grad op(s) for " << op->Type()
+             << ", pending ops: " << GradPendingOpTypes(*double_grad_node);
+    double_grad_nodes_.emplace_back(std::move(double_grad_node));
   }

   VLOG(10) << "There are " << grads_to_accumulate_.size() << " to sum gradient";
...
@@ -298,16 +298,15 @@ class TestDygraphDoubleGradSortGradient(TestDygraphDoubleGrad):

 class TestDygraphDoubleGradVisitedUniq(TestCase):
     def test_compare(self):
-        value = np.random.uniform(-0.5, 0.5, 100).reshape(10, 2,
-                                                          5).astype("float32")
+        value = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')

         def model_f(input):
-            linear = fluid.dygraph.Linear(5, 3, bias_attr=False)
+            conv2d = fluid.dygraph.Conv2D(3, 2, 3)
             for i in range(10):
                 if i == 0:
-                    out = linear(input)
+                    out = conv2d(input)
                 else:
-                    out = out + linear(input)
+                    out = out + conv2d(input)
             return out

         backward_strategy = fluid.dygraph.BackwardStrategy()
@@ -319,8 +318,14 @@ class TestDygraphDoubleGradVisitedUniq(TestCase):
             out = model_f(a)

-            dx=fluid.dygraph.grad(outputs=[out],inputs=[a],create_graph=True,retain_graph=True, \
-                    only_inputs=True,allow_unused=False, backward_strategy=backward_strategy)
+            dx = fluid.dygraph.grad(
+                outputs=[out],
+                inputs=[a],
+                create_graph=True,
+                retain_graph=True,
+                only_inputs=True,
+                allow_unused=False,
+                backward_strategy=backward_strategy)

             grad_1 = dx[0].numpy()
@@ -334,7 +339,9 @@ class TestDygraphDoubleGradVisitedUniq(TestCase):
             grad_2 = a.gradient()

-        self.assertTrue(np.array_equal(grad_1, grad_2))
+        self.assertTrue(
+            np.allclose(
+                grad_1, grad_2, rtol=1.e-5, atol=1.e-8, equal_nan=True))


 if __name__ == '__main__':
...