From 4233d6a84cc11f049052fa8911fc99c3a27a05c2 Mon Sep 17 00:00:00 2001 From: baiyfbupt Date: Tue, 7 Jul 2020 21:23:44 +0800 Subject: [PATCH] code fix --- demo/bert/train_distill.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/demo/bert/train_distill.py b/demo/bert/train_distill.py index 44eecbf4..8f4758b0 100755 --- a/demo/bert/train_distill.py +++ b/demo/bert/train_distill.py @@ -100,12 +100,14 @@ def main(): teacher_model_dir = "./data/teacher_model/mrpc" num_samples = 3668 max_layer = 4 + num_labels = 2 processor_func = MrpcProcessor elif task_name == 'mnli': data_dir = "./data/glue_data/MNLI/" - teacher_model_dir = "./data/teacher_model/step_23000" + teacher_model_dir = "./data/teacher_model/steps_23000" num_samples = 392702 max_layer = 8 + num_labels = 3 processor_func = MnliProcessor device_num = fluid.dygraph.parallel.Env().nranks @@ -120,7 +122,7 @@ def main(): np.random.seed(1) fluid.default_main_program().random_seed = 1 model = AdaBERTClassifier( - 2, + num_labels, n_layer=max_layer, hidden_size=hidden_size, task_name=task_name, -- GitLab