未验证 提交 73c861ef 编写于 作者: W wangxinxin08 提交者: GitHub

[Dygraph] add yolov3 voc configs for dygraph (#2337)

* add yolov3 voc configs for dygraph

* simplify yolov3 darknet voc configs

* support num_classes shared from dataset config
上级 73acfad8
_BASE_: [
'../datasets/voc.yml',
'../runtime.yml',
'_base_/optimizer_270e.yml',
'_base_/yolov3_darknet53.yml',
'_base_/yolov3_reader.yml',
]
snapshot_epoch: 5
weights: output/yolov3_darknet53_270e_voc/model_final
...@@ -43,9 +43,11 @@ class Compose(object): ...@@ -43,9 +43,11 @@ class Compose(object):
for t in self.transforms: for t in self.transforms:
for k, v in t.items(): for k, v in t.items():
op_cls = getattr(transform, k) op_cls = getattr(transform, k)
self.transforms_cls.append(op_cls(**v)) f = op_cls(**v)
if hasattr(op_cls, 'num_classes'): if hasattr(f, 'num_classes'):
op_cls.num_classes = num_classes f.num_classes = num_classes
self.transforms_cls.append(f)
def __call__(self, data): def __call__(self, data):
for f in self.transforms_cls: for f in self.transforms_cls:
...@@ -109,8 +111,6 @@ class BatchCompose(Compose): ...@@ -109,8 +111,6 @@ class BatchCompose(Compose):
class BaseDataLoader(object): class BaseDataLoader(object):
__share__ = ['num_classes']
def __init__(self, def __init__(self,
inputs_def=None, inputs_def=None,
sample_transforms=[], sample_transforms=[],
...@@ -194,6 +194,8 @@ class BaseDataLoader(object): ...@@ -194,6 +194,8 @@ class BaseDataLoader(object):
@register @register
class TrainReader(BaseDataLoader): class TrainReader(BaseDataLoader):
__shared__ = ['num_classes']
def __init__(self, def __init__(self,
inputs_def=None, inputs_def=None,
sample_transforms=[], sample_transforms=[],
...@@ -211,6 +213,8 @@ class TrainReader(BaseDataLoader): ...@@ -211,6 +213,8 @@ class TrainReader(BaseDataLoader):
@register @register
class EvalReader(BaseDataLoader): class EvalReader(BaseDataLoader):
__shared__ = ['num_classes']
def __init__(self, def __init__(self,
inputs_def=None, inputs_def=None,
sample_transforms=[], sample_transforms=[],
...@@ -228,6 +232,8 @@ class EvalReader(BaseDataLoader): ...@@ -228,6 +232,8 @@ class EvalReader(BaseDataLoader):
@register @register
class TestReader(BaseDataLoader): class TestReader(BaseDataLoader):
__shared__ = ['num_classes']
def __init__(self, def __init__(self,
inputs_def=None, inputs_def=None,
sample_transforms=[], sample_transforms=[],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册