deeplabv3+: how to change the number of classes
Created by: 994374821
- Version and environment: 1) PaddlePaddle version: not provided (the template prompt was left unfilled); 2) OS: Ubuntu 16.04
- Training info: 1) single GPU
- Reproduction info: I am running the deeplabv3+ code at https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleCV/deeplabv3%2B and want to change the number of classes. I set `label_number=9` in models.py and initialized the network from the official pretrained model, but when I print the logit it has 19 channels. The code is as follows.

train.py:

```python
# Excerpt from train.py; tp, sp, exe, exe_p, args, pred, loss_mean,
# image_shape and batches are defined elsewhere in the script.
with fluid.program_guard(tp, sp):
    img = fluid.layers.data(
        name='img', shape=[3] + image_shape, dtype='float32')
    label = fluid.layers.data(name='label', shape=image_shape, dtype='int32')
    logit_9 = deeplabv3p(img)
    # Print the static (build-time) description of the logit variable.
    print('logit_9', logit_9)

for i, imgs, labels, names in batches:
    prev_start_time = time.time()
    if args.parallel:
        retv = exe_p.run(fetch_list=[pred.name, loss_mean.name],
                         feed={'img': imgs, 'label': labels})
    else:
        retv = exe.run(tp,
                       feed={'img': imgs, 'label': labels},
                       fetch_list=[logit_9.name])
    # Print the shape of the tensor actually produced at run time.
    print('logit shape: {}'.format(retv[0].shape))
```

Training command:

```
python train.py --init_weights_path init_params/deeplabv3plus_xception65_initialize.params --save_weights_path output/ --dataset_path ~/Downloads/gaomingda/dataset/baidu_road/
```

Terminal output:

```
logit_9 name: "conv2d_146.tmp_1"
type {
  type: LOD_TENSOR
  lod_tensor {
    tensor {
      data_type: FP32
      dims: -1
      dims: 9
      dims: 193
      dims: 193
    }
    lod_level: 0
  }
}
persistable: false
W1212 14:44:48.281414 13888 device_context.cc:203] Please NOTE: device: 0, CUDA Capability: 52, Driver Version: 9.0, Runtime Version: 9.0
W1212 14:44:48.281466 13888 device_context.cc:210] device: 0, cuDNN Version: 7.0.
load from: init_params/deeplabv3plus_xception65_initialize.params
total number 18914
/home/gaomingda/baidu_road/deeplabv3+/reader.py:39: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.
  a = a[slices]
logit shape: (2, 19, 193, 193)
```
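- Problem description: the network is built to predict 9 channels (the printed variable shows `dims: 9`), yet the logit fetched at run time has 19 channels. Why does this happen, and how can it be fixed?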
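
One possible explanation, offered as an assumption rather than a confirmed diagnosis: the pretrained checkpoint was trained with 19 classes, so if loading it replaces the weights of the final logit convolution, the run-time output follows the checkpoint's shape (19) instead of the build-time shape (9). The sketch below illustrates that idea in plain Python, without PaddlePaddle. The parameter names, the 256 input channels, and the helper functions are all hypothetical and do not come from the repository. It only shows why skipping shape-incompatible checkpoint entries would keep the 9-class head.

```python
# Hypothetical illustration: parameters are modelled as name -> shape.
# Nothing here is PaddlePaddle API; all names and shapes are assumptions.


def select_loadable(model_shapes, checkpoint_shapes):
    """Split checkpoint entries into loadable and skipped names.

    An entry is loadable only if the model has a parameter with the
    same name and exactly the same shape.
    """
    loadable, skipped = [], []
    for name, ckpt_shape in sorted(checkpoint_shapes.items()):
        if model_shapes.get(name) == tuple(ckpt_shape):
            loadable.append(name)
        else:
            skipped.append(name)
    return loadable, skipped


def output_channels(weight_shape):
    # For a conv weight laid out as (out_channels, in_channels, kh, kw),
    # the number of output channels is the first dimension.
    return weight_shape[0]


model_shapes = {
    "backbone/conv/weights": (256, 256, 3, 3),
    "logit/weights": (9, 256, 1, 1),   # new 9-class head
}
checkpoint_shapes = {
    "backbone/conv/weights": (256, 256, 3, 3),
    "logit/weights": (19, 256, 1, 1),  # 19-class pretrained head
}

# Naive load: every checkpoint tensor replaces the model's tensor.
naive = dict(model_shapes)
naive.update(checkpoint_shapes)

# Filtered load: only shape-compatible tensors come from the checkpoint.
loadable, skipped = select_loadable(model_shapes, checkpoint_shapes)
filtered = dict(model_shapes)
filtered.update({name: checkpoint_shapes[name] for name in loadable})

print(output_channels(naive["logit/weights"]))     # 19
print(output_channels(filtered["logit/weights"]))  # 9
print(skipped)                                     # ['logit/weights']
```

If this assumption holds for train.py, the corresponding change would be to exclude the logit-layer parameters when loading the pretrained weights, so that the new head keeps its random initialization and is trained from scratch on the 9-class data.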