提交 0b8b4815 编写于 作者: D dengkaipeng

use in_dygraph_mode

上级 4643cad7
......@@ -97,7 +97,7 @@ def test_bmn(args):
eval_dataset = BmnDataset(eval_cfg, 'test')
#model
model = bmn(config, args.dynamic, pretrained=args.weights is None)
model = bmn(config, pretrained=args.weights is None)
model.prepare(
loss_function=BmnLoss(config),
metrics=BmnMetric(
......
......@@ -92,7 +92,7 @@ def infer_bmn(args):
#data
infer_dataset = BmnDataset(infer_cfg, 'infer')
model = bmn(config, args.dynamic, pretrained=args.weights is None)
model = bmn(config, pretrained=args.weights is None)
model.prepare(
metrics=BmnMetric(
config, mode='infer'),
......
......@@ -136,7 +136,7 @@ def train_bmn(args):
val_dataset = BmnDataset(val_cfg, 'valid')
# model
model = bmn(config, args.dynamic, pretrained=False)
model = bmn(config, pretrained=False)
optim = optimizer(config, parameter_list=model.parameters())
model.prepare(
optimizer=optim,
......
......@@ -14,6 +14,7 @@
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
from paddle.fluid.framework import in_dygraph_mode
import numpy as np
import math
......@@ -131,9 +132,8 @@ class BMN(Model):
Args:
cfg (AttrDict): configs for BMN model
is_dygraph (bool): whether in dygraph mode, default True.
"""
def __init__(self, cfg, is_dygraph=True):
def __init__(self, cfg):
super(BMN, self).__init__()
#init config
......@@ -142,7 +142,6 @@ class BMN(Model):
self.prop_boundary_ratio = cfg.MODEL.prop_boundary_ratio
self.num_sample = cfg.MODEL.num_sample
self.num_sample_perbin = cfg.MODEL.num_sample_perbin
self.is_dygraph = is_dygraph
self.hidden_dim_1d = 256
self.hidden_dim_2d = 128
......@@ -197,7 +196,7 @@ class BMN(Model):
sample_mask_array = get_interp1d_mask(
self.tscale, self.dscale, self.prop_boundary_ratio,
self.num_sample, self.num_sample_perbin)
if self.is_dygraph:
if in_dygraph_mode():
self.sample_mask = fluid.dygraph.base.to_variable(
sample_mask_array)
else: # static
......@@ -438,16 +437,15 @@ class BmnLoss(Loss):
return loss
def bmn(cfg, is_dygraph=True, pretrained=True):
def bmn(cfg, pretrained=True):
"""BMN model
Args:
cfg (AttrDict): configs for BMN model
is_dygraph (bool): whether in dygraph mode, default True.
pretrained (bool): If True, returns a model with pre-trained model
on COCO, default True
"""
model = BMN(cfg, is_dygraph=is_dygraph)
model = BMN(cfg)
if pretrained:
weight_path = get_weights_path(*(pretrain_infos['bmn']))
assert weight_path.endswith('.pdparams'), \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册