Commit 5b54a7a5 authored by: W wuzewu

Fix Detection Task issue

Parent 88d1f7dc
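
Below is a minimal, illustrative sketch (not part of the commit) of the lazy-caching pattern that ObjectDetectionDataset.train_data()/test_data()/validate_data() follow in the diff: parse each split once, store it, and return the cached examples on later calls. The class and helper below are simplified placeholders, not the real PaddleHub code.

# Hypothetical stand-in showing the caching idea only; the real _parse_data
# in the diff reads the list file and image dir and sets label metadata.
class CachedSplitDataset(object):
    def __init__(self):
        self._train_data = None  # filled on first access, then reused

    def _parse_data(self, phase, shuffle=False):
        # placeholder for the real list-file parsing
        return ["record for %s" % phase]

    def train_data(self, shuffle=True):
        # parse once, cache, and return the cached examples afterwards
        if not self._train_data:
            self._train_data = self._parse_data(phase='train', shuffle=shuffle)
        return self._train_data
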
......@@ -28,6 +28,7 @@ from paddlehub.common.logger import logger
from ..contrib.ppdet.data.source import build_source
from ..common import detection_config as dconf
class BaseCVDataset(BaseDataset):
def __init__(self,
base_path,
......@@ -164,10 +165,21 @@ class ImageClassificationDataset(object):
return self.test_examples
class ObjectDetectionDataset(ImageClassificationDataset):
def __init__(self, base_path, train_image_dir, train_list_file, validate_image_dir, validate_list_file,
test_image_dir, test_list_file, model_type='ssd'):
super(ObjectDetectionDataset, self).__init__()
class ObjectDetectionDataset(BaseCVDataset):
def __init__(self,
base_path,
train_image_dir,
train_list_file,
validate_image_dir,
validate_list_file,
test_image_dir,
test_list_file,
model_type='ssd'):
super(ObjectDetectionDataset, self).__init__(
base_path=base_path,
train_list_file=train_list_file,
validate_list_file=validate_list_file,
test_list_file=test_list_file)
self.base_path = base_path
self.train_image_dir = train_image_dir
self.train_list_file = train_list_file
......@@ -178,16 +190,10 @@ class ObjectDetectionDataset(ImageClassificationDataset):
self.model_type = model_type
self._dsc = None
self.cid2cname = None
self.label_dict() # refresh cid2cname and num_labels
assert self.cid2cname is not None
assert self.num_labels > 0
def label_dict(self):
if self.cid2cname is not None:
return self.cid2cname
# get label info from train data json
_ = self.train_data()
return self.cid2cname
self._val_data = None
self._train_data = None
self._test_data = None
self.train_data()
def _parse_data(self, data_path, image_dir, shuffle=False, phase=None):
with_background = dconf.conf[self.model_type]['with_background']
......@@ -214,9 +220,9 @@ class ObjectDetectionDataset(ImageClassificationDataset):
cid2cname = {v: k for k, v in cname2cid.items()}
self.cid2cname = cid2cname
if with_background:
self.num_labels = len(cid2cname) + 1
self.label_list = ['background'] + list(self.cid2cname.values())
else:
self.num_labels = len(cid2cname)
self.label_list = list(self.cid2cname.values())
if phase == 'train':
self.train_examples = data
......@@ -229,19 +235,26 @@ class ObjectDetectionDataset(ImageClassificationDataset):
def train_data(self, shuffle=True):
train_data_path = os.path.join(self.base_path, self.train_list_file)
train_image_dir = os.path.join(self.base_path, self.train_image_dir)
return self._parse_data(
train_data_path, train_image_dir, shuffle, phase='train')
if not self._train_data:
self._train_data = self._parse_data(
train_data_path, train_image_dir, shuffle, phase='train')
return self._train_data
def test_data(self, shuffle=False):
test_data_path = os.path.join(self.base_path, self.test_list_file)
test_image_dir = os.path.join(self.base_path, self.test_image_dir)
return self._parse_data(
test_data_path, test_image_dir, shuffle, phase='dev')
if not self._test_data:
self._test_data = self._parse_data(
test_data_path, test_image_dir, shuffle, phase='test')
return self._test_data
def validate_data(self, shuffle=False):
validate_data_path = os.path.join(self.base_path,
self.validate_list_file)
validate_image_dir = os.path.join(self.base_path,
self.validate_image_dir)
return self._parse_data(
validate_data_path, validate_image_dir, shuffle, phase='test')
if not self._val_data:
self._val_data = self._parse_data(
validate_data_path, validate_image_dir, shuffle, phase='dev')
return self._val_data
......@@ -411,8 +411,8 @@ class BaseTask(object):
if self.is_predict_phase or self.is_test_phase:
# Todo: paddle.fluid.core_avx.EnforceNotMet: Getting 'tensor_desc' is not supported by the type of var kCUDNNFwdAlgoCache. at
# self.env.main_program = clone_program(
# self.env.main_program, for_test=True)
self.env.main_program = clone_program(
self.env.main_program, for_test=True)
hub.common.paddle_helper.set_op_attr(
self.env.main_program, is_test=True)
......@@ -1063,7 +1063,8 @@ class BaseTask(object):
capacity=64,
use_double_buffer=True,
iterable=True)
data_reader = data_loader.set_sample_list_generator(self.reader, self.places[0])
data_reader = data_loader.set_sample_list_generator(
self.reader, self.places)
# data_reader = data_loader.set_batch_generator(
# self.reader, places=self.places)
else:
......@@ -1090,9 +1091,8 @@ class BaseTask(object):
return_numpy=False)
# fetch_result = [x if isinstance(x,fluid.LoDTensor) else np.array(x) for x in fetch_result]
fetch_result = [
x
if hasattr(x, 'recursive_sequence_lengths') else np.array(x)
for x in fetch_result
x if hasattr(x, 'recursive_sequence_lengths') else
np.array(x) for x in fetch_result
]
elif self.return_numpy:
fetch_result = self.exe.run(
......
......@@ -204,10 +204,13 @@ class DetectionTask(BaseTask):
inputs=feature_list,
image=image,
num_classes=self.num_classes,
aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]],
aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2., 3.], [2.],
[2.]],
base_size=512, # 300,
min_sizes=[20.0, 51.0, 133.0, 215.0, 296.0, 378.0, 460.0], # [60.0, 105.0, 150.0, 195.0, 240.0, 285.0],
max_sizes=[51.0, 133.0, 215.0, 296.0, 378.0, 460.0, 542.0], # [[], 150.0, 195.0, 240.0, 285.0, 300.0],
min_sizes=[20.0, 51.0, 133.0, 215.0, 296.0, 378.0,
460.0], # [60.0, 105.0, 150.0, 195.0, 240.0, 285.0],
max_sizes=[51.0, 133.0, 215.0, 296.0, 378.0, 460.0,
542.0], # [[], 150.0, 195.0, 240.0, 285.0, 300.0],
steps=[8, 16, 32, 64, 128, 256, 512],
min_ratio=15,
max_ratio=90,
......@@ -242,7 +245,8 @@ class DetectionTask(BaseTask):
idx_list = [1, 2] # 'gt_box', 'gt_label'
elif self.is_test_phase:
# xTodo: remove 'im_shape' when using new module
idx_list = [2, 3, 4, 5] # 'im_id', 'gt_box', 'gt_label', 'is_difficult'
idx_list = [2, 3, 4,
5] # 'im_id', 'gt_box', 'gt_label', 'is_difficult'
else:
idx_list = [1] # im_id
return self._add_label_by_fields(idx_list)
......@@ -289,13 +293,17 @@ class DetectionTask(BaseTask):
# xTodo: update when using new module
# im_id, bbox, dets, loss
return [
self.base_feed_list[1], self.labels[0].name, self.outputs[0].name,
self.loss.name]
self.base_feed_list[1], self.labels[0].name,
self.outputs[0].name, self.loss.name
]
# im_shape, im_id, bbox
if for_export:
return [self.outputs[0].name]
else:
return [self.base_feed_list[1], self.labels[0].name, self.outputs[0].name]
return [
self.base_feed_list[1], self.labels[0].name,
self.outputs[0].name
]
def _rcnn_build_net(self):
if self.is_train_phase:
......@@ -306,30 +314,29 @@ class DetectionTask(BaseTask):
# Rename following layers for: ValueError: Variable cls_score_w has been created before.
# the previous shape is (2048, 81); the new shape is (100352, 81).
# They are not matched.
cls_score = fluid.layers.fc(input=head_feat,
size=self.num_classes,
act=None,
name='my_cls_score',
param_attr=ParamAttr(
name='my_cls_score_w',
initializer=Normal(
loc=0.0, scale=0.01)),
bias_attr=ParamAttr(
name='my_cls_score_b',
learning_rate=2.,
regularizer=L2Decay(0.)))
bbox_pred = fluid.layers.fc(input=head_feat,
size=4 * self.num_classes,
act=None,
name='my_bbox_pred',
param_attr=ParamAttr(
name='my_bbox_pred_w',
initializer=Normal(
loc=0.0, scale=0.001)),
bias_attr=ParamAttr(
name='my_bbox_pred_b',
learning_rate=2.,
regularizer=L2Decay(0.)))
cls_score = fluid.layers.fc(
input=head_feat,
size=self.num_classes,
act=None,
name='my_cls_score',
param_attr=ParamAttr(
name='my_cls_score_w', initializer=Normal(loc=0.0, scale=0.01)),
bias_attr=ParamAttr(
name='my_cls_score_b',
learning_rate=2.,
regularizer=L2Decay(0.)))
bbox_pred = fluid.layers.fc(
input=head_feat,
size=4 * self.num_classes,
act=None,
name='my_bbox_pred',
param_attr=ParamAttr(
name='my_bbox_pred_w', initializer=Normal(loc=0.0,
scale=0.001)),
bias_attr=ParamAttr(
name='my_bbox_pred_b',
learning_rate=2.,
regularizer=L2Decay(0.)))
if self.is_train_phase:
rpn_cls_loss, rpn_reg_loss, outs = self.feature[1:]
......@@ -349,7 +356,8 @@ class DetectionTask(BaseTask):
outside_weight=bbox_outside_weights,
sigma=1.0)
loss_bbox = fluid.layers.reduce_mean(loss_bbox)
total_loss = fluid.layers.sum([loss_bbox, loss_cls, rpn_cls_loss, rpn_reg_loss])
total_loss = fluid.layers.sum(
[loss_bbox, loss_cls, rpn_cls_loss, rpn_reg_loss])
return [total_loss]
else:
rois = self.predict_feature[1]
......@@ -359,33 +367,41 @@ class DetectionTask(BaseTask):
im_scale = fluid.layers.sequence_expand(im_scale, rois)
boxes = rois / im_scale
cls_prob = fluid.layers.softmax(cls_score, use_cudnn=False)
bbox_pred = fluid.layers.reshape(bbox_pred, (-1, self.num_classes, 4))
bbox_pred = fluid.layers.reshape(bbox_pred,
(-1, self.num_classes, 4))
# decoded_box = self.box_coder(prior_box=boxes, target_box=bbox_pred)
decoded_box = fluid.layers.box_coder(
prior_box=boxes, prior_box_var=[0.1, 0.1, 0.2, 0.2],
target_box=bbox_pred, code_type='decode_center_size',
box_normalized=False, axis=1)
cliped_box = fluid.layers.box_clip(input=decoded_box, im_info=im_shape)
prior_box=boxes,
prior_box_var=[0.1, 0.1, 0.2, 0.2],
target_box=bbox_pred,
code_type='decode_center_size',
box_normalized=False,
axis=1)
cliped_box = fluid.layers.box_clip(
input=decoded_box, im_info=im_shape)
# pred_result = self.nms(bboxes=cliped_box, scores=cls_prob)
pred_result = fluid.layers.multiclass_nms(
bboxes=cliped_box, scores=cls_prob,
bboxes=cliped_box,
scores=cls_prob,
score_threshold=.05,
nms_top_k=-1,
keep_top_k=100,
nms_threshold=.5,
normalized=False,
nms_eta=1.0,
background_label=0
)
background_label=0)
if self.is_predict_phase:
self.env.labels = self._rcnn_add_label()
return [pred_result]
def _rcnn_add_label(self):
if self.is_train_phase:
idx_list = [2,] # 'im_id'
idx_list = [
2,
] # 'im_id'
elif self.is_test_phase:
idx_list = [2, 4, 5, 6] # 'im_id', 'gt_box', 'gt_label', 'is_difficult'
idx_list = [2, 4, 5,
6] # 'im_id', 'gt_box', 'gt_label', 'is_difficult'
else: # predict
idx_list = [2]
return self._add_label_by_fields(idx_list)
......@@ -394,7 +410,8 @@ class DetectionTask(BaseTask):
if self.is_train_phase:
loss = self.env.outputs[-1]
else:
loss = fluid.layers.fill_constant(shape=[1], value=-1, dtype='float32')
loss = fluid.layers.fill_constant(
shape=[1], value=-1, dtype='float32')
return loss
def _rcnn_feed_list(self, for_export=False):
......@@ -417,14 +434,19 @@ class DetectionTask(BaseTask):
if self.is_train_phase:
return [self.loss.name]
elif self.is_test_phase:
# im_shape, im_id, bbox
return [self.feed_list[2], self.labels[0].name, self.outputs[0].name, self.loss.name]
# im_shape, im_id, bbox
return [
self.feed_list[2], self.labels[0].name, self.outputs[0].name,
self.loss.name
]
# im_shape, im_id, bbox
if for_export:
return [self.outputs[0].name]
else:
return [self.feed_list[2], self.labels[0].name, self.outputs[0].name]
return [
self.feed_list[2], self.labels[0].name, self.outputs[0].name
]
def _yolo_parse_anchors(self, anchors):
"""
......@@ -449,8 +471,8 @@ class DetectionTask(BaseTask):
def _yolo_build_net(self):
self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
[59, 119], [116, 90], [156, 198], [373, 326]]
anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
self._yolo_parse_anchors(anchors)
tip_list = self.feature
......@@ -466,7 +488,8 @@ class DetectionTask(BaseTask):
padding=0,
act=None,
# Rename for: conflict with module pretrain weights
param_attr=ParamAttr(name="ft_yolo_output.{}.conv.weights".format(i)),
param_attr=ParamAttr(
name="ft_yolo_output.{}.conv.weights".format(i)),
bias_attr=ParamAttr(
regularizer=L2Decay(0.),
name="ft_yolo_output.{}.conv.bias".format(i)))
......@@ -495,7 +518,8 @@ class DetectionTask(BaseTask):
yolo_scores = fluid.layers.concat(scores, axis=2)
# pred = self.nms(bboxes=yolo_boxes, scores=yolo_scores)
pred = fluid.layers.multiclass_nms(
bboxes=yolo_boxes, scores=yolo_scores,
bboxes=yolo_boxes,
scores=yolo_scores,
score_threshold=.01,
nms_top_k=1000,
keep_top_k=100,
......@@ -511,7 +535,8 @@ class DetectionTask(BaseTask):
if self.is_train_phase:
idx_list = [1, 2, 3] # 'gt_box', 'gt_label', 'gt_score'
elif self.is_test_phase:
idx_list = [2, 3, 4, 5] # 'im_id', 'gt_box', 'gt_label', 'is_difficult'
idx_list = [2, 3, 4,
5] # 'im_id', 'gt_box', 'gt_label', 'is_difficult'
else: # predict
idx_list = [2]
return self._add_label_by_fields(idx_list)
......@@ -541,7 +566,8 @@ class DetectionTask(BaseTask):
loss = sum(losses)
else:
loss = fluid.layers.fill_constant(shape=[1], value=-1, dtype='float32')
loss = fluid.layers.fill_constant(
shape=[1], value=-1, dtype='float32')
return loss
def _yolo_feed_list(self, for_export=False):
......@@ -560,14 +586,19 @@ class DetectionTask(BaseTask):
if self.is_train_phase:
return [self.loss.name]
elif self.is_test_phase:
# im_shape, im_id, bbox
return [self.feed_list[1], self.labels[0].name, self.outputs[0].name, self.loss.name]
# im_shape, im_id, bbox
return [
self.feed_list[1], self.labels[0].name, self.outputs[0].name,
self.loss.name
]
# im_shape, im_id, bbox
if for_export:
return [self.outputs[0].name]
else:
return [self.feed_list[1], self.labels[0].name, self.outputs[0].name]
return [
self.feed_list[1], self.labels[0].name, self.outputs[0].name
]
def _build_net(self):
if self.model_type == 'ssd':
......@@ -694,14 +725,15 @@ class DetectionTask(BaseTask):
is_bbox_normalized = dconf.conf[self.model_type]['is_bbox_normalized']
eval_feed = Feed()
eval_feed.with_background = dconf.conf[self.model_type]['with_background']
eval_feed.with_background = dconf.conf[
self.model_type]['with_background']
eval_feed.dataset = self.reader
for metric in self.metrics_choices:
if metric == "ap":
box_ap_stats = eval_results(results, eval_feed, 'COCO',
self.num_classes, None,
is_bbox_normalized, None, None)
box_ap_stats = eval_results(
results, eval_feed, 'COCO', self.num_classes, None,
is_bbox_normalized, self.config.checkpoint_dir)
print("box_ap_stats", box_ap_stats)
scores["ap"] = box_ap_stats[0]
else:
......
......@@ -18,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='paddlehub.module.checkinfo',
syntax='proto3',
serialized_pb=_b(
'\n\x10\x63heck_info.proto\x12\x1apaddlehub.module.checkinfo\"\x85\x01\n\x08\x46ileInfo\x12\x11\n\tfile_name\x18\x01 \x01(\t\x12\x33\n\x04type\x18\x02 \x01(\x0e\x32%.paddlehub.module.checkinfo.FILE_TYPE\x12\x0f\n\x07is_need\x18\x03 \x01(\x08\x12\x0b\n\x03md5\x18\x04 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x05 \x01(\t\"\x84\x01\n\x08Requires\x12>\n\x0crequire_type\x18\x01 \x01(\x0e\x32(.paddlehub.module.checkinfo.REQUIRE_TYPE\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x12\n\ngreat_than\x18\x03 \x01(\x08\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\"\xe5\x01\n\tCheckInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\x12\x1c\n\x14module_proto_version\x18\x03 \x01(\t\x12\x38\n\nfile_infos\x18\x04 \x03(\x0b\x32$.paddlehub.module.checkinfo.FileInfo\x12\x36\n\x08requires\x18\x05 \x03(\x0b\x32$.paddlehub.module.checkinfo.Requires\x12\x1b\n\x13module_code_version\x18\x06 \x01(\t*\x1e\n\tFILE_TYPE\x12\x08\n\x04\x46ILE\x10\x00\x12\x07\n\x03\x44IR\x10\x01*[\n\x0cREQUIRE_TYPE\x12\x12\n\x0ePYTHON_PACKAGE\x10\x00\x12\x0e\n\nHUB_MODULE\x10\x01\x12\n\n\x06SYSTEM\x10\x02\x12\x0b\n\x07\x43OMMAND\x10\x03\x12\x0e\n\nPY_VERSION\x10\x04\x42\x02H\x03\x62\x06proto3'
'\n\x10\x63heck_info.proto\x12\x1apaddlehub.module.checkinfo\"\x85\x01\n\x08\x46ileInfo\x12\x11\n\tfile_name\x18\x01 \x01(\t\x12\x33\n\x04type\x18\x02 \x01(\x0e\x32%.paddlehub.module.checkinfo.FILE_TYPE\x12\x0f\n\x07is_need\x18\x03 \x01(\x08\x12\x0b\n\x03md5\x18\x04 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x05 \x01(\t\"\x84\x01\n\x08Requires\x12>\n\x0crequire_type\x18\x01 \x01(\x0e\x32(.paddlehub.module.checkinfo.REQUIRE_TYPE\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x12\n\ngreat_than\x18\x03 \x01(\x08\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\"\xc8\x01\n\tCheckInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\x12\x1c\n\x14module_proto_version\x18\x03 \x01(\t\x12\x38\n\nfile_infos\x18\x04 \x03(\x0b\x32$.paddlehub.module.checkinfo.FileInfo\x12\x36\n\x08requires\x18\x05 \x03(\x0b\x32$.paddlehub.module.checkinfo.Requires*\x1e\n\tFILE_TYPE\x12\x08\n\x04\x46ILE\x10\x00\x12\x07\n\x03\x44IR\x10\x01*[\n\x0cREQUIRE_TYPE\x12\x12\n\x0ePYTHON_PACKAGE\x10\x00\x12\x0e\n\nHUB_MODULE\x10\x01\x12\n\n\x06SYSTEM\x10\x02\x12\x0b\n\x07\x43OMMAND\x10\x03\x12\x0e\n\nPY_VERSION\x10\x04\x42\x02H\x03\x62\x06proto3'
))
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
......@@ -35,8 +35,8 @@ _FILE_TYPE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=551,
serialized_end=581,
serialized_start=522,
serialized_end=552,
)
_sym_db.RegisterEnumDescriptor(_FILE_TYPE)
......@@ -60,8 +60,8 @@ _REQUIRE_TYPE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=583,
serialized_end=674,
serialized_start=554,
serialized_end=645,
)
_sym_db.RegisterEnumDescriptor(_REQUIRE_TYPE)
......@@ -346,22 +346,6 @@ _CHECKINFO = _descriptor.Descriptor(
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='module_code_version',
full_name='paddlehub.module.checkinfo.CheckInfo.module_code_version',
index=5,
number=6,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
......@@ -372,7 +356,7 @@ _CHECKINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[],
serialized_start=320,
serialized_end=549,
serialized_end=520,
)
_FILEINFO.fields_by_name['type'].enum_type = _FILE_TYPE
......
......@@ -185,17 +185,15 @@ class ImageClassificationReader(BaseReader):
return _data_reader
class ObjectDetectionReader(ImageClassificationReader):
class ObjectDetectionReader(BaseReader):
def __init__(self,
dataset=None,
model_type='ssd',
channel_order="RGB",
worker_num=1,
use_process=False,
):
super(ObjectDetectionReader,
self).__init__(1, 1, dataset, channel_order,
None, None, None)
random_seed=None):
super(ObjectDetectionReader, self).__init__(
dataset, random_seed=random_seed)
self.model_type = model_type
self.worker_num = worker_num
self.use_process = use_process
......@@ -205,8 +203,7 @@ class ObjectDetectionReader(ImageClassificationReader):
phase="train",
shuffle=False,
data=None,
return_list=False
):
return_list=False):
if phase != 'predict' and not self.dataset:
raise ValueError("The dataset is none and it's not allowed!")
drop_last = False
......@@ -224,10 +221,7 @@ class ObjectDetectionReader(ImageClassificationReader):
self.num_examples['dev'] = len(self.get_dev_examples())
else: # phase == "predict":
from ..contrib.ppdet.data.source import build_source
data_config = {
"IMAGES": data,
"TYPE": "SimpleSource"
}
data_config = {"IMAGES": data, "TYPE": "SimpleSource"}
data_src = build_source(data_config)
data_cf = {}
......@@ -243,11 +237,7 @@ class ObjectDetectionReader(ImageClassificationReader):
'USE_PADDED_IM_INFO': False,
}
phase_trans = {
"val": "dev",
"test": "dev",
"inference": "predict"
}
phase_trans = {"val": "dev", "test": "dev", "inference": "predict"}
if phase in phase_trans:
phase = phase_trans[phase]
assert phase in ('train', 'dev', 'predict')
......@@ -256,17 +246,8 @@ class ObjectDetectionReader(ImageClassificationReader):
ppdet_mode = 'VAL' if phase != 'train' else 'TRAIN'
_batch_reader = Reader.create(
batch_reader = Reader.create(
ppdet_mode, data_cf, transform_config, my_source=data_src)
# return itr
# When call `_batch_reader()`, then return generator(or iterator)
def batch_reader():
"""batch reader"""
for b in _batch_reader():
if return_list:
yield [b]
else:
yield b
batch_reader.annotation = _batch_reader.annotation
return batch_reader