未验证 提交 c56832ed 编写于 作者: F Francisco Massa 提交者: GitHub

Add support for Python 2 (#11)

* Add missing __init__.py files

* Add packages

* Rename logging.py to logger.py

Import rules from Python2 makes this a bad idea

* Make import_file py2 compatible

* list does not have .copy() in py2

* math.log2 does not exist in py2

* Miscellaneous fixes for py2

* Address comments
上级 8323c118
......@@ -3,7 +3,7 @@ __pycache__
_ext
*.pyc
*.so
maskrcnn-benchmark.egg-info/
maskrcnn_benchmark.egg-info/
build/
dist/
......
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import bisect
import copy
import logging
import torch.utils.data
......@@ -63,7 +64,8 @@ def make_data_sampler(dataset, shuffle, distributed):
def _quantize(x, bins):
bins = sorted(bins.copy())
bins = copy.copy(bins)
bins = sorted(bins)
quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
return quantized
......
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import math
import torch
import torch.nn.functional as F
from torch import nn
......@@ -57,7 +56,7 @@ class Pooler(nn.Module):
"""
Arguments:
output_size (list[tuple[int]] or list[int]): output size for the pooled region
scales (list[flaot]): scales for each Pooler
scales (list[float]): scales for each Pooler
sampling_ratio (int): sampling ratio for ROIAlign
"""
super(Pooler, self).__init__()
......@@ -72,8 +71,8 @@ class Pooler(nn.Module):
self.output_size = output_size
# get the levels in the feature map by leveraging the fact that the network always
# downsamples by a factor of 2 at each level.
lvl_min = -math.log2(scales[0])
lvl_max = -math.log2(scales[-1])
lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()
self.map_levels = LevelMapper(lvl_min, lvl_max)
def convert_to_roi_format(self, boxes):
......
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from __future__ import division
import torch
......
......@@ -119,7 +119,10 @@ def _rename_weights_for_resnet(weights, stage_names):
def _load_c2_pickled_weights(file_path):
with open(file_path, "rb") as f:
data = pickle.load(f, encoding="latin1")
if torch._six.PY3:
data = pickle.load(f, encoding="latin1")
else:
data = pickle.load(f)
if "blobs" in data:
weights = data["blobs"]
else:
......
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import importlib
import importlib.util
import sys
# from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa
def import_file(module_name, file_path, make_importable=False):
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
if make_importable:
sys.modules[module_name] = module
return module
import torch
if torch._six.PY3:
import importlib
import importlib.util
import sys
# from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa
def import_file(module_name, file_path, make_importable=False):
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
if make_importable:
sys.modules[module_name] = module
return module
else:
import imp
def import_file(module_name, file_path, make_importable=None):
module = imp.load_source(module_name, file_path)
return module
......@@ -62,7 +62,7 @@ setup(
author="fmassa",
url="https://github.com/facebookresearch/maskrnn-benchmark",
description="object detection in pytorch",
# packages=find_packages(exclude=("configs", "examples", "test",)),
packages=find_packages(exclude=("configs", "tests",)),
# install_requires=requirements,
ext_modules=get_extensions(),
cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
......
......@@ -14,7 +14,7 @@ from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.utils.collect_env import collect_env_info
from maskrcnn_benchmark.utils.comm import synchronize
from maskrcnn_benchmark.utils.logging import setup_logger
from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.miscellaneous import mkdir
......
......@@ -22,7 +22,7 @@ from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.utils.collect_env import collect_env_info
from maskrcnn_benchmark.utils.comm import synchronize
from maskrcnn_benchmark.utils.imports import import_file
from maskrcnn_benchmark.utils.logging import setup_logger
from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.miscellaneous import mkdir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册