未验证 提交 cea62c00 编写于 作者: W WangZhen 提交者: GitHub

Eval during train for ResNet (#52768)

* Eval during train for ResNet
上级 9f2e3064
...@@ -131,31 +131,13 @@ def optimizer_setting(parameter_list=None): ...@@ -131,31 +131,13 @@ def optimizer_setting(parameter_list=None):
return optimizer return optimizer
def train(to_static, enable_prim, enable_cinn): def run(model, data_loader, optimizer, mode):
if core.is_compiled_with_cuda(): if mode == 'train':
paddle.set_device('gpu') model.train()
else: end_step = 9
paddle.set_device('cpu') elif mode == 'eval':
np.random.seed(SEED) model.eval()
paddle.seed(SEED) end_step = 1
paddle.framework.random._manual_program_seed(SEED)
fluid.core._set_prim_all_enabled(enable_prim)
train_reader = paddle.batch(
reader_decorator(paddle.dataset.flowers.train(use_xmap=False)),
batch_size=batch_size,
drop_last=True,
)
data_loader = fluid.io.DataLoader.from_generator(capacity=5, iterable=True)
data_loader.set_sample_list_generator(train_reader)
resnet = resnet50(False)
if to_static:
build_strategy = paddle.static.BuildStrategy()
if enable_cinn:
build_strategy.build_cinn_pass = True
resnet = paddle.jit.to_static(resnet, build_strategy=build_strategy)
optimizer = optimizer_setting(parameter_list=resnet.parameters())
for epoch in range(epoch_num): for epoch in range(epoch_num):
total_acc1 = 0.0 total_acc1 = 0.0
...@@ -167,7 +149,7 @@ def train(to_static, enable_prim, enable_cinn): ...@@ -167,7 +149,7 @@ def train(to_static, enable_prim, enable_cinn):
start_time = time.time() start_time = time.time()
img, label = data img, label = data
pred = resnet(img) pred = model(img)
avg_loss = paddle.nn.functional.cross_entropy( avg_loss = paddle.nn.functional.cross_entropy(
input=pred, input=pred,
label=label, label=label,
...@@ -179,9 +161,10 @@ def train(to_static, enable_prim, enable_cinn): ...@@ -179,9 +161,10 @@ def train(to_static, enable_prim, enable_cinn):
acc_top1 = paddle.static.accuracy(input=pred, label=label, k=1) acc_top1 = paddle.static.accuracy(input=pred, label=label, k=1)
acc_top5 = paddle.static.accuracy(input=pred, label=label, k=5) acc_top5 = paddle.static.accuracy(input=pred, label=label, k=5)
if mode == 'train':
avg_loss.backward() avg_loss.backward()
optimizer.minimize(avg_loss) optimizer.minimize(avg_loss)
resnet.clear_gradients() model.clear_gradients()
total_acc1 += acc_top1 total_acc1 += acc_top1
total_acc5 += acc_top5 total_acc5 += acc_top5
...@@ -190,8 +173,9 @@ def train(to_static, enable_prim, enable_cinn): ...@@ -190,8 +173,9 @@ def train(to_static, enable_prim, enable_cinn):
end_time = time.time() end_time = time.time()
print( print(
"epoch %d | batch step %d, loss %0.8f, acc1 %0.3f, acc5 %0.3f, time %f" "[%s]epoch %d | batch step %d, loss %0.8f, acc1 %0.3f, acc5 %0.3f, time %f"
% ( % (
mode,
epoch, epoch,
batch_id, batch_id,
avg_loss, avg_loss,
...@@ -200,7 +184,7 @@ def train(to_static, enable_prim, enable_cinn): ...@@ -200,7 +184,7 @@ def train(to_static, enable_prim, enable_cinn):
end_time - start_time, end_time - start_time,
) )
) )
if batch_id >= 9: if batch_id >= end_step:
# avoid dataloader throw abort signaal # avoid dataloader throw abort signaal
data_loader._reset() data_loader._reset()
break break
...@@ -208,6 +192,38 @@ def train(to_static, enable_prim, enable_cinn): ...@@ -208,6 +192,38 @@ def train(to_static, enable_prim, enable_cinn):
return losses return losses
def train(to_static, enable_prim, enable_cinn):
if core.is_compiled_with_cuda():
paddle.set_device('gpu')
else:
paddle.set_device('cpu')
np.random.seed(SEED)
paddle.seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
fluid.core._set_prim_all_enabled(enable_prim)
train_reader = paddle.batch(
reader_decorator(paddle.dataset.flowers.train(use_xmap=False)),
batch_size=batch_size,
drop_last=True,
)
data_loader = fluid.io.DataLoader.from_generator(capacity=5, iterable=True)
data_loader.set_sample_list_generator(train_reader)
resnet = resnet50(False)
if to_static:
build_strategy = paddle.static.BuildStrategy()
if enable_cinn:
build_strategy.build_cinn_pass = True
resnet = paddle.jit.to_static(resnet, build_strategy=build_strategy)
optimizer = optimizer_setting(parameter_list=resnet.parameters())
train_losses = run(resnet, data_loader, optimizer, 'train')
if to_static and enable_prim and enable_cinn:
eval_losses = run(resnet, data_loader, optimizer, 'eval')
return train_losses
class TestResnet(unittest.TestCase): class TestResnet(unittest.TestCase):
@unittest.skipIf( @unittest.skipIf(
not (paddle.is_compiled_with_cinn() and paddle.is_compiled_with_cuda()), not (paddle.is_compiled_with_cinn() and paddle.is_compiled_with_cuda()),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册