From 462ee101224a7d2ac2ea1a88d41ef90e341a98a4 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Fri, 24 Dec 2021 13:03:39 +0800 Subject: [PATCH] [CustomOp]Add RWLock to protect loading module under multi-thread and multi-process (#38128) (#38271) [Cherry-pick][CustomOp]Add RWLock to protect loading module under multi-thread and multi-process (#38128) --- .../utils/cpp_extension/extension_utils.py | 38 +++++++++++++------ 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/python/paddle/utils/cpp_extension/extension_utils.py b/python/paddle/utils/cpp_extension/extension_utils.py index 0a2d71abfd..dd69441817 100644 --- a/python/paddle/utils/cpp_extension/extension_utils.py +++ b/python/paddle/utils/cpp_extension/extension_utils.py @@ -17,16 +17,24 @@ import re import sys import json import glob +import atexit import hashlib import logging import collections import textwrap import warnings import subprocess +import threading +from importlib import machinery from contextlib import contextmanager from setuptools.command import bdist_egg +try: + from subprocess import DEVNULL # py3 +except ImportError: + DEVNULL = open(os.devnull, 'wb') + from ...fluid import core from ...fluid.framework import OpProtoHolder from ...sysconfig import get_include, get_lib @@ -777,7 +785,6 @@ def parse_op_info(op_name): Parse input names and outpus detail information from registered custom op from OpInfoMap. """ - from paddle.fluid.framework import OpProtoHolder if op_name not in OpProtoHolder.instance().op_proto_map: raise ValueError( "Please load {} shared library file firstly by `paddle.utils.cpp_extension.load_op_meta_info_and_register_op(...)`". @@ -824,16 +831,28 @@ def _generate_python_module(module_name, """ Automatically generate python file to allow import or load into as module """ - api_file = os.path.join(build_directory, module_name + '.py') + + def remove_if_exit(filepath): + if os.path.exists(filepath): + os.remove(filepath) + + # NOTE: Use unique id as suffix to avoid write same file at same time in + # both multi-thread and multi-process. + thread_id = str(threading.currentThread().ident) + api_file = os.path.join(build_directory, + module_name + '_' + thread_id + '.py') log_v("generate api file: {}".format(api_file), verbose) - # write into .py file + # delete the temp file before exit python process + atexit.register(lambda: remove_if_exit(api_file)) + + # write into .py file with RWLock api_content = [_custom_api_content(op_name) for op_name in op_names] with open(api_file, 'w') as f: f.write('\n\n'.join(api_content)) # load module - custom_module = _load_module_from_file(api_file, verbose) + custom_module = _load_module_from_file(api_file, module_name, verbose) return custom_module @@ -881,7 +900,7 @@ def _custom_api_content(op_name): return api_content -def _load_module_from_file(api_file_path, verbose=False): +def _load_module_from_file(api_file_path, module_name, verbose=False): """ Load module from python file. """ @@ -891,8 +910,9 @@ def _load_module_from_file(api_file_path, verbose=False): # Unique readable module name to place custom api. log_v('import module from file: {}'.format(api_file_path), verbose) - ext_name = "_paddle_cpp_extension_" - from importlib import machinery + ext_name = "_paddle_cpp_extension_" + module_name + + # load module with RWLock loader = machinery.SourceFileLoader(ext_name, api_file_path) module = loader.load_module() @@ -1046,10 +1066,6 @@ def run_cmd(command, verbose=False): """ # logging log_v("execute command: {}".format(command), verbose) - try: - from subprocess import DEVNULL # py3 - except ImportError: - DEVNULL = open(os.devnull, 'wb') # execute command try: -- GitLab