Unverified · Commit a0a66616 authored by chajchaj, committed by GitHub

fix bug: use_gpu false (#4403)

Parent 47e1bf21
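
The bug: with use_gpu set to false, the per-place batch size was still derived from paddle.fluid.core.get_cuda_device_count(), which does not reflect the number of places on a CPU-only run. The commit replaces that call with a place_num value (CUDA device count on GPU, the CPU_NUM environment variable on CPU) and threads it through the reader and the train/eval loops. A minimal sketch of the selection logic, using the same calls that appear in the diff below (the helper name get_place_num is illustrative, not part of the commit):

    import os
    import paddle.fluid as fluid

    def get_place_num(use_gpu):
        # Number of parallel places the global batch is split across:
        # CUDA devices when running on GPU, otherwise the CPU_NUM
        # environment variable, defaulting to a single place.
        if use_gpu:
            return fluid.core.get_cuda_device_count()
        return int(os.environ.get('CPU_NUM', 1))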
@@ -256,8 +256,9 @@ def process_batch_data(input_data, settings, mode, color_jitter, rotate):
 class ImageNetReader:
-    def __init__(self, seed=None):
+    def __init__(self, seed=None, place_num=1):
         self.shuffle_seed = seed
+        self.place_num = place_num

     def set_shuffle_seed(self, seed):
         assert isinstance(seed, int), "shuffle seed must be int"

@@ -275,8 +276,7 @@ class ImageNetReader:
         if mode == 'test':
             batch_size = 1
         else:
-            batch_size = settings.batch_size / paddle.fluid.core.get_cuda_device_count(
-            )
+            batch_size = settings.batch_size / self.place_num

         def reader():
             def read_file_list():

@@ -365,8 +365,7 @@ class ImageNetReader:
             reader = create_mixup_reader(settings, reader)
         reader = fluid.io.batch(
             reader,
-            batch_size=int(settings.batch_size /
-                           paddle.fluid.core.get_cuda_device_count()),
+            batch_size=int(settings.batch_size / self.place_num),
             drop_last=True)
         return reader
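
With place_num stored on the reader, the per-place batch size no longer depends on a CUDA device being present. A reduced sketch of the reader-side batching, assuming the same settings.batch_size attribute and fluid.io.batch call shown above (the class and _batch helper are illustrative, not the full reader):

    import paddle.fluid as fluid

    class ImageNetReaderSketch:
        def __init__(self, seed=None, place_num=1):
            self.shuffle_seed = seed
            self.place_num = place_num  # places the global batch is split across

        def _batch(self, sample_reader, settings):
            # Each place consumes batch_size / place_num samples per step,
            # whether the places are GPUs or CPU threads.
            return fluid.io.batch(
                sample_reader,
                batch_size=int(settings.batch_size / self.place_num),
                drop_last=True)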
@@ -42,10 +42,11 @@ def eval(net, test_data_loader, eop):
     total_acc5 = 0.0
     total_sample = 0
     t_last = 0
+    place_num = paddle.fluid.core.get_cuda_device_count() if args.use_gpu else int(os.environ.get('CPU_NUM', 1))
     for img, label in test_data_loader():
         t1 = time.time()
         label = to_variable(label.numpy().astype('int64').reshape(
-            int(args.batch_size // paddle.fluid.core.get_cuda_device_count()),
+            int(args.batch_size // place_num),
             1))
         out = net(img)
         softmax_out = fluid.layers.softmax(out, use_cudnn=False)

@@ -77,6 +78,7 @@ def train_mobilenet():
     place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
     with fluid.dygraph.guard(place):
         # 1. init net and optimizer
+        place_num = paddle.fluid.core.get_cuda_device_count() if args.use_gpu else int(os.environ.get('CPU_NUM', 1))
         if args.ce:
             print("ce mode")
             seed = 33

@@ -118,7 +120,7 @@ def train_mobilenet():
         test_data_loader, test_data = utility.create_data_loader(
             is_train=False, args=args)
         num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
-        imagenet_reader = reader.ImageNetReader(0)
+        imagenet_reader = reader.ImageNetReader(seed=0, place_num=place_num)
         train_reader = imagenet_reader.train(settings=args)
         test_reader = imagenet_reader.val(settings=args)
         train_data_loader.set_sample_list_generator(train_reader, place)

@@ -140,8 +142,8 @@ def train_mobilenet():
             for img, label in train_data_loader():
                 t1 = time.time()
                 label = to_variable(label.numpy().astype('int64').reshape(
-                    int(args.batch_size //
-                        paddle.fluid.core.get_cuda_device_count()), 1))
+                    int(args.batch_size // place_num),
+                    1))
                 t_start = time.time()
                 # 4.1.1 call net()
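
On the training side, place_num is computed once per scope and passed to the reader, so CPU runs (use_gpu false) and multi-GPU runs share one code path. A condensed wiring sketch under the same assumptions as the diff (args.use_gpu, args.batch_size, the patched reader module; build_readers is an illustrative helper):

    import os
    import paddle.fluid as fluid
    import reader  # module containing the patched ImageNetReader

    def build_readers(args):
        # Same conditional as the eval/train hunks above.
        place_num = (fluid.core.get_cuda_device_count()
                     if args.use_gpu else int(os.environ.get('CPU_NUM', 1)))
        imagenet_reader = reader.ImageNetReader(seed=0, place_num=place_num)
        train_reader = imagenet_reader.train(settings=args)
        test_reader = imagenet_reader.val(settings=args)
        # Labels in the loops above are reshaped to (batch_size // place_num, 1).
        return train_reader, test_reader, place_num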