未验证 提交 b0984c7c 编写于 作者: Z Zhen Wang 提交者: GitHub

Fix the timeout problem of test_multi_precision_fp16_train UT. (#33596)

上级 bb1216f5
......@@ -19,11 +19,35 @@ import paddle.fluid as fluid
import contextlib
import unittest
import numpy as np
from paddle.io import Dataset
from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_model_to_fp16
paddle.enable_static()
class RandomDataset(Dataset):
    """Map-style dataset that serves randomly generated image/label pairs.

    Every access draws a fresh sample from a random generator owned by this
    instance, so the values returned depend on the order of accesses and not
    on the index that was requested.

    Args:
        num_samples: Value reported by ``len()`` for this dataset.
        seed: Seed for this dataset's private random generator.
    """

    def __init__(self, num_samples, seed=123):
        super(RandomDataset, self).__init__()
        # Use a private generator instead of np.random.seed(): re-seeding the
        # global NumPy state here would let a dataset constructed later
        # silently override the seed of one constructed earlier.
        self._rng = np.random.RandomState(seed)
        self.num_samples = num_samples

    def __getitem__(self, idx):
        # idx is not used; samples are drawn sequentially from the generator.
        image = self._rng.random_sample([3, 32, 32]).astype('float32')
        # randint's upper bound is exclusive, so labels lie in [0, 8].
        label = self._rng.randint(0, 9, (1, )).astype('int64')
        return image, label

    def __len__(self):
        return self.num_samples
def reader_decorator(reader):
    """Wrap an indexable, sized object as a zero-argument sample generator.

    Args:
        reader: Object supporting ``len()`` and integer indexing.

    Returns:
        A function that takes no arguments and, each time it is called,
        returns a generator yielding ``reader[0]`` through
        ``reader[len(reader) - 1]`` in order.
    """

    def __reader__():
        # Index explicitly up to len(reader) rather than iterating the object
        # directly: direct iteration only stops on IndexError, which an
        # object with an unbounded __getitem__ never raises.
        sample_count = len(reader)
        for position in range(sample_count):
            yield reader[position]

    return __reader__
def resnet_cifar10(input, depth=32):
def conv_bn_layer(input,
ch_out,
......@@ -76,7 +100,6 @@ def resnet_cifar10(input, depth=32):
def train(use_pure_fp16=True, use_nesterov=False, optimizer=""):
classdim = 10
data_shape = [3, 32, 32]
BATCH_SIZE = 32
PASS_NUM = 1
train_program = fluid.Program()
......@@ -124,25 +147,31 @@ def train(use_pure_fp16=True, use_nesterov=False, optimizer=""):
optimizer.minimize(sum_cost)
# no shuffle for unit test
train_reader = paddle.batch(
paddle.dataset.cifar.train10(), batch_size=BATCH_SIZE)
reader_decorator(RandomDataset(
16 * 5, seed=123)),
batch_size=16,
drop_last=True)
test_reader = paddle.batch(
paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE)
reader_decorator(RandomDataset(
4 * 5, seed=456)),
batch_size=4,
drop_last=True)
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(place=place, feed_list=[images, label])
def train_loop(main_program):
def train_loop():
exe.run(startup_prog)
if use_pure_fp16:
optimizer.amp_init(
place, test_program=test_program, use_fp16_test=True)
loss = 0.0
train_loss_list = []
test_loss_list = []
for pass_id in range(PASS_NUM):
train_loss_list = []
for batch_id, data in enumerate(train_reader()):
loss, = exe.run(train_program,
feed=feeder.feed(data),
......@@ -152,21 +181,17 @@ def train(use_pure_fp16=True, use_nesterov=False, optimizer=""):
format(pass_id, batch_id + 1, float(loss_v)))
train_loss_list.append(float(loss_v))
if batch_id >= 4: # For speeding up CI
test_loss_list = []
for tid, test_data in enumerate(test_reader()):
loss_t, = exe.run(program=test_program,
feed=feeder.feed(test_data),
fetch_list=[sum_cost])
test_loss_list.append(float(loss_t))
print(
'PassID {0:1}, Test Batch ID {1:04}, test loss {2:2.4}'.
format(pass_id, tid + 1, float(loss_t)))
if tid >= 4:
break # For speeding up CI
return train_loss_list, test_loss_list
return train_loop(train_program)
for tid, test_data in enumerate(test_reader()):
loss_t, = exe.run(program=test_program,
feed=feeder.feed(test_data),
fetch_list=[sum_cost])
test_loss_list.append(float(loss_t))
print('PassID {0:1}, Test Batch ID {1:04}, test loss {2:2.4}'.
format(pass_id, tid + 1, float(loss_t)))
return train_loss_list, test_loss_list
return train_loop()
class TestImageMultiPrecision(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册