PaddlePaddle / PaddleDetection
Opened May 26, 2020 by saxon_zh (@saxon_zh, Guest)

faster_rcnn_r50_vd_fpn_ciou_loss_1x throws an error when run on a VOC dataset

Created by: yinggo

```
2020-05-26 23:48:12,707-INFO: Save model to output/faster_rcnn_r50_vd_fpn_ciou_loss_1x/10000.
I0526 23:48:15.279527 24507 parallel_executor.cc:440] The Program will be executed on CUDA using ParallelExecutor, 1 cards are used, so 1 programs are executed in parallel.
I0526 23:48:15.287273 24507 build_strategy.cc:365] SeqOnlyAllReduceOps:0, num_trainers:1
W0526 23:48:15.904718 24683 operator.cc:181] py_func raises an exception pybind11::error_already_set, ValueError: all the input arrays must have same number of dimensions, but the array at index 0 has 2 dimension(s) and the array at index 1 has 3 dimension(s)

At:
  <__array_function__ internals>(6): concatenate
  /home/ds1/anaconda3/envs/paddle/lib/python3.7/site-packages/numpy/core/shape_base.py(344): hstack
  <__array_function__ internals>(6): hstack
  /home/ds1/anaconda3/envs/paddle/PaddleDetection/ppdet/modeling/ops.py(718): _diou_nms
  /home/ds1/anaconda3/envs/paddle/lib/python3.7/site-packages/paddle/fluid/layers/nn.py(12409): __call__
F0526 23:48:15.904808 24683 exception_holder.h:37] std::exception caught, ValueError: all the input arrays must have same number of dimensions, but the array at index 0 has 2 dimension(s) and the array at index 1 has 3 dimension(s)

At:
  <__array_function__ internals>(6): concatenate
  /home/ds1/anaconda3/envs/paddle/lib/python3.7/site-packages/numpy/core/shape_base.py(344): hstack
  <__array_function__ internals>(6): hstack
  /home/ds1/anaconda3/envs/paddle/PaddleDetection/ppdet/modeling/ops.py(718): _diou_nms
  /home/ds1/anaconda3/envs/paddle/lib/python3.7/site-packages/paddle/fluid/layers/nn.py(12409): __call__
*** Check failure stack trace: ***
    @     0x7f81fe916c2d  google::LogMessage::Fail()
    @     0x7f81fe91a6dc  google::LogMessage::SendToLog()
    @     0x7f81fe916753  google::LogMessage::Flush()
    @     0x7f81fe91bbee  google::LogMessageFatal::~LogMessageFatal()
    @     0x7f8200ef09b8  paddle::framework::details::ExceptionHolder::Catch()
    @     0x7f8200f9c68e  paddle::framework::details::FastThreadedSSAGraphExecutor::RunOpSync()
    @     0x7f8200f9b29f  paddle::framework::details::FastThreadedSSAGraphExecutor::RunOp()
    @     0x7f8200f9b564  _ZNSt17_Function_handlerIFvvESt17reference_wrapperISt12_Bind_simpleIFS1_ISt5_BindIFZN6paddle9framework7details28FastThreadedSSAGraphExecutor10RunOpAsyncEPSt13unordered_mapIPNS6_12OpHandleBaseESt6atomicIiESt4hashISA_ESt8equal_toISA_ESaISt4pairIKSA_SC_EEESA_RKSt10shared_ptrINS5_13BlockingQueueImEEEEUlvE_vEEEvEEEE9_M_invokeERKSt9_Any_data
    @     0x7f81fe96f983  std::_Function_handler<>::_M_invoke()
    @     0x7f81fe6fdc37  std::__future_base::_State_base::_M_do_set()
    @     0x7f8231911a99  __pthread_once_slow
    @     0x7f8200f96a52  _ZNSt13__future_base11_Task_stateISt5_BindIFZN6paddle9framework7details28FastThreadedSSAGraphExecutor10RunOpAsyncEPSt13unordered_mapIPNS4_12OpHandleBaseESt6atomicIiESt4hashIS8_ESt8equal_toIS8_ESaISt4pairIKS8_SA_EEES8_RKSt10shared_ptrINS3_13BlockingQueueImEEEEUlvE_vEESaIiEFvvEE6_M_runEv
    @     0x7f81fe6ffe64  _ZZN10ThreadPoolC1EmENKUlvE_clEv
    @     0x7f822a83a3e7  execute_native_thread_routine_compat
    @     0x7f823190a6ba  start_thread
    @     0x7f823164041d  clone
    @              (nil)  (unknown)
```
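The traceback shows that the failure is raised by NumPy, not by Paddle itself: `np.hstack` inside `_diou_nms` (ops.py line 718) received one 2-D array and one 3-D array. Here is a minimal reproduction of the same `ValueError`. It assumes, hypothetically (the log does not show the actual arrays), that the 2-D operand is a score column of shape `[k, 1]` and the 3-D operand is a box slice with one axis too many. `stack_scores_and_boxes` is an illustrative helper, not a PaddleDetection function.

```python
import numpy as np


def stack_scores_and_boxes(scores_j, rois_j):
    """Hypothetical helper: prepend a score column to a [k, 4] box array -> [k, 5]."""
    return np.hstack((scores_j[:, np.newaxis], rois_j))


k = 3
scores_j = np.arange(k, dtype=np.float32)           # [k]
good_rois = np.zeros((k, 4), dtype=np.float32)      # [k, 4]: stacks fine
bad_rois = np.zeros((k, 1, 4), dtype=np.float32)    # [k, 1, 4]: one extra axis

print(stack_scores_and_boxes(scores_j, good_rois).shape)  # (3, 5)
try:
    stack_scores_and_boxes(scores_j, bad_rois)
except ValueError as e:
    # Same class of error as in the log: 2 dimensions vs 3 dimensions
    print(e)
```

If this reading is right, the thing to check is the rank of the box tensor that reaches the NMS op, not the NMS thresholds.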

I have tested several other models before, all with data in VOC format, but this one will not run. Any advice would be appreciated.

Here is my config:

```
architecture: FasterRCNN
max_iters: 90000
snapshot_iter: 10000
use_gpu: true
log_smooth_window: 20
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_pretrained.tar
weights: output/faster_rcnn_r50_vd_fpn_ciou_loss_1x_voc/model_final
metric: VOC
num_classes: 20

FasterRCNN:
  backbone: ResNet
  fpn: FPN
  rpn_head: FPNRPNHead
  roi_extractor: FPNRoIAlign
  bbox_head: BBoxHead
  bbox_assigner: BBoxAssigner

ResNet:
  depth: 50
  feature_maps: [2, 3, 4, 5]
  freeze_at: 2
  norm_type: bn
  variant: d

FPN:
  max_level: 6
  min_level: 2
  num_chan: 256
  spatial_scale: [0.03125, 0.0625, 0.125, 0.25]

FPNRPNHead:
  anchor_generator:
    anchor_sizes: [32, 64, 128, 256, 512]
    aspect_ratios: [0.5, 1.0, 2.0]
    stride: [16.0, 16.0]
    variance: [1.0, 1.0, 1.0, 1.0]
  anchor_start_size: 32
  max_level: 6
  min_level: 2
  num_chan: 256
  rpn_target_assign:
    rpn_batch_size_per_im: 256
    rpn_fg_fraction: 0.5
    rpn_negative_overlap: 0.3
    rpn_positive_overlap: 0.7
    rpn_straddle_thresh: 0.0
  train_proposal:
    min_size: 0.0
    nms_thresh: 0.7
    post_nms_top_n: 2000
    pre_nms_top_n: 2000
  test_proposal:
    min_size: 0.0
    nms_thresh: 0.7
    post_nms_top_n: 1000
    pre_nms_top_n: 1000

FPNRoIAlign:
  canconical_level: 4
  canonical_size: 224
  max_level: 5
  min_level: 2
  box_resolution: 7
  sampling_ratio: 2

BBoxAssigner:
  batch_size_per_im: 512
  bbox_reg_weights: [0.1, 0.1, 0.2, 0.2]
  bg_thresh_hi: 0.5
  bg_thresh_lo: 0.0
  fg_fraction: 0.25
  fg_thresh: 0.5

BBoxHead:
  head: TwoFCHead
  nms: MultiClassDiouNMS
  bbox_loss: DiouLoss

MultiClassDiouNMS:
  keep_top_k: 100
  nms_threshold: 0.5
  score_threshold: 0.05

DiouLoss:
  loss_weight: 10.0
  is_cls_agnostic: false
  use_complete_iou_loss: true

TwoFCHead:
  mlp_dim: 1024

LearningRate:
  base_lr: 0.02
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones: [60000, 80000]
  - !LinearWarmup
    start_factor: 0.1
    steps: 1000

OptimizerBuilder:
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.0001
    type: L2

_READER_: '../faster_fpn_reader.yml'
TrainReader:
  batch_size: 2
  
TrainReader:
  inputs_def:
    fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd']
  dataset:
    !VOCDataSet
    anno_path: train.txt #annotations/instances_train2017.json
    dataset_dir: dataset/voc
    use_default_label: true
    with_background: false
  sample_transforms:
  - !DecodeImage
    to_rgb: true
  - !RandomFlipImage
    prob: 0.5
  - !NormalizeImage
    is_channel_first: false
    is_scale: True
    mean:
    - 0.485
    - 0.456
    - 0.406
    std:
    - 0.229
    - 0.224
    - 0.225
  - !ResizeImage
    interp: 1
    target_size: [416, 448, 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992, 1024, 1056, 1088, 1120, 1152, 1184, 1216, 1248, 1280, 1312, 1344, 1376, 1408]
    max_size: 1800
    use_cv2: true
  - !Permute
    to_bgr: false
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
  batch_size: 2
  shuffle: true
  drop_last: false
  worker_num: 2

EvalReader:
  inputs_def:
    fields: ['image', 'im_info', 'im_id', 'im_shape', 'gt_bbox', 'gt_class', 'is_difficult']
    use_flip: true
  dataset:
    !VOCDataSet
    anno_path: val.txt #annotations/instances_val2017.json
    dataset_dir: dataset/voc
    use_default_label: true
    with_background: false
  sample_transforms:
  - !DecodeImage
    to_rgb: True
    with_mixup: False
  - !NormalizeImage
    is_channel_first: false
    is_scale: True
    mean:
    - 0.485
    - 0.456
    - 0.406
    std:
    - 0.229
    - 0.224
    - 0.225
  - !ResizeImage
    interp: 1
    target_size:
    - 1200
    max_size: 2000
    use_cv2: true
  - !Permute
    to_bgr: false
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
  batch_size: 2
  worker_num: 2
  drop_empty: false

TestReader:
  inputs_def:
    fields: ['image', 'im_info', 'im_id','im_shape']
  dataset:
    !ImageFolder
#    anno_path: annotations/instances_val2017.json
    use_default_label: true
    with_background: false
  sample_transforms:
  - !DecodeImage
    to_rgb: true
    with_mixup: false
  - !NormalizeImage
    is_channel_first: false
    is_scale: true
    mean: [0.485,0.456,0.406]
    std: [0.229, 0.224,0.225]
  - !ResizeImage
    interp: 1
    max_size: 1333
    target_size: 800
    use_cv2: true
  - !Permute
    channel_first: true
    to_bgr: false
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
    use_padded_im_info: true
  batch_size: 1
  worker_num: 2
```
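For reference, the `MultiClassDiouNMS` settings above (`keep_top_k: 100`, `nms_threshold: 0.5`, `score_threshold: 0.05`) describe per-class DIoU-NMS. The NumPy code below is a simplified stand-in written for illustration. It is not PaddleDetection's `_diou_nms`, and all function names are hypothetical. It assumes per-class boxes arrive as a 3-D `[N, C, 4]` array and scores as `[N, C]`. Under that layout, `bboxes[inds, j, :]` is 2-D, so stacking it next to the score column works. The suppression rule follows the DIoU-NMS definition: a box is dropped once its IoU minus the normalized center-distance penalty reaches the threshold.

```python
import numpy as np


def iou(box, others):
    """IoU between one box [4] and boxes [M, 4], all in (x1, y1, x2, y2)."""
    ix1 = np.maximum(box[0], others[:, 0])
    iy1 = np.maximum(box[1], others[:, 1])
    ix2 = np.minimum(box[2], others[:, 2])
    iy2 = np.minimum(box[3], others[:, 3])
    inter = np.clip(ix2 - ix1, 0, None) * np.clip(iy2 - iy1, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (others[:, 2] - others[:, 0]) * (others[:, 3] - others[:, 1])
    return inter / np.maximum(area + areas - inter, 1e-10)


def center_distance_penalty(box, others):
    """DIoU penalty: squared center distance / squared diagonal of the enclosing box."""
    dx = (box[0] + box[2]) / 2 - (others[:, 0] + others[:, 2]) / 2
    dy = (box[1] + box[3]) / 2 - (others[:, 1] + others[:, 3]) / 2
    ew = np.maximum(box[2], others[:, 2]) - np.minimum(box[0], others[:, 0])
    eh = np.maximum(box[3], others[:, 3]) - np.minimum(box[1], others[:, 1])
    return (dx ** 2 + dy ** 2) / np.maximum(ew ** 2 + eh ** 2, 1e-10)


def diou_nms_single_class(dets, nms_threshold):
    """dets: [k, 5] rows of (score, x1, y1, x2, y2); returns the kept rows."""
    order = np.argsort(-dets[:, 0])
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        rest = order[1:]
        diou = iou(dets[i, 1:], dets[rest, 1:]) - center_distance_penalty(dets[i, 1:], dets[rest, 1:])
        # keep only boxes whose DIoU with the current best box is below the threshold
        order = rest[diou < nms_threshold]
    return dets[np.array(keep, dtype=np.intp)]


def multiclass_diou_nms(bboxes, scores, score_threshold=0.05, nms_threshold=0.5, keep_top_k=100):
    """bboxes: [N, C, 4], scores: [N, C] -> [M, 6] rows of (label, score, x1, y1, x2, y2)."""
    if bboxes.ndim != 3 or scores.ndim != 2:
        raise ValueError(f"expected [N, C, 4] boxes and [N, C] scores, got {bboxes.shape} and {scores.shape}")
    per_class = []
    for j in range(scores.shape[1]):
        inds = np.where(scores[:, j] >= score_threshold)[0]
        scores_j = scores[inds, j]        # [k]
        rois_j = bboxes[inds, j, :]       # [k, 4] because bboxes is 3-D
        dets_j = np.hstack((scores_j[:, np.newaxis], rois_j)).astype(np.float32)
        kept = diou_nms_single_class(dets_j, nms_threshold)
        labels = np.full((kept.shape[0], 1), j, dtype=np.float32)
        per_class.append(np.hstack((labels, kept)))
    result = np.vstack(per_class) if per_class else np.zeros((0, 6), dtype=np.float32)
    # keep the keep_top_k highest-scoring detections over all classes
    return result[np.argsort(-result[:, 1])[:keep_top_k]]


# Usage: one class, three boxes; the second box heavily overlaps the first and is suppressed
bboxes = np.array([[[0, 0, 10, 10]], [[1, 1, 11, 11]], [[50, 50, 60, 60]]], dtype=np.float32)  # [3, 1, 4]
scores = np.array([[0.9], [0.8], [0.7]], dtype=np.float32)                                     # [3, 1]
print(multiclass_diou_nms(bboxes, scores))
```

The explicit rank check fails fast with a readable message. Without it, a 4-D box tensor would only surface later as the `hstack` dimension error seen in the log.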
Reference: paddlepaddle/PaddleDetection#787