提交 0b8b4815 编写于 作者: D dengkaipeng

use in_dygraph_mode

上级 4643cad7
...@@ -97,7 +97,7 @@ def test_bmn(args): ...@@ -97,7 +97,7 @@ def test_bmn(args):
eval_dataset = BmnDataset(eval_cfg, 'test') eval_dataset = BmnDataset(eval_cfg, 'test')
#model #model
model = bmn(config, args.dynamic, pretrained=args.weights is None) model = bmn(config, pretrained=args.weights is None)
model.prepare( model.prepare(
loss_function=BmnLoss(config), loss_function=BmnLoss(config),
metrics=BmnMetric( metrics=BmnMetric(
......
...@@ -92,7 +92,7 @@ def infer_bmn(args): ...@@ -92,7 +92,7 @@ def infer_bmn(args):
#data #data
infer_dataset = BmnDataset(infer_cfg, 'infer') infer_dataset = BmnDataset(infer_cfg, 'infer')
model = bmn(config, args.dynamic, pretrained=args.weights is None) model = bmn(config, pretrained=args.weights is None)
model.prepare( model.prepare(
metrics=BmnMetric( metrics=BmnMetric(
config, mode='infer'), config, mode='infer'),
......
...@@ -136,7 +136,7 @@ def train_bmn(args): ...@@ -136,7 +136,7 @@ def train_bmn(args):
val_dataset = BmnDataset(val_cfg, 'valid') val_dataset = BmnDataset(val_cfg, 'valid')
# model # model
model = bmn(config, args.dynamic, pretrained=False) model = bmn(config, pretrained=False)
optim = optimizer(config, parameter_list=model.parameters()) optim = optimizer(config, parameter_list=model.parameters())
model.prepare( model.prepare(
optimizer=optim, optimizer=optim,
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import ParamAttr from paddle.fluid import ParamAttr
from paddle.fluid.framework import in_dygraph_mode
import numpy as np import numpy as np
import math import math
...@@ -131,9 +132,8 @@ class BMN(Model): ...@@ -131,9 +132,8 @@ class BMN(Model):
Args: Args:
cfg (AttrDict): configs for BMN model cfg (AttrDict): configs for BMN model
is_dygraph (bool): whether in dygraph mode, default True.
""" """
def __init__(self, cfg, is_dygraph=True): def __init__(self, cfg):
super(BMN, self).__init__() super(BMN, self).__init__()
#init config #init config
...@@ -142,7 +142,6 @@ class BMN(Model): ...@@ -142,7 +142,6 @@ class BMN(Model):
self.prop_boundary_ratio = cfg.MODEL.prop_boundary_ratio self.prop_boundary_ratio = cfg.MODEL.prop_boundary_ratio
self.num_sample = cfg.MODEL.num_sample self.num_sample = cfg.MODEL.num_sample
self.num_sample_perbin = cfg.MODEL.num_sample_perbin self.num_sample_perbin = cfg.MODEL.num_sample_perbin
self.is_dygraph = is_dygraph
self.hidden_dim_1d = 256 self.hidden_dim_1d = 256
self.hidden_dim_2d = 128 self.hidden_dim_2d = 128
...@@ -197,7 +196,7 @@ class BMN(Model): ...@@ -197,7 +196,7 @@ class BMN(Model):
sample_mask_array = get_interp1d_mask( sample_mask_array = get_interp1d_mask(
self.tscale, self.dscale, self.prop_boundary_ratio, self.tscale, self.dscale, self.prop_boundary_ratio,
self.num_sample, self.num_sample_perbin) self.num_sample, self.num_sample_perbin)
if self.is_dygraph: if in_dygraph_mode():
self.sample_mask = fluid.dygraph.base.to_variable( self.sample_mask = fluid.dygraph.base.to_variable(
sample_mask_array) sample_mask_array)
else: # static else: # static
...@@ -438,16 +437,15 @@ class BmnLoss(Loss): ...@@ -438,16 +437,15 @@ class BmnLoss(Loss):
return loss return loss
def bmn(cfg, is_dygraph=True, pretrained=True): def bmn(cfg, pretrained=True):
"""BMN model """BMN model
Args: Args:
cfg (AttrDict): configs for BMN model cfg (AttrDict): configs for BMN model
is_dygraph (bool): whether in dygraph mode, default True.
pretrained (bool): If True, returns a model with pre-trained model pretrained (bool): If True, returns a model with pre-trained model
on COCO, default True on COCO, default True
""" """
model = BMN(cfg, is_dygraph=is_dygraph) model = BMN(cfg)
if pretrained: if pretrained:
weight_path = get_weights_path(*(pretrain_infos['bmn'])) weight_path = get_weights_path(*(pretrain_infos['bmn']))
assert weight_path.endswith('.pdparams'), \ assert weight_path.endswith('.pdparams'), \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册