提交 b06bd21e 编写于 作者: W wqz960

modify eval and vgg

上级 a68d90c5
...@@ -106,7 +106,7 @@ class VGGNet(fluid.dygraph.Layer): ...@@ -106,7 +106,7 @@ class VGGNet(fluid.dygraph.Layer):
x = self._conv_block_4(x) x = self._conv_block_4(x)
x = self._conv_block_5(x) x = self._conv_block_5(x)
x = fluid.layers.flatten(x, axis=0) x = fluid.layers.reshape(x, [0,-1])
x = self._fc1(x) x = self._fc1(x)
x = self._drop(x) x = self._drop(x)
x = self._fc2(x) x = self._fc2(x)
......
...@@ -19,13 +19,10 @@ from __future__ import print_function ...@@ -19,13 +19,10 @@ from __future__ import print_function
import os import os
import argparse import argparse
import paddle.fluid as fluid
import program
from ppcls.data import Reader from ppcls.data import Reader
from ppcls.utils.config import get_config from ppcls.utils.config import get_config
from ppcls.utils.save_load import init_model from ppcls.utils.save_load import init_model
from ppcls.utils import logger
from paddle.fluid.incubate.fleet.collective import fleet from paddle.fluid.incubate.fleet.collective import fleet
from paddle.fluid.incubate.fleet.base import role_maker from paddle.fluid.incubate.fleet.base import role_maker
...@@ -45,37 +42,25 @@ def parse_args(): ...@@ -45,37 +42,25 @@ def parse_args():
action='append', action='append',
default=[], default=[],
help='config options to be overridden') help='config options to be overridden')
args = parser.parse_args() args = parser.parse_args()
return args return args
def main(args): def main(args):
role = role_maker.PaddleCloudRoleMaker(is_collective=True) # assign the place
fleet.init(role) gpu_id = fluid.dygraph.parallel.Env().dev_id
config = get_config(args.config, overrides=args.override, show=True)
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = fluid.CUDAPlace(gpu_id) place = fluid.CUDAPlace(gpu_id)
with fluid.dygraph.guard(place):
startup_prog = fluid.Program() pre_weights_dict = fluid.dygraph.load_dygraph(config.pretrained_model)[0]
valid_prog = fluid.Program() strategy = fluid.dygraph.parallel.prepare_context()
valid_dataloader, valid_fetchs = program.build( net = program.create_model(config.ARCHITECTURE, config.classes_num)
config, valid_prog, startup_prog, is_train=False) net = fluid.dygraph.parallel.DataParallel(net, strategy)
valid_prog = valid_prog.clone(for_test=True) net.set_dict(pre_weights_dict)
valid_dataloader = program.create_dataloader()
exe = fluid.Executor(place)
exe.run(startup_prog)
init_model(config, valid_prog, exe)
valid_reader = Reader(config, 'valid')() valid_reader = Reader(config, 'valid')()
valid_dataloader.set_sample_list_generator(valid_reader, place) valid_dataloader.set_sample_list_generator(valid_reader, place)
net.eval()
compiled_valid_prog = program.compile(config, valid_prog) top1_acc = program.run(valid_dataloader, config, net, None, 0, 'valid')
program.run(valid_dataloader, exe, compiled_valid_prog, valid_fetchs, -1,
'eval')
if __name__ == '__main__': if __name__ == '__main__':
args = parse_args() args = parse_args()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册