提交 8666b7b8 编写于 作者: S Steffy-zxf

update preset net

上级 e43d0b26
......@@ -28,6 +28,7 @@ parser.add_argument("--warmup_proportion", type=float, default=0.1, help="Warmup
parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest seqence.")
parser.add_argument("--batch_size", type=int, default=32, help="Total examples' number in batch for training.")
parser.add_argument("--network", type=str, default='bilstm', help="Preset network which was connected after Transformer model, such as ERNIE, BERT ,RoBERTa and ELECTRA.")
parser.add_argument("--use_data_parallel", type=ast.literal_eval, default=False, help="Whether use data parallel.")
args = parser.parse_args()
# yapf: enable.
......@@ -57,7 +58,7 @@ if __name__ == '__main__':
# Construct transfer learning network
# Use "pooled_output" for classification tasks on an entire sentence.
# Use "sequence_output" for token-level output.
pooled_output = outputs["pooled_output"]
pooled_output = outputs["sequence_output"]
# Setup feed list for data feeder
# Must feed all the tensor of module need
......@@ -88,6 +89,7 @@ if __name__ == '__main__':
data_reader=reader,
feature=pooled_output,
feed_list=feed_list,
network='dpcnn',
num_classes=dataset.num_labels,
config=config,
metrics_choices=metrics_choices)
......
......@@ -997,11 +997,6 @@ class BaseTask(object):
Returns:
RunState: the running result of predict phase
"""
if not version_compare(paddle.__version__, "1.6.2") and accelerate_mode:
logger.warning(
"Fail to open predict accelerate mode as it does not support paddle < 1.6.2. Please update PaddlePaddle."
)
accelerate_mode = False
self.accelerate_mode = accelerate_mode
with self.phase_guard(phase="predict"):
......
......@@ -193,12 +193,7 @@ class TextClassifierTask(ClassifierTask):
def _build_net(self):
self.seq_len = fluid.layers.data(
name="seq_len", shape=[1], dtype='int64', lod_level=0)
if version_compare(paddle.__version__, "1.6"):
self.seq_len_used = fluid.layers.squeeze(self.seq_len, axes=[1])
else:
self.seq_len_used = self.seq_len
self.seq_len_used = fluid.layers.squeeze(self.seq_len, axes=[1])
unpad_feature = fluid.layers.sequence_unpad(
self.feature, length=self.seq_len_used)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册