提交 df79334c 编写于 作者: M Megvii Engine Team 提交者: huangxinda

feat(mge/distributed): add user_pop function to save device memory

BREAKING CHANGE:

GitOrigin-RevId: 0a8e406da5275d0712a5dbb8d032f5cd7ff2cfe6
上级 1eaf32cd
...@@ -145,6 +145,16 @@ class Methods: ...@@ -145,6 +145,16 @@ class Methods:
del self.bcast_dict[key] del self.bcast_dict[key]
return val return val
def _del(self, key):
with self.lock:
del self.user_dict[key]
# thread safe function
def user_pop(self, key):
ret = self.user_get(key)
self._del(key)
return ret
class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer): class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
pass pass
...@@ -274,6 +284,10 @@ class Client: ...@@ -274,6 +284,10 @@ class Client:
"""Get user defined key-value pairs across processes.""" """Get user defined key-value pairs across processes."""
return self.proxy.user_get(key) return self.proxy.user_get(key)
def user_pop(self, key):
"""Get user defined key-value pairs and delete the resources when the get is done"""
return self.proxy.user_pop(key)
def bcast_val(self, val, key, size): def bcast_val(self, val, key, size):
idx = self.bcast_dict[key] + 1 idx = self.bcast_dict[key] + 1
self.bcast_dict[key] = idx self.bcast_dict[key] = idx
......
...@@ -219,3 +219,17 @@ def test_collect_results(early_return, output_size): ...@@ -219,3 +219,17 @@ def test_collect_results(early_return, output_size):
else [[dev] * output_size for dev in range(world_size)] else [[dev] * output_size for dev in range(world_size)]
) )
assert results == expects assert results == expects
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_user_set_pop():
@dist.launcher
def worker():
# set in race condition
dist.get_client().user_set("foo", 1)
if dist.get_rank() == 1:
ret = dist.get_client().user_pop("foo")
assert ret == 1
worker()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册