提交 b06bd21e 编写于 作者: W wqz960

modify eval and vgg

上级 a68d90c5
......@@ -106,7 +106,7 @@ class VGGNet(fluid.dygraph.Layer):
x = self._conv_block_4(x)
x = self._conv_block_5(x)
x = fluid.layers.flatten(x, axis=0)
x = fluid.layers.reshape(x, [0,-1])
x = self._fc1(x)
x = self._drop(x)
x = self._fc2(x)
......
......@@ -19,13 +19,10 @@ from __future__ import print_function
import os
import argparse
import paddle.fluid as fluid
import program
from ppcls.data import Reader
from ppcls.utils.config import get_config
from ppcls.utils.save_load import init_model
from ppcls.utils import logger
from paddle.fluid.incubate.fleet.collective import fleet
from paddle.fluid.incubate.fleet.base import role_maker
......@@ -45,37 +42,25 @@ def parse_args():
action='append',
default=[],
help='config options to be overridden')
args = parser.parse_args()
return args
def main(args):
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
config = get_config(args.config, overrides=args.override, show=True)
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
# assign the place
gpu_id = fluid.dygraph.parallel.Env().dev_id
place = fluid.CUDAPlace(gpu_id)
startup_prog = fluid.Program()
valid_prog = fluid.Program()
valid_dataloader, valid_fetchs = program.build(
config, valid_prog, startup_prog, is_train=False)
valid_prog = valid_prog.clone(for_test=True)
exe = fluid.Executor(place)
exe.run(startup_prog)
init_model(config, valid_prog, exe)
valid_reader = Reader(config, 'valid')()
valid_dataloader.set_sample_list_generator(valid_reader, place)
compiled_valid_prog = program.compile(config, valid_prog)
program.run(valid_dataloader, exe, compiled_valid_prog, valid_fetchs, -1,
'eval')
with fluid.dygraph.guard(place):
pre_weights_dict = fluid.dygraph.load_dygraph(config.pretrained_model)[0]
strategy = fluid.dygraph.parallel.prepare_context()
net = program.create_model(config.ARCHITECTURE, config.classes_num)
net = fluid.dygraph.parallel.DataParallel(net, strategy)
net.set_dict(pre_weights_dict)
valid_dataloader = program.create_dataloader()
valid_reader = Reader(config, 'valid')()
valid_dataloader.set_sample_list_generator(valid_reader, place)
net.eval()
top1_acc = program.run(valid_dataloader, config, net, None, 0, 'valid')
if __name__ == '__main__':
args = parse_args()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册