diff --git a/imperative/python/megengine/distributed/server.py b/imperative/python/megengine/distributed/server.py index e83efb926c3bdaefa375eac4f264e58f57b808a8..8b3b569e7369fffddd0dbe8796f6879b3ca5f049 100644 --- a/imperative/python/megengine/distributed/server.py +++ b/imperative/python/megengine/distributed/server.py @@ -145,6 +145,16 @@ class Methods: del self.bcast_dict[key] return val + def _del(self, key): + with self.lock: + del self.user_dict[key] + + # thread safe function + def user_pop(self, key): + ret = self.user_get(key) + self._del(key) + return ret + class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer): pass @@ -274,6 +284,10 @@ class Client: """Get user defined key-value pairs across processes.""" return self.proxy.user_get(key) + def user_pop(self, key): + """Get user defined key-value pairs and delete the resources when the get is done""" + return self.proxy.user_pop(key) + def bcast_val(self, val, key, size): idx = self.bcast_dict[key] + 1 self.bcast_dict[key] = idx diff --git a/imperative/python/test/unit/distributed/test_distributed.py b/imperative/python/test/unit/distributed/test_distributed.py index eb95fdaed48318bd736bb1d5580072463e601ec9..ce51f0107ed2d3d330f634db9af9e419ebbd7483 100644 --- a/imperative/python/test/unit/distributed/test_distributed.py +++ b/imperative/python/test/unit/distributed/test_distributed.py @@ -219,3 +219,17 @@ def test_collect_results(early_return, output_size): else [[dev] * output_size for dev in range(world_size)] ) assert results == expects + + +@pytest.mark.require_ngpu(2) +@pytest.mark.isolated_distributed +def test_user_set_pop(): + @dist.launcher + def worker(): + # set in race condition + dist.get_client().user_set("foo", 1) + if dist.get_rank() == 1: + ret = dist.get_client().user_pop("foo") + assert ret == 1 + + worker()