提交 4920392b 编写于 作者: Z Zeyu Chen

migrade all tools package to common

上级 6d13d7a4
......@@ -77,6 +77,7 @@ def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3):
def prepare_batch_data(insts,
total_token_num,
voc_size=0,
max_seq_len=128,
pad_id=None,
cls_id=None,
sep_id=None,
......@@ -115,15 +116,17 @@ def prepare_batch_data(insts,
out = batch_src_ids
# Second step: padding
src_id, self_input_mask = pad_batch_data(
out, pad_idx=pad_id, return_input_mask=True)
out, pad_idx=pad_id, max_seq_len=max_seq_len, return_input_mask=True)
pos_id = pad_batch_data(
batch_pos_ids,
pad_idx=pad_id,
max_seq_len=max_seq_len,
return_pos=False,
return_input_mask=False)
sent_id = pad_batch_data(
batch_sent_ids,
pad_idx=pad_id,
max_seq_len=max_seq_len,
return_pos=False,
return_input_mask=False)
......@@ -139,6 +142,7 @@ def prepare_batch_data(insts,
def pad_batch_data(insts,
pad_idx=0,
max_seq_len=128,
return_pos=False,
return_input_mask=False,
return_max_len=False,
......@@ -149,7 +153,7 @@ def pad_batch_data(insts,
"""
return_list = []
#max_len = max(len(inst) for inst in insts)
max_len = 50
max_len = max_seq_len
# Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients.
......
......@@ -93,6 +93,7 @@ class DataProcessor(object):
batch_data,
total_token_num,
voc_size=-1,
max_seq_len=self.max_seq_len,
pad_id=self.vocab["[PAD]"],
cls_id=self.vocab["[CLS]"],
sep_id=self.vocab["[SEP]"],
......
......@@ -15,5 +15,5 @@ python -u finetune_with_hub.py \
--checkpoint_dir $CKPT_DIR \
--warmup_proportion 0.0 \
--epoch 3 \
--max_seq_len 50 \
--max_seq_len 128 \
--learning_rate 5e-5
......@@ -11,21 +11,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .dir import USER_HOME
from .dir import HUB_HOME
from .dir import MODULE_HOME
from .dir import CACHE_HOME
from . import module
from . import tools
from . import common
from . import io
from .common.dir import USER_HOME
from .common.dir import HUB_HOME
from .common.dir import MODULE_HOME
from .common.dir import CACHE_HOME
from .common.logger import logger
from .common.paddle_helper import connect_program
from .common.hub_server import default_hub_server
from .module.module import Module, create_module
from .module.base_processor import BaseProcessor
from .module.signature import Signature, create_signature
from .module.manager import default_module_manager
from .tools.logger import logger
from .tools.paddle_helper import connect_program
from .io.type import DataType
from .hub_server import default_hub_server
from .finetune.network import append_mlp_classifier
from .finetune.finetune import finetune_and_eval
from .finetune.config import FinetuneConfig
......
......@@ -15,7 +15,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle_hub.tools.logger import logger
from paddle_hub.common.logger import logger
import six
import distutils.util
......
......@@ -26,8 +26,8 @@ import requests
import tempfile
import tarfile
from paddle_hub.tools import utils
from paddle_hub.tools.logger import logger
from paddle_hub.common import utils
from paddle_hub.common.logger import logger
from paddle_hub.io.reader import csv_reader
__all__ = ['Downloader']
......
......@@ -15,8 +15,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle_hub.tools import utils
from paddle_hub.tools.downloader import default_downloader
from paddle_hub.common import utils
from paddle_hub.common.downloader import default_downloader
from paddle_hub.io.reader import csv_reader
import os
import time
......
......@@ -15,9 +15,9 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle_hub.module import module_desc_pb2
from paddle_hub.tools.utils import from_pyobj_to_flexible_data, from_flexible_data_to_pyobj
from paddle_hub.tools.logger import logger
from ..module import module_desc_pb2
from .utils import from_pyobj_to_flexible_data, from_flexible_data_to_pyobj
from .logger import logger
import paddle
import paddle.fluid as fluid
import copy
......
......@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle_hub.module import module_desc_pb2
from paddle_hub.tools.logger import logger
from paddle_hub.common.logger import logger
import paddle
import paddle.fluid as fluid
import os
......
......@@ -22,7 +22,21 @@ from collections import namedtuple
DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/chnsenticorp_data.tar.gz"
class ChnSentiCorp(object):
class HubDataset(object):
def get_train_examples(self):
raise NotImplementedError()
def get_dev_examples(self):
raise NotImplementedError()
def get_test_examples(self):
raise NotImplementedError()
def get_val_examples(self):
return self.get_dev_examples()
class ChnSentiCorp(HubDataset):
def __init__(self):
ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
url=DATA_URL, save_path=DATA_HOME, print_progress=True)
......
......@@ -23,7 +23,7 @@ import paddle
import paddle.fluid as fluid
from visualdl import LogWriter
from paddle_hub.tools.logger import logger
from paddle_hub.common.logger import logger
from paddle_hub.finetune.optimization import bert_finetune
from paddle_hub.finetune.checkpoint import load_checkpoint, save_checkpoint
......
......@@ -16,7 +16,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from PIL import Image, ImageEnhance
from paddle_hub.tools import utils
from paddle_hub.common import utils
import numpy as np
......
......@@ -14,8 +14,8 @@
from enum import Enum
from PIL import Image
from paddle_hub.tools.logger import logger
from paddle_hub.tools import utils
from paddle_hub.common.logger import logger
from paddle_hub.common import utils
class DataType(Enum):
......
......@@ -14,7 +14,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle_hub.tools.logger import logger
from paddle_hub.common.logger import logger
from paddle_hub.module import check_info_pb2
from paddle_hub.version import hub_version, module_proto_version
import os
......
......@@ -19,8 +19,8 @@ from __future__ import print_function
import os
import shutil
from paddle_hub.tools import utils
from paddle_hub.tools.downloader import default_downloader
from paddle_hub.common import utils
from paddle_hub.common.downloader import default_downloader
import paddle_hub as hub
......
......@@ -15,10 +15,10 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle_hub.tools import utils
from paddle_hub.tools.logger import logger
from paddle_hub.tools.downloader import default_downloader
from paddle_hub.tools import paddle_helper
from paddle_hub.common import utils
from paddle_hub.common.logger import logger
from paddle_hub.common.downloader import default_downloader
from paddle_hub.common import paddle_helper
from paddle_hub.module import module_desc_pb2
from paddle_hub.module import check_info_pb2
from paddle_hub.module.signature import Signature, create_signature
......@@ -458,7 +458,6 @@ class Module(object):
# TODO(ZeyuChen) encapsulate into a funtion
# update BERT/ERNIE's input tensor's sequence length to max_seq_len
if self.name.startswith("bert") or self.name.startswith("ernie"):
print("module_name", self.name)
MAX_SEQ_LENGTH = 512
if max_seq_len > MAX_SEQ_LENGTH or max_seq_len <= 0:
raise ValueError(
......
......@@ -16,7 +16,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle.fluid.framework import Variable
from paddle_hub.tools.utils import to_list
from paddle_hub.common.utils import to_list
class Signature:
......
test_downloader
test_export_n_load_module
#test_downloader
#test_export_n_load_module
#test_module
test_train_w2v
test_pyobj_serialize
test_signature
test_param_serialize
\ No newline at end of file
#test_train_w2v
#test_pyobj_serialize
#test_signature
#test_param_serialize
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册