From df79334cae280bdf08c37ffd896f62d55761d887 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 7 Jul 2021 20:35:06 +0800 Subject: [PATCH] feat(mge/distributed): add user_pop function to save device memory BREAKING CHANGE: GitOrigin-RevId: 0a8e406da5275d0712a5dbb8d032f5cd7ff2cfe6 --- imperative/python/megengine/distributed/server.py | 14 ++++++++++++++ .../test/unit/distributed/test_distributed.py | 14 ++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/imperative/python/megengine/distributed/server.py b/imperative/python/megengine/distributed/server.py index e83efb92..8b3b569e 100644 --- a/imperative/python/megengine/distributed/server.py +++ b/imperative/python/megengine/distributed/server.py @@ -145,6 +145,16 @@ class Methods: del self.bcast_dict[key] return val + def _del(self, key): + with self.lock: + del self.user_dict[key] + + # thread safe function + def user_pop(self, key): + ret = self.user_get(key) + self._del(key) + return ret + class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer): pass @@ -274,6 +284,10 @@ class Client: """Get user defined key-value pairs across processes.""" return self.proxy.user_get(key) + def user_pop(self, key): + """Get the value stored under a user defined key and delete the entry once it has been fetched.""" + return self.proxy.user_pop(key) + def bcast_val(self, val, key, size): idx = self.bcast_dict[key] + 1 self.bcast_dict[key] = idx diff --git a/imperative/python/test/unit/distributed/test_distributed.py b/imperative/python/test/unit/distributed/test_distributed.py index eb95fdae..ce51f010 100644 --- a/imperative/python/test/unit/distributed/test_distributed.py +++ b/imperative/python/test/unit/distributed/test_distributed.py @@ -219,3 +219,17 @@ def test_collect_results(early_return, output_size): else [[dev] * output_size for dev in range(world_size)] ) assert results == expects + + +@pytest.mark.require_ngpu(2) +@pytest.mark.isolated_distributed +def test_user_set_pop(): + @dist.launcher + def worker(): + 
# every rank sets the same key, so the writes race with each other + dist.get_client().user_set("foo", 1) + if dist.get_rank() == 1: + ret = dist.get_client().user_pop("foo") + assert ret == 1 + + worker() -- GitLab