From 6d876a0072f2f60478c5f40db38f9386d73f5f32 Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Tue, 7 Apr 2020 02:29:13 +0000 Subject: [PATCH] Clear code --- cyclegan/data.py | 23 +---------------------- cyclegan/infer.py | 3 ++- cyclegan/test.py | 3 ++- cyclegan/train.py | 15 +++++++-------- 4 files changed, 12 insertions(+), 32 deletions(-) diff --git a/cyclegan/data.py b/cyclegan/data.py index 4b8c4ae..effa4ee 100644 --- a/cyclegan/data.py +++ b/cyclegan/data.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -119,24 +119,3 @@ class ImagePool(object): return temp else: return image - - -if __name__ == '__main__': - place = fluid.CUDAPlace(0) - #fluid.enable_dygraph(place) - dataset = DataA(shuffle=False) - a_loader = fluid.io.DataLoader( - dataset, - feed_list=[ - fluid.data( - name='im', shape=[ - None, - 2, - 2, - ], dtype='float32') - ], - places=place, - return_list=False, - batch_size=2) - for data in a_loader: - print(data) diff --git a/cyclegan/infer.py b/cyclegan/infer.py index f21202f..0b61a95 100644 --- a/cyclegan/infer.py +++ b/cyclegan/infer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -102,6 +102,7 @@ if __name__ == "__main__": parser.add_argument( "-s", "--input_style", type=str, default='A', help="A or B") FLAGS = parser.parse_args() + print(FLAGS) check_gpu(str.lower(FLAGS.device) == 'gpu') check_version() main() diff --git a/cyclegan/test.py b/cyclegan/test.py index 8f20898..9956630 100644 --- a/cyclegan/test.py +++ b/cyclegan/test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -97,6 +97,7 @@ if __name__ == "__main__": default='checkpoint/199', help="The init model file of directory.") FLAGS = parser.parse_args() + print(FLAGS) check_gpu(str.lower(FLAGS.device) == 'gpu') check_version() main() diff --git a/cyclegan/train.py b/cyclegan/train.py index 249f71c..c2203fc 100644 --- a/cyclegan/train.py +++ b/cyclegan/train.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -88,21 +88,15 @@ def main(): loader_A = fluid.io.DataLoader( data.DataA(), - feed_list=[x.forward() for x in [input_A]] - if not FLAGS.dynamic else None, places=place, shuffle=True, return_list=True, - use_buffer_reader=True, batch_size=FLAGS.batch_size) loader_B = fluid.io.DataLoader( data.DataB(), - feed_list=[x.forward() for x in [input_B]] - if not FLAGS.dynamic else None, places=place, shuffle=True, return_list=True, - use_buffer_reader=True, batch_size=FLAGS.batch_size) A_pool = data.ImagePool() @@ -136,7 +130,11 @@ if __name__ == "__main__": parser.add_argument( "-d", "--dynamic", action='store_false', help="Enable dygraph mode") parser.add_argument( - "--device", type=str, default='gpu', help="device to use, gpu or cpu") + "-p", + "--device", + type=str, + default='gpu', + help="device to use, gpu or cpu") parser.add_argument( "-e", "--epoch", default=200, type=int, help="Epoch number") parser.add_argument( @@ -154,6 +152,7 @@ if __name__ == "__main__": type=str, help="checkpoint path to resume") FLAGS = parser.parse_args() + print(FLAGS) check_gpu(str.lower(FLAGS.device) == 'gpu') check_version() main() -- GitLab