Commit 8dfe8f62 authored by wuzewu

Update the interface used from fluid.xxx to paddle.xxx

Parent 621931d2
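Every hunk below applies the same migration pattern: drop the paddle.fluid namespace in favour of the Paddle 2.x dynamic-graph entry points. A minimal sketch of the mapping, collected from the removed and added lines of this commit (Paddle 2.x is assumed; the transitional APIs at the time of the commit may differ slightly from the final release):

# Sketch only: fluid -> paddle renames applied throughout the hunks below.
#
#   paddle.fluid.io.Dataset                              -> paddle.io.Dataset
#   paddle.fluid.io.DataLoader                           -> paddle.io.DataLoader
#   paddle.incubate.hapi.distributed.DistributedBatchSampler -> paddle.io.DistributedBatchSampler
#   fluid.dygraph.Layer                                  -> paddle.nn.Layer
#   fluid.optimizer.Optimizer                            -> paddle.optimizer.Optimizer
#   fluid.dygraph.guard(place)                           -> paddle.disable_static(place)
#   fluid.save_dygraph(state_dict, path)                 -> paddle.save(state_dict, path)
#   fluid.load_dygraph(path)                             -> paddle.load(path)
#   fluid.layers.softmax / fluid.layers.mean             -> paddle.nn.functional.softmax / paddle.mean
#   fluid.core.VarBase                                   -> paddle.Tensor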
@@ -25,14 +25,15 @@ from paddlehub.utils import log, platform
 class ListCommand:
     def execute(self, argv: List) -> bool:
         manager = LocalModuleManager()
-        table = log.Table()
         widths = [20, 40] if platform.is_windows() else [25, 50]
         aligns = ['^', '<']
-        table.append('ModuleName', 'Path', widths=widths, aligns=aligns, colors=['green', 'green'])
+        table = log.Table(widths=widths, aligns=aligns)
+        table.append('ModuleName', 'Path', colors=['green', 'green'])
         for module in manager.list():
-            table.append(module.name, module.directory, widths=widths, aligns=aligns)
+            table.append(module.name, module.directory)
         print(table)
         return True
@@ -45,16 +45,16 @@ class ShowCommand:
             print('{} is not existed!'.format(argv))
             return False
-        table = log.Table()
         widths = [15, 40] if platform.is_windows else [15, 50]
         aligns = ['^', '<']
         colors = ['yellow', '']
-        table.append('ModuleName', module.name, widths=widths, colors=colors, aligns=aligns)
-        table.append('Version', str(module.version), widths=widths, colors=colors, aligns=aligns)
-        table.append('Summary', module.summary, widths=widths, colors=colors, aligns=aligns)
-        table.append('Author', module.author, widths=widths, colors=colors, aligns=aligns)
-        table.append('Author-Email', module.author_email, widths=widths, colors=colors, aligns=aligns)
-        table.append('Location', module.directory, widths=widths, colors=colors, aligns=aligns)
+        table = log.Table(widths=widths, colors=colors, aligns=aligns)
+        table.append('ModuleName', module.name)
+        table.append('Version', str(module.version))
+        table.append('Summary', module.summary)
+        table.append('Author', module.author)
+        table.append('Author-Email', module.author_email)
+        table.append('Location', module.directory)
         print(table)
         return True
@@ -15,12 +15,12 @@
 import os
-import paddle.fluid as fluid
+import paddle
 from paddlehub.env import DATA_HOME
-class Flowers(fluid.io.Dataset):
+class Flowers(paddle.io.Dataset):
     def __init__(self, transforms=None, mode='train'):
         self.mode = mode
         self.transforms = transforms
...
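A short usage sketch for the migrated dataset base class, assuming a Paddle 2.x install. The real Flowers class needs its data under DATA_HOME, so a tiny stand-in subclass is used here; only the paddle.io.Dataset / paddle.io.DataLoader contract is being illustrated:

import numpy as np
import paddle

class FakeFlowers(paddle.io.Dataset):
    '''Stand-in with the same base class and constructor shape as Flowers above.'''
    def __init__(self, transforms=None, mode='train'):
        self.mode = mode
        self.transforms = transforms

    def __getitem__(self, idx):
        image = np.zeros([3, 224, 224], dtype='float32')   # fake image
        label = np.array([0], dtype='int64')                # fake label
        return image, label

    def __len__(self):
        return 8

# Any paddle.io.Dataset plugs straight into paddle.io.DataLoader,
# which is how the Trainer later in this commit consumes it.
loader = paddle.io.DataLoader(FakeFlowers(), batch_size=4, shuffle=True)
for images, labels in loader:
    print(images.shape, labels.shape)   # [4, 3, 224, 224] and [4, 1]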
@@ -19,11 +19,8 @@ import time
 from collections import defaultdict
 from typing import Any, Callable
-import paddle.fluid as fluid
-from paddle.fluid.dygraph.base import to_variable
-from paddle.fluid.dygraph.parallel import ParallelEnv
-from paddle.fluid.io import DataLoader
-from paddle.incubate.hapi.distributed import DistributedBatchSampler
+import paddle
+from paddle.distributed import ParallelEnv
 from visualdl import LogWriter
 from paddlehub.utils.log import logger
@@ -36,8 +33,8 @@ class Trainer(object):
     '''
     def __init__(self,
-                 model: fluid.dygraph.Layer,
-                 strategy: fluid.optimizer.Optimizer,
+                 model: paddle.nn.Layer,
+                 strategy: paddle.optimizer.Optimizer,
                  use_vdl: bool = True,
                  checkpoint_dir: str = None,
                  compare_metrics: Callable = None):
@@ -59,8 +56,8 @@ class Trainer(object):
         self.best_metrics = defaultdict(int)
         if self.nranks > 1:
-            context = fluid.dygraph.prepare_context()
-            self.model = fluid.dygraph.DataParallel(self.model, context)
+            context = paddle.distributed.init_parallel_env()
+            self.model = paddle.DataParallel(self.model, context)
         self.compare_metrics = self._compare_metrics if not compare_metrics else compare_metrics
         self._load_checkpoint()
@@ -96,7 +93,7 @@ class Trainer(object):
             # load model from checkpoint
             model_path = os.path.join(self.checkpoint_dir, '{}_{}'.format('epoch', self.current_epoch), 'model')
-            state_dict, _ = fluid.load_dygraph(model_path)
+            state_dict, _ = paddle.load(model_path)
             self.model.set_dict(state_dict)
     def _save_checkpoint(self):
@@ -107,7 +104,7 @@ class Trainer(object):
     def save_model(self, save_dir: str):
         '''Save model'''
-        fluid.save_dygraph(self.model.state_dict(), save_dir)
+        paddle.save(self.model.state_dict(), save_dir)
     def _save_metrics(self):
         with open(os.path.join(self.checkpoint_dir, 'metrics.pkl'), 'wb') as file:
@@ -118,156 +115,159 @@ class Trainer(object):
             self.best_metrics = pickle.load(file)
     def train(self,
-              train_dataset: fluid.io.Dataset,
+              train_dataset: paddle.io.Dataset,
               epochs: int = 1,
               batch_size: int = 1,
               num_workers: int = 0,
-              eval_dataset: fluid.io.Dataset = None,
+              eval_dataset: paddle.io.Dataset = None,
               log_interval: int = 10,
               save_interval: int = 10):
         '''
         Train a model with specific config.
         Args:
-            train_dataset(fluid.io.Dataset) : Dataset to train the model
+            train_dataset(paddle.io.Dataset) : Dataset to train the model
             epochs(int) : Number of training loops, default is 1.
             batch_size(int) : Batch size of per step, default is 1.
             num_workers(int) : Number of subprocess to load data, default is 0.
-            eval_dataset(fluid.io.Dataset) : The validation dataset, default is None. If set, the Trainer will execute the evaluate function every `save_interval` epochs.
+            eval_dataset(paddle.io.Dataset) : The validation dataset, default is None. If set, the Trainer will execute the evaluate function every `save_interval` epochs.
             log_interval(int) : Log the train information every `log_interval` steps.
             save_interval(int) : Save the checkpoint every `save_interval` epochs.
         '''
         use_gpu = True
-        place = fluid.CUDAPlace(ParallelEnv().dev_id) if use_gpu else fluid.CPUPlace()
-        with fluid.dygraph.guard(place):
-            batch_sampler = DistributedBatchSampler(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
-            loader = DataLoader(
-                train_dataset, batch_sampler=batch_sampler, places=place, num_workers=num_workers, return_list=True)
+        place = paddle.CUDAPlace(ParallelEnv().dev_id) if use_gpu else paddle.CPUPlace()
+        paddle.disable_static(place)
+        batch_sampler = paddle.io.DistributedBatchSampler(
+            train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
+        loader = paddle.io.DataLoader(
+            train_dataset, batch_sampler=batch_sampler, places=place, num_workers=num_workers, return_list=True)
         steps_per_epoch = len(batch_sampler)
         timer = Timer(steps_per_epoch * epochs)
         timer.start()
         for i in range(epochs):
             self.current_epoch += 1
             avg_loss = 0
             avg_metrics = defaultdict(int)
             self.model.train()
             for batch_idx, batch in enumerate(loader):
                 loss, metrics = self.training_step(batch, batch_idx)
                 self.optimizer_step(self.current_epoch, batch_idx, self.optimizer, loss)
                 self.optimizer_zero_grad(self.current_epoch, batch_idx, self.optimizer)
                 # calculate metrics and loss
                 avg_loss += loss.numpy()[0]
                 for metric, value in metrics.items():
                     avg_metrics[metric] += value.numpy()[0]
                 timer.count()
                 if (batch_idx + 1) % log_interval == 0 and self.local_rank == 0:
                     lr = self.optimizer.current_step_lr()
                     avg_loss /= log_interval
                     if self.use_vdl:
                         self.log_writer.add_scalar(tag='TRAIN/loss', step=timer.current_step, value=avg_loss)
                     print_msg = 'Epoch={}/{}, Step={}/{}'.format(self.current_epoch, epochs, batch_idx + 1,
                                                                  steps_per_epoch)
                     print_msg += ' loss={:.4f}'.format(avg_loss)
                     for metric, value in avg_metrics.items():
                         value /= log_interval
                         if self.use_vdl:
                             self.log_writer.add_scalar(
                                 tag='TRAIN/{}'.format(metric), step=timer.current_step, value=value)
                         print_msg += ' {}={:.4f}'.format(metric, value)
                     print_msg += ' lr={:.6f} step/sec={:.2f} | ETA {}'.format(lr, timer.timing, timer.eta)
                     logger.train(print_msg)
                     avg_loss = 0
                     avg_metrics = defaultdict(int)
                 if self.current_epoch % save_interval == 0 and batch_idx + 1 == steps_per_epoch and self.local_rank == 0:
                     if eval_dataset:
                         result = self.evaluate(eval_dataset, batch_size, num_workers)
                         eval_loss = result.get('loss', None)
                         eval_metrics = result.get('metrics', {})
                         if self.use_vdl:
                             if eval_loss:
-                                self.log_writer.add_scalar(
-                                    tag='EVAL/loss', step=timer.current_step, value=eval_loss)
+                                self.log_writer.add_scalar(tag='EVAL/loss', step=timer.current_step, value=eval_loss)
                             for metric, value in eval_metrics.items():
                                 self.log_writer.add_scalar(
                                     tag='EVAL/{}'.format(metric), step=timer.current_step, value=value)
                         if not self.best_metrics or self.compare_metrics(self.best_metrics, eval_metrics):
                             self.best_metrics = eval_metrics
                             best_model_path = os.path.join(self.checkpoint_dir, 'best_model')
                             self.save_model(best_model_path)
                             self._save_metrics()
                             metric_msg = [
                                 '{}={:.4f}'.format(metric, value) for metric, value in self.best_metrics.items()
                             ]
                             metric_msg = ' '.join(metric_msg)
                             logger.eval('Saving best model to {} [best {}]'.format(best_model_path, metric_msg))
                     self._save_checkpoint()
-    def evaluate(self, eval_dataset: fluid.io.Dataset, batch_size: int = 1, num_workers: int = 0):
+    def evaluate(self, eval_dataset: paddle.io.Dataset, batch_size: int = 1, num_workers: int = 0):
         '''
         Run evaluation and returns metrics.
         Args:
-            eval_dataset(fluid.io.Dataset) : The validation dataset
+            eval_dataset(paddle.io.Dataset) : The validation dataset
             batch_size(int) : Batch size of per step, default is 1.
             num_workers(int) : Number of subprocess to load data, default is 0.
         '''
         use_gpu = True
-        place = fluid.CUDAPlace(ParallelEnv().dev_id) if use_gpu else fluid.CPUPlace()
-        with fluid.dygraph.guard(place):
-            batch_sampler = DistributedBatchSampler(eval_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
-            loader = DataLoader(
-                eval_dataset, batch_sampler=batch_sampler, places=place, num_workers=num_workers, return_list=True)
+        place = paddle.CUDAPlace(ParallelEnv().dev_id) if use_gpu else paddle.CPUPlace()
+        paddle.disable_static(place)
+        batch_sampler = paddle.io.DistributedBatchSampler(
+            eval_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
+        loader = paddle.io.DataLoader(
+            eval_dataset, batch_sampler=batch_sampler, places=place, num_workers=num_workers, return_list=True)
         self.model.eval()
         avg_loss = num_samples = 0
         sum_metrics = defaultdict(int)
         avg_metrics = defaultdict(int)
         for batch_idx, batch in enumerate(loader):
             result = self.validation_step(batch, batch_idx)
             loss = result.get('loss', None)
             metrics = result.get('metrics', {})
             bs = batch[0].shape[0]
             num_samples += bs
             if loss:
                 avg_loss += loss.numpy()[0] * bs
             for metric, value in metrics.items():
                 sum_metrics[metric] += value.numpy()[0] * bs
         # print avg metrics and loss
         print_msg = '[Evaluation result]'
         if loss:
             avg_loss /= num_samples
             print_msg += ' avg_loss={:.4f}'.format(avg_loss)
         for metric, value in sum_metrics.items():
             avg_metrics[metric] = value / num_samples
             print_msg += ' avg_{}={:.4f}'.format(metric, avg_metrics[metric])
         logger.eval(print_msg)
         if loss:
             return {'loss': avg_loss, 'metrics': avg_metrics}
         return {'metrics': avg_metrics}
     def training_step(self, batch: Any, batch_idx: int):
         if self.nranks > 1:
@@ -302,11 +302,11 @@ class Trainer(object):
         result = self.model.validation_step(batch, batch_idx)
         return result
-    def optimizer_step(self, current_epoch: int, batch_idx: int, optimizer: fluid.optimizer.Optimizer,
-                       loss: fluid.core.VarBase):
+    def optimizer_step(self, current_epoch: int, batch_idx: int, optimizer: paddle.optimizer.Optimizer,
+                       loss: paddle.Tensor):
         self.optimizer.minimize(loss)
-    def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: fluid.optimizer.Optimizer):
+    def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: paddle.optimizer.Optimizer):
         self.model.clear_gradients()
     def _compare_metrics(self, old_metric: dict, new_metric: dict):
...
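A hedged sketch of driving the migrated Trainer, using only the signatures visible above. The Trainer import path is not part of this diff, so the class is passed in as an argument, and the Adam optimizer is merely an illustrative choice for the `strategy` parameter:

import paddle

def run_finetune(Trainer, model: paddle.nn.Layer,
                 train_ds: paddle.io.Dataset, eval_ds: paddle.io.Dataset = None):
    # 'strategy' is a paddle.optimizer.Optimizer per the __init__ annotation above.
    optimizer = paddle.optimizer.Adam(learning_rate=1e-3, parameters=model.parameters())
    trainer = Trainer(model=model, strategy=optimizer, use_vdl=True, checkpoint_dir='ckpt')
    trainer.train(
        train_ds,
        epochs=10,
        batch_size=32,
        num_workers=0,
        eval_dataset=eval_ds,    # evaluated every save_interval epochs when provided
        log_interval=10,
        save_interval=1)
    if eval_ds is not None:
        return trainer.evaluate(eval_ds, batch_size=32)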
@@ -13,10 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import List
 import numpy as np
 import paddle
-import paddle.fluid as fluid
-from paddle.fluid.dygraph import to_variable
+import paddle.nn.functional as F
 from paddlehub.module.module import serving, RunModule
 from paddlehub.utils.utils import base64_to_cv2
@@ -24,7 +25,7 @@ from paddlehub.utils.utils import base64_to_cv2
 class ImageServing(object):
     @serving
-    def serving_method(self, images, **kwargs):
+    def serving_method(self, images: List[str], **kwargs) -> List[dict]:
         """Run as a service."""
         images_decode = [base64_to_cv2(image) for image in images]
         results = self.predict(images=images_decode, **kwargs)
@@ -32,25 +33,55 @@ class ImageServing(object):
 class ImageClassifierModule(RunModule, ImageServing):
-    def training_step(self, batch, batch_idx):
+    def training_step(self, batch: int, batch_idx: int) -> dict:
+        '''
+        One step for training, which should be called as forward computation.
+        Args:
+            batch(list[paddle.Variable]): The one batch data, which contains images and labels.
+            batch_idx(int): The index of batch.
+        Returns:
+            results(dict) : The model outputs, such as loss and metrics.
+        '''
         return self.validation_step(batch, batch_idx)
-    def validation_step(self, batch, batch_idx):
+    def validation_step(self, batch: int, batch_idx: int) -> dict:
+        '''
+        One step for validation, which should be called as forward computation.
+        Args:
+            batch(list[paddle.Variable]): The one batch data, which contains images and labels.
+            batch_idx(int): The index of batch.
+        Returns:
+            results(dict) : The model outputs, such as metrics.
+        '''
         images = batch[0]
-        labels = paddle.unsqueeze(batch[1], axes=-1)
+        labels = paddle.unsqueeze(batch[1], axis=-1)
         preds = self(images)
-        loss, _ = fluid.layers.softmax_with_cross_entropy(preds, labels, return_softmax=True, axis=1)
-        loss = fluid.layers.mean(loss)
-        acc = fluid.layers.accuracy(preds, labels)
+        loss, _ = F.softmax_with_cross_entropy(preds, labels, return_softmax=True, axis=1)
+        loss = paddle.mean(loss)
+        acc = paddle.metric.accuracy(preds, labels)
         return {'loss': loss, 'metrics': {'acc': acc}}
-    def predict(self, images, top_k=1):
+    def predict(self, images: List[np.ndarray], top_k: int = 1) -> List[dict]:
+        '''
+        Predict images
+        Args:
+            images(list[numpy.ndarray]) : Images to be predicted, consist of np.ndarray in bgr format.
+            top_k(int) : Output top k result of each image.
+        Returns:
+            results(list[dict]) : The prediction result of each input image
+        '''
         images = self.transforms(images)
         if len(images.shape) == 3:
             images = images[np.newaxis, :]
-        preds = self(to_variable(images))
-        preds = fluid.layers.softmax(preds, axis=1).numpy()
+        preds = self(paddle.to_variable(images))
+        preds = F.softmax(preds, axis=1).numpy()
         pred_idxs = np.argsort(preds)[::-1][:, :top_k]
         res = []
         for i, pred in enumerate(pred_idxs):
@@ -61,5 +92,5 @@ class ImageClassifierModule(RunModule, ImageServing):
             res.append(res_dict)
         return res
-    def is_better_score(self, old_score, new_score):
+    def is_better_score(self, old_score: dict, new_score: dict):
         return old_score['acc'] < new_score['acc']
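A minimal sketch of calling the two entry points shown above on an already-constructed classification module (how the instance is obtained is outside this hunk); cv2 supplies the BGR ndarray that predict expects, and base64 strings feed the serving path:

import base64
import cv2

def classify_local(module, image_path: str):
    # predict() takes BGR numpy arrays and returns one result dict per image.
    image = cv2.imread(image_path)
    return module.predict(images=[image], top_k=3)[0]

def classify_serving(module, image_path: str):
    # serving_method() takes base64-encoded strings and decodes them itself.
    with open(image_path, 'rb') as f:
        b64 = base64.b64encode(f.read()).decode('utf8')
    return module.serving_method(images=[b64], top_k=3)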
@@ -17,14 +17,13 @@ import inspect
 import importlib
 import os
 import sys
-import paddle.fluid as fluid
+from typing import Callable, List, Optional, Generic
 from paddlehub.utils import utils
 class InvalidHubModule(Exception):
-    def __init__(self, directory):
+    def __init__(self, directory: str):
         self.directory = directory
     def __str__(self):
@@ -35,7 +34,7 @@ _module_serving_func = {}
 _module_runnable_func = {}
-def runnable(func):
+def runnable(func: Callable) -> Callable:
     mod = func.__module__ + '.' + inspect.stack()[1][3]
     _module_runnable_func[mod] = func.__name__
@@ -45,7 +44,7 @@ def runnable(func):
     return _wrapper
-def serving(func):
+def serving(func: Callable) -> Callable:
     mod = func.__module__ + '.' + inspect.stack()[1][3]
     _module_serving_func[mod] = func.__name__
@@ -69,7 +68,7 @@ class Module(object):
         return module
     @classmethod
-    def load(cls, directory: str):
+    def load(cls, directory: str) -> Generic:
         if directory.endswith(os.sep):
             directory = directory[:-1]
@@ -95,9 +94,7 @@ class Module(object):
     def init_with_name(cls, name: str, version: str = None, **kwargs):
         from paddlehub.module.manager import LocalModuleManager
         manager = LocalModuleManager()
-        search_result = manager.search(name)
-        user_module_cls = search_result.get('module', None)
-        directory = search_result.get('directory', None)
+        user_module_cls = manager.search(name)
         if not user_module_cls or not user_module_cls.version.match(version):
             user_module_cls = manager.install(name, version)
@@ -121,7 +118,7 @@ class RunModule(object):
         self._serving_func_name = self._get_func_name(self.__class__, _module_serving_func)
         self._is_initialize = True
-    def _get_func_name(self, current_cls, module_func_dict):
+    def _get_func_name(self, current_cls: Generic, module_func_dict: dict) -> Optional[str]:
         mod = current_cls.__module__ + '.' + current_cls.__name__
         if mod in module_func_dict:
             _func_name = module_func_dict[mod]
@@ -133,7 +130,7 @@ class RunModule(object):
         return None
     @classmethod
-    def get_py_requirements(cls):
+    def get_py_requirements(cls) -> List[str]:
         py_module = sys.modules[cls.__module__]
         directory = os.path.dirname(py_module.__file__)
         req_file = os.path.join(directory, 'requirements.txt')
@@ -156,8 +153,8 @@ def moduleinfo(name: str,
                author_email: str = None,
                summary: str = None,
                type: str = None,
-               meta=None):
-    def _wrapper(cls):
+               meta=None) -> Callable:
+    def _wrapper(cls: Generic) -> Generic:
         wrap_cls = cls
         _meta = RunModule if not meta else meta
         if not issubclass(cls, _meta):
...
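A compact sketch of the decorator surface defined in this file, using the import path the image-classification hunk above already relies on (paddlehub.module.module). The metadata values are placeholders, and the `version`/`author` parameters are elided by the hunk header above and assumed to exist:

from paddlehub.module.module import RunModule, moduleinfo, runnable, serving

@moduleinfo(
    name='toy_module',                      # placeholder metadata throughout
    version='1.0.0',
    author='someone',
    author_email='someone@example.com',
    summary='Toy module used only to illustrate the decorators.',
    type='cv/classification')
class ToyModule(RunModule):
    @runnable
    def run_cmd(self, argvs):
        # registered as the module's command-line entry point
        return 'ran with {}'.format(argvs)

    @serving
    def serve(self, data, **kwargs):
        # registered as the module's serving entry point
        return {'echo': data}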
@@ -53,19 +53,16 @@ def discard_oe():
     Redirect the output and error streams to a temporary file. In a sense,
     it is equivalent to discarding the output and error messages
     '''
-    with generate_tempfile() as _file:
-        with open(_file.name, 'w') as _stream:
-            with redirect_ostream(_stream):
-                with redirect_estream(_stream):
-                    yield
+    with generate_tempfile(mode='w') as _stream:
+        with redirect_ostream(_stream), redirect_estream(_stream):
+            yield
 @contextlib.contextmanager
 def typein(chars: str = 'y'):
     # typein chars to input stream
-    with generate_tempfile() as _file:
-        with open(_file.name, 'w+') as _stream:
-            with redirect_istream(_stream):
-                _stream.write('{}\n'.format(chars))
-                _stream.seek(0)
-                yield
+    with generate_tempfile(mode='w+') as _stream:
+        with redirect_istream(_stream):
+            _stream.write('{}\n'.format(chars))
+            _stream.seek(0)
+            yield
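A small sketch of how the rewritten context managers are meant to compose (the same stacking that the pip uninstall helper later in this commit uses). The import path is assumed, not shown in this hunk, and noisy_confirm is a made-up stand-in for any chatty, prompting call:

# from paddlehub.utils.io import discard_oe, typein   # assumed path

def noisy_confirm() -> bool:
    print('...pages of output nobody wants to see...')
    return input('Proceed (y/n)? ').strip() == 'y'

def quiet_confirm() -> bool:
    # stdout/stderr are discarded and stdin is pre-filled with 'y',
    # so the prompt is answered without user interaction or console noise.
    with discard_oe(), typein('y'):
        return noisy_confirm()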
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import functools
 import logging
 import sys
@@ -55,7 +56,14 @@ log_config = {
 class Logger(object):
-    def __init__(self, name=None):
+    '''
+    Default logger in PaddleHub
+    Args:
+        name(str) : Logger name, default is 'PaddleHub'
+    '''
+
+    def __init__(self, name: str = None):
         name = 'PaddleHub' if not name else name
         self.logger = logging.getLogger(name)
@@ -80,6 +88,23 @@ class Logger(object):
 class ProgressBar(object):
     '''
+    Progress bar printer
+    Args:
+        title(str) : Title text
+        flush_interval(float): Flush rate of progress bar, default is 0.1.
+    Examples:
+        .. code-block:: python
+            with ProgressBar('Download module') as bar:
+                for i in range(100):
+                    bar.update(i / 100)
+                    # with continuous bar.update, the progress bar in the terminal
+                    # will continue to update until 100%
+                    #
+                    # Download module
+                    # [##################################################] 100.00%
     '''
     def __init__(self, title: str, flush_interval: float = 0.1):
@@ -99,7 +124,9 @@ class ProgressBar(object):
         else:
             sys.stdout.write('\n')
-    def update(self, progress):
+    def update(self, progress: float):
+        '''Update the progress to `progress`, a float in [0, 1], and refresh the printed bar.'''
         msg = '[{:<50}] {:.2f}%'.format('#' * int(progress * 50), progress * 100)
         need_flush = (time.time() - self.last_flush_time) >= self.flush_interval
@@ -112,16 +139,32 @@ class ProgressBar(object):
             sys.stdout.write('\n')
-class _FormattedText(object):
+class FormattedText(object):
+    '''
+    Cross-platform formatted string
+    Args:
+        text(str) : Text content
+        width(int) : Text length, if the text is less than the specified length, it will be filled with spaces
+        align(str) : it must be:
+            ========= ==================
+            Character Meaning
+            --------- ------------------
+            '<'       left aligned
+            '^'       middle aligned
+            '>'       right aligned
+            ========= ==================
+        color(str) : Text color, default is None(depends on terminal configuration)
+    '''
     _MAP = {'red': Fore.RED, 'yellow': Fore.YELLOW, 'green': Fore.GREEN, 'blue': Fore.BLUE}
-    def __init__(self, text: str, width: int, align='<', color=None):
+    def __init__(self, text: str, width: int, align: str = '<', color: str = None):
         self.text = text
         self.align = align
-        self.color = _FormattedText._MAP[color] if color else color
+        self.color = FormattedText._MAP[color] if color else color
         self.width = width
-    def __repr__(self):
+    def __repr__(self) -> str:
         form = ':{}{}'.format(self.align, self.width)
         text = ('{' + form + '}').format(self.text)
         if not self.color:
@@ -130,12 +173,14 @@ class _FormattedText(object):
 class TableCell(object):
+    '''The basic components of a table'''
+
     def __init__(self, content: str = '', width: int = 0, align: str = '<', color: str = ''):
         self._width = width if width else len(content)
         self._width = 1 if self._width < 1 else self._width
         self._contents = []
         for i in range(0, len(content), self._width):
-            text = _FormattedText(content[i:i + self._width], width, align, color)
+            text = FormattedText(content[i:i + self._width], width, align, color)
             self._contents.append(text)
         self.align = align
         self.color = color
@@ -158,7 +203,7 @@ class TableCell(object):
     def height(self, value: int):
         if value < self.height:
             raise RuntimeError(self.height, value)
-        self._contents += [_FormattedText('', width=self.width, align=self.align, color=self.color)
+        self._contents += [FormattedText('', width=self.width, align=self.align, color=self.color)
                            ] * (value - self.height)
     def __len__(self) -> int:
@@ -171,7 +216,9 @@ class TableCell(object):
         return '\n'.join([str(item) for item in self._contents])
-class TableLine(object):
+class TableRow(object):
+    '''Table row composed of TableCell'''
+
     def __init__(self):
         self.cells = []
@@ -179,23 +226,23 @@ class TableLine(object):
         self.cells.append(cell)
     @property
-    def width(self):
+    def width(self) -> int:
         _width = 0
         for cell in self.cells():
             _width += cell.width
         return _width
     @property
-    def height(self):
+    def height(self) -> int:
         _height = -1
         for cell in self.cells:
             _height = max(_height, cell.height)
         return _height
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.cells)
-    def __repr__(self):
+    def __repr__(self) -> str:
         content = ''
         for i in range(self.height):
             content += '|'
@@ -211,7 +258,9 @@ class TableLine(object):
         return self.cells[idx]
-class TableRow(object):
+class TableColumn(object):
+    '''Table column composed of TableCell'''
+
     def __init__(self):
         self.cells = []
@@ -219,20 +268,20 @@ class TableRow(object):
         self.cells.append(cell)
     @property
-    def width(self):
+    def width(self) -> int:
         _width = -1
         for cell in self.cells:
             _width = max(_width, cell.width)
         return _width
     @property
-    def height(self):
+    def height(self) -> int:
         _height = 0
         for cell in self.cells:
             _height += cell.height
         return _height
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.cells)
     def __getitem__(self, idx: int) -> TableCell:
@@ -240,12 +289,63 @@ class TableRow(object):
 class Table(object):
-    def __init__(self):
-        self.lines = []
+    '''
+    Table with adaptive width and height
+    Args:
+        colors(list[str]) : Text colors of contents one by one
+        aligns(list[str]) : Text aligns of contents one by one
+        widths(list[str]) : Text widths of contents one by one
+    Examples:
+        .. code-block:: python
+            table = Table(widths=[12, 20])
+            table.append('name', 'PaddleHub')
+            table.append('version', '2.0.0')
+            table.append(
+                'description',
+                'PaddleHub is a pretrainied model application tool under the PaddlePaddle')
+            table.append('author')
+            print(table)
+            # the result is
+            # +------------+--------------------+
+            # |name        |PaddleHub           |
+            # +------------+--------------------+
+            # |version     |2.0.0               |
+            # +------------+--------------------+
+            # |description |PaddleHub is a pretr|
+            # |            |ainied model applica|
+            # |            |tion tool under the |
+            # |            |PaddlePaddle        |
+            # +------------+--------------------+
+            # |author      |                    |
+            # +------------+--------------------+
+    '''
+
+    def __init__(self, colors: List[str] = [], aligns: List[str] = [], widths: List[int] = []):
         self.rows = []
+        self.columns = []
+        self.colors = colors
+        self.aligns = aligns
+        self.widths = widths
     def append(self, *contents, colors: List[str] = [], aligns: List[str] = [], widths: List[int] = []):
-        newline = TableLine()
+        '''
+        Add a row to the table
+        Args:
+            *contents(*list): Contents of the row, each content will be placed in a separate cell
+            colors(list[str]) : Text colors of contents one by one, if not set, the default value will be used.
+            aligns(list[str]) : Text aligns of contents one by one, if not set, the default value will be used.
+            widths(list[str]) : Text widths of contents one by one, if not set, the default value will be used.
+        '''
+        newrow = TableRow()
+        widths = copy.deepcopy(self.widths) if not widths else widths
+        colors = copy.deepcopy(self.colors) if not colors else colors
+        aligns = copy.deepcopy(self.aligns) if not aligns else aligns
         for idx, content in enumerate(contents):
             width = widths[idx] if idx < len(widths) else len(content)
@@ -253,73 +353,66 @@ class Table(object):
             align = aligns[idx] if idx < len(aligns) else ''
             newcell = TableCell(content, width=width, color=color, align=align)
-            newline.append(newcell)
-            if idx >= len(self.rows):
-                newrow = TableRow()
-                for line in self.lines:
+            newrow.append(newcell)
+            if idx >= len(self.columns):
+                newcolumn = TableColumn()
+                for row in self.rows:
                     cell = TableCell(width=width, color=color, align=align)
-                    line.append(cell)
-                    newrow.append(cell)
-                newrow.append(newcell)
-                self.rows.append(newrow)
+                    row.append(cell)
+                    newcolumn.append(cell)
+                newcolumn.append(newcell)
+                self.columns.append(newcolumn)
             else:
-                self.rows[idx].append(newcell)
-        for idx in range(len(newline), len(self.rows)):
-            width = widths[idx] if idx < len(widths) else self.rows[idx].width
+                self.columns[idx].append(newcell)
+        for idx in range(len(newrow), len(self.columns)):
+            width = widths[idx] if idx < len(widths) else self.columns[idx].width
             color = colors[idx] if idx < len(colors) else ''
             align = aligns[idx] if idx < len(aligns) else ''
             cell = TableCell(width=width, color=color, align=align)
-            newline.append(cell)
-        self.lines.append(newline)
+            newrow.append(cell)
+        self.rows.append(newrow)
         self._adjust()
     def _adjust(self):
-        for row in self.rows:
+        '''Adjust the width and height of the cells in each row and column.'''
+        for column in self.columns:
             _width = -1
-            for cell in row:
+            for cell in column:
                 _width = max(_width, cell.width)
-            for cell in row:
+            for cell in column:
                 cell.width = _width
-        for line in self.lines:
+        for row in self.rows:
             _height = -1
-            for cell in line:
+            for cell in row:
                 _height = max(_height, cell.height)
-            for cell in line:
+            for cell in row:
                 cell.height = _height
     @property
-    def width(self):
+    def width(self) -> int:
         _width = -1
-        for line in self.lines:
-            _width = max(_width, line.width)
+        for row in self.rows:
+            _width = max(_width, row.width)
         return _width
     @property
-    def height(self):
+    def height(self) -> int:
         _height = -1
-        for row in self.rows:
-            _height = max(_height, row.height)
+        for column in self.columns:
+            _height = max(_height, column.height)
         return _height
-    def __repr__(self):
-        sepline = '+{}+\n'.format('+'.join(['-' * row.width for row in self.rows]))
+    def __repr__(self) -> str:
+        seprow = '+{}+\n'.format('+'.join(['-' * column.width for column in self.columns]))
         content = ''
-        for line in self.lines:
-            content = content + str(line)
-            content += sepline
-        return sepline + content
-# table = Table()
-# table.append('123', '234')
-# table.append('122223', '22444')
-# table.append('121111111111111111111111111111111113', '234', widths=[10, 20], colors=['red', 'yellow'], aligns=['^', '>'])
-# table.append('122223', '22444')
-# table.append('122223', '22444', '123')
-# print(table)
+        for row in self.rows:
+            content = content + str(row)
+            content += seprow
+        return seprow + content
 logger = Logger()
@@ -54,9 +54,8 @@ def install(package: str, version: str = '', upgrade=False) -> bool:
 def uninstall(package: str) -> bool:
     '''Uninstall the python package.'''
-    with discard_oe():
+    with discard_oe(), typein('y'):
         # type in 'y' to confirm the uninstall operation
-        with typein('y'):
-            cmds = ['uninstall', '{}'.format(package)]
-            result = pip.main(cmds)
+        cmds = ['uninstall', '{}'.format(package)]
+        result = pip.main(cmds)
     return result == 0
@@ -18,9 +18,9 @@ import contextlib
 import cv2
 import math
 import os
-import requests
 import sys
 import time
+import requests
 import tempfile
 import numpy as np
 from typing import Generator
@@ -119,7 +119,7 @@ class Timer(object):
         return seconds_to_hms(remaining_time)
-def seconds_to_hms(seconds: int):
+def seconds_to_hms(seconds: int) -> str:
     '''Convert the number of seconds to hh:mm:ss'''
     h = math.floor(seconds / 3600)
     m = math.floor((seconds - h * 3600) / 60)
@@ -128,7 +128,7 @@ def seconds_to_hms(seconds: int):
     return hms_str
-def base64_to_cv2(b64str: str):
+def base64_to_cv2(b64str: str) -> np.ndarray:
     '''Convert a string in base64 format to cv2 data'''
     data = base64.b64decode(b64str.encode('utf8'))
     data = np.fromstring(data, np.uint8)
@@ -137,22 +137,22 @@ def base64_to_cv2(b64str: str):
 @contextlib.contextmanager
-def generate_tempfile(directory: str = None):
+def generate_tempfile(directory: str = None, **kwargs):
     '''Generate a temporary file'''
     directory = hubenv.TMP_HOME if not directory else directory
-    with tempfile.NamedTemporaryFile(dir=directory) as file:
+    with tempfile.NamedTemporaryFile(dir=directory, **kwargs) as file:
         yield file
 @contextlib.contextmanager
-def generate_tempdir(directory: str = None):
+def generate_tempdir(directory: str = None, **kwargs):
     '''Generate a temporary directory'''
     directory = hubenv.TMP_HOME if not directory else directory
-    with tempfile.TemporaryDirectory(dir=directory) as _dir:
+    with tempfile.TemporaryDirectory(dir=directory, **kwargs) as _dir:
         yield _dir
-def download(url: str, path: str = None):
+def download(url: str, path: str = None) -> str:
     '''
     Download a file
...
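A quick sketch of what the new **kwargs pass-through buys: any tempfile.NamedTemporaryFile / TemporaryDirectory keyword (the mode='w' / mode='w+' used by discard_oe and typein above, or a suffix/prefix) is now forwarded unchanged. The import path is assumed from the package layout, not shown in the hunk:

# from paddlehub.utils.utils import generate_tempfile, generate_tempdir   # assumed path

with generate_tempfile(mode='w+', suffix='.txt') as f:
    f.write('hello')        # text mode works because mode= reaches NamedTemporaryFile
    f.seek(0)
    print(f.read())

with generate_tempdir(prefix='hub_') as tmp_dir:
    print(tmp_dir)          # prefix= is forwarded to tempfile.TemporaryDirectory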
@@ -22,9 +22,7 @@ import rarfile
 class XarInfo(object):
-    '''
-    Informational class which holds the details about an archive member given by a XarFile.
-    '''
+    '''Informational class which holds the details about an archive member given by a XarFile.'''
     def __init__(self, _xarinfo, arctype='tar'):
         self._info = _xarinfo
@@ -136,29 +134,21 @@ class XarFile(object):
             self._archive_fp.write(item)
     def extract(self, name: str, path: str):
-        '''
-        Extract a file from the archive to the specified path
-        '''
+        '''Extract a file from the archive to the specified path.'''
         return self._archive_fp.extract(name, path)
     def extractall(self, path: str):
-        '''
-        Extract all files from the archive to the specified path
-        '''
+        '''Extract all files from the archive to the specified path.'''
         return self._archive_fp.extractall(path)
     def getnames(self) -> List[str]:
-        '''
-        Return a list of file names in the archive.
-        '''
+        '''Return a list of file names in the archive.'''
         if self.arctype == 'tar':
             return self._archive_fp.getnames()
         return self._archive_fp.namelist()
     def getxarinfo(self, name: str) -> List[XarInfo]:
-        '''
-        Return the instance of XarInfo given 'name'.
-        '''
+        '''Return the instance of XarInfo given 'name'.'''
        if self.arctype == 'tar':
             return XarInfo(self._archive_fp.getmember(name), self.arctype)
         return XarInfo(self._archive_fp.getinfo(name), self.arctype)
...
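A brief sketch against the XarFile methods documented above; the archive object is taken as an argument because its constructor is not part of this hunk:

def inspect_and_unpack(xar_file, dst_dir: str):
    # Only the methods shown above are used: getnames, getxarinfo, extractall.
    for name in xar_file.getnames():
        info = xar_file.getxarinfo(name)      # per-member metadata wrapper
        print(name, type(info).__name__)
    xar_file.extractall(dst_dir)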