未验证 提交 cb6f778b 编写于 作者: S suytingwan 提交者: GitHub

Metric learning api update (#4631)

* test=develop update one_hot usage according to 1.8

* test=develop update dataloader api by 1.8

* test=develop fix image shape

* test=develop update save load fluid.data
上级 9d2346a5
...@@ -52,8 +52,9 @@ def eval(args): ...@@ -52,8 +52,9 @@ def eval(args):
assert model_name in model_list, "{} is not in lists: {}".format(args.model, assert model_name in model_list, "{} is not in lists: {}".format(args.model,
model_list) model_list)
image = fluid.layers.data(name='image', shape=[None] + image_shape, dtype='float32') image = fluid.data(name='image', shape=[None] + image_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[None, 1], dtype='int64') label = fluid.data(name='label', shape=[None, 1], dtype='int64')
test_loader = fluid.io.DataLoader.from_generator( test_loader = fluid.io.DataLoader.from_generator(
feed_list=[image, label], feed_list=[image, label],
capacity=64, capacity=64,
...@@ -75,7 +76,7 @@ def eval(args): ...@@ -75,7 +76,7 @@ def eval(args):
def if_exist(var): def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name)) return os.path.exists(os.path.join(pretrained_model, var.name))
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist) fluid.load(program=test_program, model_path=pretrained_model, executor=exe)
test_loader.set_sample_generator( test_loader.set_sample_generator(
reader.test(args), reader.test(args),
......
...@@ -51,7 +51,7 @@ def infer(args): ...@@ -51,7 +51,7 @@ def infer(args):
assert model_name in model_list, "{} is not in lists: {}".format(args.model, assert model_name in model_list, "{} is not in lists: {}".format(args.model,
model_list) model_list)
image = fluid.layers.data(name='image', shape=[None] + image_shape, dtype='float32') image = fluid.data(name='image', shape=[None] + image_shape, dtype='float32')
infer_loader = fluid.io.DataLoader.from_generator( infer_loader = fluid.io.DataLoader.from_generator(
feed_list=[image], feed_list=[image],
...@@ -74,7 +74,7 @@ def infer(args): ...@@ -74,7 +74,7 @@ def infer(args):
def if_exist(var): def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name)) return os.path.exists(os.path.join(pretrained_model, var.name))
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist) fluid.load(model_path=pretrained_model, program=test_program, executor=exe)
infer_loader.set_sample_generator( infer_loader.set_sample_generator(
reader.test(args), reader.test(args),
......
...@@ -108,9 +108,9 @@ def build_program(is_train, main_prog, startup_prog, args): ...@@ -108,9 +108,9 @@ def build_program(is_train, main_prog, startup_prog, args):
model = models.__dict__[args.model]() model = models.__dict__[args.model]()
with fluid.program_guard(main_prog, startup_prog): with fluid.program_guard(main_prog, startup_prog):
queue_capacity = 64 queue_capacity = 64
image = fluid.layers.data( image = fluid.data(
name='image', shape=[None] + image_shape, dtype='float32') name='image', shape=[None] + image_shape, dtype='float32')
label = fluid.layers.data( label = fluid.data(
name='label', shape=[None, 1], dtype='int64') name='label', shape=[None, 1], dtype='int64')
loader = fluid.io.DataLoader.from_generator( loader = fluid.io.DataLoader.from_generator(
feed_list=[image, label], feed_list=[image, label],
...@@ -190,15 +190,14 @@ def train_async(args): ...@@ -190,15 +190,14 @@ def train_async(args):
logging.debug('after run startup program') logging.debug('after run startup program')
if checkpoint is not None: if checkpoint is not None:
fluid.io.load_persistables(exe, checkpoint, main_program=train_prog) fluid.load(program=train_prog, model_path=checkpoint, executor=exe)
if pretrained_model: if pretrained_model:
def if_exist(var): def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name)) return os.path.exists(os.path.join(pretrained_model, var.name))
fluid.io.load_vars( fluid.load(program=train_prog, model_path=pretrained_model, executor=exe)
exe, pretrained_model, main_program=train_prog, predicate=if_exist)
if args.use_gpu: if args.use_gpu:
devicenum = get_gpu_num() devicenum = get_gpu_num()
...@@ -287,7 +286,7 @@ def train_async(args): ...@@ -287,7 +286,7 @@ def train_async(args):
str(iter_no)) str(iter_no))
if not os.path.isdir(model_path): if not os.path.isdir(model_path):
os.makedirs(model_path) os.makedirs(model_path)
fluid.io.save_persistables(exe, model_path, main_program=train_prog) fluid.save(program=train_prog, model_path=model_path)
iter_no += 1 iter_no += 1
......
...@@ -115,9 +115,9 @@ def build_program(is_train, main_prog, startup_prog, args): ...@@ -115,9 +115,9 @@ def build_program(is_train, main_prog, startup_prog, args):
model = models.__dict__[args.model]() model = models.__dict__[args.model]()
with fluid.program_guard(main_prog, startup_prog): with fluid.program_guard(main_prog, startup_prog):
queue_capacity = 64 queue_capacity = 64
image = fluid.layers.data( image = fluid.data(
name='image', shape=[None] + image_shape, dtype='float32') name='image', shape=[None] + image_shape, dtype='float32')
label = fluid.layers.data( label = fluid.data(
name='label', shape=[None, 1], dtype='int64') name='label', shape=[None, 1], dtype='int64')
loader = fluid.io.DataLoader.from_generator( loader = fluid.io.DataLoader.from_generator(
feed_list=[image, label], feed_list=[image, label],
...@@ -188,15 +188,15 @@ def train_async(args): ...@@ -188,15 +188,15 @@ def train_async(args):
logging.debug('after run startup program') logging.debug('after run startup program')
if checkpoint is not None: if checkpoint is not None:
fluid.io.load_persistables(exe, checkpoint, main_program=train_prog) fluid.load(program=train_prog, model_path=checkpoint, executor=exe)
if pretrained_model: if pretrained_model:
def if_exist(var): def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name)) return os.path.exists(os.path.join(pretrained_model, var.name))
fluid.io.load_vars( fluid.load(program=train_prog, model_path=pretrained_model, executor=exe)
exe, pretrained_model, main_program=train_prog, predicate=if_exist)
if args.use_gpu: if args.use_gpu:
devicenum = get_gpu_num() devicenum = get_gpu_num()
...@@ -283,7 +283,7 @@ def train_async(args): ...@@ -283,7 +283,7 @@ def train_async(args):
str(iter_no)) str(iter_no))
if not os.path.isdir(model_path): if not os.path.isdir(model_path):
os.makedirs(model_path) os.makedirs(model_path)
fluid.io.save_persistables(exe, model_path, main_program=train_prog) fluid.save(program=train_prog, model_path=model_path)
iter_no += 1 iter_no += 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册