未验证 提交 b0984c7c 编写于 作者: Z Zhen Wang 提交者: GitHub

Fix the timeout problem of test_multi_precision_fp16_train UT. (#33596)

上级 bb1216f5
...@@ -19,11 +19,35 @@ import paddle.fluid as fluid ...@@ -19,11 +19,35 @@ import paddle.fluid as fluid
import contextlib import contextlib
import unittest import unittest
import numpy as np import numpy as np
from paddle.io import Dataset
from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_model_to_fp16 from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_model_to_fp16
paddle.enable_static() paddle.enable_static()
class RandomDataset(Dataset):
    """Map-style dataset that produces synthetic (image, label) pairs.

    Samples are drawn from a random generator owned by the instance, so
    creating another dataset (for example with a different seed) neither
    resets nor shares this dataset's stream, and the process-wide
    ``np.random`` state is left untouched.

    Args:
        num_samples: Value reported by ``len()``.
        seed: Seed for this dataset's private random generator.
    """

    def __init__(self, num_samples, seed=123):
        super(RandomDataset, self).__init__()
        # Use a per-instance generator rather than np.random.seed(): seeding
        # the global generator here would let a dataset constructed later
        # override the seed of one constructed earlier.
        self._rng = np.random.RandomState(seed)
        self.num_samples = num_samples

    def __getitem__(self, idx):
        """Return the next random ``(image, label)`` pair.

        ``idx`` is not used: every call advances the private generator, so
        the values depend on how many samples were drawn before, not on the
        index requested.

        Returns:
            A tuple ``(image, label)`` where ``image`` is a float32 array of
            shape ``[3, 32, 32]`` with values in ``[0, 1)`` and ``label`` is
            an int64 array of shape ``(1,)``.
        """
        image = self._rng.random_sample([3, 32, 32]).astype('float32')
        # randint's upper bound is exclusive, so labels fall in 0..8.
        label = self._rng.randint(0, 9, (1, )).astype('int64')
        return image, label

    def __len__(self):
        """Return the number of samples given at construction time."""
        return self.num_samples
def reader_decorator(reader):
    """Wrap an indexable, sized object as a zero-argument sample generator.

    Args:
        reader: Object supporting ``len()`` and integer indexing.

    Returns:
        A callable that, each time it is called, returns a generator
        yielding ``reader[0]`` through ``reader[len(reader) - 1]`` in order.
    """

    def __reader__():
        # The length is read once per pass; items are fetched by index
        # rather than by iterating ``reader`` directly.
        total = len(reader)
        position = 0
        while position < total:
            yield reader[position]
            position += 1

    return __reader__
def resnet_cifar10(input, depth=32): def resnet_cifar10(input, depth=32):
def conv_bn_layer(input, def conv_bn_layer(input,
ch_out, ch_out,
...@@ -76,7 +100,6 @@ def resnet_cifar10(input, depth=32): ...@@ -76,7 +100,6 @@ def resnet_cifar10(input, depth=32):
def train(use_pure_fp16=True, use_nesterov=False, optimizer=""): def train(use_pure_fp16=True, use_nesterov=False, optimizer=""):
classdim = 10 classdim = 10
data_shape = [3, 32, 32] data_shape = [3, 32, 32]
BATCH_SIZE = 32
PASS_NUM = 1 PASS_NUM = 1
train_program = fluid.Program() train_program = fluid.Program()
...@@ -124,25 +147,31 @@ def train(use_pure_fp16=True, use_nesterov=False, optimizer=""): ...@@ -124,25 +147,31 @@ def train(use_pure_fp16=True, use_nesterov=False, optimizer=""):
optimizer.minimize(sum_cost) optimizer.minimize(sum_cost)
# no shuffle for unit test
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.dataset.cifar.train10(), batch_size=BATCH_SIZE) reader_decorator(RandomDataset(
16 * 5, seed=123)),
batch_size=16,
drop_last=True)
test_reader = paddle.batch( test_reader = paddle.batch(
paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE) reader_decorator(RandomDataset(
4 * 5, seed=456)),
batch_size=4,
drop_last=True)
place = fluid.CUDAPlace(0) place = fluid.CUDAPlace(0)
exe = fluid.Executor(place) exe = fluid.Executor(place)
feeder = fluid.DataFeeder(place=place, feed_list=[images, label]) feeder = fluid.DataFeeder(place=place, feed_list=[images, label])
def train_loop(main_program): def train_loop():
exe.run(startup_prog) exe.run(startup_prog)
if use_pure_fp16: if use_pure_fp16:
optimizer.amp_init( optimizer.amp_init(
place, test_program=test_program, use_fp16_test=True) place, test_program=test_program, use_fp16_test=True)
loss = 0.0
train_loss_list = []
test_loss_list = []
for pass_id in range(PASS_NUM): for pass_id in range(PASS_NUM):
train_loss_list = []
for batch_id, data in enumerate(train_reader()): for batch_id, data in enumerate(train_reader()):
loss, = exe.run(train_program, loss, = exe.run(train_program,
feed=feeder.feed(data), feed=feeder.feed(data),
...@@ -152,21 +181,17 @@ def train(use_pure_fp16=True, use_nesterov=False, optimizer=""): ...@@ -152,21 +181,17 @@ def train(use_pure_fp16=True, use_nesterov=False, optimizer=""):
format(pass_id, batch_id + 1, float(loss_v))) format(pass_id, batch_id + 1, float(loss_v)))
train_loss_list.append(float(loss_v)) train_loss_list.append(float(loss_v))
if batch_id >= 4: # For speeding up CI for tid, test_data in enumerate(test_reader()):
test_loss_list = [] loss_t, = exe.run(program=test_program,
for tid, test_data in enumerate(test_reader()): feed=feeder.feed(test_data),
loss_t, = exe.run(program=test_program, fetch_list=[sum_cost])
feed=feeder.feed(test_data), test_loss_list.append(float(loss_t))
fetch_list=[sum_cost]) print('PassID {0:1}, Test Batch ID {1:04}, test loss {2:2.4}'.
test_loss_list.append(float(loss_t)) format(pass_id, tid + 1, float(loss_t)))
print(
'PassID {0:1}, Test Batch ID {1:04}, test loss {2:2.4}'. return train_loss_list, test_loss_list
format(pass_id, tid + 1, float(loss_t)))
if tid >= 4: return train_loop()
break # For speeding up CI
return train_loss_list, test_loss_list
return train_loop(train_program)
class TestImageMultiPrecision(unittest.TestCase): class TestImageMultiPrecision(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册