提交 881f62a3 编写于 作者: D dengkaipeng

add pretrain info in stnet.py

上级 0741c752
......@@ -139,10 +139,10 @@ class ModelBase(object):
logger.info("Download pretrain weights of {} from {}".format(
self.name, url))
utils.download(url, path)
download(url, path)
return path
def load_pretrain_params(self, exe, pretrain, prog):
def load_pretrain_params(self, exe, pretrain, prog, place):
def if_exist(var):
return os.path.exists(os.path.join(pretrained_base, var.name))
fluid.io.load_params(exe, pretrain, main_program=prog)
......
......@@ -122,7 +122,10 @@ class STNET(ModelBase):
self.label_input
]
def load_pretrain_params(self, exe, pretrain, prog):
def pretrain_info(self):
return ('ResNet50_pretrained', 'https://paddlemodels.bj.bcebos.com/video_classification/ResNet50_pretrained.tar.gz')
def load_pretrain_params(self, exe, pretrain, prog, place):
def is_parameter(var):
if isinstance(var, fluid.framework.Parameter):
return isinstance(var, fluid.framework.Parameter) and (not ("fc_0" in var.name)) \
......
......@@ -140,7 +140,6 @@ def train(args):
valid_feeds = valid_model.feeds()
valid_outputs = valid_model.outputs()
valid_loss = valid_model.loss()
#valid_metrics = valid_model.metrics()
valid_pyreader = valid_model.pyreader()
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册