提交 6d876a00 编写于 作者: Q qingqing01

Clear code

上级 44d49573
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -119,24 +119,3 @@ class ImagePool(object):
return temp
else:
return image
if __name__ == '__main__':
place = fluid.CUDAPlace(0)
#fluid.enable_dygraph(place)
dataset = DataA(shuffle=False)
a_loader = fluid.io.DataLoader(
dataset,
feed_list=[
fluid.data(
name='im', shape=[
None,
2,
2,
], dtype='float32')
],
places=place,
return_list=False,
batch_size=2)
for data in a_loader:
print(data)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -102,6 +102,7 @@ if __name__ == "__main__":
parser.add_argument(
"-s", "--input_style", type=str, default='A', help="A or B")
FLAGS = parser.parse_args()
print(FLAGS)
check_gpu(str.lower(FLAGS.device) == 'gpu')
check_version()
main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -97,6 +97,7 @@ if __name__ == "__main__":
default='checkpoint/199',
help="The init model file of directory.")
FLAGS = parser.parse_args()
print(FLAGS)
check_gpu(str.lower(FLAGS.device) == 'gpu')
check_version()
main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -88,21 +88,15 @@ def main():
loader_A = fluid.io.DataLoader(
data.DataA(),
feed_list=[x.forward() for x in [input_A]]
if not FLAGS.dynamic else None,
places=place,
shuffle=True,
return_list=True,
use_buffer_reader=True,
batch_size=FLAGS.batch_size)
loader_B = fluid.io.DataLoader(
data.DataB(),
feed_list=[x.forward() for x in [input_B]]
if not FLAGS.dynamic else None,
places=place,
shuffle=True,
return_list=True,
use_buffer_reader=True,
batch_size=FLAGS.batch_size)
A_pool = data.ImagePool()
......@@ -136,7 +130,11 @@ if __name__ == "__main__":
parser.add_argument(
"-d", "--dynamic", action='store_false', help="Enable dygraph mode")
parser.add_argument(
"--device", type=str, default='gpu', help="device to use, gpu or cpu")
"-p",
"--device",
type=str,
default='gpu',
help="device to use, gpu or cpu")
parser.add_argument(
"-e", "--epoch", default=200, type=int, help="Epoch number")
parser.add_argument(
......@@ -154,6 +152,7 @@ if __name__ == "__main__":
type=str,
help="checkpoint path to resume")
FLAGS = parser.parse_args()
print(FLAGS)
check_gpu(str.lower(FLAGS.device) == 'gpu')
check_version()
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册