提交 86da02ba 编写于 作者: K kinghuin 提交者: wuzewu

fix paddle1.6 unpad and log handler (#243)

* fix unpad
上级 ff5907aa
......@@ -43,7 +43,7 @@ PaddleHub是基于PaddlePaddle生态下的预训练模型管理和迁移学习
除上述依赖外,PaddleHub的预训练模型和预置数据集需要连接服务端进行下载,请确保机器可以正常访问网络。若本地已存在相关的数据集和预训练模型,则可以离线运行PaddleHub。
**NOTE:**
**NOTE:**
1. 若是出现离线运行PaddleHub错误,请更新PaddleHub 1.1.1版本之上。
pip安装方式如下:
......
......@@ -261,6 +261,10 @@ class BasicTask(object):
var = self.env.main_program.global_block().vars[var_name]
var.persistable = True
# to avoid to print logger two times in result of the logger usage of paddle-fluid
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
if self.is_train_phase:
with fluid.program_guard(self.env.main_program,
self._base_startup_program):
......@@ -287,10 +291,6 @@ class BasicTask(object):
self.exe.run(self.env.startup_program)
# to avoid to print logger two times in result of the logger usage of paddle-fluid
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
self._build_env_end_event()
@property
......
......@@ -19,9 +19,12 @@ from __future__ import print_function
import time
from collections import OrderedDict
import numpy as np
import paddle
import paddle.fluid as fluid
from paddlehub.finetune.evaluate import chunk_eval, calculate_f1
from paddlehub.common.utils import version_compare
from .basic_task import BasicTask
......@@ -61,8 +64,12 @@ class SequenceLabelTask(BasicTask):
return True
def _build_net(self):
self.seq_len = fluid.layers.data(
name="seq_len", shape=[1], dtype='int64')
if version_compare(paddle.__version__, "1.6"):
self.seq_len = fluid.layers.data(
name="seq_len", shape=[-1], dtype='int64')
else:
self.seq_len = fluid.layers.data(
name="seq_len", shape=[1], dtype='int64')
seq_len = fluid.layers.assign(self.seq_len)
if self.add_crf:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册