Enforce failed. Expected dim_imgsize.size() == 2, but received dim_imgsize.size():1 != 2:2.
Created by: stevehejianguo
为使您的问题得到快速解决,在建立Issue前,请您先通过如下方式搜索是否有相似问题:【搜索issue关键字】【使用labels筛选】【官方文档】
如果您没有查询到相似问题,为快速解决您的提问,建立issue时请提供如下细节信息:
- 标题:简洁、精准描述您的问题,例如“最新预测库的API文档在哪儿 ”
- 版本、环境信息: 1)PaddlePaddle版本:请提供您的PaddlePaddle版本号(如1.1)或CommitID 2)CPU:预测若用CPU,请提供CPU型号,MKL/OpenBlas/MKLDNN/等数学库使用情况 3)GPU:预测若用GPU,请提供GPU型号、CUDA和CUDNN版本号 4)系统环境:请您描述系统类型、版本(如Mac OS 10.14),Python版本 -预测信息 1)C++预测:请您提供预测库安装包的版本信息,及其中的version.txt文件 2)CMake包含路径的完整命令 3)API信息(如调用请提供) 4)预测库来源:官网下载/特殊环境(如BCLOUD编译)
- 复现信息:如为报错,请给出复现环境、复现步骤
- 问题描述:请详细描述您的问题,同步贴出报错信息、日志/代码关键片段
Thank you for contributing to PaddlePaddle. Before submitting the issue, you could search existing GitHub issues in case a similar issue has already been reported. If there is no solution, please make sure that this is an inference issue and include the following details:
System information
- PaddlePaddle version (e.g. 1.1) or CommitID
- CPU: including CPU MKL/OpenBlas/MKLDNN version
- GPU: including CUDA/CUDNN version
- OS platform (e.g. Mac OS 10.14)
- Python version
- CMake commands
- C++ version.txt
- API information
To Reproduce
Steps to reproduce the behavior
Describe your current behavior
Code to reproduce the issue
Other info / logs
运行如下命令 (ran the following command):
!cd scores && python score.py --model_dir /home/aistudio/output/yolov3tiny --data_dir /home/aistudio/data/data7122/
比较完整的错误 (fuller error output):
    exe.run(program.desc, scope, 0, True, True, fetch_var_name)
paddle.fluid.core_avx.EnforceNotMet: Invoke operator yolo_box error.
Python Callstacks:
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.5/site-packages/paddle/fluid/framework.py", line 1748, in append_op
    attrs=kwargs.get("attrs", None))
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.5/site-packages/paddle/fluid/layer_helper.py", line 43, in append_op
    return self.main_program.current_block().append_op(*args, **kwargs)
  File "/opt/conda/envs/python35-paddle120-env/lib/python3.5/site-packages/paddle/fluid/layers/detection.py", line 962, in yolo_box
    attrs=attrs)
  File "/home/aistudio/models/yolov3_tiny.py", line 226, in net
    name="yolo_box" + str(i))
  File "infer_tiny.py", line 24, in infer
    model.net()
  File "infer_tiny.py", line 77, in
    infer()
def infer():
    """Run YOLOv3-tiny inference on one image or on every image in a directory.

    All settings come from the module-level ``cfg`` (use_gpu, weights,
    image_name, image_path, input_size, draw_thresh, class_num, anchors,
    anchor_masks). Side effects: exports an inference model to
    ``./output/yolov3tiny`` and draws the detections on each image through
    ``box_utils.draw_boxes_on_image``.
    """
    # check if set use_gpu=True in paddlepaddle cpu version
    check_gpu(cfg.use_gpu)
    os.makedirs('output', exist_ok=True)

    model = YOLOv3Tiny(cfg.class_num, cfg.anchors, cfg.anchor_masks, False)
    model.net()
    # Build the raw (pre-NMS) outputs first, then the NMS output; the order
    # of graph construction is kept identical to the original script.
    pred_boxes, pred_scores = model.get_boxes_scores()
    nms_out = model.get_pred()

    input_size = cfg.input_size
    place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    # yapf: disable
    if cfg.weights:
        def if_exist(var):
            return os.path.exists(os.path.join(cfg.weights, var.name))
        fluid.io.load_vars(exe, cfg.weights, predicate=if_exist)
    # yapf: enable

    # Export an inference model with inputs 'image' and 'im_shape'.
    # NOTE: 'im_shape' is declared with shape=[2], to which fluid prepends a
    # batch dimension, and yolo_box requires img_size to be 2-D. Whoever runs
    # the exported model must therefore feed an int32 array of shape
    # [batch, 2] (e.g. np.array([[h, w]], dtype='int32')), NOT a 1-D [h, w]
    # array; a 1-D feed fails with "Expected dim_imgsize.size() == 2".
    fluid.io.save_inference_model("./output/yolov3tiny",
                                  feeded_var_names=['image', 'im_shape'],
                                  target_vars=[pred_boxes, pred_scores],
                                  executor=exe)

    feeder = fluid.DataFeeder(place=place, feed_list=model.feeds())
    fetch_list = [nms_out.name]

    if cfg.image_name is not None:
        image_names = [cfg.image_name]
    else:
        # Extension match is case-insensitive so e.g. "IMG_01.JPG" is found.
        image_names = [
            name for name in os.listdir(cfg.image_path)
            if name.split('.')[-1].lower() in ('jpg', 'jpeg', 'png')
        ]

    # Label names do not depend on the image, so read them once.
    label_names, _ = reader.get_label_infos()
    for image_name in image_names:
        path = os.path.join(cfg.image_path, image_name)
        infer_reader = reader.infer(input_size, path)
        data = next(infer_reader())
        results = exe.run(fetch_list=fetch_list,
                          feed=feeder.feed(data),
                          return_numpy=False)
        bboxes = np.array(results[0])
        # Each detection row is [label, score, box(4)]; any other width
        # means NMS kept nothing for this image.
        if bboxes.shape[1] != 6:
            print("No object found in {}".format(image_name))
            continue
        labels = bboxes[:, 0].astype('int32')
        scores = bboxes[:, 1].astype('float32')
        boxes = bboxes[:, 2:].astype('float32')
        box_utils.draw_boxes_on_image(path, boxes, scores, labels, label_names,
                                      cfg.draw_thresh)
from config_tiny import cfg


class YOLOv3Tiny(object):
    """YOLOv3-tiny detector built with the PaddlePaddle fluid static-graph API.

    Backbone: five conv-BN-ReLU + 2x2 max-pool stages (16..256 filters, total
    stride 32), then a 3x3 conv with 512 filters and a 1x1 conv with 1024.
    Two detection heads follow: head 0 on the stride-32 feature map, head 1
    on the stride-16 map (blocks[3]) concatenated with an upsampled route
    from head 0.

    With ``is_train=True`` ``net()`` builds ``yolov3_loss`` ops (summed by
    ``loss()``); otherwise it builds ``yolo_box`` ops whose outputs are
    returned by ``get_boxes_scores()`` / ``get_pred()``.
    """

    def __init__(self, class_num, anchors, anchor_mask, is_train=True,
                 use_random=True):
        """
        Args:
            class_num: number of object classes.
            anchors: flat list of anchor sizes, two numbers per anchor
                (presumably [w0, h0, w1, h1, ...]).
            anchor_mask: one list of anchor indices per detection head;
                anchor_mask[i] is paired with head i (head 0 = stride 32).
            is_train: build the training loss (True) or inference box
                decoding (False).
            use_random: stored on the instance; not read by this class.
        """
        self.outputs = []
        # Multiplied by 2 in downsample() for every pooling stage.
        self.downsample_ratio = 1
        self.anchor_mask = anchor_mask
        self.anchors = anchors
        self.class_num = class_num
        self.is_train = is_train
        self.use_random = use_random
        self.losses = []
        # Per-head anchor lists ([w, h, w, h, ...] for that head's mask),
        # the form fluid.layers.yolo_box expects.
        self.yolo_anchors = []
        self.yolo_classes = []
        for mask_pair in self.anchor_mask:
            mask_anchors = []
            for mask in mask_pair:
                mask_anchors.append(self.anchors[2 * mask])
                mask_anchors.append(self.anchors[2 * mask + 1])
            self.yolo_anchors.append(mask_anchors)
            self.yolo_classes.append(class_num)

    def name(self):
        return 'YOLOv3-tiny'

    def get_anchors(self):
        return self.anchors

    def get_anchor_mask(self):
        return self.anchor_mask

    def get_class_num(self):
        return self.class_num

    def get_downsample_ratio(self):
        """Total backbone stride (product of all pooling stages) after net()."""
        return self.downsample_ratio

    def get_yolo_anchors(self):
        return self.yolo_anchors

    def get_yolo_classes(self):
        return self.yolo_classes

    def conv_bn(self,
                input,
                num_filters,
                filter_size,
                stride,
                padding,
                num_groups=1,
                use_cudnn=True):
        """Conv2D without bias, followed by batch norm with ReLU activation."""
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            act=None,
            groups=num_groups,
            use_cudnn=use_cudnn,
            param_attr=ParamAttr(initializer=fluid.initializer.Normal(0., 0.02)),
            bias_attr=False)
        # Batch-norm scale/shift should not be regularized, so an explicit
        # L2Decay(0.) regularizer is attached to both to switch it off.
        out = fluid.layers.batch_norm(
            input=conv, act='relu',
            param_attr=ParamAttr(initializer=fluid.initializer.Normal(0., 0.02),
                                 regularizer=L2Decay(0.)),
            bias_attr=ParamAttr(initializer=fluid.initializer.Constant(0.0),
                                regularizer=L2Decay(0.)))
        return out

    def depthwise_conv_bn(self, input, filter_size=3, stride=1, padding=1):
        """conv_bn with groups == channels (depthwise); assumes NCHW input."""
        num_filters = input.shape[1]
        return self.conv_bn(input,
                            num_filters=num_filters,
                            filter_size=filter_size,
                            stride=stride,
                            padding=padding,
                            num_groups=num_filters)

    def downsample(self, input, pool_size=2, pool_stride=2):
        """Max-pool the input and record the extra 2x stride."""
        self.downsample_ratio *= 2
        return fluid.layers.pool2d(input=input, pool_type='max',
                                   pool_size=pool_size,
                                   pool_stride=pool_stride)

    def basicblock(self, input, num_filters):
        """3x3 conv_bn followed by a 2x2 max-pool."""
        conv1 = self.conv_bn(input, num_filters, filter_size=3, stride=1, padding=1)
        out = self.downsample(conv1)
        return out

    def upsample(self, input, scale=2):
        """Nearest-neighbour upsample by ``scale`` using the runtime H/W."""
        # Compute the output shape from the input's dynamic shape so that
        # variable input sizes are handled.
        shape_nchw = fluid.layers.shape(input)
        shape_hw = fluid.layers.slice(shape_nchw, axes=[0], starts=[2], ends=[4])
        shape_hw.stop_gradient = True
        in_shape = fluid.layers.cast(shape_hw, dtype='int32')
        out_shape = in_shape * scale
        out_shape.stop_gradient = True
        # Resize by actual_shape.
        out = fluid.layers.resize_nearest(
            input=input,
            scale=scale,
            actual_shape=out_shape)
        return out

    def yolo_detection_block(self, input, num_filters):
        """Return (route, tip): a 1x1 conv_bn and a 3x3 conv_bn on top of it."""
        route = self.conv_bn(input, num_filters, filter_size=1, stride=1, padding=0)
        tip = self.conv_bn(route, num_filters * 2, filter_size=3, stride=1, padding=1)
        return route, tip

    def feeds(self):
        """Input variables in feed order for the current mode."""
        if not self.is_train:
            return [self.image, self.im_id, self.im_shape]
        return [self.image, self.gtbox, self.gtlabel, self.gtscore]

    def build_input(self):
        """Create the input variables (py_reader for training, data layers otherwise)."""
        self.image_shape = [3, 416, 416]
        if self.is_train:
            self.py_reader = fluid.layers.py_reader(
                capacity=64,
                shapes=[[-1] + self.image_shape, [-1, cfg.max_box_num, 4],
                        [-1, cfg.max_box_num], [-1, cfg.max_box_num]],
                lod_levels=[0, 0, 0, 0],
                dtypes=['float32'] * 2 + ['int32'] + ['float32'],
                use_double_buffer=True)
            self.image, self.gtbox, self.gtlabel, self.gtscore = \
                fluid.layers.read_file(self.py_reader)
        else:
            self.image = fluid.layers.data(
                name='image', shape=self.image_shape, dtype='float32')
            # fluid.layers.data prepends a batch dimension, so im_shape is a
            # [batch, 2] int32 tensor; yolo_box requires this 2-D rank, so a
            # 1-D [h, w] feed is rejected ("Expected dim_imgsize.size() == 2").
            self.im_shape = fluid.layers.data(
                name="im_shape", shape=[2], dtype='int32')
            self.im_id = fluid.layers.data(
                name="im_id", shape=[1], dtype='int32')

    def net(self):
        """Build the whole graph and return the list of raw head outputs.

        Safe to call more than once: per-build state (outputs, losses, boxes,
        scores, downsample ratio) is reset at the start.
        """
        self.build_input()
        self.outputs = []
        self.losses = []
        self.boxes = []
        self.scores = []
        # downsample() doubles this for each pooling stage of this build.
        self.downsample_ratio = 1

        # darknet-tiny backbone
        stages = [16, 32, 64, 128, 256, 512]
        assert len(self.anchor_mask) <= len(stages), "anchor masks can't bigger than downsample times"
        # Input is self.image, 3x416x416 (see build_input).
        tmp = self.image
        blocks = []
        for i, stage_count in enumerate(stages):
            if i == len(stages) - 1:
                block = self.conv_bn(tmp, stage_count, filter_size=3, stride=1, padding=1)
                blocks.append(block)
                # NOTE(review): both depthwise results below are discarded and
                # the 1x1 conv reads blocks[-1] (the 3x3 conv output), so these
                # two layers do not affect the output. They are kept as-is
                # because chaining them would change the architecture and,
                # presumably, the auto-generated parameter names that existing
                # checkpoints rely on.
                block = self.depthwise_conv_bn(blocks[-1])
                block = self.depthwise_conv_bn(blocks[-1])
                block = self.conv_bn(blocks[-1], stage_count * 2, filter_size=1, stride=1, padding=0)
                blocks.append(block)
            else:
                tmp = self.basicblock(tmp, stage_count)
                blocks.append(tmp)
        # Head inputs: final feature map (stride 32) and blocks[3] (stride 16).
        blocks = [blocks[-1], blocks[3]]

        # yolo detector heads
        for i, block in enumerate(blocks):
            # Cross-scale connection: concat the upsampled route from the
            # previous head with this head's backbone features.
            if i > 0:
                block = fluid.layers.concat(input=[route, block], axis=1)
            if i < 1:
                route, tip = self.yolo_detection_block(block, num_filters=256 // (2**i))
            else:
                tip = self.conv_bn(block, num_filters=256, filter_size=3, stride=1, padding=1)
            block_out = fluid.layers.conv2d(
                input=tip,
                num_filters=len(self.anchor_mask[i]) * (self.class_num + 5),  # 5 elements represent x|y|h|w|score
                filter_size=1,
                stride=1,
                padding=0,
                act=None,
                param_attr=ParamAttr(initializer=fluid.initializer.Normal(0., 0.02)),
                bias_attr=ParamAttr(initializer=fluid.initializer.Constant(0.0),
                                    regularizer=L2Decay(0.)))
            self.outputs.append(block_out)
            # Upsample (nearest interpolation) the route so it matches the
            # next head's resolution for the cross-scale concat.
            if i < len(blocks) - 1:
                route = self.conv_bn(route, 128 // (2**i), filter_size=1, stride=1, padding=0)
                route = self.upsample(route)

        # Head 0 runs at the full backbone stride; each following head is 2x
        # finer. Use a local counter so self.downsample_ratio keeps the total
        # backbone stride for get_downsample_ratio().
        downsample = self.downsample_ratio
        for i, out in enumerate(self.outputs):
            anchor_mask = self.anchor_mask[i]
            if self.is_train:
                loss = fluid.layers.yolov3_loss(
                    x=out,
                    gt_box=self.gtbox,
                    gt_label=self.gtlabel,
                    gt_score=self.gtscore,
                    anchors=self.anchors,
                    anchor_mask=anchor_mask,
                    class_num=self.class_num,
                    ignore_thresh=cfg.ignore_thresh,
                    downsample_ratio=downsample,
                    use_label_smooth=bool(cfg.label_smooth),
                    name="yolo_loss" + str(i))
                self.losses.append(fluid.layers.reduce_mean(loss))
            else:
                # Use this instance's anchors/classes (precomputed per head in
                # __init__) rather than the global cfg values.
                boxes, scores = fluid.layers.yolo_box(
                    x=out,
                    img_size=self.im_shape,
                    anchors=self.yolo_anchors[i],
                    class_num=self.class_num,
                    conf_thresh=cfg.valid_thresh,
                    downsample_ratio=downsample,
                    name="yolo_box" + str(i))
                self.boxes.append(boxes)
                # Move the class axis before the box axis, the layout that
                # multiclass_nms takes for its scores input.
                self.scores.append(
                    fluid.layers.transpose(
                        scores, perm=[0, 2, 1]))
            downsample //= 2
        return self.outputs

    def loss(self):
        """Sum of the per-head mean YOLOv3 losses (training mode only)."""
        return sum(self.losses)

    def get_boxes_scores(self):
        """Concatenate all heads' decoded boxes (axis 1) and scores (axis 2)."""
        yolo_boxes = fluid.layers.concat(self.boxes, axis=1)
        yolo_scores = fluid.layers.concat(self.scores, axis=2)
        return yolo_boxes, yolo_scores

    def get_pred(self):
        """Multiclass NMS over the concatenated boxes and scores of all heads."""
        yolo_boxes, yolo_scores = self.get_boxes_scores()
        return fluid.layers.multiclass_nms(
            bboxes=yolo_boxes,
            scores=yolo_scores,
            score_threshold=cfg.valid_thresh,
            nms_top_k=cfg.nms_topk,
            keep_top_k=cfg.nms_posk,
            nms_threshold=cfg.nms_thresh,
            background_label=-1,
            name="multiclass_nms")