diff --git a/tools/eval.py b/tools/eval.py index 2e43c52d1978de96ad9eea65104af01ebc1f654c..d6bd82c0ae6fc33bbbfeb42808588cb6e505fe31 100644 --- a/tools/eval.py +++ b/tools/eval.py @@ -12,6 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from paddle.distributed import ParallelEnv +import paddle +from ppcls.utils import logger +from ppcls.utils.save_load import init_model +from ppcls.utils.config import get_config +from ppcls.data import Reader +import program import argparse import os import sys @@ -19,14 +26,6 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) -import program -from ppcls.data import Reader -from ppcls.utils.config import get_config -from ppcls.utils.save_load import init_model -from ppcls.utils import logger - -import paddle -from paddle.distributed import ParallelEnv def parse_args(): parser = argparse.ArgumentParser("PaddleClas eval script") @@ -66,7 +65,8 @@ def main(args): valid_reader = Reader(config, 'valid')() valid_dataloader.set_sample_list_generator(valid_reader, place) net.eval() - top1_acc = program.run(valid_dataloader, config, net, None, 0, 'valid') + top1_acc = program.run(valid_dataloader, config, net, None, None, 0, + 'valid') if __name__ == '__main__':