From dec293164d92702edc92225df21e371583637ef9 Mon Sep 17 00:00:00 2001 From: chenyun Date: Wed, 13 Jun 2018 16:03:57 +0800 Subject: [PATCH] update to pytorch 0.4 --- README.MD | 55 +++++++--------------- data/dataset.py | 6 ++- data/util.py | 2 +- misc/convert_caffe_pretrain.py | 6 ++- misc/train_fast.py | 2 - model/__init__.py | 2 +- model/faster_rcnn.py | 12 ++++- model/faster_rcnn_vgg16.py | 3 +- model/roi_module.py | 9 ++-- model/utils/creator_tool.py | 2 +- model/utils/nms/non_maximum_suppression.py | 2 +- train.py | 18 ++++--- trainer.py | 20 ++++---- utils/array_tool.py | 29 +++--------- utils/config.py | 2 +- utils/vis_tool.py | 6 +-- 16 files changed, 75 insertions(+), 101 deletions(-) diff --git a/README.MD b/README.MD index 10a0cbf..a988105 100644 --- a/README.MD +++ b/README.MD @@ -2,6 +2,8 @@ ## 1. Introduction +**I've updated the code to support both Python 2 and Python 3, and PyTorch 0.4. If you want the old version of the code, please check out branch [v0.3]()** + This project is a **Simplified** Faster R-CNN implementation based on [chainercv](https://github.com/chainer/chainercv) and other [projects](#acknowledgement) . It aims to: - Simplify the code (*Simple is better than complex*) @@ -43,16 +45,16 @@ VGG16 train on `trainval` and test on `test` split. | This[1] | TITAN Xp | 14-15 fps | 6 fps | | [pytorch-faster-rcnn](https://github.com/ruotianluo/pytorch-faster-rcnn) | TITAN Xp | 15-17fps | 6fps | -[1]: make sure you install cupy correctly and only one program run on the GPU. The training speed is sensitive to your gpu status. see [troubleshooting](troubleshooting) for more info. Morever it's slow in the start of the program. +[1]: make sure you install cupy correctly and only one program runs on the GPU. The training speed is sensitive to your GPU status. See [troubleshooting](#troubleshooting) for more info. Moreover, it's slow at the start of the program -- it needs time to warm up. It could be faster by removing visualization, logging, averaging loss etc.
## 3. Install dependencies -requires python3 and PyTorch 0.3 +requires PyTorch >=0.4 -- install PyTorch >=0.3 with GPU (code are GPU-only), refer to [official website](http://pytorch.org) +- install PyTorch >=0.4 with GPU (the code is GPU-only), refer to the [official website](http://pytorch.org) -- install cupy, you can install via `pip install` but it's better to read the [docs](https://docs-cupy.chainer.org/en/latest/install.html#install-cupy-with-cudnn-and-nccl) and make sure the environ is correctly set +- install cupy, you can install it via `pip install cupy-cuda80` (or `cupy-cuda90`, `cupy-cuda91`, matching your CUDA version). - install other dependencies: `pip install -r requirements.txt ` @@ -60,13 +62,14 @@ requires python3 and PyTorch 0.3 ```Bash cd model/utils/nms/ - python3 build.py build_ext --inplace + python build.py build_ext --inplace + cd - ``` -- start vidom for visualization +- start visdom for visualization ```Bash -nohup python3 -m visdom.server & +nohup python -m visdom.server & ``` @@ -124,7 +127,7 @@ python misc/convert_caffe_pretrain.py This scripts would download pretrained model and converted it to the format compatible with torchvision. -Then you should specify where caffe-pretraind model `vgg16_caffe.pth` stored in `utils/config.py` by setting `caffe_pretrain_path` +Then you could specify where the caffe-pretrained model `vgg16_caffe.pth` is stored in `utils/config.py` by setting `caffe_pretrain_path`. The default path is ok. If you want to use pretrained model from torchvision, you may skip this step. @@ -139,7 +142,7 @@ mkdir checkpoints/ # folder for snapshots ``` ```bash -python3 train.py train --env='fasterrcnn-caffe' --plot-every=100 --caffe-pretrain +python train.py train --env='fasterrcnn-caffe' --plot-every=100 --caffe-pretrain ``` you may refer to `utils/config.py` for more argument.
@@ -156,47 +159,25 @@ Some Key arguments: you may open browser, visit `http://:8097` and see the visualization of training procedure as below: -![visdom](http://7zh43r.com2.z0.glb.clouddn.com/del/visdom-fasterrcnn.png) - -If you're in China and encounter problem with visdom (i.e. timeout, blank screen), you may refer to [visdom issue](https://github.com/facebookresearch/visdom/issues/111#issuecomment-321743890), and see [troubleshooting](#troubleshooting) for solution. +![visdom](http://7zh43r.com2.z0.glb.clouddn.com/del/visdom-fasterrcnn.png) ## Troubleshooting -- visdom - - Some js files in visdom was blocked in China, see simple solution [here](https://github.com/chenyuntc/PyTorch-book/blob/master/README.md#visdom打不开及其解决方案) - - Also, `updata=append` doesn't work due to a bug brought in latest version, see [issue](https://github.com/facebookresearch/visdom/issues/233) and [fix](https://github.com/facebookresearch/visdom/pull/234/files) - - You don't need to build from source, modifying related files would be OK. - dataloader: `received 0 items of ancdata` see [discussion](https://github.com/pytorch/pytorch/issues/973#issuecomment-346405667), It's alreadly fixed in [train.py](https://github.com/chenyuntc/simple-faster-rcnn-pytorch/blob/master/train.py#L17-L22). So I think you are free from this problem. -- cupy `numpy.core._internal.AxisError: axis 1 out of bounds [0, 1)` - - bug of cupy, see [issue](https://github.com/cupy/cupy/issues/793), fix via [pull request](https://github.com/cupy/cupy/pull/749) - - You don't need to build from source, modifying related files would be OK. - -- VGG: Slow in construction - - VGG16 is slow in construction(i.e. 9 seconds),it could be speed up by this [PR](https://github.com/pytorch/vision/pull/377) - - You don't need to build from source, modifying related files would be OK. 
- -- About the speed - - One strange thing is that, even the code doesn't use chainer, but if I remove `from chainer import cuda`, the speed drops a lot (train 6.5->6.1,test 14.5->10), because Chainer replaces the default allocator of CuPy by its memory pool implementation. But ever since V4.0, cupy use memory pool as default. However you need to build from souce if you are gona use the latest version of cupy (uninstall cupy -> git clone -> git checkout v4.0 -> setup.py install) @_@ +- Windows support + + I don't have a Windows machine with a GPU to debug and test it. Anyone who can make a pull request and test it is welcome to do so. - Another simple fix: add `from chainer import cuda` at the begining of `train.py`. in such case,you'll need to `pip install chainer` first. ## More - [ ] training on coco - [ ] resnet - [ ] Maybe;replace cupy with THTensor+cffi? - [ ] Maybe:Convert all numpy code to tensor? -- [ ] check python2-compatibility +- [x] python2-compatibility ## Acknowledgement This work builds on many excellent works, which include: @@ -211,7 +192,7 @@ Licensed under MIT, see the LICENSE for more detail. Contribution Welcome. -If you encounter any problem, feel free to open an issue. +If you encounter any problem, feel free to open an issue, but note that I have been too busy lately. Correct me if anything is wrong or unclear. diff --git a/data/dataset.py b/data/dataset.py index b76f546..aeaf12f 100644 --- a/data/dataset.py +++ b/data/dataset.py @@ -1,8 +1,10 @@ +from __future__ import absolute_import +from __future__ import division import torch as t -from .voc_dataset import VOCBboxDataset +from data.voc_dataset import VOCBboxDataset from skimage import transform as sktsf from torchvision import transforms as tvtsf -from . 
import util +from data import util import numpy as np from utils.config import opt diff --git a/data/util.py b/data/util.py index 01c3169..3a3fbc1 100644 --- a/data/util.py +++ b/data/util.py @@ -167,7 +167,7 @@ def crop_bbox( if allow_outside_center: mask = np.ones(bbox.shape[0], dtype=bool) else: - center = (bbox[:, :2] + bbox[:, 2:]) / 2 + center = (bbox[:, :2] + bbox[:, 2:]) / 2.0 mask = np.logical_and(crop_bb[:2] <= center, center < crop_bb[2:]) \ .all(axis=1) diff --git a/misc/convert_caffe_pretrain.py b/misc/convert_caffe_pretrain.py index 805b9bd..a2d3d93 100644 --- a/misc/convert_caffe_pretrain.py +++ b/misc/convert_caffe_pretrain.py @@ -15,6 +15,8 @@ sd['classifier.3.bias'] = sd['classifier.4.bias'] del sd['classifier.4.weight'] del sd['classifier.4.bias'] - +import os # speicify the path to save -torch.save(sd, "vgg16_caffe.pth") \ No newline at end of file +if not os.path.exists('checkpoints'): + os.makedirs('checkpoints') +torch.save(sd, "checkpoints/vgg16_caffe.pth") \ No newline at end of file diff --git a/misc/train_fast.py b/misc/train_fast.py index 0201550..33240ec 100644 --- a/misc/train_fast.py +++ b/misc/train_fast.py @@ -7,7 +7,6 @@ from tqdm import tqdm from utils.config import opt from data.dataset import Dataset, TestDataset from model import FasterRCNNVGG16 -from torch.autograd import Variable from torch.utils import data as data_ from trainer import FasterRCNNTrainer from utils import array_tool as at @@ -68,7 +67,6 @@ def train(**kwargs): for ii, (img, bbox_, label_, scale, ori_img) in tqdm(enumerate(dataloader)): scale = at.scalar(scale) img, bbox, label = img.cuda().float(), bbox_.cuda(), label_.cuda() - img, bbox, label = Variable(img), Variable(bbox), Variable(label) losses = trainer.train_step(img, bbox, label, scale) if (ii + 1) % opt.plot_every == 0: diff --git a/model/__init__.py b/model/__init__.py index 70c38e3..ce66f96 100644 --- a/model/__init__.py +++ b/model/__init__.py @@ -1 +1 @@ -from .faster_rcnn_vgg16 import 
FasterRCNNVGG16 +from model.faster_rcnn_vgg16 import FasterRCNNVGG16 diff --git a/model/faster_rcnn.py b/model/faster_rcnn.py index 7aefafb..5dba41e 100644 --- a/model/faster_rcnn.py +++ b/model/faster_rcnn.py @@ -1,3 +1,4 @@ +from __future__ import absolute_import from __future__ import division import torch as t import numpy as np @@ -12,6 +13,12 @@ from torch.nn import functional as F from utils.config import opt +def nograd(f): + def new_f(*args,**kwargs): + with t.no_grad(): + return f(*args,**kwargs) + return new_f + class FasterRCNN(nn.Module): """Base class for Faster R-CNN. @@ -176,6 +183,7 @@ class FasterRCNN(nn.Module): score = np.concatenate(score, axis=0).astype(np.float32) return bbox, label, score + @nograd def predict(self, imgs,sizes=None,visualize=False): """Detect objects from images. @@ -220,7 +228,7 @@ class FasterRCNN(nn.Module): labels = list() scores = list() for img, size in zip(prepared_imgs, sizes): - img = t.autograd.Variable(at.totensor(img).float()[None], volatile=True) + img = at.totensor(img[None]).float() scale = img.shape[3] / size[1] roi_cls_loc, roi_scores, rois, _ = self(img, scale=scale) # We are assuming that batch size is 1. 
@@ -246,7 +254,7 @@ class FasterRCNN(nn.Module): cls_bbox[:, 0::2] = (cls_bbox[:, 0::2]).clamp(min=0, max=size[0]) cls_bbox[:, 1::2] = (cls_bbox[:, 1::2]).clamp(min=0, max=size[1]) - prob = at.tonumpy(F.softmax(at.tovariable(roi_score), dim=1)) + prob = at.tonumpy(F.softmax(at.totensor(roi_score), dim=1)) raw_cls_bbox = at.tonumpy(cls_bbox) raw_prob = at.tonumpy(prob) diff --git a/model/faster_rcnn_vgg16.py b/model/faster_rcnn_vgg16.py index de1eda2..edf02e3 100644 --- a/model/faster_rcnn_vgg16.py +++ b/model/faster_rcnn_vgg16.py @@ -1,3 +1,4 @@ +from __future__ import absolute_import import torch as t from torch import nn from torchvision.models import vgg16 @@ -136,7 +137,7 @@ class VGG16RoIHead(nn.Module): indices_and_rois = t.cat([roi_indices[:, None], rois], dim=1) # NOTE: important: yx->xy xy_indices_and_rois = indices_and_rois[:, [0, 2, 1, 4, 3]] - indices_and_rois = t.autograd.Variable(xy_indices_and_rois.contiguous()) + indices_and_rois = xy_indices_and_rois.contiguous() pool = self.roi(x, indices_and_rois) pool = pool.view(pool.size(0), -1) diff --git a/model/roi_module.py b/model/roi_module.py index 6b45f89..9218af8 100644 --- a/model/roi_module.py +++ b/model/roi_module.py @@ -27,10 +27,6 @@ def GET_BLOCKS(N, K=CUDA_NUM_THREADS): class RoI(Function): - """ - NOTE:only CUDA-compatible - """ - def __init__(self, outh, outw, spatial_scale): self.forward_fn = load_kernel('roi_forward', kernel_forward) self.backward_fn = load_kernel('roi_backward', kernel_backward) @@ -104,8 +100,9 @@ def test_roi_module(): # pytorch version module = RoIPooling2D(outh, outw, spatial_scale) - x = t.autograd.Variable(bottom_data, requires_grad=True) - rois = t.autograd.Variable(bottom_rois) + x = bottom_data.requires_grad_() + rois = bottom_rois.detach() + output = module(x, rois) output.sum().backward() diff --git a/model/utils/creator_tool.py b/model/utils/creator_tool.py index 91e19dc..d9a0d0e 100644 --- a/model/utils/creator_tool.py +++ b/model/utils/creator_tool.py @@ 
-38,7 +38,7 @@ class ProposalTargetCreator(object): self.pos_ratio = pos_ratio self.pos_iou_thresh = pos_iou_thresh self.neg_iou_thresh_hi = neg_iou_thresh_hi - self.neg_iou_thresh_lo = neg_iou_thresh_lo # NOTE: py-faster-rcnn默认的值是0.1 + self.neg_iou_thresh_lo = neg_iou_thresh_lo # NOTE:default 0.1 in py-faster-rcnn def __call__(self, roi, bbox, label, loc_normalize_mean=(0., 0., 0., 0.), diff --git a/model/utils/nms/non_maximum_suppression.py b/model/utils/nms/non_maximum_suppression.py index c488b52..f176594 100644 --- a/model/utils/nms/non_maximum_suppression.py +++ b/model/utils/nms/non_maximum_suppression.py @@ -167,7 +167,7 @@ def _call_nms_kernel(bbox, thresh): threads = (threads_per_block, 1, 1) mask_dev = cp.zeros((n_bbox * col_blocks,), dtype=np.uint64) - bbox = cp.ascontiguousarray(bbox, dtype=np.float32) # NOTE: 变成连续的 + bbox = cp.ascontiguousarray(bbox, dtype=np.float32) kern = _load_kernel('nms_kernel', _nms_gpu_code) kern(blocks, threads, args=(cp.int32(n_bbox), cp.float32(thresh), bbox, mask_dev)) diff --git a/train.py b/train.py index 7c59556..83f3159 100644 --- a/train.py +++ b/train.py @@ -1,3 +1,4 @@ +from __future__ import absolute_import import os import ipdb @@ -7,7 +8,6 @@ from tqdm import tqdm from utils.config import opt from data.dataset import Dataset, TestDataset, inverse_normalize from model import FasterRCNNVGG16 -from torch.autograd import Variable from torch.utils import data as data_ from trainer import FasterRCNNTrainer from utils import array_tool as at @@ -28,7 +28,7 @@ def eval(dataloader, faster_rcnn, test_num=10000): pred_bboxes, pred_labels, pred_scores = list(), list(), list() gt_bboxes, gt_labels, gt_difficults = list(), list(), list() for ii, (imgs, sizes, gt_bboxes_, gt_labels_, gt_difficults_) in tqdm(enumerate(dataloader)): - sizes = [sizes[0][0], sizes[1][0]] + sizes = [sizes[0][0].item(), sizes[1][0].item()] pred_bboxes_, pred_labels_, pred_scores_ = faster_rcnn.predict(imgs, [sizes]) gt_bboxes += 
list(gt_bboxes_.numpy()) gt_labels += list(gt_labels_.numpy()) @@ -68,7 +68,6 @@ def train(**kwargs): if opt.load_path: trainer.load(opt.load_path) print('load pretrained model from %s' % opt.load_path) - trainer.vis.text(dataset.db.label_names, win='labels') best_map = 0 for epoch in range(opt.epoch): @@ -76,7 +75,6 @@ def train(**kwargs): for ii, (img, bbox_, label_, scale) in tqdm(enumerate(dataloader)): scale = at.scalar(scale) img, bbox, label = img.cuda().float(), bbox_.cuda(), label_.cuda() - img, bbox, label = Variable(img), Variable(bbox), Variable(label) trainer.train_step(img, bbox, label, scale) if (ii + 1) % opt.plot_every == 0: @@ -106,6 +104,12 @@ def train(**kwargs): # roi confusion matrix trainer.vis.img('roi_cm', at.totensor(trainer.roi_cm.conf, False).float()) eval_result = eval(test_dataloader, faster_rcnn, test_num=opt.test_num) + trainer.vis.plot('test_map', eval_result['map']) + lr_ = trainer.faster_rcnn.optimizer.param_groups[0]['lr'] + log_info = 'lr:{}, map:{},loss:{}'.format(str(lr_), + str(eval_result['map']), + str(trainer.get_meter_data())) + trainer.vis.log(log_info) if eval_result['map'] > best_map: best_map = eval_result['map'] @@ -114,12 +118,6 @@ def train(**kwargs): trainer.load(best_path) trainer.faster_rcnn.scale_lr(opt.lr_decay) - trainer.vis.plot('test_map', eval_result['map']) - lr_ = trainer.faster_rcnn.optimizer.param_groups[0]['lr'] - log_info = 'lr:{}, map:{},loss:{}'.format(str(lr_), - str(eval_result['map']), - str(trainer.get_meter_data())) - trainer.vis.log(log_info) if epoch == 13: break diff --git a/trainer.py b/trainer.py index 29c142b..913d59f 100644 --- a/trainer.py +++ b/trainer.py @@ -1,3 +1,5 @@ +from __future__ import absolute_import +import os from collections import namedtuple import time from torch.nn import functional as F @@ -5,7 +7,6 @@ from model.utils.creator_tool import AnchorTargetCreator, ProposalTargetCreator from torch import nn import torch as t -from torch.autograd import Variable from utils 
import array_tool as at from utils.vis_tool import Visualizer @@ -126,8 +127,8 @@ class FasterRCNNTrainer(nn.Module): at.tonumpy(bbox), anchor, img_size) - gt_rpn_label = at.tovariable(gt_rpn_label).long() - gt_rpn_loc = at.tovariable(gt_rpn_loc) + gt_rpn_label = at.totensor(gt_rpn_label).long() + gt_rpn_loc = at.totensor(gt_rpn_loc) rpn_loc_loss = _fast_rcnn_loc_loss( rpn_loc, gt_rpn_loc, @@ -145,8 +146,8 @@ class FasterRCNNTrainer(nn.Module): roi_cls_loc = roi_cls_loc.view(n_sample, -1, 4) roi_loc = roi_cls_loc[t.arange(0, n_sample).long().cuda(), \ at.totensor(gt_roi_label).long()] - gt_roi_label = at.tovariable(gt_roi_label).long() - gt_roi_loc = at.tovariable(gt_roi_loc) + gt_roi_label = at.totensor(gt_roi_label).long() + gt_roi_loc = at.totensor(gt_roi_loc) roi_loc_loss = _fast_rcnn_loc_loss( roi_loc.contiguous(), @@ -199,6 +200,10 @@ class FasterRCNNTrainer(nn.Module): for k_, v_ in kwargs.items(): save_path += '_%s' % v_ + save_dir = os.path.dirname(save_path) + if not os.path.exists(save_dir): + os.makedirs(save_dir) + t.save(save_dict, save_path) self.vis.save([self.vis.env]) return save_path @@ -236,7 +241,6 @@ def _smooth_l1_loss(x, t, in_weight, sigma): diff = in_weight * (x - t) abs_diff = diff.abs() flag = (abs_diff.data < (1. / sigma2)).float() - flag = Variable(flag) y = (flag * (sigma2 / 2.) * (diff ** 2) + (1 - flag) * (abs_diff - 0.5 / sigma2)) return y.sum() @@ -248,7 +252,7 @@ def _fast_rcnn_loc_loss(pred_loc, gt_loc, gt_label, sigma): # NOTE: unlike origin implementation, # we don't need inside_weight and outside_weight, they can calculate by gt_label in_weight[(gt_label > 0).view(-1, 1).expand_as(in_weight).cuda()] = 1 - loc_loss = _smooth_l1_loss(pred_loc, gt_loc, Variable(in_weight), sigma) + loc_loss = _smooth_l1_loss(pred_loc, gt_loc, in_weight.detach(), sigma) # Normalize by total number of negtive and positive rois. 
- loc_loss /= (gt_label >= 0).sum() # ignore gt_label==-1 for rpn_loss + loc_loss /= ((gt_label >= 0).sum().float()) # ignore gt_label==-1 for rpn_loss return loc_loss diff --git a/utils/array_tool.py b/utils/array_tool.py index 6c75fe0..12798a2 100644 --- a/utils/array_tool.py +++ b/utils/array_tool.py @@ -8,39 +8,22 @@ import numpy as np def tonumpy(data): if isinstance(data, np.ndarray): return data - if isinstance(data, t._TensorBase): - return data.cpu().numpy() - if isinstance(data, t.autograd.Variable): - return tonumpy(data.data) + if isinstance(data, t.Tensor): + return data.detach().cpu().numpy() def totensor(data, cuda=True): if isinstance(data, np.ndarray): tensor = t.from_numpy(data) - if isinstance(data, t._TensorBase): - tensor = data - if isinstance(data, t.autograd.Variable): - tensor = data.data + if isinstance(data, t.Tensor): + tensor = data.detach() if cuda: tensor = tensor.cuda() return tensor -def tovariable(data): - if isinstance(data, np.ndarray): - return tovariable(totensor(data)) - if isinstance(data, t._TensorBase): - return t.autograd.Variable(data) - if isinstance(data, t.autograd.Variable): - return data - else: - raise ValueError("UnKnow data type: %s, input should be {np.ndarray,Tensor,Variable}" %type(data)) - - def scalar(data): if isinstance(data, np.ndarray): return data.reshape(1)[0] - if isinstance(data, t._TensorBase): - return data.view(1)[0] - if isinstance(data, t.autograd.Variable): - return data.data.view(1)[0] + if isinstance(data, t.Tensor): + return data.view(1)[0].item() \ No newline at end of file diff --git a/utils/config.py b/utils/config.py index 89d7bf1..cb7609f 100644 --- a/utils/config.py +++ b/utils/config.py @@ -48,7 +48,7 @@ class Config: load_path = None caffe_pretrain = False # use caffe pretrained model instead of torchvision - caffe_pretrain_path = 'checkpoints/vgg16-caffe.pth' + caffe_pretrain_path = 'checkpoints/vgg16_caffe.pth' def _parse(self, kwargs): state_dict = self._state_dict() diff --git 
a/utils/vis_tool.py b/utils/vis_tool.py index 6878474..9cf651a 100644 --- a/utils/vis_tool.py +++ b/utils/vis_tool.py @@ -133,7 +133,7 @@ def fig2data(fig): brief Convert a Matplotlib figure to a 4D numpy array with RGBA channels and return it - @param fig: a matplotlib figure + @param fig: a matplotlib figure @return a numpy 3D array of RGBA values """ # draw the renderer @@ -178,7 +178,7 @@ class Visualizer(object): self.vis = visdom.Visdom(env=env, **kwargs) self._vis_kw = kwargs - # e.g.(’loss',23) the 23th value of loss + # e.g.('loss',23) the 23th value of loss self.index = {} self.log_text = '' @@ -221,7 +221,7 @@ class Visualizer(object): self.img('input_imgs',t.Tensor(3,64,64)) self.img('input_imgs',t.Tensor(100,1,64,64)) self.img('input_imgs',t.Tensor(100,3,64,64),nrows=10) - !!!don‘t ~~self.img('input_imgs',t.Tensor(100,64,64),nrows=10)~~!!! + !!don't ~~self.img('input_imgs',t.Tensor(100,64,64),nrows=10)~~!! """ self.vis.images(t.Tensor(img_).cpu().numpy(), win=name, -- GitLab