提交 6d876a00 编写于 作者: Q qingqing01

Clear code

上级 44d49573
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -119,24 +119,3 @@ class ImagePool(object): ...@@ -119,24 +119,3 @@ class ImagePool(object):
return temp return temp
else: else:
return image return image
if __name__ == '__main__':
place = fluid.CUDAPlace(0)
#fluid.enable_dygraph(place)
dataset = DataA(shuffle=False)
a_loader = fluid.io.DataLoader(
dataset,
feed_list=[
fluid.data(
name='im', shape=[
None,
2,
2,
], dtype='float32')
],
places=place,
return_list=False,
batch_size=2)
for data in a_loader:
print(data)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -102,6 +102,7 @@ if __name__ == "__main__": ...@@ -102,6 +102,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"-s", "--input_style", type=str, default='A', help="A or B") "-s", "--input_style", type=str, default='A', help="A or B")
FLAGS = parser.parse_args() FLAGS = parser.parse_args()
print(FLAGS)
check_gpu(str.lower(FLAGS.device) == 'gpu') check_gpu(str.lower(FLAGS.device) == 'gpu')
check_version() check_version()
main() main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -97,6 +97,7 @@ if __name__ == "__main__": ...@@ -97,6 +97,7 @@ if __name__ == "__main__":
default='checkpoint/199', default='checkpoint/199',
help="The init model file of directory.") help="The init model file of directory.")
FLAGS = parser.parse_args() FLAGS = parser.parse_args()
print(FLAGS)
check_gpu(str.lower(FLAGS.device) == 'gpu') check_gpu(str.lower(FLAGS.device) == 'gpu')
check_version() check_version()
main() main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -88,21 +88,15 @@ def main(): ...@@ -88,21 +88,15 @@ def main():
loader_A = fluid.io.DataLoader( loader_A = fluid.io.DataLoader(
data.DataA(), data.DataA(),
feed_list=[x.forward() for x in [input_A]]
if not FLAGS.dynamic else None,
places=place, places=place,
shuffle=True, shuffle=True,
return_list=True, return_list=True,
use_buffer_reader=True,
batch_size=FLAGS.batch_size) batch_size=FLAGS.batch_size)
loader_B = fluid.io.DataLoader( loader_B = fluid.io.DataLoader(
data.DataB(), data.DataB(),
feed_list=[x.forward() for x in [input_B]]
if not FLAGS.dynamic else None,
places=place, places=place,
shuffle=True, shuffle=True,
return_list=True, return_list=True,
use_buffer_reader=True,
batch_size=FLAGS.batch_size) batch_size=FLAGS.batch_size)
A_pool = data.ImagePool() A_pool = data.ImagePool()
...@@ -136,7 +130,11 @@ if __name__ == "__main__": ...@@ -136,7 +130,11 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"-d", "--dynamic", action='store_false', help="Enable dygraph mode") "-d", "--dynamic", action='store_false', help="Enable dygraph mode")
parser.add_argument( parser.add_argument(
"--device", type=str, default='gpu', help="device to use, gpu or cpu") "-p",
"--device",
type=str,
default='gpu',
help="device to use, gpu or cpu")
parser.add_argument( parser.add_argument(
"-e", "--epoch", default=200, type=int, help="Epoch number") "-e", "--epoch", default=200, type=int, help="Epoch number")
parser.add_argument( parser.add_argument(
...@@ -154,6 +152,7 @@ if __name__ == "__main__": ...@@ -154,6 +152,7 @@ if __name__ == "__main__":
type=str, type=str,
help="checkpoint path to resume") help="checkpoint path to resume")
FLAGS = parser.parse_args() FLAGS = parser.parse_args()
print(FLAGS)
check_gpu(str.lower(FLAGS.device) == 'gpu') check_gpu(str.lower(FLAGS.device) == 'gpu')
check_version() check_version()
main() main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册