未验证 提交 dd76c1f9 编写于 作者: W Wang Feng 提交者: GitHub

fix(gpu_nms): use lazy compile and import for gpu_nms (#50)

上级 9be38bee
...@@ -14,7 +14,6 @@ import megengine.module as M ...@@ -14,7 +14,6 @@ import megengine.module as M
import megengine.random as rand import megengine.random as rand
from official.vision.detection import layers from official.vision.detection import layers
from official.vision.detection.tools.gpu_nms import batched_nms
class RPN(M.Module): class RPN(M.Module):
...@@ -109,6 +108,7 @@ class RPN(M.Module): ...@@ -109,6 +108,7 @@ class RPN(M.Module):
def find_top_rpn_proposals( def find_top_rpn_proposals(
self, rpn_bbox_offset_list, rpn_cls_score_list, all_anchors_list, im_info self, rpn_bbox_offset_list, rpn_cls_score_list, all_anchors_list, im_info
): ):
from official.vision.detection.tools.gpu_nms import batched_nms
prev_nms_top_n = ( prev_nms_top_n = (
self.cfg.train_prev_nms_top_n self.cfg.train_prev_nms_top_n
if self.training if self.training
......
...@@ -12,7 +12,24 @@ from megengine._internal.craniotome import CraniotomeBase ...@@ -12,7 +12,24 @@ from megengine._internal.craniotome import CraniotomeBase
from megengine.core.tensor import wrap_io_tensor from megengine.core.tensor import wrap_io_tensor
_so_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "lib_nms.so") _so_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "lib_nms.so")
_so_lib = ctypes.CDLL(_so_path) try:
_so_lib = ctypes.CDLL(_so_path)
except Exception:
MGE = mge.__file__.rsplit("/", 1)[0]
current_path = os.path.abspath(__file__).rsplit("/", 1)[0]
mge_path = os.path.join(MGE, "_internal/include")
src_file = os.path.join(current_path, "gpu_nms/nms.cu")
dst_file = os.path.join(current_path, "lib_nms.so")
assert os.path.exists(mge_path)
assert os.path.exists(src_file)
cmd = (
"nvcc -I {} -shared -o {} -Xcompiler '-fno-strict-aliasing -fPIC' {}".format(
mge_path, dst_file, src_file
)
)
os.system(cmd)
_so_lib = ctypes.CDLL(_so_path)
_TYPE_POINTER = ctypes.c_void_p _TYPE_POINTER = ctypes.c_void_p
_TYPE_POINTER = ctypes.c_void_p _TYPE_POINTER = ctypes.c_void_p
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册