提交 4233d6a8 编写于 作者: B baiyfbupt

code fix

上级 f7e57cb6
...@@ -100,12 +100,14 @@ def main(): ...@@ -100,12 +100,14 @@ def main():
teacher_model_dir = "./data/teacher_model/mrpc" teacher_model_dir = "./data/teacher_model/mrpc"
num_samples = 3668 num_samples = 3668
max_layer = 4 max_layer = 4
num_labels = 2
processor_func = MrpcProcessor processor_func = MrpcProcessor
elif task_name == 'mnli': elif task_name == 'mnli':
data_dir = "./data/glue_data/MNLI/" data_dir = "./data/glue_data/MNLI/"
teacher_model_dir = "./data/teacher_model/step_23000" teacher_model_dir = "./data/teacher_model/steps_23000"
num_samples = 392702 num_samples = 392702
max_layer = 8 max_layer = 8
num_labels = 3
processor_func = MnliProcessor processor_func = MnliProcessor
device_num = fluid.dygraph.parallel.Env().nranks device_num = fluid.dygraph.parallel.Env().nranks
...@@ -120,7 +122,7 @@ def main(): ...@@ -120,7 +122,7 @@ def main():
np.random.seed(1) np.random.seed(1)
fluid.default_main_program().random_seed = 1 fluid.default_main_program().random_seed = 1
model = AdaBERTClassifier( model = AdaBERTClassifier(
2, num_labels,
n_layer=max_layer, n_layer=max_layer,
hidden_size=hidden_size, hidden_size=hidden_size,
task_name=task_name, task_name=task_name,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册