未验证 提交 cb6f778b 编写于 作者: S suytingwan 提交者: GitHub

Metric learning api update (#4631)

* test=develop update one_hot usage according to 1.8

* test=develop update dataloader api by 1.8

* test=develop fix image shape

* test=develop update save load fluid.data
上级 9d2346a5
......@@ -52,8 +52,9 @@ def eval(args):
assert model_name in model_list, "{} is not in lists: {}".format(args.model,
model_list)
image = fluid.layers.data(name='image', shape=[None] + image_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[None, 1], dtype='int64')
image = fluid.data(name='image', shape=[None] + image_shape, dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
test_loader = fluid.io.DataLoader.from_generator(
feed_list=[image, label],
capacity=64,
......@@ -75,7 +76,7 @@ def eval(args):
def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name))
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
fluid.load(program=test_program, model_path=pretrained_model, executor=exe)
test_loader.set_sample_generator(
reader.test(args),
......
......@@ -51,7 +51,7 @@ def infer(args):
assert model_name in model_list, "{} is not in lists: {}".format(args.model,
model_list)
image = fluid.layers.data(name='image', shape=[None] + image_shape, dtype='float32')
image = fluid.data(name='image', shape=[None] + image_shape, dtype='float32')
infer_loader = fluid.io.DataLoader.from_generator(
feed_list=[image],
......@@ -74,7 +74,7 @@ def infer(args):
def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name))
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
fluid.load(model_path=pretrained_model, program=test_program, executor=exe)
infer_loader.set_sample_generator(
reader.test(args),
......
......@@ -108,9 +108,9 @@ def build_program(is_train, main_prog, startup_prog, args):
model = models.__dict__[args.model]()
with fluid.program_guard(main_prog, startup_prog):
queue_capacity = 64
image = fluid.layers.data(
image = fluid.data(
name='image', shape=[None] + image_shape, dtype='float32')
label = fluid.layers.data(
label = fluid.data(
name='label', shape=[None, 1], dtype='int64')
loader = fluid.io.DataLoader.from_generator(
feed_list=[image, label],
......@@ -190,15 +190,14 @@ def train_async(args):
logging.debug('after run startup program')
if checkpoint is not None:
fluid.io.load_persistables(exe, checkpoint, main_program=train_prog)
fluid.load(program=train_prog, model_path=checkpoint, executor=exe)
if pretrained_model:
def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name))
fluid.io.load_vars(
exe, pretrained_model, main_program=train_prog, predicate=if_exist)
fluid.load(program=train_prog, model_path=pretrained_model, executor=exe)
if args.use_gpu:
devicenum = get_gpu_num()
......@@ -287,7 +286,7 @@ def train_async(args):
str(iter_no))
if not os.path.isdir(model_path):
os.makedirs(model_path)
fluid.io.save_persistables(exe, model_path, main_program=train_prog)
fluid.save(program=train_prog, model_path=model_path)
iter_no += 1
......
......@@ -115,9 +115,9 @@ def build_program(is_train, main_prog, startup_prog, args):
model = models.__dict__[args.model]()
with fluid.program_guard(main_prog, startup_prog):
queue_capacity = 64
image = fluid.layers.data(
image = fluid.data(
name='image', shape=[None] + image_shape, dtype='float32')
label = fluid.layers.data(
label = fluid.data(
name='label', shape=[None, 1], dtype='int64')
loader = fluid.io.DataLoader.from_generator(
feed_list=[image, label],
......@@ -188,15 +188,15 @@ def train_async(args):
logging.debug('after run startup program')
if checkpoint is not None:
fluid.io.load_persistables(exe, checkpoint, main_program=train_prog)
fluid.load(program=train_prog, model_path=checkpoint, executor=exe)
if pretrained_model:
def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name))
fluid.io.load_vars(
exe, pretrained_model, main_program=train_prog, predicate=if_exist)
fluid.load(program=train_prog, model_path=pretrained_model, executor=exe)
if args.use_gpu:
devicenum = get_gpu_num()
......@@ -283,7 +283,7 @@ def train_async(args):
str(iter_no))
if not os.path.isdir(model_path):
os.makedirs(model_path)
fluid.io.save_persistables(exe, model_path, main_program=train_prog)
fluid.save(program=train_prog, model_path=model_path)
iter_no += 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册