提交 d3ff11c0 编写于 作者: B baiyf 提交者: qingqing01

Load pretrained VGG model for face detection. (#925)

* load pretrained vgg model

* load pretrained vgg model

* delete irrelevant tools

* delete unused lines
上级 c2ed053c
import os
import numpy as np
import argparse
import functools
......@@ -13,11 +14,13 @@ add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('parallel', bool, True, "Parallel.")
add_arg('pretrained_model', str, "./vgg_model/", "The init model path.")
#yapf: enable
def train(args,
learning_rate,
batch_size):
batch_size,
pretrained_model):
network = PyramidBox([3, 640, 640])
face_loss, head_loss = network.train()
......@@ -37,9 +40,16 @@ def train(args,
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
#print(fluid.default_main_program())
#print(test_program)
#fluid.io.save_persistables(exe, "model")
# fluid.io.save_inference_model('./vgg_model/', ['image'], [loss], exe)
if pretrained_model:
def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name))
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
# print(fluid.default_main_program())
# print(test_program)
# fluid.io.save_persistables(exe, "model")
if __name__ == '__main__':
......@@ -48,4 +58,5 @@ if __name__ == '__main__':
train(args,
learning_rate=0.01,
batch_size=args.batch_size)
batch_size=args.batch_size,
pretrained_model=args.pretrained_model)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册