提交 d2fece3e 编写于 作者: W wuzewu

Update comment

上级 1562d65f
...@@ -17,7 +17,7 @@ import os ...@@ -17,7 +17,7 @@ import os
import pickle import pickle
import time import time
from collections import defaultdict from collections import defaultdict
from typing import Any, Callable from typing import Any, Callable, List
import paddle import paddle
from paddle.distributed import ParallelEnv from paddle.distributed import ParallelEnv
...@@ -29,7 +29,25 @@ from paddlehub.utils.utils import Timer ...@@ -29,7 +29,25 @@ from paddlehub.utils.utils import Timer
class Trainer(object): class Trainer(object):
''' '''
Trainer Model trainer
Args:
model(paddle.nn.Layer) : Model to train or evaluate.
strategy(paddle.optimizer.Optimizer) : Optimizer strategy.
use_vdl(bool) : Whether to use visualdl to record training data.
checkpoint_dir(str) : Directory where the checkpoint is saved, and the trainer will restore the
state and model parameters from the checkpoint.
compare_metrics(callable) : The method of comparing the model metrics. If not specified, the main
metric returned by `validation_step` will be used for comparison by default, the larger the
value, the better the effect. This method will affect the saving of the best model. If the
default behavior does not meet your requirements, please pass in a custom method.
Example:
.. code-block:: python
def compare_metrics(old_metric: dict, new_metric: dict):
mainkey = list(new_metric.keys())[0]
return old_metric[mainkey] < new_metric[mainkey]
''' '''
def __init__(self, def __init__(self,
...@@ -130,7 +148,8 @@ class Trainer(object): ...@@ -130,7 +148,8 @@ class Trainer(object):
epochs(int) : Number of training loops, default is 1. epochs(int) : Number of training loops, default is 1.
batch_size(int) : Batch size of per step, default is 1. batch_size(int) : Batch size of per step, default is 1.
num_workers(int) : Number of subprocess to load data, default is 0. num_workers(int) : Number of subprocess to load data, default is 0.
eval_dataset(paddle.io.Dataset) : The validation dataset, default is None. If set, the Trainer will execute evaluate function every `save_interval` epochs. eval_dataset(paddle.io.Dataset) : The validation dataset, default is None. If set, the Trainer will
execute evaluate function every `save_interval` epochs.
log_interval(int) : Log the train information every `log_interval` steps. log_interval(int) : Log the train information every `log_interval` steps.
save_interval(int) : Save the checkpoint every `save_interval` epochs. save_interval(int) : Save the checkpoint every `save_interval` epochs.
''' '''
...@@ -269,7 +288,14 @@ class Trainer(object): ...@@ -269,7 +288,14 @@ class Trainer(object):
return {'loss': avg_loss, 'metrics': avg_metrics} return {'loss': avg_loss, 'metrics': avg_metrics}
return {'metrics': avg_metrics} return {'metrics': avg_metrics}
def training_step(self, batch: Any, batch_idx: int): def training_step(self, batch: List[paddle.Tensor], batch_idx: int):
'''
One step for training, which should be called as forward computation.
Args:
batch(list[paddle.Tensor]) : The one batch data
batch_idx(int) : The index of batch.
'''
if self.nranks > 1: if self.nranks > 1:
result = self.model._layers.training_step(batch, batch_idx) result = self.model._layers.training_step(batch, batch_idx)
else: else:
...@@ -296,17 +322,42 @@ class Trainer(object): ...@@ -296,17 +322,42 @@ class Trainer(object):
return loss, metrics return loss, metrics
def validation_step(self, batch: Any, batch_idx: int): def validation_step(self, batch: Any, batch_idx: int):
'''
One step for validation, which should be called as forward computation.
Args:
batch(list[paddle.Tensor]) : The one batch data
batch_idx(int) : The index of batch.
'''
if self.nranks > 1: if self.nranks > 1:
result = self.model._layers.validation_step(batch, batch_idx) result = self.model._layers.validation_step(batch, batch_idx)
else: else:
result = self.model.validation_step(batch, batch_idx) result = self.model.validation_step(batch, batch_idx)
return result return result
def optimizer_step(self, current_epoch: int, batch_idx: int, optimizer: paddle.optimizer.Optimizer, def optimizer_step(self, epoch_idx: int, batch_idx: int, optimizer: paddle.optimizer.Optimizer,
loss: paddle.Tensor): loss: paddle.Tensor):
'''
One step for optimization.
Args:
epoch_idx(int) : The index of epoch.
batch_idx(int) : The index of batch.
optimizer(paddle.optimizer.Optimizer) : Optimizer used.
loss(paddle.Tensor) : Loss tensor.
'''
self.optimizer.minimize(loss) self.optimizer.minimize(loss)
def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: paddle.optimizer.Optimizer): def optimizer_zero_grad(self, epoch_idx: int, batch_idx: int, optimizer: paddle.optimizer.Optimizer):
'''
One step for clearing gradients.
Args:
epoch_idx(int) : The index of epoch.
batch_idx(int) : The index of batch.
optimizer(paddle.optimizer.Optimizer) : Optimizer used.
'''
self.model.clear_gradients() self.model.clear_gradients()
def _compare_metrics(self, old_metric: dict, new_metric: dict): def _compare_metrics(self, old_metric: dict, new_metric: dict):
......
...@@ -38,8 +38,8 @@ class ImageClassifierModule(RunModule, ImageServing): ...@@ -38,8 +38,8 @@ class ImageClassifierModule(RunModule, ImageServing):
One step for training, which should be called as forward computation. One step for training, which should be called as forward computation.
Args: Args:
batch(list[paddle.Variable]): The one batch data, which contains images and labels. batch(list[paddle.Tensor]) : The one batch data, which contains images and labels.
batch_idx(int): The index of batch. batch_idx(int) : The index of batch.
Returns: Returns:
results(dict) : The model outputs, such as loss and metrics. results(dict) : The model outputs, such as loss and metrics.
...@@ -51,8 +51,8 @@ class ImageClassifierModule(RunModule, ImageServing): ...@@ -51,8 +51,8 @@ class ImageClassifierModule(RunModule, ImageServing):
One step for validation, which should be called as forward computation. One step for validation, which should be called as forward computation.
Args: Args:
batch(list[paddle.Variable]): The one batch data, which contains images and labels. batch(list[paddle.Tensor]) : The one batch data, which contains images and labels.
batch_idx(int): The index of batch. batch_idx(int) : The index of batch.
Returns: Returns:
results(dict) : The model outputs, such as metrics. results(dict) : The model outputs, such as metrics.
...@@ -80,7 +80,7 @@ class ImageClassifierModule(RunModule, ImageServing): ...@@ -80,7 +80,7 @@ class ImageClassifierModule(RunModule, ImageServing):
images = self.transforms(images) images = self.transforms(images)
if len(images.shape) == 3: if len(images.shape) == 3:
images = images[np.newaxis, :] images = images[np.newaxis, :]
preds = self(paddle.to_variable(images)) preds = self(paddle.to_tensor(images))
preds = F.softmax(preds, axis=1).numpy() preds = F.softmax(preds, axis=1).numpy()
pred_idxs = np.argsort(preds)[::-1][:, :top_k] pred_idxs = np.argsort(preds)[::-1][:, :top_k]
res = [] res = []
...@@ -91,6 +91,3 @@ class ImageClassifierModule(RunModule, ImageServing): ...@@ -91,6 +91,3 @@ class ImageClassifierModule(RunModule, ImageServing):
res_dict[class_name] = preds[i][k] res_dict[class_name] = preds[i][k]
res.append(res_dict) res.append(res_dict)
return res return res
def is_better_score(self, old_score: dict, new_score: dict):
return old_score['acc'] < new_score['acc']
...@@ -29,7 +29,15 @@ from paddlehub.utils import log ...@@ -29,7 +29,15 @@ from paddlehub.utils import log
class GitSource(object): class GitSource(object):
def __init__(self, url, path=None): '''
Git source for PaddleHub module
Args:
url(str) : Url of git repository
path(str) : Path to store the git repository
'''
def __init__(self, url: str, path: str = None):
self.url = url self.url = url
self._parse_result = urlparse(self.url) self._parse_result = urlparse(self.url)
...@@ -66,10 +74,25 @@ class GitSource(object): ...@@ -66,10 +74,25 @@ class GitSource(object):
log.logger.warning('An error occurred while loading {}'.format(self.path)) log.logger.warning('An error occurred while loading {}'.format(self.path))
sys.path.remove(self.path) sys.path.remove(self.path)
def search_module(self, name, version=None): def search_module(self, name: str, version: str = None) -> dict:
'''
Search PaddleHub module
Args:
name(str) : PaddleHub module name
version(str) : PaddleHub module version
'''
return self.search_resouce(type='module', name=name, version=version) return self.search_resouce(type='module', name=name, version=version)
def search_resouce(self, type, name, version=None): def search_resouce(self, type: str, name: str, version: str = None) -> dict:
'''
Search PaddleHub Resource
Args:
type(str) : Resource type
name(str) : Resource name
version(str) : Resource version
'''
module = self.hub_modules.get(name, None) module = self.hub_modules.get(name, None)
if module and module.version.match(version): if module and module.version.match(version):
return { return {
...@@ -82,7 +105,13 @@ class GitSource(object): ...@@ -82,7 +105,13 @@ class GitSource(object):
return None return None
@classmethod @classmethod
def check(cls, url): def check(cls, url: str) -> bool:
'''
Check if the specified url is a valid git repository link
Args:
url(str) : Url to check
'''
try: try:
git.cmd.Git().ls_remote(url) git.cmd.Git().ls_remote(url)
return True return True
......
...@@ -21,10 +21,12 @@ PADDLEHUB_PUBLIC_SERVER = 'http://paddlepaddle.org.cn/paddlehub' ...@@ -21,10 +21,12 @@ PADDLEHUB_PUBLIC_SERVER = 'http://paddlepaddle.org.cn/paddlehub'
class HubServer(object): class HubServer(object):
'''PaddleHub server'''
def __init__(self): def __init__(self):
self.sources = OrderedDict() self.sources = OrderedDict()
def _generate_source(self, url): def _generate_source(self, url: str):
if ServerSource.check(url): if ServerSource.check(url):
source = ServerSource(url) source = ServerSource(url)
elif GitSource.check(url): elif GitSource.check(url):
...@@ -33,17 +35,34 @@ class HubServer(object): ...@@ -33,17 +35,34 @@ class HubServer(object):
raise RuntimeError() raise RuntimeError()
return source return source
def add_source(self, url, key=None): def add_source(self, url: str, key: str = None):
'''Add a module source(GitSource or ServerSource)'''
key = "source_{}".format(len(self.sources)) if not key else key key = "source_{}".format(len(self.sources)) if not key else key
self.sources[key] = self._generate_source(url) self.sources[key] = self._generate_source(url)
def remove_source(self, url=None, key=None): def remove_source(self, url: str = None, key: str = None):
'''Remove a module source'''
self.sources.pop(key) self.sources.pop(key)
def search_module(self, name, version=None, source=None): def search_module(self, name: str, version: str = None, source: str = None) -> dict:
'''
Search PaddleHub module
Args:
name(str) : PaddleHub module name
version(str) : PaddleHub module version
'''
return self.search_resouce(type='module', name=name, version=version, source=source) return self.search_resouce(type='module', name=name, version=version, source=source)
def search_resouce(self, type, name, version=None, source=None): def search_resouce(self, type: str, name: str, version: str = None, source: str = None) -> dict:
'''
Search PaddleHub Resource
Args:
type(str) : Resource type
name(str) : Resource name
version(str) : Resource version
'''
sources = self.sources.values() if not source else [self._generate_source(source)] sources = self.sources.values() if not source else [self._generate_source(source)]
for source in sources: for source in sources:
result = source.search_resouce(name=name, type=type, version=version) result = source.search_resouce(name=name, type=type, version=version)
......
...@@ -25,14 +25,37 @@ from paddlehub.utils import utils ...@@ -25,14 +25,37 @@ from paddlehub.utils import utils
class ServerSource(object): class ServerSource(object):
def __init__(self, url, timeout=10): '''
PaddleHub server source
Args:
url(str) : Url of the server
timeout(int) : Request timeout
'''
def __init__(self, url: str, timeout: int = 10):
self._url = url self._url = url
self._timeout = timeout self._timeout = timeout
def search_module(self, name, version=None): def search_module(self, name: str, version: str = None) -> dict:
'''
Search PaddleHub module
Args:
name(str) : PaddleHub module name
version(str) : PaddleHub module version
'''
return self.search_resouce(type='module', name=name, version=version) return self.search_resouce(type='module', name=name, version=version)
def search_resouce(self, type, name, version=None): def search_resouce(self, type: str, name: str, version: str = None) -> dict:
'''
Search PaddleHub Resource
Args:
type(str) : Resource type
name(str) : Resource name
version(str) : Resource version
'''
payload = {'environments': {}} payload = {'environments': {}}
payload['word'] = name payload['word'] = name
...@@ -59,7 +82,13 @@ class ServerSource(object): ...@@ -59,7 +82,13 @@ class ServerSource(object):
return None return None
@classmethod @classmethod
def check(cls, url): def check(cls, url: str) -> bool:
'''
Check if the specified url is a valid paddlehub server
Args:
url(str) : Url to check
'''
try: try:
r = requests.get(url + '/search') r = requests.get(url + '/search')
return r.status_code == 200 return r.status_code == 200
......
...@@ -50,7 +50,7 @@ def redirect_estream(stream: IO): ...@@ -50,7 +50,7 @@ def redirect_estream(stream: IO):
@contextlib.contextmanager @contextlib.contextmanager
def discard_oe(): def discard_oe():
''' '''
Redirect input and output stream to temporary file. In a sense, Redirect output and error stream to temporary file. In a sense,
it is equivalent to discarding the output and error messages it is equivalent to discarding the output and error messages
''' '''
with generate_tempfile(mode='w') as _stream: with generate_tempfile(mode='w') as _stream:
......
...@@ -96,6 +96,7 @@ class ProgressBar(object): ...@@ -96,6 +96,7 @@ class ProgressBar(object):
Examples: Examples:
.. code-block:: python .. code-block:: python
with ProgressBar('Download module') as bar: with ProgressBar('Download module') as bar:
for i in range(100): for i in range(100):
bar.update(i / 100) bar.update(i / 100)
...@@ -126,6 +127,10 @@ class ProgressBar(object): ...@@ -126,6 +127,10 @@ class ProgressBar(object):
def update(self, progress: float): def update(self, progress: float):
''' '''
Update progress bar
Args:
progress(float) : Processing progress, from 0.0 to 1.0
''' '''
msg = '[{:<50}] {:.2f}%'.format('#' * int(progress * 50), progress * 100) msg = '[{:<50}] {:.2f}%'.format('#' * int(progress * 50), progress * 100)
need_flush = (time.time() - self.last_flush_time) >= self.flush_interval need_flush = (time.time() - self.last_flush_time) >= self.flush_interval
...@@ -146,14 +151,14 @@ class FormattedText(object): ...@@ -146,14 +151,14 @@ class FormattedText(object):
Args: Args:
text(str) : Text content text(str) : Text content
width(int) : Text length, if the text is less than the specified length, it will be filled with spaces width(int) : Text length, if the text is less than the specified length, it will be filled with spaces
align(str) : it must be: align(str) : Text alignment, it must be:
========= ================== ========= ====================================
Character Meaning Character Meaning
--------- ------------------ --------- ------------------------------------
'<' left aligned '<' The text will remain left aligned
'^' middle aligned '^' The text will remain middle aligned
'>' right aligned '>' The text will remain right aligned
========= ================== ========= ====================================
color(str) : Text color, default is None(depends on terminal configuration) color(str) : Text color, default is None(depends on terminal configuration)
''' '''
_MAP = {'red': Fore.RED, 'yellow': Fore.YELLOW, 'green': Fore.GREEN, 'blue': Fore.BLUE} _MAP = {'red': Fore.RED, 'yellow': Fore.YELLOW, 'green': Fore.GREEN, 'blue': Fore.BLUE}
...@@ -293,12 +298,13 @@ class Table(object): ...@@ -293,12 +298,13 @@ class Table(object):
Table with adaptive width and height Table with adaptive width and height
Args: Args:
colors(list[str]) : Text colors of contents one by one colors(list[str]) : Text colors
aligns(list[str]) : Text aligns of contents one by one aligns(list[str]) : Text alignments
widths(list[int]) : Text widths of contents one by one widths(list[int]) : Text widths
Examples: Examples:
.. code-block:: python .. code-block:: python
table = Table(widths=[12, 20]) table = Table(widths=[12, 20])
table.append('name', 'PaddleHub') table.append('name', 'PaddleHub')
table.append('version', '2.0.0') table.append('version', '2.0.0')
...@@ -337,9 +343,9 @@ class Table(object): ...@@ -337,9 +343,9 @@ class Table(object):
Args: Args:
*contents(*list): Contents of the row, each content will be placed in a separate cell *contents(*list): Contents of the row, each content will be placed in a separate cell
colors(list[str]) : Text colors of contents one by one, if not set, the default value will be used. colors(list[str]) : Text colors
aligns(list[str]) : Text aligns of contents one by one, if not set, the default value will be used. aligns(list[str]) : Text alignments
widths(list[int]) : Text widths of contents one by one, if not set, the default value will be used. widths(list[int]) : Text widths
''' '''
newrow = TableRow() newrow = TableRow()
......
...@@ -32,7 +32,7 @@ import paddlehub.env as hubenv ...@@ -32,7 +32,7 @@ import paddlehub.env as hubenv
class Version(packaging.version.Version): class Version(packaging.version.Version):
'''Expand realization of packaging.version.Version''' '''Extended implementation of packaging.version.Version'''
def match(self, condition: str) -> bool: def match(self, condition: str) -> bool:
''' '''
...@@ -45,7 +45,7 @@ class Version(packaging.version.Version): ...@@ -45,7 +45,7 @@ class Version(packaging.version.Version):
bool: True if the given version condition is met, else False bool: True if the given version condition is met, else False
Examples: Examples:
from paddlehub.utils import Version .. code-block:: python
Version('1.2.0').match('>=1.2.0a') Version('1.2.0').match('>=1.2.0a')
''' '''
...@@ -162,7 +162,6 @@ def download(url: str, path: str = None) -> str: ...@@ -162,7 +162,6 @@ def download(url: str, path: str = None) -> str:
Examples: Examples:
.. code-block:: python .. code-block:: python
from paddlehub.utils.utils import download
url = 'https://xxxxx.xx/xx.tar.gz' url = 'https://xxxxx.xx/xx.tar.gz'
download(url, path='./output') download(url, path='./output')
...@@ -182,7 +181,6 @@ def download_with_progress(url: str, path: str = None) -> Generator[str, int, in ...@@ -182,7 +181,6 @@ def download_with_progress(url: str, path: str = None) -> Generator[str, int, in
Examples: Examples:
.. code-block:: python .. code-block:: python
from paddlehub.utils.utils import download_with_progress
url = 'https://xxxxx.xx/xx.tar.gz' url = 'https://xxxxx.xx/xx.tar.gz'
for filename, download_size, total_szie in download_with_progress(url, path='./output'): for filename, download_size, total_szie in download_with_progress(url, path='./output'):
......
...@@ -177,7 +177,6 @@ def archive(filename: str, recursive: bool = True, exclude: Callable = None, arc ...@@ -177,7 +177,6 @@ def archive(filename: str, recursive: bool = True, exclude: Callable = None, arc
Examples: Examples:
.. code-block:: python .. code-block:: python
from paddlehub.utils import archive
archive_path = '/PATH/TO/FILE' archive_path = '/PATH/TO/FILE'
archive(archive_path, arcname='output.tar.gz', arctype='tar.gz') archive(archive_path, arcname='output.tar.gz', arctype='tar.gz')
...@@ -200,7 +199,6 @@ def unarchive(name: str, path: str): ...@@ -200,7 +199,6 @@ def unarchive(name: str, path: str):
Examples: Examples:
.. code-block:: python .. code-block:: python
from paddlehub.utils import unarchive
unarchive_path = '/PATH/TO/FILE' unarchive_path = '/PATH/TO/FILE'
unarchive(unarchive_path, path='./output') unarchive(unarchive_path, path='./output')
...@@ -219,7 +217,6 @@ def unarchive_with_progress(name: str, path: str) -> Generator[str, int, int]: ...@@ -219,7 +217,6 @@ def unarchive_with_progress(name: str, path: str) -> Generator[str, int, int]:
Examples: Examples:
.. code-block:: python .. code-block:: python
from paddlehub.utils.xarfile import unarchive_with_progress
unarchive_path = 'test.tar.gz' unarchive_path = 'test.tar.gz'
for filename, extract_size, total_szie in unarchive_with_progress(unarchive_path, path='./output'): for filename, extract_size, total_szie in unarchive_with_progress(unarchive_path, path='./output'):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册