diff --git a/dygraph/bert/run_classifier.py b/dygraph/bert/run_classifier.py
index 737feb4f1bf3d17134b36983642768e078e43ab7..5f93ad4e66b04c44df4c0e2370b8836a9647a2c1 100755
--- a/dygraph/bert/run_classifier.py
+++ b/dygraph/bert/run_classifier.py
@@ -84,6 +84,7 @@ run_type_g.add_arg("task_name", str, None,
 run_type_g.add_arg("do_train", bool, True, "Whether to perform training.")
 run_type_g.add_arg("do_test", bool, False, "Whether to perform evaluation on test data set.")
 run_type_g.add_arg("use_data_parallel", bool, False, "The flag indicating whether to shuffle instances in each pass.")
+run_type_g.add_arg("enable_ce", bool, False, help="The flag indicating whether to run the task for continuous evaluation.")
 
 args = parser.parse_args()
 
@@ -184,6 +185,8 @@ def train(args):
 
         steps = 0
         time_begin = time.time()
+        ce_time = []
+        ce_acc = []
         for batch in train_data_generator():
             data_ids = create_data(batch)
             loss, accuracy, num_seqs = cls_model(data_ids)
@@ -197,6 +200,8 @@ def train(args):
                 current_example, current_epoch = processor.get_train_progress()
                 localtime = time.asctime(time.localtime(time.time()))
                 print("%s, epoch: %s, steps: %s, dy_graph loss: %f, acc: %f, speed: %f steps/s" % (localtime, current_epoch, steps, loss.numpy(), accuracy.numpy(), args.skip_steps / used_time))
+                ce_time.append(used_time)
+                ce_acc.append(accuracy.numpy())
                 time_begin = time.time()
 
             if steps != 0 and steps % args.save_steps == 0 and fluid.dygraph.parallel.Env().local_rank == 0:
@@ -220,6 +225,18 @@ def train(args):
                     optimizer.optimizer.state_dict(), save_path)
                 print("Save model parameters and optimizer status at %s" % save_path)
 
+        if args.enable_ce:
+            _acc = 0
+            _time = 0
+            try:
+                _time = ce_time[-1]
+                _acc = ce_acc[-1]
+            except:
+                print("ce info error")
+            print("kpis\ttrain_duration_card%s\t%s" % (dev_count, _time))
+            print("kpis\ttrain_acc_card%s\t%f" % (dev_count, _acc))
+
         return cls_model
 
 def predict(args, cls_model = None):