未验证 提交 8bc27015 编写于 作者: A Aurelius84 提交者: GitHub

[CustomOp]Add RWLock to protect loading module under multi-thread and multi-process (#38128)

* Add RWLock to protect loading module under multi-thread

* refine code

* remove import statement
上级 18a59822
...@@ -17,16 +17,24 @@ import re ...@@ -17,16 +17,24 @@ import re
import sys import sys
import json import json
import glob import glob
import atexit
import hashlib import hashlib
import logging import logging
import collections import collections
import textwrap import textwrap
import warnings import warnings
import subprocess import subprocess
import threading
from importlib import machinery
from contextlib import contextmanager from contextlib import contextmanager
from setuptools.command import bdist_egg from setuptools.command import bdist_egg
try:
from subprocess import DEVNULL # py3
except ImportError:
DEVNULL = open(os.devnull, 'wb')
from ...fluid import core from ...fluid import core
from ...fluid.framework import OpProtoHolder from ...fluid.framework import OpProtoHolder
from ...sysconfig import get_include, get_lib from ...sysconfig import get_include, get_lib
...@@ -797,7 +805,6 @@ def parse_op_info(op_name): ...@@ -797,7 +805,6 @@ def parse_op_info(op_name):
Parse input names and outpus detail information from registered custom op Parse input names and outpus detail information from registered custom op
from OpInfoMap. from OpInfoMap.
""" """
from paddle.fluid.framework import OpProtoHolder
if op_name not in OpProtoHolder.instance().op_proto_map: if op_name not in OpProtoHolder.instance().op_proto_map:
raise ValueError( raise ValueError(
"Please load {} shared library file firstly by `paddle.utils.cpp_extension.load_op_meta_info_and_register_op(...)`". "Please load {} shared library file firstly by `paddle.utils.cpp_extension.load_op_meta_info_and_register_op(...)`".
...@@ -844,16 +851,28 @@ def _generate_python_module(module_name, ...@@ -844,16 +851,28 @@ def _generate_python_module(module_name,
""" """
Automatically generate python file to allow import or load into as module Automatically generate python file to allow import or load into as module
""" """
api_file = os.path.join(build_directory, module_name + '.py')
def remove_if_exit(filepath):
if os.path.exists(filepath):
os.remove(filepath)
# NOTE: Use unique id as suffix to avoid write same file at same time in
# both multi-thread and multi-process.
thread_id = str(threading.currentThread().ident)
api_file = os.path.join(build_directory,
module_name + '_' + thread_id + '.py')
log_v("generate api file: {}".format(api_file), verbose) log_v("generate api file: {}".format(api_file), verbose)
# write into .py file # delete the temp file before exit python process
atexit.register(lambda: remove_if_exit(api_file))
# write into .py file with RWLock
api_content = [_custom_api_content(op_name) for op_name in op_names] api_content = [_custom_api_content(op_name) for op_name in op_names]
with open(api_file, 'w') as f: with open(api_file, 'w') as f:
f.write('\n\n'.join(api_content)) f.write('\n\n'.join(api_content))
# load module # load module
custom_module = _load_module_from_file(api_file, verbose) custom_module = _load_module_from_file(api_file, module_name, verbose)
return custom_module return custom_module
...@@ -901,7 +920,7 @@ def _custom_api_content(op_name): ...@@ -901,7 +920,7 @@ def _custom_api_content(op_name):
return api_content return api_content
def _load_module_from_file(api_file_path, verbose=False): def _load_module_from_file(api_file_path, module_name, verbose=False):
""" """
Load module from python file. Load module from python file.
""" """
...@@ -911,8 +930,9 @@ def _load_module_from_file(api_file_path, verbose=False): ...@@ -911,8 +930,9 @@ def _load_module_from_file(api_file_path, verbose=False):
# Unique readable module name to place custom api. # Unique readable module name to place custom api.
log_v('import module from file: {}'.format(api_file_path), verbose) log_v('import module from file: {}'.format(api_file_path), verbose)
ext_name = "_paddle_cpp_extension_" ext_name = "_paddle_cpp_extension_" + module_name
from importlib import machinery
# load module with RWLock
loader = machinery.SourceFileLoader(ext_name, api_file_path) loader = machinery.SourceFileLoader(ext_name, api_file_path)
module = loader.load_module() module = loader.load_module()
...@@ -1066,10 +1086,6 @@ def run_cmd(command, verbose=False): ...@@ -1066,10 +1086,6 @@ def run_cmd(command, verbose=False):
""" """
# logging # logging
log_v("execute command: {}".format(command), verbose) log_v("execute command: {}".format(command), verbose)
try:
from subprocess import DEVNULL # py3
except ImportError:
DEVNULL = open(os.devnull, 'wb')
# execute command # execute command
try: try:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册