提交 1ce7581a 编写于 作者: K kinghuin

fix bugs

上级 2341277b
......@@ -511,10 +511,6 @@ class CombinedStrategy(DefaultStrategy):
unfreeze_depths=self.
sorted_depth[:self.max_depth * self.epoch //
self.scheduler["gradual_unfreeze"]["blocks"]])
else:
logger.warning(
"The max op-depth in the network is %s. That results in that can't use the gradual unfreeze finetune strategy."
% (self.max_depth))
elif self.scheduler["gradual_unfreeze"]["params_layer"]:
max_layer = max(
self.scheduler["gradual_unfreeze"]["params_layer"].values())
......@@ -631,8 +627,9 @@ class ULMFiTStrategy(CombinedStrategy):
ratio=32,
dis_blocks=3,
factor=2.6,
dis_params_layer=None,
frz_blocks=3,
params_layer=None):
frz_params_layer=None):
scheduler = {
"slanted_triangle": {
......@@ -641,12 +638,12 @@ class ULMFiTStrategy(CombinedStrategy):
},
"gradual_unfreeze": {
"blocks": frz_blocks,
"params_layer": params_layer
"params_layer": frz_params_layer
},
"discriminative": {
"blocks": dis_blocks,
"factor": factor,
"params_layer": params_layer
"params_layer": dis_params_layer
}
}
regularization = {}
......
......@@ -951,12 +951,6 @@ class BaseTask(object):
Returns:
RunState: the running result of predict phase
"""
if isinstance(self._base_data_reader, hub.reader.LACClassifyReader):
raise Exception(
"LACClassifyReader does not support predictor, please close accelerate_mode"
)
global_run_states = []
period_run_states = []
......@@ -993,16 +987,22 @@ class BaseTask(object):
data (list): the data will be predicted.
load_best_model (bool): load the best model or not
return_result (bool): return a readable result or just the raw run result
accelerate_mode (bool): use high-performance predictor or not
accelerate_mode (bool): use high-performance predictor or not.
Returns:
RunState: the running result of predict phase
"""
if not version_compare(paddle.__version__, "1.6.2") and accelerate_mode:
if accelerate_mode:
if not version_compare(paddle.__version__, "1.6.1"):
logger.warning(
"Fail to open predict accelerate mode as it does not support paddle < 1.6.2. Please update PaddlePaddle."
)
accelerate_mode = False
if isinstance(self._base_data_reader, hub.reader.LACClassifyReader):
logger.warning(
"LACClassifyReader does not support predictor, the accelerate_mode is closed now."
)
accelerate_mode = False
self.accelerate_mode = accelerate_mode
with self.phase_guard(phase="predict"):
......
......@@ -205,7 +205,7 @@ def get_predictions(all_examples, all_features, all_results, n_best_size,
for (feature_index, feature) in enumerate(features):
if feature.unique_id not in unique_id_to_result:
logger.info(
"As using pyreader, the last one batch is so small that the feature %s in the last batch is discarded "
"As using multidevice, the last one batch is so small that the feature %s in the last batch is discarded "
% feature.unique_id)
continue
result = unique_id_to_result[feature.unique_id]
......
......@@ -397,7 +397,8 @@ class TransformerModule(NLPBaseModule):
return inputs, outputs, module_program
def get_embedding(self, texts, use_gpu=False, batch_size=1):
def get_embedding(self, texts, max_seq_len=512, use_gpu=False,
batch_size=1):
"""
get pooled_output and sequence_output for input texts.
Warnings: this method depends on Paddle Inference Library, it may not work properly in PaddlePaddle <= 1.6.2.
......@@ -405,6 +406,7 @@ class TransformerModule(NLPBaseModule):
Args:
texts (list): each element is a text sample, each sample include text_a and text_b where text_b can be omitted.
for example: [[sample0_text_a, sample0_text_b], [sample1_text_a, sample1_text_b], ...]
max_seq_len (int): the max sequence length.
use_gpu (bool): use gpu or not, default False.
batch_size (int): the data batch size, default 1.
......@@ -417,12 +419,12 @@ class TransformerModule(NLPBaseModule):
) or self.emb_job["batch_size"] != batch_size or self.emb_job[
"use_gpu"] != use_gpu:
inputs, outputs, program = self.context(
trainable=True, max_seq_len=self.MAX_SEQ_LEN)
trainable=True, max_seq_len=max_seq_len)
reader = hub.reader.ClassifyReader(
dataset=None,
vocab_path=self.get_vocab_path(),
max_seq_len=self.MAX_SEQ_LEN,
max_seq_len=max_seq_len,
sp_model_path=self.get_spm_path() if hasattr(
self, "get_spm_path") else None,
word_dict_path=self.get_word_dict_path() if hasattr(
......
......@@ -1213,7 +1213,7 @@ class LACClassifyReader(BaseReader):
return processed
if not self.has_processed[phase]:
if not self.has_processed[phase] or phase == "predict":
logger.info(
"processing %s data now... this may take a few minutes" % phase)
for i in range(len(data)):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册