提交 4643cad7 编写于 作者: D dengkaipeng

update BMN

上级 d8541eac
......@@ -19,7 +19,7 @@ import logging
import paddle.fluid as fluid
from hapi.model import set_device, Input
from hapi.vision.models import BMN, BmnLoss
from hapi.vision.models import bmn, BmnLoss
from bmn_metric import BmnMetric
from reader import BmnDataset
from config_utils import *
......@@ -53,7 +53,7 @@ def parse_args():
parser.add_argument(
'--weights',
type=str,
default="checkpoint/final",
default=None,
help='weight path, None to automatically download weights provided by Paddle.'
)
parser.add_argument(
......@@ -97,7 +97,7 @@ def test_bmn(args):
eval_dataset = BmnDataset(eval_cfg, 'test')
#model
model = BMN(config, args.dynamic)
model = bmn(config, args.dynamic, pretrained=args.weights is None)
model.prepare(
loss_function=BmnLoss(config),
metrics=BmnMetric(
......@@ -107,11 +107,11 @@ def test_bmn(args):
device=device)
#load checkpoint
if args.weights:
if args.weights is not None:
assert os.path.exists(args.weights + '.pdparams'), \
"Given weight dir {} not exist.".format(args.weights)
logger.info('load test weights from {}'.format(args.weights))
model.load(args.weights)
logger.info('load test weights from {}'.format(args.weights))
model.load(args.weights)
model.evaluate(
eval_data=eval_dataset,
......
......@@ -19,7 +19,7 @@ import logging
import paddle.fluid as fluid
from hapi.model import set_device, Input
from hapi.vision.models import BMN, BmnLoss
from hapi.vision.models import bmn, BmnLoss
from bmn_metric import BmnMetric
from reader import BmnDataset
from config_utils import *
......@@ -50,7 +50,7 @@ def parse_args():
parser.add_argument(
'--weights',
type=str,
default="checkpoint/final",
default=None,
help='weight path, None to automatically download weights provided by Paddle.'
)
parser.add_argument(
......@@ -92,7 +92,7 @@ def infer_bmn(args):
#data
infer_dataset = BmnDataset(infer_cfg, 'infer')
model = BMN(config, args.dynamic)
model = bmn(config, args.dynamic, pretrained=args.weights is None)
model.prepare(
metrics=BmnMetric(
config, mode='infer'),
......@@ -101,12 +101,12 @@ def infer_bmn(args):
device=device)
# load checkpoint
if args.weights:
if args.weights is not None:
assert os.path.exists(
args.weights +
".pdparams"), "Given weight dir {} not exist.".format(args.weights)
logger.info('load test weights from {}'.format(args.weights))
model.load(args.weights)
logger.info('load test weights from {}'.format(args.weights))
model.load(args.weights)
# here use model.eval instead of model.test, as post process is required in our case
model.evaluate(
......
......@@ -21,7 +21,7 @@ import sys
sys.path.append('../')
from distributed import DistributedBatchSampler
from hapi.distributed import DistributedBatchSampler
from paddle.io import Dataset, DataLoader
logger = logging.getLogger(__name__)
......
......@@ -19,7 +19,7 @@ import sys
import os
from hapi.model import set_device, Input
from hapi.vision.models import BMN, BmnLoss
from hapi.vision.models import bmn, BmnLoss
from reader import BmnDataset
from config_utils import *
......@@ -136,7 +136,7 @@ def train_bmn(args):
val_dataset = BmnDataset(val_cfg, 'valid')
# model
model = BMN(config, args.dynamic)
model = bmn(config, args.dynamic, pretrained=False)
optim = optimizer(config, parameter_list=model.parameters())
model.prepare(
optimizer=optim,
......
......@@ -19,7 +19,7 @@ from . import mobilenetv2
from . import darknet
from . import yolov3
from . import tsm
from . import bmn
from . import bmn_model
from .resnet import *
from .mobilenetv1 import *
......@@ -28,7 +28,7 @@ from .vgg import *
from .darknet import *
from .yolov3 import *
from .tsm import *
from .bmn import *
from .bmn_model import *
__all__ = resnet.__all__ \
+ vgg.__all__ \
......@@ -37,4 +37,4 @@ __all__ = resnet.__all__ \
+ darknet.__all__ \
+ yolov3.__all__ \
+ tsm.__all__ \
+ bmn.__all__
+ bmn_model.__all__
......@@ -18,11 +18,17 @@ import numpy as np
import math
from hapi.model import Model, Loss
from hapi.download import get_weights_path
__all__ = ["BMN", "BmnLoss"]
__all__ = ["BMN", "BmnLoss", "bmn"]
DATATYPE = 'float32'
pretrain_infos = {
'bmn': ('https://paddlemodels.bj.bcebos.com/hapi/bmn.pdparams',
'9286c821acc4cad46d6613b931ba468c')
}
def _get_interp1d_bin_mask(seg_xmin, seg_xmax, tscale, num_sample,
num_sample_perbin):
......@@ -120,6 +126,13 @@ class Conv1D(fluid.dygraph.Layer):
class BMN(Model):
"""BMN model from
`"BMN: Boundary-Matching Network for Temporal Action Proposal Generation" <https://arxiv.org/abs/1907.09702>`_
Args:
cfg (AttrDict): configs for BMN model
is_dygraph (bool): whether in dygraph mode, default True.
"""
def __init__(self, cfg, is_dygraph=True):
super(BMN, self).__init__()
......@@ -277,6 +290,11 @@ class BMN(Model):
class BmnLoss(Loss):
"""Loss for BMN model
Args:
cfg (AttrDict): configs for BMN model
"""
def __init__(self, cfg):
super(BmnLoss, self).__init__()
self.cfg = cfg
......@@ -418,3 +436,21 @@ class BmnLoss(Loss):
loss = tem_loss + 10 * pem_reg_loss + pem_cls_loss
return loss
def bmn(cfg, is_dygraph=True, pretrained=True):
"""BMN model
Args:
cfg (AttrDict): configs for BMN model
is_dygraph (bool): whether in dygraph mode, default True.
pretrained (bool): If True, returns a model with pre-trained model
on COCO, default True
"""
model = BMN(cfg, is_dygraph=is_dygraph)
if pretrained:
weight_path = get_weights_path(*(pretrain_infos['bmn']))
assert weight_path.endswith('.pdparams'), \
"suffix of weight must be .pdparams"
model.load(weight_path[:-9])
return model
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册