From 14c2d97208c82668a39eaddb2c9f6aa727ef0043 Mon Sep 17 00:00:00 2001 From: xyzhou-puck Date: Tue, 21 Apr 2020 05:37:48 +0000 Subject: [PATCH] fix bugs --- examples/bert/bert_classifier.py | 2 +- examples/bert_leveldb/bert_classifier.py | 2 +- hapi/text/bert/optimization.py | 12 ++++++++++++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/examples/bert/bert_classifier.py b/examples/bert/bert_classifier.py index 6bb9ea3..ef43ae2 100644 --- a/examples/bert/bert_classifier.py +++ b/examples/bert/bert_classifier.py @@ -103,7 +103,7 @@ def main(): batch_size=config.batch_size, line_processor=mnli_line_processor) - dev_dataloader = BertDataLoader( + test_dataloader = BertDataLoader( "./data/glue_data/MNLI/dev_matched.tsv", tokenizer, ["contradiction", "entailment", "neutral"], max_seq_length=config.max_seq_len, diff --git a/examples/bert_leveldb/bert_classifier.py b/examples/bert_leveldb/bert_classifier.py index 8919013..11bc857 100644 --- a/examples/bert_leveldb/bert_classifier.py +++ b/examples/bert_leveldb/bert_classifier.py @@ -105,7 +105,7 @@ def main(): mode="leveldb", phase="train") - dev_dataloader = BertDataLoader( + test_dataloader = BertDataLoader( "./data/glue_data/MNLI/dev_matched.tsv", tokenizer, ["contradiction", "entailment", "neutral"], max_seq_length=config.max_seq_len, diff --git a/hapi/text/bert/optimization.py b/hapi/text/bert/optimization.py index 2bf6b7f..b2ba8f6 100755 --- a/hapi/text/bert/optimization.py +++ b/hapi/text/bert/optimization.py @@ -130,6 +130,18 @@ class Optimizer(object): return True return False + def state_dict(self): + return self.optimizer.state_dict() + + def set_dict(self, state_dict): + return self.optimizer.set_dict(state_dict) + + def get_opti_var_name_list(self): + return self.optimizer.get_opti_var_name_list() + + def current_step_lr(self): + return self.optimizer.current_step_lr() + def minimize(self, loss, use_data_parallel=False, model=None): param_list = dict() -- GitLab