From d2fece3e09a1f143fd720d92211d90504e434494 Mon Sep 17 00:00:00 2001 From: wuzewu Date: Fri, 4 Sep 2020 18:07:56 +0800 Subject: [PATCH] Update comment --- paddlehub/finetune/trainer.py | 63 ++++++++++++++++++++++++++++--- paddlehub/module/cv_module.py | 13 +++---- paddlehub/server/git_source.py | 37 ++++++++++++++++-- paddlehub/server/server.py | 29 +++++++++++--- paddlehub/server/server_source.py | 37 ++++++++++++++++-- paddlehub/utils/io.py | 2 +- paddlehub/utils/log.py | 50 +++++++++++++----------- paddlehub/utils/utils.py | 8 ++-- paddlehub/utils/xarfile.py | 3 -- 9 files changed, 184 insertions(+), 58 deletions(-) diff --git a/paddlehub/finetune/trainer.py b/paddlehub/finetune/trainer.py index 35ef8213..620dd9d6 100644 --- a/paddlehub/finetune/trainer.py +++ b/paddlehub/finetune/trainer.py @@ -17,7 +17,7 @@ import os import pickle import time from collections import defaultdict -from typing import Any, Callable +from typing import Any, Callable, List import paddle from paddle.distributed import ParallelEnv @@ -29,7 +29,25 @@ from paddlehub.utils.utils import Timer class Trainer(object): ''' - Trainer + Model trainer + + Args: + model(paddle.nn.Layer) : Model to train or evaluate. + strategy(paddle.optimizer.Optimizer) : Optimizer strategy. + use_vdl(bool) : Whether to use visualdl to record training data. + checkpoint_dir(str) : Directory where the checkpoint is saved, and the trainer will restore the + state and model parameters from the checkpoint. + compare_metrics(callable) : The method of comparing the model metrics. If not specified, the main + metric returned by `validation_step` will be used for comparison by default, the larger the + value, the better the effect. This method will affect the saving of the best model. If the + default behavior does not meet your requirements, please pass in a custom method. + + Example: + .. 
code-block:: python + + def compare_metrics(old_metric: dict, new_metric: dict): + mainkey = list(new_metric.keys())[0] + return old_metric[mainkey] < new_metric[mainkey] ''' def __init__(self, @@ -130,7 +148,8 @@ class Trainer(object): epochs(int) : Number of training loops, default is 1. batch_size(int) : Batch size of per step, default is 1. num_workers(int) : Number of subprocess to load data, default is 0. - eval_dataset(paddle.io.Dataset) : The validation dataset, deafult is None. If set, the Trainer will execute evaluate function every `save_interval` epochs. + eval_dataset(paddle.io.Dataset) : The validation dataset, default is None. If set, the Trainer will + execute evaluate function every `save_interval` epochs. log_interval(int) : Log the train infomation every `log_interval` steps. save_interval(int) : Save the checkpoint every `save_interval` epochs. ''' @@ -269,7 +288,14 @@ class Trainer(object): return {'loss': avg_loss, 'metrics': avg_metrics} return {'metrics': avg_metrics} - def training_step(self, batch: Any, batch_idx: int): + def training_step(self, batch: List[paddle.Tensor], batch_idx: int): + ''' + One step for training, which should be called as forward computation. + + Args: + batch(list[paddle.Tensor]) : The one batch data + batch_idx(int) : The index of batch. + ''' if self.nranks > 1: result = self.model._layers.training_step(batch, batch_idx) else: @@ -296,17 +322,42 @@ class Trainer(object): return loss, metrics def validation_step(self, batch: Any, batch_idx: int): + ''' + One step for validation, which should be called as forward computation. + + Args: + batch(list[paddle.Tensor]) : The one batch data + batch_idx(int) : The index of batch. 
+ ''' if self.nranks > 1: result = self.model._layers.validation_step(batch, batch_idx) else: result = self.model.validation_step(batch, batch_idx) return result - def optimizer_step(self, current_epoch: int, batch_idx: int, optimizer: paddle.optimizer.Optimizer, + def optimizer_step(self, epoch_idx: int, batch_idx: int, optimizer: paddle.optimizer.Optimizer, loss: paddle.Tensor): + ''' + One step for optimization. + + Args: + epoch_idx(int) : The index of epoch. + batch_idx(int) : The index of batch. + optimizer(paddle.optimizer.Optimizer) : Optimizer used. + loss(paddle.Tensor) : Loss tensor. + ''' self.optimizer.minimize(loss) - def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: paddle.optimizer.Optimizer): + def optimizer_zero_grad(self, epoch_idx: int, batch_idx: int, optimizer: paddle.optimizer.Optimizer): + ''' + One step for clearing gradients. + + Args: + epoch_idx(int) : The index of epoch. + batch_idx(int) : The index of batch. + optimizer(paddle.optimizer.Optimizer) : Optimizer used. + loss(paddle.Tensor) : Loss tensor. + ''' self.model.clear_gradients() def _compare_metrics(self, old_metric: dict, new_metric: dict): diff --git a/paddlehub/module/cv_module.py b/paddlehub/module/cv_module.py index 5329bed9..2418b15b 100644 --- a/paddlehub/module/cv_module.py +++ b/paddlehub/module/cv_module.py @@ -38,8 +38,8 @@ class ImageClassifierModule(RunModule, ImageServing): One step for training, which should be called as forward computation. Args: - batch(list[paddle.Variable]): The one batch data, which contains images and labels. - batch_idx(int): The index of batch. + batch(list[paddle.Tensor]) : The one batch data, which contains images and labels. + batch_idx(int) : The index of batch. Returns: results(dict) : The model outputs, such as loss and metrics. @@ -51,8 +51,8 @@ class ImageClassifierModule(RunModule, ImageServing): One step for validation, which should be called as forward computation. 
Args: - batch(list[paddle.Variable]): The one batch data, which contains images and labels. - batch_idx(int): The index of batch. + batch(list[paddle.Tensor]) : The one batch data, which contains images and labels. + batch_idx(int) : The index of batch. Returns: results(dict) : The model outputs, such as metrics. @@ -80,7 +80,7 @@ class ImageClassifierModule(RunModule, ImageServing): images = self.transforms(images) if len(images.shape) == 3: images = images[np.newaxis, :] - preds = self(paddle.to_variable(images)) + preds = self(paddle.to_tensor(images)) preds = F.softmax(preds, axis=1).numpy() pred_idxs = np.argsort(preds)[::-1][:, :top_k] res = [] @@ -91,6 +91,3 @@ class ImageClassifierModule(RunModule, ImageServing): res_dict[class_name] = preds[i][k] res.append(res_dict) return res - - def is_better_score(self, old_score: dict, new_score: dict): - return old_score['acc'] < new_score['acc'] diff --git a/paddlehub/server/git_source.py b/paddlehub/server/git_source.py index c5b7e54c..7ee98ecc 100644 --- a/paddlehub/server/git_source.py +++ b/paddlehub/server/git_source.py @@ -29,7 +29,15 @@ from paddlehub.utils import log class GitSource(object): - def __init__(self, url, path=None): + ''' + Git source for PaddleHub module + + Args: + url(str) : Url of git repository + path(str) : Path to store the git repository + ''' + + def __init__(self, url: str, path: str = None): self.url = url self._parse_result = urlparse(self.url) @@ -66,10 +74,25 @@ class GitSource(object): log.logger.warning('An error occurred while loading {}'.format(self.path)) sys.path.remove(self.path) - def search_module(self, name, version=None): + def search_module(self, name: str, version: str = None) -> dict: + ''' + Search PaddleHub module + + Args: + name(str) : PaddleHub module name + version(str) : PaddleHub module version + ''' return self.search_resouce(type='module', name=name, version=version) - def search_resouce(self, type, name, version=None): + def search_resouce(self, type: str, 
name: str, version: str = None) -> dict: + ''' + Search PaddleHub Resource + + Args: + type(str) : Resource type + name(str) : Resource name + version(str) : Resource version + ''' module = self.hub_modules.get(name, None) if module and module.version.match(version): return { @@ -82,7 +105,13 @@ class GitSource(object): return None @classmethod - def check(cls, url): + def check(cls, url: str) -> bool: + ''' + Check if the specified url is a valid git repository link + + Args: + url(str) : Url to check + ''' try: git.cmd.Git().ls_remote(url) return True diff --git a/paddlehub/server/server.py b/paddlehub/server/server.py index 01ea98d0..6f1b0c29 100644 --- a/paddlehub/server/server.py +++ b/paddlehub/server/server.py @@ -21,10 +21,12 @@ PADDLEHUB_PUBLIC_SERVER = 'http://paddlepaddle.org.cn/paddlehub' class HubServer(object): + '''PaddleHub server''' + def __init__(self): self.sources = OrderedDict() - def _generate_source(self, url): + def _generate_source(self, url: str): if ServerSource.check(url): source = ServerSource(url) elif GitSource.check(url): @@ -33,17 +35,34 @@ class HubServer(object): raise RuntimeError() return source - def add_source(self, url, key=None): + def add_source(self, url: str, key: str = None): + '''Add a module source(GitSource or ServerSource)''' key = "source_{}".format(len(self.sources)) if not key else key self.sources[key] = self._generate_source(url) - def remove_source(self, url=None, key=None): + def remove_source(self, url: str = None, key: str = None): + '''Remove a module source''' self.sources.pop(key) - def search_module(self, name, version=None, source=None): + def search_module(self, name: str, version: str = None, source: str = None) -> dict: + ''' + Search PaddleHub module + + Args: + name(str) : PaddleHub module name + version(str) : PaddleHub module version + ''' return self.search_resouce(type='module', name=name, version=version, source=source) - def search_resouce(self, type, name, version=None, source=None): + def 
search_resouce(self, type: str, name: str, version: str = None, source: str = None) -> dict: + ''' + Search PaddleHub Resource + + Args: + type(str) : Resource type + name(str) : Resource name + version(str) : Resource version + ''' sources = self.sources.values() if not source else [self._generate_source(source)] for source in sources: result = source.search_resouce(name=name, type=type, version=version) diff --git a/paddlehub/server/server_source.py b/paddlehub/server/server_source.py index 3848f170..64aee1da 100644 --- a/paddlehub/server/server_source.py +++ b/paddlehub/server/server_source.py @@ -25,14 +25,37 @@ from paddlehub.utils import utils class ServerSource(object): - def __init__(self, url, timeout=10): + ''' + PaddleHub server source + + Args: + url(str) : Url of the server + timeout(int) : Request timeout + ''' + + def __init__(self, url: str, timeout: int = 10): self._url = url self._timeout = timeout - def search_module(self, name, version=None): + def search_module(self, name: str, version: str = None) -> dict: + ''' + Search PaddleHub module + + Args: + name(str) : PaddleHub module name + version(str) : PaddleHub module version + ''' return self.search_resouce(type='module', name=name, version=version) - def search_resouce(self, type, name, version=None): + def search_resouce(self, type: str, name: str, version: str = None) -> dict: + ''' + Search PaddleHub Resource + + Args: + type(str) : Resource type + name(str) : Resource name + version(str) : Resource version + ''' payload = {'environments': {}} payload['word'] = name @@ -59,7 +82,13 @@ class ServerSource(object): return None @classmethod - def check(cls, url): + def check(cls, url: str) -> bool: + ''' + Check if the specified url is a valid paddlehub server + + Args: + url(str) : Url to check + ''' try: r = requests.get(url + '/search') return r.status_code == 200 diff --git a/paddlehub/utils/io.py b/paddlehub/utils/io.py index de8b61d1..f89802f3 100644 --- a/paddlehub/utils/io.py +++ 
b/paddlehub/utils/io.py @@ -50,7 +50,7 @@ def redirect_estream(stream: IO): @contextlib.contextmanager def discard_oe(): ''' - Redirect input and output stream to temporary file. In a sense, + Redirect output and error stream to temporary file. In a sense, it is equivalent discarded the output and error messages ''' with generate_tempfile(mode='w') as _stream: diff --git a/paddlehub/utils/log.py b/paddlehub/utils/log.py index f0f14b35..18f9c4f2 100644 --- a/paddlehub/utils/log.py +++ b/paddlehub/utils/log.py @@ -96,15 +96,16 @@ class ProgressBar(object): Examples: .. code-block:: python - with ProgressBar('Download module') as bar: - for i in range(100): - bar.update(i / 100) - - # with continuous bar.update, the progress bar in the terminal - # will continue to update until 100% - # - # Download module - # [##################################################] 100.00% + + with ProgressBar('Download module') as bar: + for i in range(100): + bar.update(i / 100) + + # with continuous bar.update, the progress bar in the terminal + # will continue to update until 100% + # + # Download module + # [##################################################] 100.00% ''' def __init__(self, title: str, flush_interval: float = 0.1): @@ -126,6 +127,10 @@ class ProgressBar(object): def update(self, progress: float): ''' + Update progress bar + + Args: + progress: Processing progress, from 0.0 to 1.0 ''' msg = '[{:<50}] {:.2f}%'.format('#' * int(progress * 50), progress * 100) need_flush = (time.time() - self.last_flush_time) >= self.flush_interval @@ -146,14 +151,14 @@ class FormattedText(object): Args: text(str) : Text content width(int) : Text length, if the text is less than the specified length, it will be filled with spaces - align(str) : it must be: - ======== ================== + align(str) : Text alignment, it must be: + ======== ==================================== Charater Meaning - -------- ------------------ - '<' left aligned - '^' middle aligned - '>' right aligned - 
======== ================== + -------- ------------------------------------ + '<' The text will remain left aligned + '^' The text will remain middle aligned + '>' The text will remain right aligned + ======== ==================================== color(str) : Text color, default is None(depends on terminal configuration) ''' _MAP = {'red': Fore.RED, 'yellow': Fore.YELLOW, 'green': Fore.GREEN, 'blue': Fore.BLUE} @@ -293,12 +298,13 @@ class Table(object): Table with adaptive width and height Args: - colors(list[str]) : Text colors of contents one by one - aligns(list[str]) : Text aligns of contents one by one - widths(list[str]) : Text widths of contents one by one + colors(list[str]) : Text colors + aligns(list[str]) : Text alignments + widths(list[str]) : Text widths Examples: .. code-block:: python + table = Table(widths=[12, 20]) table.append('name', 'PaddleHub') table.append('version', '2.0.0') @@ -337,9 +343,9 @@ class Table(object): Args: *contents(*list): Contents of the row, each content will be placed in a separate cell - colors(list[str]) : Text colors of contents one by one, if not set, the default value will be used. - aligns(list[str]) : Text aligns of contents one by one, if not set, the default value will be used. - widths(list[str]) : Text widths of contents one by one, if not set, the default value will be used. 
+ colors(list[str]) : Text colors + aligns(list[str]) : Text alignments + widths(list[str]) : Text widths ''' newrow = TableRow() diff --git a/paddlehub/utils/utils.py b/paddlehub/utils/utils.py index bcaf3321..a9e8785e 100644 --- a/paddlehub/utils/utils.py +++ b/paddlehub/utils/utils.py @@ -32,7 +32,7 @@ import paddlehub.env as hubenv class Version(packaging.version.Version): - '''Expand realization of packaging.version.Version''' + '''Extended implementation of packaging.version.Version''' def match(self, condition: str) -> bool: ''' @@ -45,9 +45,9 @@ class Version(packaging.version.Version): bool: True if the given version condition are met, else False Examples: - from paddlehub.utils import Version + .. code-block:: python - Version('1.2.0').match('>=1.2.0a') + Version('1.2.0').match('>=1.2.0a') ''' if not condition: return True @@ -162,7 +162,6 @@ def download(url: str, path: str = None) -> str: Examples: .. code-block:: python - from paddlehub.utils.utils import download url = 'https://xxxxx.xx/xx.tar.gz' download(url, path='./output') @@ -182,7 +181,6 @@ def download_with_progress(url: str, path: str = None) -> Generator[str, int, in Examples: .. code-block:: python - from paddlehub.utils.utils import download_with_progress url = 'https://xxxxx.xx/xx.tar.gz' for filename, download_size, total_szie in download_with_progress(url, path='./output'): diff --git a/paddlehub/utils/xarfile.py b/paddlehub/utils/xarfile.py index fec66d3a..c70cb406 100644 --- a/paddlehub/utils/xarfile.py +++ b/paddlehub/utils/xarfile.py @@ -177,7 +177,6 @@ def archive(filename: str, recursive: bool = True, exclude: Callable = None, arc Examples: .. code-block:: python - from paddlehub.utils import archive archive_path = '/PATH/TO/FILE' archive(archive_path, arcname='output.tar.gz', arctype='tar.gz') @@ -200,7 +199,6 @@ def unarchive(name: str, path: str): Examples: .. 
code-block:: python - from paddlehub.utils import unarchive unarchive_path = '/PATH/TO/FILE' unarchive(unarchive_path, path='./output') @@ -219,7 +217,6 @@ def unarchive_with_progress(name: str, path: str) -> Generator[str, int, int]: Examples: .. code-block:: python - from paddlehub.utils.xarfile import unarchive_with_progress unarchive_path = 'test.tar.gz' for filename, extract_size, total_szie in unarchive_with_progress(unarchive_path, path='./output'): -- GitLab