未验证 提交 8bc27015 编写于 作者: A Aurelius84 提交者: GitHub

[CustomOp]Add RWLock to protect loading module under multi-thread and multi-process (#38128)

* Add RWLock to protect loading module under multi-thread

* refine code

* remove import statement
上级 18a59822
......@@ -17,16 +17,24 @@ import re
import sys
import json
import glob
import atexit
import hashlib
import logging
import collections
import textwrap
import warnings
import subprocess
import threading
from importlib import machinery
from contextlib import contextmanager
from setuptools.command import bdist_egg
try:
from subprocess import DEVNULL # py3
except ImportError:
DEVNULL = open(os.devnull, 'wb')
from ...fluid import core
from ...fluid.framework import OpProtoHolder
from ...sysconfig import get_include, get_lib
......@@ -797,7 +805,6 @@ def parse_op_info(op_name):
Parse input names and outpus detail information from registered custom op
from OpInfoMap.
"""
from paddle.fluid.framework import OpProtoHolder
if op_name not in OpProtoHolder.instance().op_proto_map:
raise ValueError(
"Please load {} shared library file firstly by `paddle.utils.cpp_extension.load_op_meta_info_and_register_op(...)`".
......@@ -844,16 +851,28 @@ def _generate_python_module(module_name,
"""
Automatically generate python file to allow import or load into as module
"""
api_file = os.path.join(build_directory, module_name + '.py')
def remove_if_exit(filepath):
if os.path.exists(filepath):
os.remove(filepath)
# NOTE: Use unique id as suffix to avoid write same file at same time in
# both multi-thread and multi-process.
thread_id = str(threading.currentThread().ident)
api_file = os.path.join(build_directory,
module_name + '_' + thread_id + '.py')
log_v("generate api file: {}".format(api_file), verbose)
# write into .py file
# delete the temp file before exit python process
atexit.register(lambda: remove_if_exit(api_file))
# write into .py file with RWLock
api_content = [_custom_api_content(op_name) for op_name in op_names]
with open(api_file, 'w') as f:
f.write('\n\n'.join(api_content))
# load module
custom_module = _load_module_from_file(api_file, verbose)
custom_module = _load_module_from_file(api_file, module_name, verbose)
return custom_module
......@@ -901,7 +920,7 @@ def _custom_api_content(op_name):
return api_content
def _load_module_from_file(api_file_path, verbose=False):
def _load_module_from_file(api_file_path, module_name, verbose=False):
"""
Load module from python file.
"""
......@@ -911,8 +930,9 @@ def _load_module_from_file(api_file_path, verbose=False):
# Unique readable module name to place custom api.
log_v('import module from file: {}'.format(api_file_path), verbose)
ext_name = "_paddle_cpp_extension_"
from importlib import machinery
ext_name = "_paddle_cpp_extension_" + module_name
# load module with RWLock
loader = machinery.SourceFileLoader(ext_name, api_file_path)
module = loader.load_module()
......@@ -1066,10 +1086,6 @@ def run_cmd(command, verbose=False):
"""
# logging
log_v("execute command: {}".format(command), verbose)
try:
from subprocess import DEVNULL # py3
except ImportError:
DEVNULL = open(os.devnull, 'wb')
# execute command
try:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册