diff --git a/fluid/face_detction/train.py b/fluid/face_detction/train.py index 83926115166db5bcd5fc4a56ce30ab2ca7cf3ef6..fb2d17b018cceda74f973b2879a8267705844b17 100644 --- a/fluid/face_detction/train.py +++ b/fluid/face_detction/train.py @@ -1,3 +1,4 @@ +import os import numpy as np import argparse import functools @@ -13,11 +14,13 @@ add_arg = functools.partial(add_arguments, argparser=parser) add_arg('batch_size', int, 32, "Minibatch size.") add_arg('use_gpu', bool, True, "Whether use GPU.") add_arg('parallel', bool, True, "Parallel.") +add_arg('pretrained_model', str, "./vgg_model/", "The init model path.") #yapf: enable def train(args, learning_rate, - batch_size): + batch_size, + pretrained_model): network = PyramidBox([3, 640, 640]) face_loss, head_loss = network.train() @@ -37,9 +40,16 @@ def train(args, place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) - #print(fluid.default_main_program()) - #print(test_program) - #fluid.io.save_persistables(exe, "model") + + # fluid.io.save_inference_model('./vgg_model/', ['image'], [loss], exe) + if pretrained_model: + def if_exist(var): + return os.path.exists(os.path.join(pretrained_model, var.name)) + fluid.io.load_vars(exe, pretrained_model, predicate=if_exist) + + # print(fluid.default_main_program()) + # print(test_program) + # fluid.io.save_persistables(exe, "model") if __name__ == '__main__': @@ -48,4 +58,5 @@ if __name__ == '__main__': train(args, learning_rate=0.01, - batch_size=args.batch_size) + batch_size=args.batch_size, + pretrained_model=args.pretrained_model)