提交 4233d6a8 编写于 作者: B baiyfbupt

code fix

上级 f7e57cb6
......@@ -100,12 +100,14 @@ def main():
teacher_model_dir = "./data/teacher_model/mrpc"
num_samples = 3668
max_layer = 4
num_labels = 2
processor_func = MrpcProcessor
elif task_name == 'mnli':
data_dir = "./data/glue_data/MNLI/"
teacher_model_dir = "./data/teacher_model/step_23000"
teacher_model_dir = "./data/teacher_model/steps_23000"
num_samples = 392702
max_layer = 8
num_labels = 3
processor_func = MnliProcessor
device_num = fluid.dygraph.parallel.Env().nranks
......@@ -120,7 +122,7 @@ def main():
np.random.seed(1)
fluid.default_main_program().random_seed = 1
model = AdaBERTClassifier(
2,
num_labels,
n_layer=max_layer,
hidden_size=hidden_size,
task_name=task_name,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册