diff --git a/mm/memcontrol.c b/mm/memcontrol.c index c8336e8f8df064c9a1ecdbb8055a785b3481d4f2..da07784dde875247dd7b2c47cd25eacb42780921 100644 --- a/mm/memcontrol.c +++ b/mm/memcontrol.c @@ -1158,7 +1158,15 @@ mem_cgroup_iter_load(struct mem_cgroup_reclaim_iter *iter, if (iter->last_dead_count == *sequence) { smp_rmb(); position = iter->last_visited; - if (position && !css_tryget(&position->css)) + + /* + * We cannot take a reference to root because we might race + * with root removal and returning NULL would end up in + * an endless loop on the iterator user level when root + * would be returned all the time. + */ + if (position && position != root && + !css_tryget(&position->css)) position = NULL; } return position; @@ -1167,9 +1175,11 @@ mem_cgroup_iter_load(struct mem_cgroup_reclaim_iter *iter, static void mem_cgroup_iter_update(struct mem_cgroup_reclaim_iter *iter, struct mem_cgroup *last_visited, struct mem_cgroup *new_position, + struct mem_cgroup *root, int sequence) { - if (last_visited) + /* root reference counting symmetric to mem_cgroup_iter_load */ + if (last_visited && last_visited != root) css_put(&last_visited->css); /* * We store the sequence count from the time @last_visited was @@ -1244,7 +1254,8 @@ struct mem_cgroup *mem_cgroup_iter(struct mem_cgroup *root, memcg = __mem_cgroup_iter_next(root, last_visited); if (reclaim) { - mem_cgroup_iter_update(iter, last_visited, memcg, seq); + mem_cgroup_iter_update(iter, last_visited, memcg, root, + seq); if (!memcg) iter->generation++;