PaddlePaddle / PaddleDetection
Opened Sep 21, 2020 by saxon_zh (@saxon_zh)

aistudio: cascade_rcnn_mobilenetv3_fpn training error

Created by: Fauny


Paddle version: None
Paddle With CUDA: None
OS: Ubuntu 18.04
Python version: 3.6.9
CUDA version: 10.0.326
cuDNN version: 7.6.3
Nvidia driver version: None


Custom dataset. Training fails with both 640 and 320.

2020-09-21 11:38:22,275-INFO: iter: 39600, lr: 0.003861, 'loss_cls_0': '0.000000', 'loss_loc_0': '3.746526', 'loss_cls_1': '0.000000', 'loss_loc_1': '0.000000', 'loss_cls_2': '0.000000', 'loss_loc_2': '0.000000', 'loss_rpn_cls': '0.250100', 'loss_rpn_bbox': '0.029995', 'loss': '4.017196', time: 0.072, eta: 9:10:32
Traceback (most recent call last):
  File "tools/train.py", line 372, in <module>
    main()
  File "tools/train.py", line 245, in main
    outs = exe.run(compiled_train_prog, fetch_list=train_values)
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/executor.py", line 1082, in run
    six.reraise(*sys.exc_info())
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/six.py", line 703, in reraise
    raise value
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/executor.py", line 1080, in run
    return_merged=return_merged)
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/executor.py", line 1178, in _run_impl
    return_merged=return_merged)
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/executor.py", line 893, in _run_parallel
    tensors = exe.run(fetch_var_names, return_merged)._move_to_list()
paddle.fluid.core_noavx.EnforceNotMet: 

--------------------------------------------
C++ Call Stacks (More useful to developers):
--------------------------------------------
0   std::string paddle::platform::GetTraceBackString<std::string const&>(std::string const&, char const*, int)
1   paddle::platform::EnforceNotMet::EnforceNotMet(std::string const&, char const*, int)
2   paddle::framework::Tensor::mutable_data(paddle::platform::Place const&, paddle::framework::proto::VarType_Type, unsigned long)
3   paddle::operators::ElementwiseAddKernel<paddle::platform::CUDADeviceContext, float>::Compute(paddle::framework::ExecutionContext const&) const
4   std::_Function_handler<void (paddle::framework::ExecutionContext const&), paddle::framework::OpKernelRegistrarFunctor<paddle::platform::CUDAPlace, false, 0ul, paddle::operators::ElementwiseAddKernel<paddle::platform::CUDADeviceContext, float>, paddle::operators::ElementwiseAddKernel<paddle::platform::CUDADeviceContext, double>, paddle::operators::ElementwiseAddKernel<paddle::platform::CUDADeviceContext, int>, paddle::operators::ElementwiseAddKernel<paddle::platform::CUDADeviceContext, long>, paddle::operators::ElementwiseAddKernel<paddle::platform::CUDADeviceContext, paddle::platform::float16> >::operator()(char const*, char const*, int) const::{lambda(paddle::framework::ExecutionContext const&)#1}>::_M_invoke(std::_Any_data const&, paddle::framework::ExecutionContext const&)
5   paddle::framework::OperatorWithKernel::RunImpl(paddle::framework::Scope const&, paddle::platform::Place const&, paddle::framework::RuntimeContext*) const
6   paddle::framework::OperatorWithKernel::RunImpl(paddle::framework::Scope const&, paddle::platform::Place const&) const
7   paddle::framework::OperatorBase::Run(paddle::framework::Scope const&, paddle::platform::Place const&)
8   paddle::framework::details::ComputationOpHandle::RunImpl()
9   paddle::framework::details::FastThreadedSSAGraphExecutor::RunOpSync(paddle::framework::details::OpHandleBase*)
10  paddle::framework::details::FastThreadedSSAGraphExecutor::RunOp(paddle::framework::details::OpHandleBase*, std::shared_ptr<paddle::framework::BlockingQueue<unsigned long> > const&, unsigned long*)
11  std::_Function_handler<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> (), std::__future_base::_Task_setter<std::unique_ptr<std::__future_base::_Result<void>, std::__future_base::_Result_base::_Deleter>, void> >::_M_invoke(std::_Any_data const&)
12  std::__future_base::_State_base::_M_do_set(std::function<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> ()>&, bool&)
13  ThreadPool::ThreadPool(unsigned long)::{lambda()#1}::operator()() const

------------------------------------------
Python Call Stacks (More useful to users):
------------------------------------------
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/framework.py", line 2798, in append_op
    attrs=kwargs.get("attrs", None))
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layer_helper.py", line 43, in append_op
    return self.main_program.current_block().append_op(*args, **kwargs)
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layer_helper.py", line 135, in append_bias_op
    attrs={'axis': dim_start})
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/nn.py", line 363, in fc
    pre_activation = helper.append_bias_op(pre_bias, dim_start=num_flatten_dims)
  File "/home/aistudio/cascade_rcnn_mobilenetv3_fpn/PaddleDetection/ppdet/modeling/roi_heads/cascade_head.py", line 352, in __call__
    regularizer=L2Decay(0.)))
  File "/home/aistudio/cascade_rcnn_mobilenetv3_fpn/PaddleDetection/ppdet/modeling/roi_heads/cascade_head.py", line 79, in get_output
    head_feat = self.head(roi_feat, wb_scalar, name)
  File "/home/aistudio/cascade_rcnn_mobilenetv3_fpn/PaddleDetection/ppdet/modeling/architectures/cascade_rcnn.py", line 158, in build
    name='_' + str(i + 1) if i > 0 else '')
  File "/home/aistudio/cascade_rcnn_mobilenetv3_fpn/PaddleDetection/ppdet/modeling/architectures/cascade_rcnn.py", line 327, in train
    return self.build(feed_vars, 'train')
  File "tools/train.py", line 117, in main
    train_fetches = model.train(feed_vars)
  File "tools/train.py", line 372, in <module>
    main()

----------------------
Error Message Summary:
----------------------
Error: When calling this method, the Tensor's numel must be equal or larger than zero. Please check Tensor::dims, or Tensor::Resize has been called first. The Tensor's shape is [-1, 128] now
  [Hint: Expected numel() >= 0, but received numel():-128 < 0:0.] at (/paddle/paddle/fluid/framework/tensor.cc:45)
  [operator < elementwise_add > error]
terminate called without an active exception
W0921 11:38:23.122191  1643 init.cc:235] Warning: PaddlePaddle catches a failure signal, it may not work properly
W0921 11:38:23.122239  1643 init.cc:237] You could check whether you killed PaddlePaddle thread/process accidentally or report the case to PaddlePaddle
W0921 11:38:23.122243  1643 init.cc:240] The detail failure signal is:

W0921 11:38:23.122251  1643 init.cc:243] *** Aborted at 1600659503 (unix time) try "date -d @1600659503" if you are using GNU date ***
W0921 11:38:23.124091  1643 init.cc:243] PC: @                0x0 (unknown)
W0921 11:38:23.124198  1643 init.cc:243] *** SIGABRT (@0x3e80000062e) received by PID 1582 (TID 0x7f5bacc3f700) from PID 1582; stack trace: ***
W0921 11:38:23.125437  1643 init.cc:243]     @     0x7f5bbdc28390 (unknown)
W0921 11:38:23.126608  1643 init.cc:243]     @     0x7f5bbd882428 gsignal
W0921 11:38:23.127753  1643 init.cc:243]     @     0x7f5bbd88402a abort
W0921 11:38:23.128615  1643 init.cc:243]     @     0x7f5b7e60184a __gnu_cxx::__verbose_terminate_handler()
W0921 11:38:23.129341  1643 init.cc:243]     @     0x7f5b7e5fff47 __cxxabiv1::__terminate()
W0921 11:38:23.130137  1643 init.cc:243]     @     0x7f5b7e5fff7d std::terminate()
W0921 11:38:23.130897  1643 init.cc:243]     @     0x7f5b7e5ffc5a __gxx_personality_v0
W0921 11:38:23.131584  1643 init.cc:243]     @     0x7f5b7e8f2b97 _Unwind_ForcedUnwind_Phase2
W0921 11:38:23.132267  1643 init.cc:243]     @     0x7f5b7e8f2e7d _Unwind_ForcedUnwind
W0921 11:38:23.133441  1643 init.cc:243]     @     0x7f5bbdc27070 __GI___pthread_unwind
W0921 11:38:23.134588  1643 init.cc:243]     @     0x7f5bbdc1f845 __pthread_exit
W0921 11:38:23.134867  1643 init.cc:243]     @     0x55de55a09e59 PyThread_exit_thread
W0921 11:38:23.134949  1643 init.cc:243]     @     0x55de5588fc17 PyEval_RestoreThread.cold.798
W0921 11:38:23.136456  1643 init.cc:243]     @     0x7f5b40694e39 pybind11::gil_scoped_release::~gil_scoped_release()
W0921 11:38:23.136919  1643 init.cc:243]     @     0x7f5b407e0a0c _ZZN8pybind1112cpp_function10initializeIZN6paddle6pybind10BindReaderEPNS_6moduleEEUlRNS2_9operators6reader40OrderedMultiDeviceLoDTensorBlockingQueueERKSt6vectorINS2_9framework9LoDTensorESaISC_EEE2_bIS9_SG_EINS_4nameENS_9is_methodENS_7siblingENS_10call_guardIINS_18gil_scoped_releaseEEEEEEEvOT_PFT0_DpT1_EDpRKT2_ENUlRNS_6detail13function_callEE1_4_FUNES11_
W0921 11:38:23.138195  1643 init.cc:243]     @     0x7f5b406b1666 pybind11::cpp_function::dispatcher()
W0921 11:38:23.138525  1643 init.cc:243]     @     0x55de5598b744 _PyMethodDef_RawFastCallKeywords
W0921 11:38:23.138772  1643 init.cc:243]     @     0x55de5598b861 _PyCFunction_FastCallKeywords
W0921 11:38:23.139005  1643 init.cc:243]     @     0x55de559f76e8 _PyEval_EvalFrameDefault
W0921 11:38:23.139221  1643 init.cc:243]     @     0x55de5593b81a _PyEval_EvalCodeWithName
W0921 11:38:23.139432  1643 init.cc:243]     @     0x55de5593c635 _PyFunction_FastCallDict
W0921 11:38:23.139658  1643 init.cc:243]     @     0x55de559f4232 _PyEval_EvalFrameDefault
W0921 11:38:23.139865  1643 init.cc:243]     @     0x55de5598accb _PyFunction_FastCallKeywords
W0921 11:38:23.140095  1643 init.cc:243]     @     0x55de559f2a93 _PyEval_EvalFrameDefault
W0921 11:38:23.140291  1643 init.cc:243]     @     0x55de5598accb _PyFunction_FastCallKeywords
W0921 11:38:23.140516  1643 init.cc:243]     @     0x55de559f2a93 _PyEval_EvalFrameDefault
W0921 11:38:23.140728  1643 init.cc:243]     @     0x55de5593c56b _PyFunction_FastCallDict
W0921 11:38:23.140941  1643 init.cc:243]     @     0x55de5595ae53 _PyObject_Call_Prepend
W0921 11:38:23.141172  1643 init.cc:243]     @     0x55de5594ddbe PyObject_Call
W0921 11:38:23.141268  1643 init.cc:243]     @     0x55de55a4a817 t_bootstrap
W0921 11:38:23.141319  1643 init.cc:243]     @     0x55de55a05788 pythread_wrapper
W0921 11:38:23.142601  1643 init.cc:243]     @     0x7f5bbdc1e6ba start_thread
Aborted (core dumped)
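The error summary says the output tensor of the `fc` layer in the cascade head (cascade_head.py, line 352) still had shape [-1, 128] when `Tensor::mutable_data` tried to allocate it. The 128 likely corresponds to `mlp_dim: 128` in `CascadeTwoFCHead`, and the -1 is presumably the RoI (batch) dimension that was never resolved at runtime. A tensor's element count is the product of its dims, so -1 × 128 = -128, which fails the `numel() >= 0` check. The minimal Python sketch below only illustrates that arithmetic; `numel` and `check_numel` are hypothetical helper names, not Paddle APIs.

```python
from functools import reduce
import operator


def numel(dims):
    """Element count of a tensor shape: the product of all dims.

    An unresolved -1 dim makes the product negative (1 for a scalar shape [])."""
    return reduce(operator.mul, dims, 1)


def check_numel(dims):
    """Mimic the 'Expected numel() >= 0' enforcement from the log (hypothetical)."""
    n = numel(dims)
    if n < 0:
        raise ValueError(f"numel must be >= 0, got {n} for shape {list(dims)}")
    return n


if __name__ == "__main__":
    print(numel([-1, 128]))  # prints -128, the value reported in the hint
```

A resolved shape such as [0, 128] (no RoIs) passes the check with 0 elements; only the unresolved -1 produces the negative count seen in the log.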

Configuration file:

architecture: CascadeRCNN
max_iters: 500000
snapshot_iter: 2000
use_gpu: true
log_smooth_window: 200
save_dir: output
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_ssld_pretrained.tar
weights: output/big/model_final
metric: COCO
num_classes: 1

CascadeRCNN:
  backbone: MobileNetV3RCNN
  fpn: FPN
  rpn_head: FPNRPNHead
  roi_extractor: FPNRoIAlign
  bbox_head: CascadeBBoxHead
  bbox_assigner: CascadeBBoxAssigner

MobileNetV3RCNN:
  norm_type: bn
  freeze_norm: true
  norm_decay: 0.0
  feature_maps: [2, 3, 4]
  conv_decay: 0.00001
  lr_mult_list: [1.0, 1.0, 1.0, 1.0, 1.0]
  scale: 1.0
  model_name: large

FPN:
  min_level: 2
  max_level: 6
  num_chan: 48
  has_extra_convs: true
  spatial_scale: [0.0625, 0.125, 0.25]

FPNRPNHead:
  anchor_generator:
    anchor_sizes: [32, 64, 128, 256, 512]
    aspect_ratios: [0.5, 1.0, 2.0]
    stride: [16.0, 16.0]
    variance: [1.0, 1.0, 1.0, 1.0]
  anchor_start_size: 24
  min_level: 2
  max_level: 6
  num_chan: 48
  rpn_target_assign:
    rpn_batch_size_per_im: 256
    rpn_fg_fraction: 0.5
    rpn_positive_overlap: 0.7
    rpn_negative_overlap: 0.3
    rpn_straddle_thresh: 0.0
  train_proposal:
    min_size: 0.0
    nms_thresh: 0.7
    pre_nms_top_n: 2000
    post_nms_top_n: 2000
  test_proposal:
    min_size: 0.0
    nms_thresh: 0.7
    pre_nms_top_n: 300
    post_nms_top_n: 100

FPNRoIAlign:
  canconical_level: 4
  canonical_size: 224
  min_level: 2
  max_level: 5
  box_resolution: 7
  sampling_ratio: 2

CascadeBBoxAssigner:
  batch_size_per_im: 512
  bbox_reg_weights: [10, 20, 30]
  bg_thresh_lo: [0.0, 0.0, 0.0]
  bg_thresh_hi: [0.5, 0.6, 0.7]
  fg_thresh: [0.5, 0.6, 0.7]
  fg_fraction: 0.25

CascadeBBoxHead:
  head: CascadeTwoFCHead
  bbox_loss: BalancedL1Loss
  nms:
    keep_top_k: 100
    nms_threshold: 0.5
    score_threshold: 0.05

BalancedL1Loss:
  alpha: 0.5
  gamma: 1.5
  beta: 1.0
  loss_weight: 1.0

CascadeTwoFCHead:
  mlp_dim: 128

LearningRate:
  base_lr: 0.005
  schedulers:
  - !CosineDecay
    max_iters: 125000
  - !LinearWarmup
    start_factor: 0.1
    steps: 500

OptimizerBuilder:
  optimizer:
    momentum: 0.9
    type: Momentum
  regularizer:
    factor: 0.00004
    type: L2

TrainReader:
  inputs_def:
    fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd']
  dataset:
    !COCODataSet
    image_dir: train
    anno_path: annotations/instance_train.json
    dataset_dir: dataset/data50
  sample_transforms:
  - !DecodeImage
    to_rgb: true
  - !RandomFlipImage
    prob: 0.5
  - !AutoAugmentImage
    autoaug_type: v1
  - !NormalizeImage
    is_channel_first: false
    is_scale: true
    mean: [0.485,0.456,0.406]
    std: [0.229, 0.224,0.225]
  - !ResizeImage
    target_size: [416, 448, 480, 512, 544, 576, 608, 640, 672]
    max_size: 1000
    interp: 1
    use_cv2: true
  - !Permute
    to_bgr: false
    channel_first: true
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
    use_padded_im_info: false
  batch_size: 2
  shuffle: true
  worker_num: 2
  use_process: false


TestReader:
  inputs_def:
    # set image_shape if needed
    fields: ['image', 'im_info', 'im_id', 'im_shape']
  dataset:
    !ImageFolder
    anno_path: annotations/instance_val.json
  sample_transforms:
  - !DecodeImage
    to_rgb: true
    with_mixup: false
  - !NormalizeImage
    is_channel_first: false
    is_scale: true
    mean: [0.485,0.456,0.406]
    std: [0.229, 0.224,0.225]
  - !ResizeImage
    interp: 1
    max_size: 640
    target_size: 640
    use_cv2: true
  - !Permute
    channel_first: true
    to_bgr: false
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
    use_padded_im_info: true
  batch_size: 1
  shuffle: false



EvalReader:
  inputs_def:
    fields: ['image', 'im_info', 'im_id', 'im_shape']
    # for voc
    #fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_difficult']
  dataset:
    !COCODataSet
    image_dir: val
    anno_path: annotations/instance_val.json
    dataset_dir: dataset/data50
  sample_transforms:
  - !DecodeImage
    to_rgb: true
    with_mixup: false
  - !NormalizeImage
    is_channel_first: false
    is_scale: true
    mean: [0.485,0.456,0.406]
    std: [0.229, 0.224,0.225]
  - !ResizeImage
    interp: 1
    max_size: 640
    target_size: 640
    use_cv2: true
  - !Permute
    channel_first: true
    to_bgr: false
  batch_transforms:
  - !PadBatch
    pad_to_stride: 32
    use_padded_im_info: true
  batch_size: 1
  shuffle: false
  drop_empty: false
  worker_num: 2
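To see what the `LearningRate` block above produces, the sketch below assumes `LinearWarmup` ramps linearly from `base_lr * start_factor` to `base_lr` over `steps` iterations and that `CosineDecay` then follows `base_lr * 0.5 * (cos(step * pi / max_iters) + 1)`. This is an assumption about ppdet's scheduler, not its actual code, but it reproduces the logged `lr: 0.003861` at iter 39600. Under the same formula the value is not clamped, so with the top-level `max_iters: 500000` and `CosineDecay` `max_iters: 125000`, the learning rate would reach 0 at iteration 125000 and climb back toward `base_lr` afterwards.

```python
import math

# Values copied from the LearningRate block of the config above.
BASE_LR = 0.005
COSINE_MAX_ITERS = 125000   # CosineDecay.max_iters
WARMUP_STEPS = 500          # LinearWarmup.steps
WARMUP_START_FACTOR = 0.1   # LinearWarmup.start_factor


def cosine_lr(step, base_lr=BASE_LR, max_iters=COSINE_MAX_ITERS):
    """Assumed cosine decay: base_lr * 0.5 * (cos(step * pi / max_iters) + 1).

    Not clamped: for max_iters < step < 2 * max_iters the value rises again."""
    return base_lr * 0.5 * (math.cos(step * math.pi / max_iters) + 1.0)


def scheduled_lr(step):
    """Assumed combination: linear warmup first, then cosine decay."""
    if step < WARMUP_STEPS:
        start = BASE_LR * WARMUP_START_FACTOR
        return start + (BASE_LR - start) * step / WARMUP_STEPS
    return cosine_lr(step)


if __name__ == "__main__":
    # The log line above reports lr: 0.003861 at iter 39600.
    print(f"{scheduled_lr(39600):.6f}")  # prints 0.003861
```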
Reference: paddlepaddle/PaddleDetection#1450