From b06bd21e94da813db2a8ffa1b54dabbd8b4221b5 Mon Sep 17 00:00:00 2001 From: wqz960 <362379625@qq.com> Date: Mon, 20 Jul 2020 06:34:51 +0000 Subject: [PATCH] modify eval and vgg --- ppcls/modeling/architectures/vgg.py | 2 +- tools/eval.py | 43 ++++++++++------------------- 2 files changed, 15 insertions(+), 30 deletions(-) diff --git a/ppcls/modeling/architectures/vgg.py b/ppcls/modeling/architectures/vgg.py index b439f267..28845b3e 100644 --- a/ppcls/modeling/architectures/vgg.py +++ b/ppcls/modeling/architectures/vgg.py @@ -106,7 +106,7 @@ class VGGNet(fluid.dygraph.Layer): x = self._conv_block_4(x) x = self._conv_block_5(x) - x = fluid.layers.flatten(x, axis=0) + x = fluid.layers.reshape(x, [0,-1]) x = self._fc1(x) x = self._drop(x) x = self._fc2(x) diff --git a/tools/eval.py b/tools/eval.py index 291f77f0..c4589518 100644 --- a/tools/eval.py +++ b/tools/eval.py @@ -19,13 +19,10 @@ from __future__ import print_function import os import argparse -import paddle.fluid as fluid - -import program - from ppcls.data import Reader from ppcls.utils.config import get_config from ppcls.utils.save_load import init_model +from ppcls.utils import logger from paddle.fluid.incubate.fleet.collective import fleet from paddle.fluid.incubate.fleet.base import role_maker @@ -45,37 +42,25 @@ def parse_args(): action='append', default=[], help='config options to be overridden') - args = parser.parse_args() return args def main(args): - role = role_maker.PaddleCloudRoleMaker(is_collective=True) - fleet.init(role) - - config = get_config(args.config, overrides=args.override, show=True) - gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0)) + # assign the place + gpu_id = fluid.dygraph.parallel.Env().dev_id place = fluid.CUDAPlace(gpu_id) - - startup_prog = fluid.Program() - valid_prog = fluid.Program() - valid_dataloader, valid_fetchs = program.build( - config, valid_prog, startup_prog, is_train=False) - valid_prog = valid_prog.clone(for_test=True) - - exe = fluid.Executor(place) - exe.run(startup_prog) - - init_model(config, valid_prog, exe) - - valid_reader = Reader(config, 'valid')() - valid_dataloader.set_sample_list_generator(valid_reader, place) - - compiled_valid_prog = program.compile(config, valid_prog) - program.run(valid_dataloader, exe, compiled_valid_prog, valid_fetchs, -1, - 'eval') - + with fluid.dygraph.guard(place): + pre_weights_dict = fluid.dygraph.load_dygraph(config.pretrained_model)[0] + strategy = fluid.dygraph.parallel.prepare_context() + net = program.create_model(config.ARCHITECTURE, config.classes_num) + net = fluid.dygraph.parallel.DataParallel(net, strategy) + net.set_dict(pre_weights_dict) + valid_dataloader = program.create_dataloader() + valid_reader = Reader(config, 'valid')() + valid_dataloader.set_sample_list_generator(valid_reader, place) + net.eval() + top1_acc = program.run(valid_dataloader, config, net, None, 0, 'valid') if __name__ == '__main__': args = parse_args() -- GitLab