未验证 提交 2497f439 编写于 作者: W Wojciech Uss 提交者: GitHub

Handle missing symlink method on Windows (#31006)

上级 5653c3a4
...@@ -1276,11 +1276,11 @@ class TestProgramStateOldSave(unittest.TestCase): ...@@ -1276,11 +1276,11 @@ class TestProgramStateOldSave(unittest.TestCase):
# case 2: load with no need file # case 2: load with no need file
def symlink_force(target, link_name): def symlink_force(target, link_name):
try: try:
os.symlink(target, link_name) self.create_symlink(target, link_name)
except OSError as e: except OSError as e:
if e.errno == errno.EEXIST: if e.errno == errno.EEXIST:
os.remove(link_name) os.remove(link_name)
os.symlink(target, link_name) self.create_symlink(target, link_name)
else: else:
raise e raise e
...@@ -1304,6 +1304,14 @@ class TestProgramStateOldSave(unittest.TestCase): ...@@ -1304,6 +1304,14 @@ class TestProgramStateOldSave(unittest.TestCase):
for k, v in load_state.items(): for k, v in load_state.items():
self.assertTrue(np.array_equal(base_map[k], v)) self.assertTrue(np.array_equal(base_map[k], v))
def create_symlink(self, target, link_name):
try:
os.symlink(target, link_name)
except AttributeError:
import ctypes
kernel_dll = ctypes.windll.LoadLibrary("kernel32.dll")
kernel_dll.CreateSymbolicLinkA(target, link_name, 0)
def check_in_static(self, main_program, base_map): def check_in_static(self, main_program, base_map):
for var in main_program.list_vars(): for var in main_program.list_vars():
if isinstance(var, framework.Parameter) or var.persistable: if isinstance(var, framework.Parameter) or var.persistable:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册