Commit 3ffaf7f2 authored by 文幕地方

add distributed train support

Parent b069a091
@@ -36,6 +36,9 @@ from ppocr.utils.logging import get_logger
 def train(args):
     logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
+    rank = paddle.distributed.get_rank()
+    distributed = paddle.distributed.get_world_size() > 1
+
     print_arguments(args, logger)
     # Added here for reproducibility (even between python 2 and 3)
@@ -45,7 +48,7 @@ def train(args):
     pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
     # dist mode
-    if paddle.distributed.get_world_size() > 1:
+    if distributed:
         paddle.distributed.init_parallel_env()
 
     tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
@@ -59,8 +62,8 @@ def train(args):
         args.model_name_or_path)
 
     # dist mode
-    if paddle.distributed.get_world_size() > 1:
-        model = paddle.distributed.DataParallel(model)
+    if distributed:
+        model = paddle.DataParallel(model)
 
     train_dataset = XFUNDataset(
         tokenizer,
@@ -90,8 +93,7 @@ def train(args):
     train_sampler = paddle.io.DistributedBatchSampler(
         train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)
-    args.train_batch_size = args.per_gpu_train_batch_size * \
-        max(1, paddle.distributed.get_world_size())
 
     train_dataloader = paddle.io.DataLoader(
         train_dataset,
         batch_sampler=train_sampler,
@@ -136,7 +138,8 @@ def train(args):
             args.per_gpu_train_batch_size))
     logger.info(
         " Total train batch size (w. parallel, distributed & accumulation) = {}".
-        format(args.train_batch_size * paddle.distributed.get_world_size()))
+        format(args.per_gpu_train_batch_size *
+               paddle.distributed.get_world_size()))
     logger.info(" Total optimization steps = {}".format(t_total))
 
     global_step = 0
@@ -170,7 +173,7 @@ def train(args):
             global_step += 1
             total_samples += batch['image'].shape[0]
 
-            if step % print_step == 0:
+            if rank == 0 and step % print_step == 0:
                 logger.info(
                     "epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {:.6f}, lr: {:.6f}, avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec".
                     format(epoch, args.num_train_epochs, step,
@@ -185,38 +188,38 @@ def train(args):
                 train_run_cost = 0.0
                 total_samples = 0
 
-            if (paddle.distributed.get_rank() == 0 and args.eval_steps > 0 and
-                    global_step % args.eval_steps == 0):
+            if rank == 0 and args.eval_steps > 0 and global_step % args.eval_steps == 0 and args.evaluate_during_training:
                 # Log metrics
-                if (paddle.distributed.get_rank() == 0 and args.
-                        evaluate_during_training):  # Only evaluate when single GPU otherwise metrics may not average well
-                    results = evaluate(model, eval_dataloader, logger)
-                    if results['f1'] >= best_metirc['f1']:
-                        best_metirc = results
-                        output_dir = os.path.join(args.output_dir, "best_model")
-                        os.makedirs(output_dir, exist_ok=True)
-                        model.save_pretrained(output_dir)
-                        tokenizer.save_pretrained(output_dir)
-                        paddle.save(args,
-                                    os.path.join(output_dir,
-                                                 "training_args.bin"))
-                        logger.info("Saving model checkpoint to {}".format(
-                            output_dir))
-                    logger.info("eval results: {}".format(results))
-                    logger.info("best_metirc: {}".format(best_metirc))
+                # Only evaluate when single GPU otherwise metrics may not average well
+                results = evaluate(model, eval_dataloader, logger)
+                if results['f1'] >= best_metirc['f1']:
+                    best_metirc = results
+                    output_dir = os.path.join(args.output_dir, "best_model")
+                    os.makedirs(output_dir, exist_ok=True)
+                    if distributed:
+                        model._layers.save_pretrained(output_dir)
+                    else:
+                        model.save_pretrained(output_dir)
+                    tokenizer.save_pretrained(output_dir)
+                    paddle.save(args,
+                                os.path.join(output_dir, "training_args.bin"))
+                    logger.info("Saving model checkpoint to {}".format(
+                        output_dir))
+                logger.info("eval results: {}".format(results))
+                logger.info("best_metirc: {}".format(best_metirc))
+            reader_start = time.time()
 
-            if paddle.distributed.get_rank() == 0:
-                # Save model checkpoint
-                output_dir = os.path.join(args.output_dir, "latest_model")
-                os.makedirs(output_dir, exist_ok=True)
-                if paddle.distributed.get_rank() == 0:
-                    model.save_pretrained(output_dir)
-                    tokenizer.save_pretrained(output_dir)
-                    paddle.save(args,
-                                os.path.join(output_dir, "training_args.bin"))
-                    logger.info("Saving model checkpoint to {}".format(
-                        output_dir))
-            reader_start = time.time()
+            if rank == 0:
+                # Save model checkpoint
+                output_dir = os.path.join(args.output_dir, "latest_model")
+                os.makedirs(output_dir, exist_ok=True)
+                if distributed:
+                    model._layers.save_pretrained(output_dir)
+                else:
+                    model.save_pretrained(output_dir)
+                tokenizer.save_pretrained(output_dir)
+                paddle.save(args, os.path.join(output_dir, "training_args.bin"))
+                logger.info("Saving model checkpoint to {}".format(output_dir))
     logger.info("best_metirc: {}".format(best_metirc))
...
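Taken together, the hunks above apply the usual single-machine multi-GPU recipe in Paddle: read the rank and world size once, call init_parallel_env only when more than one process is running, wrap the model in paddle.DataParallel, shard batches with DistributedBatchSampler, and gate logging and checkpointing on rank 0. The following is a minimal, self-contained sketch of that recipe, not the repository's code; the toy Linear model, random data, and output path are placeholders.

# Minimal sketch of the distributed-training pattern used in this commit.
# Assumptions: a toy Linear model and random tensors stand in for LayoutXLM
# and XFUNDataset. Launch with `python -m paddle.distributed.launch <script>.py`.
import os

import paddle
from paddle.io import DataLoader, DistributedBatchSampler, TensorDataset


def train():
    rank = paddle.distributed.get_rank()
    distributed = paddle.distributed.get_world_size() > 1
    if distributed:
        paddle.distributed.init_parallel_env()

    model = paddle.nn.Linear(10, 2)
    if distributed:
        # paddle.DataParallel (not paddle.distributed.DataParallel) wraps the
        # model for gradient all-reduce in dygraph mode.
        model = paddle.DataParallel(model)

    dataset = TensorDataset(
        [paddle.randn([64, 10]), paddle.randint(0, 2, [64, 1])])
    # DistributedBatchSampler gives each rank its own shard of the data.
    sampler = DistributedBatchSampler(dataset, batch_size=8, shuffle=True)
    loader = DataLoader(dataset, batch_sampler=sampler)

    opt = paddle.optimizer.Adam(parameters=model.parameters())
    loss_fn = paddle.nn.CrossEntropyLoss()

    for epoch in range(2):
        for step, (feats, labels) in enumerate(loader):
            loss = loss_fn(model(feats), labels.squeeze(-1))
            loss.backward()
            opt.step()
            opt.clear_grad()
            if rank == 0 and step % 2 == 0:  # only rank 0 logs
                print("epoch {} step {} loss {:.4f}".format(
                    epoch, step, float(loss)))

    if rank == 0:  # only rank 0 writes checkpoints
        output_dir = "latest_model"
        os.makedirs(output_dir, exist_ok=True)
        # Unwrap DataParallel before saving, mirroring model._layers above.
        net = model._layers if distributed else model
        paddle.save(net.state_dict(),
                    os.path.join(output_dir, "model.pdparams"))


if __name__ == "__main__":
    train()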
@@ -37,6 +37,9 @@ from ppocr.utils.logging import get_logger
 def train(args):
     os.makedirs(args.output_dir, exist_ok=True)
+    rank = paddle.distributed.get_rank()
+    distributed = paddle.distributed.get_world_size() > 1
+
     logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
     print_arguments(args, logger)
@@ -44,7 +47,7 @@ def train(args):
     pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
     # dist mode
-    if paddle.distributed.get_world_size() > 1:
+    if distributed:
         paddle.distributed.init_parallel_env()
 
     tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
@@ -59,7 +62,7 @@ def train(args):
         args.model_name_or_path)
 
     # dist mode
-    if paddle.distributed.get_world_size() > 1:
+    if distributed:
         model = paddle.DataParallel(model)
 
     train_dataset = XFUNDataset(
@@ -88,9 +91,6 @@ def train(args):
     train_sampler = paddle.io.DistributedBatchSampler(
         train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)
-    args.train_batch_size = args.per_gpu_train_batch_size * max(
-        1, paddle.distributed.get_world_size())
-
     train_dataloader = paddle.io.DataLoader(
         train_dataset,
         batch_sampler=train_sampler,
@@ -134,7 +134,7 @@ def train(args):
         args.per_gpu_train_batch_size)
     logger.info(
         " Total train batch size (w. parallel, distributed) = %d",
-        args.train_batch_size * paddle.distributed.get_world_size(), )
+        args.per_gpu_train_batch_size * paddle.distributed.get_world_size(), )
     logger.info(" Total optimization steps = %d", t_total)
 
     global_step = 0
@@ -168,7 +168,7 @@ def train(args):
             global_step += 1
             total_samples += batch['image'].shape[0]
 
-            if step % print_step == 0:
+            if rank == 0 and step % print_step == 0:
                 logger.info(
                     "epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {:.6f}, lr: {:.6f}, avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec".
                     format(epoch_id, args.num_train_epochs, step,
@@ -183,47 +183,43 @@ def train(args):
                 train_run_cost = 0.0
                 total_samples = 0
 
-            if (paddle.distributed.get_rank() == 0 and args.eval_steps > 0 and
-                    global_step % args.eval_steps == 0):
+            if rank == 0 and args.eval_steps > 0 and global_step % args.eval_steps == 0 and args.evaluate_during_training:
                 # Log metrics
                 # Only evaluate when single GPU otherwise metrics may not average well
-                if paddle.distributed.get_rank(
-                ) == 0 and args.evaluate_during_training:
-                    results, _ = evaluate(
-                        args, model, tokenizer, eval_dataloader, label2id_map,
-                        id2label_map, pad_token_label_id, logger)
-                    if best_metrics is None or results["f1"] >= best_metrics[
-                            "f1"]:
-                        best_metrics = copy.deepcopy(results)
-                        output_dir = os.path.join(args.output_dir, "best_model")
-                        os.makedirs(output_dir, exist_ok=True)
-                        if paddle.distributed.get_rank() == 0:
-                            model.save_pretrained(output_dir)
-                            tokenizer.save_pretrained(output_dir)
-                            paddle.save(
-                                args,
-                                os.path.join(output_dir, "training_args.bin"))
-                            logger.info("Saving model checkpoint to %s",
-                                        output_dir)
+                results, _ = evaluate(args, model, tokenizer, eval_dataloader,
+                                      label2id_map, id2label_map,
+                                      pad_token_label_id, logger)
+                if best_metrics is None or results["f1"] >= best_metrics["f1"]:
+                    best_metrics = copy.deepcopy(results)
+                    output_dir = os.path.join(args.output_dir, "best_model")
+                    os.makedirs(output_dir, exist_ok=True)
+                    if distributed:
+                        model._layers.save_pretrained(output_dir)
+                    else:
+                        model.save_pretrained(output_dir)
+                    tokenizer.save_pretrained(output_dir)
+                    paddle.save(args,
                                 os.path.join(output_dir, "training_args.bin"))
+                    logger.info("Saving model checkpoint to %s", output_dir)
 
                 logger.info("[epoch {}/{}][iter: {}/{}] results: {}".format(
                     epoch_id, args.num_train_epochs, step,
                     len(train_dataloader), results))
                 if best_metrics is not None:
                     logger.info("best metrics: {}".format(best_metrics))
+            reader_start = time.time()
 
-            if paddle.distributed.get_rank() == 0:
-                # Save model checkpoint
-                output_dir = os.path.join(args.output_dir, "latest_model")
-                os.makedirs(output_dir, exist_ok=True)
-                if paddle.distributed.get_rank() == 0:
-                    model.save_pretrained(output_dir)
-                    tokenizer.save_pretrained(output_dir)
-                    paddle.save(args,
-                                os.path.join(output_dir, "training_args.bin"))
-                    logger.info("Saving model checkpoint to %s", output_dir)
-            reader_start = time.time()
+            if rank == 0:
+                # Save model checkpoint
+                output_dir = os.path.join(args.output_dir, "latest_model")
+                os.makedirs(output_dir, exist_ok=True)
+                if distributed:
+                    model._layers.save_pretrained(output_dir)
+                else:
+                    model.save_pretrained(output_dir)
+                tokenizer.save_pretrained(output_dir)
+                paddle.save(args, os.path.join(output_dir, "training_args.bin"))
+                logger.info("Saving model checkpoint to %s", output_dir)
 
     return global_step, tr_loss / global_step
...
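Both files repeat the same checkpoint logic, so it can be read as one rule: only rank 0 touches the filesystem, and when the model is wrapped in paddle.DataParallel the PaddleNLP-style save_pretrained() has to be called on the unwrapped network, which DataParallel exposes as model._layers. A hedged sketch of that rule as a standalone helper follows; the helper name save_checkpoint is illustrative and does not appear in the commit.

import os

import paddle


def save_checkpoint(model, tokenizer, args, output_dir, distributed, rank):
    """Illustrative helper: rank-0-only save of a PaddleNLP pretrained model."""
    if rank != 0:
        return
    os.makedirs(output_dir, exist_ok=True)
    # DataParallel keeps the original network in ._layers; save_pretrained()
    # lives on that inner object, not on the wrapper.
    net = model._layers if distributed else model
    net.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    paddle.save(args, os.path.join(output_dir, "training_args.bin"))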