未验证 提交 67e36296 编写于 作者: H hong 提交者: GitHub

Cherry pick dygraph double grad depend bug (#25828)

* Fix dygraph grad bugs (#25781)

* fix double grad visitid unit; test=develop

* change name hash_pair to HashPair; test=develop

* follow comment; test=develop

* remove manual seed; test=develop

* change create_graph from True to False; test=develop
上级 731caea3
......@@ -36,6 +36,15 @@
namespace paddle {
namespace imperative {
struct HashPair {
template <class T1, class T2>
size_t operator()(const std::pair<T1, T2> &p) const noexcept {
auto hash1 = std::hash<T1>{}(p.first);
auto hash2 = std::hash<T2>{}(p.second);
return hash1 ^ hash2;
}
};
/**
* This function prunes the graph to get the ops between `output_targets`
* and `input_target_grads`.
......@@ -152,8 +161,10 @@ static void GetGraphInfoBetweenTargets(
target_vars = *input_target_grads;
std::queue<std::pair<OpBase * /*op*/, OpBase * /*pending op*/>> op_queue;
std::unordered_set<std::pair<OpBase *, OpBase *>, HashPair> op_base_visited;
for (auto &endpoint_op : endpoint_ops) {
op_queue.emplace(endpoint_op, nullptr);
op_base_visited.emplace(endpoint_op, nullptr);
}
while (!op_queue.empty()) {
......@@ -207,6 +218,7 @@ static void GetGraphInfoBetweenTargets(
if (pending_op) {
VLOG(10) << "Pending op of " << op->Type() << " is "
<< pending_op->Type();
pending_ops[op].insert(pending_op);
++op_deps[pending_op];
} else {
......@@ -216,7 +228,10 @@ static void GetGraphInfoBetweenTargets(
auto iter = preceding_ops.find(op);
if (iter != preceding_ops.end()) {
for (auto &preceding_op : iter->second) {
op_queue.emplace(preceding_op, op);
if (op_base_visited.count(std::make_pair(preceding_op, op)) == 0) {
op_queue.emplace(preceding_op, op);
op_base_visited.emplace(preceding_op, op);
}
}
}
}
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import paddle.fluid as fluid
import paddle
from paddle.fluid.wrapped_decorator import wrap_decorator
import unittest
from unittest import TestCase
......@@ -295,5 +296,48 @@ class TestDygraphDoubleGradSortGradient(TestDygraphDoubleGrad):
self.shape = [5, 10]
class TestDygraphDoubleGradVisitedUniq(TestCase):
def test_compare(self):
value = np.random.uniform(-0.5, 0.5, 100).reshape(10, 2,
5).astype("float32")
def model_f(input):
linear = fluid.dygraph.Linear(5, 3, bias_attr=False)
for i in range(10):
if i == 0:
out = linear(input)
else:
out = out + linear(input)
return out
backward_strategy = fluid.dygraph.BackwardStrategy()
backward_strategy.sort_sum_gradient = True
with fluid.dygraph.guard():
fluid.default_startup_program().random_seed = 123
fluid.default_main_program().random_seed = 123
a = fluid.dygraph.to_variable(value)
a.stop_gradient = False
out = model_f(a)
dx=fluid.dygraph.grad(outputs=[out],inputs=[a],create_graph=False,retain_graph=False, \
only_inputs=True,allow_unused=False, backward_strategy=backward_strategy)
grad_1 = dx[0].numpy()
with fluid.dygraph.guard():
fluid.default_startup_program().random_seed = 123
fluid.default_main_program().random_seed = 123
a = fluid.dygraph.to_variable(value)
a.stop_gradient = False
out = model_f(a)
out.backward(backward_strategy)
grad_2 = a.gradient()
self.assertTrue(np.array_equal(grad_1, grad_2))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册