未验证 提交 4e92b27e 编写于 作者: P pkpk 提交者: GitHub

Merge pull request #51 from xyzhou-puck/master

fix bugs
...@@ -103,7 +103,7 @@ def main(): ...@@ -103,7 +103,7 @@ def main():
batch_size=config.batch_size, batch_size=config.batch_size,
line_processor=mnli_line_processor) line_processor=mnli_line_processor)
dev_dataloader = BertDataLoader( test_dataloader = BertDataLoader(
"./data/glue_data/MNLI/dev_matched.tsv", "./data/glue_data/MNLI/dev_matched.tsv",
tokenizer, ["contradiction", "entailment", "neutral"], tokenizer, ["contradiction", "entailment", "neutral"],
max_seq_length=config.max_seq_len, max_seq_length=config.max_seq_len,
......
...@@ -105,7 +105,7 @@ def main(): ...@@ -105,7 +105,7 @@ def main():
mode="leveldb", mode="leveldb",
phase="train") phase="train")
dev_dataloader = BertDataLoader( test_dataloader = BertDataLoader(
"./data/glue_data/MNLI/dev_matched.tsv", "./data/glue_data/MNLI/dev_matched.tsv",
tokenizer, ["contradiction", "entailment", "neutral"], tokenizer, ["contradiction", "entailment", "neutral"],
max_seq_length=config.max_seq_len, max_seq_length=config.max_seq_len,
......
...@@ -130,6 +130,18 @@ class Optimizer(object): ...@@ -130,6 +130,18 @@ class Optimizer(object):
return True return True
return False return False
def state_dict(self):
return self.optimizer.state_dict()
def set_dict(self, state_dict):
return self.optimizer.set_dict(state_dict)
def get_opti_var_name_list(self):
return self.optimizer.get_opti_var_name_list()
def current_step_lr(self):
return self.optimizer.current_step_lr()
def minimize(self, loss, use_data_parallel=False, model=None): def minimize(self, loss, use_data_parallel=False, model=None):
param_list = dict() param_list = dict()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册