提交 4920392b 编写于 作者: Z Zeyu Chen

migrate all tools package to common

上级 6d13d7a4
...@@ -77,6 +77,7 @@ def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3): ...@@ -77,6 +77,7 @@ def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3):
def prepare_batch_data(insts, def prepare_batch_data(insts,
total_token_num, total_token_num,
voc_size=0, voc_size=0,
max_seq_len=128,
pad_id=None, pad_id=None,
cls_id=None, cls_id=None,
sep_id=None, sep_id=None,
...@@ -115,15 +116,17 @@ def prepare_batch_data(insts, ...@@ -115,15 +116,17 @@ def prepare_batch_data(insts,
out = batch_src_ids out = batch_src_ids
# Second step: padding # Second step: padding
src_id, self_input_mask = pad_batch_data( src_id, self_input_mask = pad_batch_data(
out, pad_idx=pad_id, return_input_mask=True) out, pad_idx=pad_id, max_seq_len=max_seq_len, return_input_mask=True)
pos_id = pad_batch_data( pos_id = pad_batch_data(
batch_pos_ids, batch_pos_ids,
pad_idx=pad_id, pad_idx=pad_id,
max_seq_len=max_seq_len,
return_pos=False, return_pos=False,
return_input_mask=False) return_input_mask=False)
sent_id = pad_batch_data( sent_id = pad_batch_data(
batch_sent_ids, batch_sent_ids,
pad_idx=pad_id, pad_idx=pad_id,
max_seq_len=max_seq_len,
return_pos=False, return_pos=False,
return_input_mask=False) return_input_mask=False)
...@@ -139,6 +142,7 @@ def prepare_batch_data(insts, ...@@ -139,6 +142,7 @@ def prepare_batch_data(insts,
def pad_batch_data(insts, def pad_batch_data(insts,
pad_idx=0, pad_idx=0,
max_seq_len=128,
return_pos=False, return_pos=False,
return_input_mask=False, return_input_mask=False,
return_max_len=False, return_max_len=False,
...@@ -149,7 +153,7 @@ def pad_batch_data(insts, ...@@ -149,7 +153,7 @@ def pad_batch_data(insts,
""" """
return_list = [] return_list = []
#max_len = max(len(inst) for inst in insts) #max_len = max(len(inst) for inst in insts)
max_len = 50 max_len = max_seq_len
# Any token included in dict can be used to pad, since the paddings' loss # Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients. # will be masked out by weights and make no effect on parameter gradients.
......
...@@ -93,6 +93,7 @@ class DataProcessor(object): ...@@ -93,6 +93,7 @@ class DataProcessor(object):
batch_data, batch_data,
total_token_num, total_token_num,
voc_size=-1, voc_size=-1,
max_seq_len=self.max_seq_len,
pad_id=self.vocab["[PAD]"], pad_id=self.vocab["[PAD]"],
cls_id=self.vocab["[CLS]"], cls_id=self.vocab["[CLS]"],
sep_id=self.vocab["[SEP]"], sep_id=self.vocab["[SEP]"],
......
...@@ -15,5 +15,5 @@ python -u finetune_with_hub.py \ ...@@ -15,5 +15,5 @@ python -u finetune_with_hub.py \
--checkpoint_dir $CKPT_DIR \ --checkpoint_dir $CKPT_DIR \
--warmup_proportion 0.0 \ --warmup_proportion 0.0 \
--epoch 3 \ --epoch 3 \
--max_seq_len 50 \ --max_seq_len 128 \
--learning_rate 5e-5 --learning_rate 5e-5
...@@ -11,21 +11,25 @@ ...@@ -11,21 +11,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .dir import USER_HOME
from .dir import HUB_HOME
from .dir import MODULE_HOME
from .dir import CACHE_HOME
from . import module from . import module
from . import tools from . import common
from . import io from . import io
from .common.dir import USER_HOME
from .common.dir import HUB_HOME
from .common.dir import MODULE_HOME
from .common.dir import CACHE_HOME
from .common.logger import logger
from .common.paddle_helper import connect_program
from .common.hub_server import default_hub_server
from .module.module import Module, create_module from .module.module import Module, create_module
from .module.base_processor import BaseProcessor from .module.base_processor import BaseProcessor
from .module.signature import Signature, create_signature from .module.signature import Signature, create_signature
from .module.manager import default_module_manager from .module.manager import default_module_manager
from .tools.logger import logger
from .tools.paddle_helper import connect_program
from .io.type import DataType from .io.type import DataType
from .hub_server import default_hub_server
from .finetune.network import append_mlp_classifier from .finetune.network import append_mlp_classifier
from .finetune.finetune import finetune_and_eval from .finetune.finetune import finetune_and_eval
from .finetune.config import FinetuneConfig from .finetune.config import FinetuneConfig
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from paddle_hub.tools.logger import logger from paddle_hub.common.logger import logger
import six import six
import distutils.util import distutils.util
......
...@@ -26,8 +26,8 @@ import requests ...@@ -26,8 +26,8 @@ import requests
import tempfile import tempfile
import tarfile import tarfile
from paddle_hub.tools import utils from paddle_hub.common import utils
from paddle_hub.tools.logger import logger from paddle_hub.common.logger import logger
from paddle_hub.io.reader import csv_reader from paddle_hub.io.reader import csv_reader
__all__ = ['Downloader'] __all__ = ['Downloader']
......
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from paddle_hub.tools import utils from paddle_hub.common import utils
from paddle_hub.tools.downloader import default_downloader from paddle_hub.common.downloader import default_downloader
from paddle_hub.io.reader import csv_reader from paddle_hub.io.reader import csv_reader
import os import os
import time import time
......
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from paddle_hub.module import module_desc_pb2 from ..module import module_desc_pb2
from paddle_hub.tools.utils import from_pyobj_to_flexible_data, from_flexible_data_to_pyobj from .utils import from_pyobj_to_flexible_data, from_flexible_data_to_pyobj
from paddle_hub.tools.logger import logger from .logger import logger
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import copy import copy
......
...@@ -18,7 +18,7 @@ from __future__ import absolute_import ...@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from paddle_hub.module import module_desc_pb2 from paddle_hub.module import module_desc_pb2
from paddle_hub.tools.logger import logger from paddle_hub.common.logger import logger
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import os import os
......
...@@ -22,7 +22,21 @@ from collections import namedtuple ...@@ -22,7 +22,21 @@ from collections import namedtuple
DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/chnsenticorp_data.tar.gz" DATA_URL = "https://paddlehub-dataset.bj.bcebos.com/chnsenticorp_data.tar.gz"
class ChnSentiCorp(object): class HubDataset(object):
def get_train_examples(self):
raise NotImplementedError()
def get_dev_examples(self):
raise NotImplementedError()
def get_test_examples(self):
raise NotImplementedError()
def get_val_examples(self):
return self.get_dev_examples()
class ChnSentiCorp(HubDataset):
def __init__(self): def __init__(self):
ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress( ret, tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
url=DATA_URL, save_path=DATA_HOME, print_progress=True) url=DATA_URL, save_path=DATA_HOME, print_progress=True)
......
...@@ -23,7 +23,7 @@ import paddle ...@@ -23,7 +23,7 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from visualdl import LogWriter from visualdl import LogWriter
from paddle_hub.tools.logger import logger from paddle_hub.common.logger import logger
from paddle_hub.finetune.optimization import bert_finetune from paddle_hub.finetune.optimization import bert_finetune
from paddle_hub.finetune.checkpoint import load_checkpoint, save_checkpoint from paddle_hub.finetune.checkpoint import load_checkpoint, save_checkpoint
......
...@@ -16,7 +16,7 @@ from __future__ import absolute_import ...@@ -16,7 +16,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from PIL import Image, ImageEnhance from PIL import Image, ImageEnhance
from paddle_hub.tools import utils from paddle_hub.common import utils
import numpy as np import numpy as np
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
from enum import Enum from enum import Enum
from PIL import Image from PIL import Image
from paddle_hub.tools.logger import logger from paddle_hub.common.logger import logger
from paddle_hub.tools import utils from paddle_hub.common import utils
class DataType(Enum): class DataType(Enum):
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from paddle_hub.tools.logger import logger from paddle_hub.common.logger import logger
from paddle_hub.module import check_info_pb2 from paddle_hub.module import check_info_pb2
from paddle_hub.version import hub_version, module_proto_version from paddle_hub.version import hub_version, module_proto_version
import os import os
......
...@@ -19,8 +19,8 @@ from __future__ import print_function ...@@ -19,8 +19,8 @@ from __future__ import print_function
import os import os
import shutil import shutil
from paddle_hub.tools import utils from paddle_hub.common import utils
from paddle_hub.tools.downloader import default_downloader from paddle_hub.common.downloader import default_downloader
import paddle_hub as hub import paddle_hub as hub
......
...@@ -15,10 +15,10 @@ ...@@ -15,10 +15,10 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from paddle_hub.tools import utils from paddle_hub.common import utils
from paddle_hub.tools.logger import logger from paddle_hub.common.logger import logger
from paddle_hub.tools.downloader import default_downloader from paddle_hub.common.downloader import default_downloader
from paddle_hub.tools import paddle_helper from paddle_hub.common import paddle_helper
from paddle_hub.module import module_desc_pb2 from paddle_hub.module import module_desc_pb2
from paddle_hub.module import check_info_pb2 from paddle_hub.module import check_info_pb2
from paddle_hub.module.signature import Signature, create_signature from paddle_hub.module.signature import Signature, create_signature
...@@ -458,7 +458,6 @@ class Module(object): ...@@ -458,7 +458,6 @@ class Module(object):
# TODO(ZeyuChen) encapsulate into a function # TODO(ZeyuChen) encapsulate into a function
# update BERT/ERNIE's input tensor's sequence length to max_seq_len # update BERT/ERNIE's input tensor's sequence length to max_seq_len
if self.name.startswith("bert") or self.name.startswith("ernie"): if self.name.startswith("bert") or self.name.startswith("ernie"):
print("module_name", self.name)
MAX_SEQ_LENGTH = 512 MAX_SEQ_LENGTH = 512
if max_seq_len > MAX_SEQ_LENGTH or max_seq_len <= 0: if max_seq_len > MAX_SEQ_LENGTH or max_seq_len <= 0:
raise ValueError( raise ValueError(
......
...@@ -16,7 +16,7 @@ from __future__ import absolute_import ...@@ -16,7 +16,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from paddle.fluid.framework import Variable from paddle.fluid.framework import Variable
from paddle_hub.tools.utils import to_list from paddle_hub.common.utils import to_list
class Signature: class Signature:
......
test_downloader #test_downloader
test_export_n_load_module #test_export_n_load_module
#test_module #test_module
test_train_w2v #test_train_w2v
test_pyobj_serialize #test_pyobj_serialize
test_signature #test_signature
test_param_serialize #test_param_serialize
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册