faster_rcnn_r50_vd_fpn_ciou_loss_1x raises an error when run on a VOC dataset
Created by: yinggo
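For reference, the `ValueError` in the log below is the message NumPy raises when `np.hstack` receives arrays of different rank. The following is a minimal sketch that reproduces it outside PaddleDetection; the array names and shapes are hypothetical, since the log does not show which inputs to `_diou_nms` were mismatched.

```python
import numpy as np

# Hypothetical shapes, chosen only to trigger the same kind of error as in the log.
boxes = np.zeros((5, 4), dtype=np.float32)      # 2-D: one row per box
scores = np.zeros((5, 1, 1), dtype=np.float32)  # 3-D: carries an extra axis


def try_hstack(a, b):
    """Return the stacked array, or the ValueError message if the ranks differ."""
    try:
        return np.hstack((a, b))
    except ValueError as exc:
        return str(exc)


# Mixing a 2-D and a 3-D array fails inside np.concatenate.
message = try_hstack(boxes, scores)
print(message)

# Bringing both arrays to 2-D first lets the same call succeed.
fixed = np.hstack((boxes, scores.reshape(len(scores), -1)))
print(fixed.shape)
```

If the same mismatch happens inside `_diou_nms`, printing the `.shape` of each argument just before the `hstack` call at `ops.py` line 718 should show which one carries the extra axis.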
```
2020-05-26 23:48:12,707-INFO: Save model to output/faster_rcnn_r50_vd_fpn_ciou_loss_1x/10000.
I0526 23:48:15.279527 24507 parallel_executor.cc:440] The Program will be executed on CUDA using ParallelExecutor, 1 cards are used, so 1 programs are executed in parallel.
I0526 23:48:15.287273 24507 build_strategy.cc:365] SeqOnlyAllReduceOps:0, num_trainers:1
W0526 23:48:15.904718 24683 operator.cc:181] py_func raises an exception pybind11::error_already_set, ValueError: all the input arrays must have same number of dimensions, but the array at index 0 has 2 dimension(s) and the array at index 1 has 3 dimension(s)
At:
<__array_function__ internals>(6): concatenate
/home/ds1/anaconda3/envs/paddle/lib/python3.7/site-packages/numpy/core/shape_base.py(344): hstack
<__array_function__ internals>(6): hstack
/home/ds1/anaconda3/envs/paddle/PaddleDetection/ppdet/modeling/ops.py(718): _diou_nms
/home/ds1/anaconda3/envs/paddle/lib/python3.7/site-packages/paddle/fluid/layers/nn.py(12409): __call__
F0526 23:48:15.904808 24683 exception_holder.h:37] std::exception caught, ValueError: all the input arrays must have same number of dimensions, but the array at index 0 has 2 dimension(s) and the array at index 1 has 3 dimension(s)
At:
<__array_function__ internals>(6): concatenate
/home/ds1/anaconda3/envs/paddle/lib/python3.7/site-packages/numpy/core/shape_base.py(344): hstack
<__array_function__ internals>(6): hstack
/home/ds1/anaconda3/envs/paddle/PaddleDetection/ppdet/modeling/ops.py(718): _diou_nms
/home/ds1/anaconda3/envs/paddle/lib/python3.7/site-packages/paddle/fluid/layers/nn.py(12409): __call__
*** Check failure stack trace: ***
@ 0x7f81fe916c2d google::LogMessage::Fail()
@ 0x7f81fe91a6dc google::LogMessage::SendToLog()
@ 0x7f81fe916753 google::LogMessage::Flush()
@ 0x7f81fe91bbee google::LogMessageFatal::~LogMessageFatal()
@ 0x7f8200ef09b8 paddle::framework::details::ExceptionHolder::Catch()
@ 0x7f8200f9c68e paddle::framework::details::FastThreadedSSAGraphExecutor::RunOpSync()
@ 0x7f8200f9b29f paddle::framework::details::FastThreadedSSAGraphExecutor::RunOp()
@ 0x7f8200f9b564 _ZNSt17_Function_handlerIFvvESt17reference_wrapperISt12_Bind_simpleIFS1_ISt5_BindIFZN6paddle9framework7details28FastThreadedSSAGraphExecutor10RunOpAsyncEPSt13unordered_mapIPNS6_12OpHandleBaseESt6atomicIiESt4hashISA_ESt8equal_toISA_ESaISt4pairIKSA_SC_EEESA_RKSt10shared_ptrINS5_13BlockingQueueImEEEEUlvE_vEEEvEEEE9_M_invokeERKSt9_Any_data
@ 0x7f81fe96f983 std::_Function_handler<>::_M_invoke()
@ 0x7f81fe6fdc37 std::__future_base::_State_base::_M_do_set()
@ 0x7f8231911a99 __pthread_once_slow
@ 0x7f8200f96a52 _ZNSt13__future_base11_Task_stateISt5_BindIFZN6paddle9framework7details28FastThreadedSSAGraphExecutor10RunOpAsyncEPSt13unordered_mapIPNS4_12OpHandleBaseESt6atomicIiESt4hashIS8_ESt8equal_toIS8_ESaISt4pairIKS8_SA_EEES8_RKSt10shared_ptrINS3_13BlockingQueueImEEEEUlvE_vEESaIiEFvvEE6_M_runEv
@ 0x7f81fe6ffe64 _ZZN10ThreadPoolC1EmENKUlvE_clEv
@ 0x7f822a83a3e7 execute_native_thread_routine_compat
@ 0x7f823190a6ba start_thread
@ 0x7f823164041d clone
@ (nil) (unknown)
```
I have tested several other models before, all on VOC-format data, but this one just won't run. Any advice would be appreciated.
Here is the config:
```
architecture: FasterRCNN
max_iters: 90000
snapshot_iter: 10000
use_gpu: true
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar
weights: output/faster_rcnn_r50_vd_fpn_ciou_loss_1x_voc/model_final
metric: VOC
num_classes: 20
FasterRCNN:
  backbone: ResNet
  fpn: FPN
  rpn_head: FPNRPNHead
  roi_extractor: FPNRoIAlign
  bbox_head: BBoxHead
  bbox_assigner: BBoxAssigner
ResNet:
  depth: 50
  feature_maps: [2, 3, 4, 5]
  freeze_at: 2
  norm_type: bn
  variant: d
FPN:
  max_level: 6
  min_level: 2
  num_chan: 256
  spatial_scale: [0.03125, 0.0625, 0.125, 0.25]
FPNRPNHead:
  anchor_generator:
    anchor_sizes: [32, 64, 128, 256, 512]
    aspect_ratios: [0.5, 1.0, 2.0]
    stride: [16.0, 16.0]
    variance: [1.0, 1.0, 1.0, 1.0]
  anchor_start_size: 32
  max_level: 6
  min_level: 2
  num_chan: 256
  rpn_target_assign:
    rpn_batch_size_per_im: 256
    rpn_fg_fraction: 0.5
    rpn_negative_overlap: 0.3
    rpn_positive_overlap: 0.7
    rpn_straddle_thresh: 0.0
  train_proposal:
    min_size: 0.0
    nms_thresh: 0.7
    post_nms_top_n: 2000
    pre_nms_top_n: 2000
  test_proposal:
    min_size: 0.0
    nms_thresh: 0.7
    post_nms_top_n: 1000
    pre_nms_top_n: 1000
FPNRoIAlign:
  canconical_level: 4
  canonical_size: 224
  max_level: 5
  min_level: 2
  box_resolution: 7
  sampling_ratio: 2
BBoxAssigner:
  batch_size_per_im: 512
  bbox_reg_weights: [0.1, 0.1, 0.2, 0.2]
  bg_thresh_hi: 0.5
  bg_thresh_lo: 0.0
  fg_fraction: 0.25
  fg_thresh: 0.5
BBoxHead:
  head: TwoFCHead
  nms: MultiClassDiouNMS
  bbox_loss: DiouLoss
MultiClassDiouNMS:
  keep_top_k: 100
  nms_threshold: 0.5
  score_threshold: 0.05
DiouLoss:
  loss_weight: 10.0
  is_cls_agnostic: false
  use_complete_iou_loss: true
TwoFCHead:
  mlp_dim: 1024
LearningRate:
  base_lr: 0.02
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [60000, 80000]
  - !LinearWarmup
    start_factor: 0.1
    steps: 1000
OptimizerBuilder:
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0001
    type: L2
_READER_: '../faster_fpn_reader.yml'
TrainReader:
  batch_size: 2
TrainReader:
  inputs_def:
    fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd']
  dataset:
    !VOCDataSet
    anno_path: train.txt #annotations/instances_train2017.json
    dataset_dir: dataset/voc
    use_default_label: true
    with_background: false
  sample_transforms:
  - !DecodeImage
    to_rgb: true
  - !RandomFlipImage
    prob: 0.5
  - !NormalizeImage
    is_channel_first: false
    is_scale: True
    mean:
    - 0.485
    - 0.456
    - 0.406
    std:
    - 0.229
    - 0.224
    - 0.225
  - !ResizeImage
    interp: 1
    target_size: [416, 448, 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992, 1024, 1056, 1088, 1120, 1152, 1184, 1216, 1248, 1280, 1312, 1344, 1376, 1408]
    max_size: 1800
    use_cv2: true
  - !Permute
    to_bgr: false
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
  batch_size: 2
  shuffle: true
  drop_last: false
  worker_num: 2
EvalReader:
  inputs_def:
    fields: ['image', 'im_info', 'im_id', 'im_shape', 'gt_bbox', 'gt_class', 'is_difficult']
    use_flip: true
  dataset:
    !VOCDataSet
    anno_path: val.txt #annotations/instances_val2017.json
    dataset_dir: dataset/voc
    use_default_label: true
    with_background: false
  sample_transforms:
  - !DecodeImage
    to_rgb: True
    with_mixup: False
  - !NormalizeImage
    is_channel_first: false
    is_scale: True
    mean:
    - 0.485
    - 0.456
    - 0.406
    std:
    - 0.229
    - 0.224
    - 0.225
  - !ResizeImage
    interp: 1
    target_size:
    - 1200
    max_size: 2000
    use_cv2: true
  - !Permute
    to_bgr: false
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
  batch_size: 2
  worker_num: 2
  drop_empty: false
TestReader:
  inputs_def:
    fields: ['image', 'im_info', 'im_id', 'im_shape']
  dataset:
    !ImageFolder
    # anno_path: annotations/instances_val2017.json
    use_default_label: true
    with_background: false
  sample_transforms:
  - !DecodeImage
    to_rgb: true
    with_mixup: false
  - !NormalizeImage
    is_channel_first: false
    is_scale: true
    mean: [0.485, 0.456, 0.406]
    std: [0.229, 0.224, 0.225]
  - !ResizeImage
    interp: 1
    max_size: 1333
    target_size: 800
    use_cv2: true
  - !Permute
    channel_first: true
    to_bgr: false
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
    use_padded_im_info: true
  batch_size: 1
  worker_num: 2
```