提交 98268e3f 编写于 作者: A Abhishek Rao

Rename inference mode inputs from predict to test

上级 41d6111d
......@@ -308,8 +308,8 @@ that it's running on something other than a Cloud TPU, which includes a GPU.
#### Prediction from classifier
Once you have trained your classifier you can use it in inference mode by passing the --do_predict=true flag.
You need to have a file named predict.tsv in the input folder.
Output will be created in file called predict_results.tsv in the output folder.
You need to have a file named test.tsv in the input folder.
Output will be created in a file called test_results.tsv in the output folder.
Each line will contain output for each sample, columns are the class probabilities.
```shell
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
......
......@@ -71,7 +71,7 @@ flags.DEFINE_bool("do_train", False, "Whether to run training.")
flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
flags.DEFINE_bool("do_predict", False, "Whether to run the model in inference mode on predict set.")
flags.DEFINE_bool("do_predict", False, "Whether to run the model in inference mode on the test set.")
flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
......@@ -164,7 +164,7 @@ class DataProcessor(object):
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError()
def get_predict_examples(self, data_dir):
def get_test_examples(self, data_dir):
"""Gets a collection of `InputExample`s for prediction."""
raise NotImplementedError()
......@@ -245,11 +245,11 @@ class MnliProcessor(DataProcessor):
self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
"dev_matched")
def get_predict_examples(self, data_dir):
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "predict.tsv")),
"predict")
self._read_tsv(os.path.join(data_dir, "test_matched.tsv")),
"test")
def get_labels(self):
"""See base class."""
......@@ -283,10 +283,10 @@ class MrpcProcessor(DataProcessor):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_predict_examples(self, data_dir):
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "predict.tsv")), "predict")
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
......@@ -320,10 +320,10 @@ class ColaProcessor(DataProcessor):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_predict_examples(self, data_dir):
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "predict.tsv")), "predict")
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
......@@ -772,7 +772,7 @@ def main(_):
tf.logging.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
if FLAGS.do_predict:
predict_examples = processor.get_predict_examples(FLAGS.data_dir)
predict_examples = processor.get_test_examples(FLAGS.data_dir)
predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
convert_examples_to_features(predict_examples, label_list,
FLAGS.max_seq_length, tokenizer, predict_file)
......@@ -795,7 +795,7 @@ def main(_):
result = estimator.predict(input_fn=predict_input_fn)
output_predict_file = os.path.join(FLAGS.output_dir, "predict_results.tsv")
output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
with tf.gfile.GFile(output_predict_file, "w") as writer:
tf.logging.info("***** Predict results *****")
for prediction in result:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册