提交 147dbf8a 编写于 作者: M Megvii Engine Team

fix(test): fix a race condition in TestCudaMemAlloc

GitOrigin-RevId: 99468a7ae33fe6a1445d361f9e9d9688cd6e63cb
上级 7066ad5b
...@@ -570,19 +570,20 @@ public: ...@@ -570,19 +570,20 @@ public:
#endif #endif
using Callback = std::function<void()>; using Callback = std::function<void()>;
void test_free_mem(CompNode cn0, CompNode cn1, DevicePolicy* policy, void test_free_mem(CompNode::Locator loc0, CompNode::Locator loc1, DevicePolicy* policy,
const Callback& before_run, const Callback& after_run) { const Callback& before_run, const Callback& after_run) {
size_t tot, free; size_t tot, free;
policy->set_device(0); policy->set_device(0);
policy->get_mem_info(free, tot); policy->get_mem_info(free, tot);
// exception // exception
auto do_run = [cn0, cn1, policy, free]() { auto do_run = [loc0, loc1, policy, free]() {
void* tmp; void* tmp;
policy->raw_dev_malloc(&tmp, free / 3); policy->raw_dev_malloc(&tmp, free / 3);
auto dev_free = [&](void* ptr) { auto dev_free = [&](void* ptr) {
policy->raw_dev_free(ptr); policy->raw_dev_free(ptr);
}; };
auto cn0 = CompNode::load(loc0), cn1 = CompNode::load(loc1);
std::unique_ptr<void, decltype(dev_free)> tmp_owner{tmp, dev_free}; std::unique_ptr<void, decltype(dev_free)> tmp_owner{tmp, dev_free};
auto check_free = [&](const char* msg, size_t expect) { auto check_free = [&](const char* msg, size_t expect) {
auto get = cn0.get_mem_status_bytes().second; auto get = cn0.get_mem_status_bytes().second;
...@@ -648,7 +649,8 @@ TEST(TestCudaMemAlloc, FreeMem) { ...@@ -648,7 +649,8 @@ TEST(TestCudaMemAlloc, FreeMem) {
REQUIRE_GPU(1); REQUIRE_GPU(1);
CompNode::finalize(); CompNode::finalize();
// same device but different stream // same device but different stream
auto cn0 = CompNode::load("gpu0"), cn1 = CompNode::load("gpu0:1"); using Locator = CompNode::Locator;
auto loc0 = Locator::parse("gpu0"), loc1 = Locator::parse("gpu0:1");
auto policy = std::make_unique<CudaDevicePolicy>(); auto policy = std::make_unique<CudaDevicePolicy>();
constexpr const char* KEY = "MGB_CUDA_RESERVE_MEMORY"; constexpr const char* KEY = "MGB_CUDA_RESERVE_MEMORY";
...@@ -662,7 +664,7 @@ TEST(TestCudaMemAlloc, FreeMem) { ...@@ -662,7 +664,7 @@ TEST(TestCudaMemAlloc, FreeMem) {
} }
CompNode::finalize(); CompNode::finalize();
}; };
test_free_mem(cn0, cn1, policy.get(), reserve, restore); test_free_mem(loc0, loc1, policy.get(), reserve, restore);
} }
#endif // MGB_CUDA #endif // MGB_CUDA
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册