Unverified commit 82374dc1, authored by Zhen Wang, committed by GitHub

Add some error messages for the op without double grads. (#25951)

* Add some error messages for the op without double grads.

* fix the test_imperative_double_grad UT.
Parent 948bc8b7
......
@@ -885,13 +885,15 @@ void PartialGradTask::RunEachOp(OpBase *op) {
   if (create_graph_) {
     auto double_grad_node = CreateGradOpNode(op->InnerOp(), tmp_ins, tmp_outs,
                                              op->Attrs(), op->place());
-    if (double_grad_node) {
-      VLOG(10) << "Create " << double_grad_node->size()
-               << " double grad op(s) for " << op->Type()
-               << ", pending ops: " << GradPendingOpTypes(*double_grad_node);
-      double_grad_nodes_.emplace_back(std::move(double_grad_node));
-    }
+    PADDLE_ENFORCE_NOT_NULL(
+        double_grad_node,
+        platform::errors::NotFound("The Op %s doesn't have any grad op.",
+                                   op->Type()));
+    VLOG(10) << "Create " << double_grad_node->size()
+             << " double grad op(s) for " << op->Type()
+             << ", pending ops: " << GradPendingOpTypes(*double_grad_node);
+    double_grad_nodes_.emplace_back(std::move(double_grad_node));
   }
   VLOG(10) << "There are " << grads_to_accumulate_.size() << " to sum gradient";
......
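For context (a rough sketch, not part of this commit): with the silent if (double_grad_node) skip replaced by PADDLE_ENFORCE_NOT_NULL, calling fluid.dygraph.grad with create_graph=True through an op whose grad op has no double-grad definition should now fail with a NotFound-style error ("The Op xxx_grad doesn't have any grad op.") instead of the node being dropped quietly. The helper name check_double_grad_support below is invented for illustration, and the broad except Exception is used because the exact Python exception type raised for the C++ NotFound error depends on the Paddle version.

import numpy as np
import paddle.fluid as fluid


def check_double_grad_support(op_fn, x_np):
    """Illustrative helper (not part of Paddle): try to build a double-grad
    graph through op_fn, a callable mapping a dygraph Variable to a Variable."""
    with fluid.dygraph.guard():
        x = fluid.dygraph.to_variable(x_np)
        x.stop_gradient = False
        out = op_fn(x)
        try:
            # create_graph=True makes the engine build grad ops for the grad
            # ops; after this commit a missing double-grad node raises an
            # error instead of being skipped silently.
            fluid.dygraph.grad(outputs=[out], inputs=[x], create_graph=True)
            return True
        except Exception as e:  # exception type depends on the Paddle version
            print(e)  # e.g. "The Op xxx_grad doesn't have any grad op."
            return False


# conv2d is exercised with create_graph=True by the updated test below, so it
# is expected to support double grad and this should print True.
value = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
print(check_double_grad_support(lambda t: fluid.dygraph.Conv2D(3, 2, 3)(t), value))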
......
@@ -298,16 +298,15 @@ class TestDygraphDoubleGradSortGradient(TestDygraphDoubleGrad):
 class TestDygraphDoubleGradVisitedUniq(TestCase):
     def test_compare(self):
-        value = np.random.uniform(-0.5, 0.5, 100).reshape(10, 2,
-                                                          5).astype("float32")
+        value = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')

         def model_f(input):
-            linear = fluid.dygraph.Linear(5, 3, bias_attr=False)
+            conv2d = fluid.dygraph.Conv2D(3, 2, 3)
             for i in range(10):
                 if i == 0:
-                    out = linear(input)
+                    out = conv2d(input)
                 else:
-                    out = out + linear(input)
+                    out = out + conv2d(input)
             return out

         backward_strategy = fluid.dygraph.BackwardStrategy()
......
@@ -319,8 +318,14 @@ class TestDygraphDoubleGradVisitedUniq(TestCase):
             out = model_f(a)

-            dx=fluid.dygraph.grad(outputs=[out],inputs=[a],create_graph=True,retain_graph=True, \
-                only_inputs=True,allow_unused=False, backward_strategy=backward_strategy)
+            dx = fluid.dygraph.grad(
+                outputs=[out],
+                inputs=[a],
+                create_graph=True,
+                retain_graph=True,
+                only_inputs=True,
+                allow_unused=False,
+                backward_strategy=backward_strategy)

             grad_1 = dx[0].numpy()
......
@@ -334,7 +339,9 @@ class TestDygraphDoubleGradVisitedUniq(TestCase):
             grad_2 = a.gradient()

-        self.assertTrue(np.array_equal(grad_1, grad_2))
+        self.assertTrue(
+            np.allclose(
+                grad_1, grad_2, rtol=1.e-5, atol=1.e-8, equal_nan=True))


 if __name__ == '__main__':
......