提交 40950add 编写于 作者: C chenguowei01

update train.py

上级 1c06b49d
...@@ -144,8 +144,6 @@ def train(model, ...@@ -144,8 +144,6 @@ def train(model,
return_list=True, return_list=True,
) )
num_steps_each_epoch = len(train_dataset) // batch_size
for epoch in range(num_epochs): for epoch in range(num_epochs):
for step, data in enumerate(loader): for step, data in enumerate(loader):
images = data[0] images = data[0]
...@@ -165,8 +163,7 @@ def train(model, ...@@ -165,8 +163,7 @@ def train(model,
loss.numpy())) loss.numpy()))
if ((epoch + 1) % save_interval_epochs == 0 if ((epoch + 1) % save_interval_epochs == 0
or num_steps_each_epoch == num_epochs - 1 or epoch == num_epochs - 1) and ParallelEnv().local_rank == 0:
) and ParallelEnv().local_rank == 0:
current_save_dir = os.path.join(save_dir, current_save_dir = os.path.join(save_dir,
"epoch_{}".format(epoch + 1)) "epoch_{}".format(epoch + 1))
if not os.path.isdir(current_save_dir): if not os.path.isdir(current_save_dir):
...@@ -223,7 +220,10 @@ def main(args): ...@@ -223,7 +220,10 @@ def main(args):
num_classes=train_dataset.num_classes, ignore_index=255) num_classes=train_dataset.num_classes, ignore_index=255)
# Create optimizer # Create optimizer
num_steps_each_epoch = len(train_dataset) // args.batch_size # todo, may less one than len(loader)
num_steps_each_epoch = len(train_dataset) // (
args.batch_size * ParallelEnv().nranks)
print(num_steps_each_epoch, 'num_steps_each_epoch')
decay_step = args.num_epochs * num_steps_each_epoch decay_step = args.num_epochs * num_steps_each_epoch
lr_decay = fluid.layers.polynomial_decay( lr_decay = fluid.layers.polynomial_decay(
args.learning_rate, decay_step, end_learning_rate=0, power=0.9) args.learning_rate, decay_step, end_learning_rate=0, power=0.9)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册