Commit 9d5527df authored by W wuzewu

Remove redundant files

Parent 09736415
## Command-line Prediction
```shell
$ hub run retinanet_resnet50_fpn_coco2017 --input_path "/PATH/TO/IMAGE"
```
## API
```python
def context(num_classes=81,
            trainable=True,
            pretrained=True,
            phase='train')
```
Extract features for transfer learning.
**Parameters**
* num\_classes (int): number of classes;
* trainable (bool): whether the parameters are trainable;
* pretrained (bool): whether to load the default pretrained model;
* phase (str): optional values are 'train' and 'predict'.
**Returns**
* inputs (dict): model inputs; keys include 'image' and 'im\_info', with the corresponding values:
    * image (Variable): the image variable
    * im\_info (Variable): the resize information of the image
* outputs (dict): model outputs. When phase is 'predict', the output key is 'bbox\_out'; otherwise it is 'body\_features'.
* context\_prog (Program): the Program used for transfer learning.
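A minimal transfer-learning sketch (the key names follow the module code in this commit):

```python
import paddlehub as hub

# Build the transfer-learning program from the pretrained module.
module = hub.Module(name="retinanet_resnet50_fpn_coco2017")
inputs, outputs, context_prog = module.context(trainable=True, pretrained=True, phase='train')

image = inputs['image']      # input image variable
im_info = inputs['im_info']  # [resize_h, resize_w, im_scale] per image
body_features = outputs['body_features']  # FPN features to attach a custom head to
```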
```python
def object_detection(paths=None,
                     images=None,
                     batch_size=1,
                     use_gpu=False,
                     output_dir='detection_result',
                     score_thresh=0.5,
                     visualization=True)
```
Prediction API that detects the positions of all objects in the input images.
**Parameters**
* paths (list\[str\]): paths of the images;
* images (list\[numpy.ndarray\]): image data, with ndarray.shape \[H, W, C\] in BGR format;
* batch\_size (int): batch size;
* use\_gpu (bool): whether to use GPU;
* score\_thresh (float): confidence threshold for the detections;
* visualization (bool): whether to save the detection results as image files;
* output\_dir (str): directory to save the output images, defaults to detection\_result.
**Returns**
* res (list\[dict\]): list of detection results, where each element is a dict with fields:
    * data (list): detection results; each element of the list is a dict with fields:
        * confidence (float): confidence of the detection;
        * label (str): label of the detection;
        * left (float): x-coordinate of the top-left corner of the bounding box;
        * top (float): y-coordinate of the top-left corner of the bounding box;
        * right (float): x-coordinate of the bottom-right corner of the bounding box;
        * bottom (float): y-coordinate of the bottom-right corner of the bounding box;
    * save\_path (str, optional): path of the saved visualization (present only when visualization=True).
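A minimal sketch of reading this structure, assuming `object_detector` is the loaded module (see the code example below):

```python
res = object_detector.object_detection(paths=['/PATH/TO/IMAGE'])
for image_result in res:
    for det in image_result['data']:
        print('%s: %.2f at (%.0f, %.0f, %.0f, %.0f)' % (
            det['label'], det['confidence'],
            det['left'], det['top'], det['right'], det['bottom']))
```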
```python
def save_inference_model(dirname,
                         model_filename=None,
                         params_filename=None,
                         combined=True)
```
Save the model to the specified path.
**Parameters**
* dirname: directory where the inference model is saved
* model\_filename: model file name, defaults to \_\_model\_\_
* params\_filename: parameters file name, defaults to \_\_params\_\_ (effective only when `combined` is True)
* combined: whether to save all parameters in a single file
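For example, a minimal export call (the target directory is hypothetical):

```python
object_detector.save_inference_model(dirname='./inference_model')
```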
## Code Example
```python
import paddlehub as hub
import cv2
object_detector = hub.Module(name="retinanet_resnet50_fpn_coco2017")
result = object_detector.object_detection(images=[cv2.imread('/PATH/TO/IMAGE')])
# or
# result = object_detector.object_detection(paths=['/PATH/TO/IMAGE'])
```
## Serving Deployment
PaddleHub Serving can deploy an online object detection service.
## Step 1: Start PaddleHub Serving
Run the start command:
```shell
$ hub serving start -m retinanet_resnet50_fpn_coco2017
```
This completes the deployment of an object detection service API, listening on port 8866 by default.
**NOTE:** To run predictions on GPU, set the CUDA\_VISIBLE\_DEVICES environment variable before starting the service; otherwise it does not need to be set.
## Step 2: Send a Prediction Request
With the server configured, the following few lines of code send a prediction request and retrieve the result:
```python
import requests
import json
import cv2
import base64
def cv2_to_base64(image):
data = cv2.imencode('.jpg', image)[1]
    return base64.b64encode(data.tobytes()).decode('utf8')
# Send an HTTP request
data = {'images':[cv2_to_base64(cv2.imread("/PATH/TO/IMAGE"))]}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/retinanet_resnet50_fpn_coco2017"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
# Print the prediction result
print(r.json()["results"])
```
### Dependencies
paddlepaddle >= 1.6.2
paddlehub >= 1.6.0
# coding=utf-8
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import os
from collections import OrderedDict
import numpy as np
import cv2
from PIL import Image, ImageEnhance
from paddle import fluid
__all__ = ['test_reader', 'padding_minibatch']
def test_reader(paths=None, images=None):
"""
data generator
Args:
paths (list[str]): paths to images.
images (list(numpy.ndarray)): data of images, shape of each is [H, W, C]
Yield:
res (dict): key contains 'image' and 'im_info', the corresponding values is:
image (numpy.ndarray): the image to be fed into network
            im_info (numpy.ndarray): the resize info of the preprocessed image.
"""
img_list = list()
if paths:
for img_path in paths:
assert os.path.isfile(img_path), "The {} isn't a valid file path.".format(img_path)
img = cv2.imread(img_path).astype('float32')
img_list.append(img)
if images is not None:
for img in images:
img_list.append(img)
for im in img_list:
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im = im.astype(np.float32, copy=False)
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
mean = np.array(mean)[np.newaxis, np.newaxis, :]
std = np.array(std)[np.newaxis, np.newaxis, :]
im = im / 255.0
im -= mean
im /= std
target_size = 800
max_size = 1333
shape = im.shape
# im_shape holds the original shape of image.
# im_shape = np.array([shape[0], shape[1], 1.0]).astype('float32')
im_size_min = np.min(shape[0:2])
im_size_max = np.max(shape[0:2])
im_scale = float(target_size) / float(im_size_min)
if np.round(im_scale * im_size_max) > max_size:
im_scale = float(max_size) / float(im_size_max)
resize_w = np.round(im_scale * float(shape[1]))
resize_h = np.round(im_scale * float(shape[0]))
# im_info holds the resize info of image.
im_info = np.array([resize_h, resize_w, im_scale]).astype('float32')
im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR)
# HWC --> CHW
im = np.swapaxes(im, 1, 2)
im = np.swapaxes(im, 1, 0)
yield {'image': im, 'im_info': im_info}
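# Usage sketch (hypothetical path): each yielded sample carries a normalized,
# resized CHW float32 image plus its resize info [resize_h, resize_w, im_scale]:
#   for sample in test_reader(paths=['/PATH/TO/IMAGE']):
#       print(sample['image'].shape, sample['im_info'])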
def padding_minibatch(batch_data, coarsest_stride=0, use_padded_im_info=True):
max_shape_org = np.array([data['image'].shape for data in batch_data]).max(axis=0)
if coarsest_stride > 0:
max_shape = np.zeros((3)).astype('int32')
max_shape[1] = int(np.ceil(max_shape_org[1] / coarsest_stride) * coarsest_stride)
max_shape[2] = int(np.ceil(max_shape_org[2] / coarsest_stride) * coarsest_stride)
else:
max_shape = max_shape_org.astype('int32')
padding_image = list()
padding_info = list()
padding_shape = list()
for data in batch_data:
im_c, im_h, im_w = data['image'].shape
# image
padding_im = np.zeros((im_c, max_shape[1], max_shape[2]), dtype=np.float32)
padding_im[:, 0:im_h, 0:im_w] = data['image']
padding_image.append(padding_im)
# im_info
data['im_info'][0] = max_shape[1] if use_padded_im_info else max_shape_org[1]
data['im_info'][1] = max_shape[2] if use_padded_im_info else max_shape_org[2]
padding_info.append(data['im_info'])
padding_image = np.array(padding_image).astype('float32')
padding_info = np.array(padding_info).astype('float32')
return padding_image, padding_info
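# Example: with coarsest_stride=32, a batch of images shaped (3, 800, 1205) and
# (3, 750, 1333) is zero-padded to (3, 800, 1344): each spatial dim is the batch
# maximum rounded up to a multiple of 32 (800 is already a multiple, and
# ceil(1333 / 32) * 32 = 1344).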
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
from collections import OrderedDict
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Xavier
from paddle.fluid.regularizer import L2Decay
__all__ = ['FPN']
def ConvNorm(input,
num_filters,
filter_size,
stride=1,
groups=1,
norm_decay=0.,
norm_type='affine_channel',
norm_groups=32,
dilation=1,
lr_scale=1,
freeze_norm=False,
act=None,
norm_name=None,
initializer=None,
name=None):
fan = num_filters
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=((filter_size - 1) // 2) * dilation,
dilation=dilation,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights", initializer=initializer, learning_rate=lr_scale),
bias_attr=False,
name=name + '.conv2d.output.1')
norm_lr = 0. if freeze_norm else 1.
pattr = ParamAttr(name=norm_name + '_scale', learning_rate=norm_lr * lr_scale, regularizer=L2Decay(norm_decay))
battr = ParamAttr(name=norm_name + '_offset', learning_rate=norm_lr * lr_scale, regularizer=L2Decay(norm_decay))
if norm_type in ['bn', 'sync_bn']:
global_stats = True if freeze_norm else False
out = fluid.layers.batch_norm(
input=conv,
act=act,
name=norm_name + '.output.1',
param_attr=pattr,
bias_attr=battr,
moving_mean_name=norm_name + '_mean',
moving_variance_name=norm_name + '_variance',
use_global_stats=global_stats)
scale = fluid.framework._get_var(pattr.name)
bias = fluid.framework._get_var(battr.name)
elif norm_type == 'gn':
out = fluid.layers.group_norm(
input=conv, act=act, name=norm_name + '.output.1', groups=norm_groups, param_attr=pattr, bias_attr=battr)
scale = fluid.framework._get_var(pattr.name)
bias = fluid.framework._get_var(battr.name)
elif norm_type == 'affine_channel':
scale = fluid.layers.create_parameter(
shape=[conv.shape[1]], dtype=conv.dtype, attr=pattr, default_initializer=fluid.initializer.Constant(1.))
bias = fluid.layers.create_parameter(
shape=[conv.shape[1]], dtype=conv.dtype, attr=battr, default_initializer=fluid.initializer.Constant(0.))
out = fluid.layers.affine_channel(x=conv, scale=scale, bias=bias, act=act)
if freeze_norm:
scale.stop_gradient = True
bias.stop_gradient = True
return out
class FPN(object):
"""
Feature Pyramid Network, see https://arxiv.org/abs/1612.03144
Args:
num_chan (int): number of feature channels
min_level (int): lowest level of the backbone feature map to use
max_level (int): highest level of the backbone feature map to use
spatial_scale (list): feature map scaling factor
        has_extra_convs (bool): whether to add extra convolutions in higher levels
norm_type (str|None): normalization type, 'bn'/'sync_bn'/'affine_channel'
"""
__shared__ = ['norm_type', 'freeze_norm']
def __init__(self,
num_chan=256,
min_level=2,
max_level=6,
spatial_scale=[1. / 32., 1. / 16., 1. / 8., 1. / 4.],
has_extra_convs=False,
norm_type=None,
freeze_norm=False):
self.freeze_norm = freeze_norm
self.num_chan = num_chan
self.min_level = min_level
self.max_level = max_level
self.spatial_scale = spatial_scale
self.has_extra_convs = has_extra_convs
self.norm_type = norm_type
def _add_topdown_lateral(self, body_name, body_input, upper_output):
lateral_name = 'fpn_inner_' + body_name + '_lateral'
topdown_name = 'fpn_topdown_' + body_name
fan = body_input.shape[1]
if self.norm_type:
initializer = Xavier(fan_out=fan)
lateral = ConvNorm(
body_input,
self.num_chan,
1,
initializer=initializer,
norm_type=self.norm_type,
freeze_norm=self.freeze_norm,
name=lateral_name,
norm_name=lateral_name)
else:
lateral = fluid.layers.conv2d(
body_input,
self.num_chan,
1,
param_attr=ParamAttr(name=lateral_name + "_w", initializer=Xavier(fan_out=fan)),
bias_attr=ParamAttr(name=lateral_name + "_b", learning_rate=2., regularizer=L2Decay(0.)),
name=lateral_name)
topdown = fluid.layers.resize_nearest(upper_output, scale=2., name=topdown_name)
return lateral + topdown
def get_output(self, body_dict):
"""
Add FPN onto backbone.
Args:
            body_dict(OrderedDict): Dictionary of variables, where each element is
                an output of the backbone.
        Return:
            fpn_dict(OrderedDict): An ordered dictionary of the FPN outputs, keyed
                by name.
spatial_scale(list): A list of multiplicative spatial scale factor.
"""
spatial_scale = copy.deepcopy(self.spatial_scale)
body_name_list = list(body_dict.keys())[::-1]
num_backbone_stages = len(body_name_list)
self.fpn_inner_output = [[] for _ in range(num_backbone_stages)]
fpn_inner_name = 'fpn_inner_' + body_name_list[0]
body_input = body_dict[body_name_list[0]]
fan = body_input.shape[1]
if self.norm_type:
initializer = Xavier(fan_out=fan)
self.fpn_inner_output[0] = ConvNorm(
body_input,
self.num_chan,
1,
initializer=initializer,
norm_type=self.norm_type,
freeze_norm=self.freeze_norm,
name=fpn_inner_name,
norm_name=fpn_inner_name)
else:
self.fpn_inner_output[0] = fluid.layers.conv2d(
body_input,
self.num_chan,
1,
param_attr=ParamAttr(name=fpn_inner_name + "_w", initializer=Xavier(fan_out=fan)),
bias_attr=ParamAttr(name=fpn_inner_name + "_b", learning_rate=2., regularizer=L2Decay(0.)),
name=fpn_inner_name)
for i in range(1, num_backbone_stages):
body_name = body_name_list[i]
body_input = body_dict[body_name]
top_output = self.fpn_inner_output[i - 1]
fpn_inner_single = self._add_topdown_lateral(body_name, body_input, top_output)
self.fpn_inner_output[i] = fpn_inner_single
fpn_dict = {}
fpn_name_list = []
for i in range(num_backbone_stages):
fpn_name = 'fpn_' + body_name_list[i]
fan = self.fpn_inner_output[i].shape[1] * 3 * 3
if self.norm_type:
initializer = Xavier(fan_out=fan)
fpn_output = ConvNorm(
self.fpn_inner_output[i],
self.num_chan,
3,
initializer=initializer,
norm_type=self.norm_type,
freeze_norm=self.freeze_norm,
name=fpn_name,
norm_name=fpn_name)
else:
fpn_output = fluid.layers.conv2d(
self.fpn_inner_output[i],
self.num_chan,
filter_size=3,
padding=1,
param_attr=ParamAttr(name=fpn_name + "_w", initializer=Xavier(fan_out=fan)),
bias_attr=ParamAttr(name=fpn_name + "_b", learning_rate=2., regularizer=L2Decay(0.)),
name=fpn_name)
fpn_dict[fpn_name] = fpn_output
fpn_name_list.append(fpn_name)
if not self.has_extra_convs and self.max_level - self.min_level == len(spatial_scale):
body_top_name = fpn_name_list[0]
body_top_extension = fluid.layers.pool2d(
fpn_dict[body_top_name], 1, 'max', pool_stride=2, name=body_top_name + '_subsampled_2x')
fpn_dict[body_top_name + '_subsampled_2x'] = body_top_extension
fpn_name_list.insert(0, body_top_name + '_subsampled_2x')
spatial_scale.insert(0, spatial_scale[0] * 0.5)
# Coarser FPN levels introduced for RetinaNet
highest_backbone_level = self.min_level + len(spatial_scale) - 1
if self.has_extra_convs and self.max_level > highest_backbone_level:
fpn_blob = body_dict[body_name_list[0]]
for i in range(highest_backbone_level + 1, self.max_level + 1):
fpn_blob_in = fpn_blob
fpn_name = 'fpn_' + str(i)
if i > highest_backbone_level + 1:
fpn_blob_in = fluid.layers.relu(fpn_blob)
fan = fpn_blob_in.shape[1] * 3 * 3
fpn_blob = fluid.layers.conv2d(
input=fpn_blob_in,
num_filters=self.num_chan,
filter_size=3,
stride=2,
padding=1,
param_attr=ParamAttr(name=fpn_name + "_w", initializer=Xavier(fan_out=fan)),
bias_attr=ParamAttr(name=fpn_name + "_b", learning_rate=2., regularizer=L2Decay(0.)),
name=fpn_name)
fpn_dict[fpn_name] = fpn_blob
fpn_name_list.insert(0, fpn_name)
spatial_scale.insert(0, spatial_scale[0] * 0.5)
res_dict = OrderedDict([(k, fpn_dict[k]) for k in fpn_name_list])
return res_dict, spatial_scale
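# Note: in this module's context(), body_dict has keys
# ['res3_sum', 'res4_sum', 'res5_sum'] and FPN is built with min_level=3,
# max_level=7, has_extra_convs=True, so get_output returns the ordered keys
# ['fpn_7', 'fpn_6', 'fpn_res5_sum', 'fpn_res4_sum', 'fpn_res3_sum'], where
# fpn_6/fpn_7 are extra stride-2 3x3 convolutions stacked on top of C5.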
background
person
bicycle
car
motorcycle
airplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
couch
potted plant
bed
dining table
toilet
tv
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import ast
import argparse
from functools import partial
import numpy as np
import paddle.fluid as fluid
import paddlehub as hub
from paddlehub.module.module import moduleinfo, runnable, serving
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
from paddlehub.io.parser import txt_parser
from paddlehub.common.paddle_helper import add_vars_prefix
from retinanet_resnet50_fpn_coco2017.fpn import FPN
from retinanet_resnet50_fpn_coco2017.retina_head import AnchorGenerator, RetinaTargetAssign, RetinaOutputDecoder, RetinaHead
from retinanet_resnet50_fpn_coco2017.processor import load_label_info, postprocess, base64_to_cv2
from retinanet_resnet50_fpn_coco2017.data_feed import test_reader, padding_minibatch
from retinanet_resnet50_fpn_coco2017.resnet import ResNet
@moduleinfo(
name="retinanet_resnet50_fpn_coco2017",
version="1.0.0",
type="cv/object_detection",
summary="Baidu's RetinaNet model for object detection, with backbone ResNet50 and FPN.",
author="paddlepaddle",
author_email="paddle-dev@baidu.com")
class RetinaNetResNet50FPN(hub.Module):
def _initialize(self):
        # default pretrained model of RetinaNet_ResNet50_FPN
self.default_pretrained_model_path = os.path.join(self.directory, "retinanet_resnet50_fpn_model")
self.label_names = load_label_info(os.path.join(self.directory, "label_file.txt"))
self.infer_prog = None
self.image = None
self.im_info = None
self.bbox_out = None
self._set_config()
def _set_config(self):
"""
predictor config setting
"""
cpu_config = AnalysisConfig(self.default_pretrained_model_path)
cpu_config.disable_glog_info()
cpu_config.disable_gpu()
self.cpu_predictor = create_paddle_predictor(cpu_config)
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
use_gpu = True
except:
use_gpu = False
if use_gpu:
gpu_config = AnalysisConfig(self.default_pretrained_model_path)
gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(memory_pool_init_size_mb=500, device_id=0)
self.gpu_predictor = create_paddle_predictor(gpu_config)
def context(self, num_classes=81, trainable=True, pretrained=True, phase='train'):
"""
        Extract features for transfer learning.
Args:
num_classes (int): number of classes.
trainable (bool): whether to set parameters trainable.
pretrained (bool): whether to load default pretrained model.
phase (str): optional choices are 'train' and 'predict'.
Returns:
inputs(dict): the input variables.
outputs(dict): the output variables.
context_prog (Program): the program to execute transfer learning.
"""
context_prog = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(context_prog, startup_program):
with fluid.unique_name.guard():
var_prefix = '@HUB_{}@'.format(self.name)
# image
image = fluid.layers.data(name='image', shape=[-1, 3, -1, -1], dtype='float32', lod_level=0)
# im_info
im_info = fluid.layers.data(name='im_info', shape=[3], dtype='float32', lod_level=0)
# backbone
backbone = ResNet(
norm_type='affine_channel', freeze_at=2, norm_decay=0., depth=50, feature_maps=[3, 4, 5])
body_feats = backbone(image)
# retina_head
retina_head = RetinaHead(
anchor_generator=AnchorGenerator(aspect_ratios=[1.0, 2.0, 0.5], variance=[1.0, 1.0, 1.0, 1.0]),
target_assign=RetinaTargetAssign(positive_overlap=0.5, negative_overlap=0.4),
output_decoder=RetinaOutputDecoder(
score_thresh=0.05, nms_thresh=0.5, pre_nms_top_n=1000, detections_per_im=100, nms_eta=1.0),
num_convs_per_octave=4,
num_chan=256,
max_level=7,
min_level=3,
prior_prob=0.01,
base_scale=4,
num_scales_per_octave=3)
# fpn
fpn = FPN(
max_level=7,
min_level=3,
num_chan=256,
spatial_scale=[0.03125, 0.0625, 0.125],
has_extra_convs=True)
# body_feats
body_feats, spatial_scale = fpn.get_output(body_feats)
# inputs, outputs, context_prog
inputs = {'image': var_prefix + image.name, 'im_info': var_prefix + im_info.name}
if phase == 'predict':
pred = retina_head.get_prediction(body_feats, spatial_scale, im_info)
outputs = {'bbox_out': var_prefix + pred.name}
else:
outputs = {'body_features': [var_prefix + var.name for key, var in body_feats.items()]}
# add_vars_prefix
add_vars_prefix(context_prog, var_prefix)
add_vars_prefix(fluid.default_startup_program(), var_prefix)
global_vars = context_prog.global_block().vars
inputs = {key: global_vars[value] for key, value in inputs.items()}
outputs = {
key: global_vars[value] if not isinstance(value, list) else [global_vars[var] for var in value]
for key, value in outputs.items()
}
place = fluid.CPUPlace()
exe = fluid.Executor(place)
for param in context_prog.global_block().iter_parameters():
param.trainable = trainable
if pretrained:
def _if_exist(var):
return os.path.exists(os.path.join(self.default_pretrained_model_path, var.name))
fluid.io.load_vars(exe, self.default_pretrained_model_path, predicate=_if_exist)
else:
exe.run(startup_program)
return inputs, outputs, context_prog
def save_inference_model(self, dirname, model_filename=None, params_filename=None, combined=True):
if combined:
model_filename = "__model__" if not model_filename else model_filename
params_filename = "__params__" if not params_filename else params_filename
place = fluid.CPUPlace()
exe = fluid.Executor(place)
program, feeded_var_names, target_vars = fluid.io.load_inference_model(
dirname=self.default_pretrained_model_path, executor=exe)
fluid.io.save_inference_model(
dirname=dirname,
main_program=program,
executor=exe,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
model_filename=model_filename,
params_filename=params_filename)
def object_detection(self,
paths=None,
images=None,
use_gpu=False,
batch_size=1,
output_dir='detection_result',
score_thresh=0.5,
visualization=True):
"""API of Object Detection.
Args:
paths (list[str]): The paths of images.
images (list(numpy.ndarray)): images data, shape of each is [H, W, C]
batch_size (int): batch size.
use_gpu (bool): Whether to use gpu.
            output_dir (str): The path to store output images.
            score_thresh (float): confidence threshold for the detections.
            visualization (bool): whether to save the detection results as images.
Returns:
            res (list[dict]): The results of object detection. keys include 'data' and 'save_path'; the corresponding values are:
data (dict): the result of object detection, keys include 'left', 'top', 'right', 'bottom', 'label', 'confidence', the corresponding value is:
left (float): The X coordinate of the upper left corner of the bounding box;
top (float): The Y coordinate of the upper left corner of the bounding box;
right (float): The X coordinate of the lower right corner of the bounding box;
bottom (float): The Y coordinate of the lower right corner of the bounding box;
label (str): The label of detection result;
confidence (float): The confidence of detection result.
save_path (str, optional): The path to save output images.
"""
if use_gpu:
try:
_places = os.environ["CUDA_VISIBLE_DEVICES"]
int(_places[0])
except:
                raise RuntimeError(
                    "Environment variable CUDA_VISIBLE_DEVICES is not set correctly. "
                    "To run prediction on GPU, please set CUDA_VISIBLE_DEVICES to the id of a CUDA device.")
all_images = list()
paths = paths if paths else list()
for yield_data in test_reader(paths, images):
all_images.append(yield_data)
images_num = len(all_images)
loop_num = int(np.ceil(images_num / batch_size))
res = list()
for iter_id in range(loop_num):
batch_data = list()
handle_id = iter_id * batch_size
for image_id in range(batch_size):
try:
batch_data.append(all_images[handle_id + image_id])
except:
pass
padding_image, padding_info = padding_minibatch(batch_data, coarsest_stride=32, use_padded_im_info=True)
padding_image_tensor = PaddleTensor(padding_image.copy())
padding_info_tensor = PaddleTensor(padding_info.copy())
feed_list = [padding_image_tensor, padding_info_tensor]
if use_gpu:
data_out = self.gpu_predictor.run(feed_list)
else:
data_out = self.cpu_predictor.run(feed_list)
output = postprocess(
paths=paths,
images=images,
data_out=data_out,
score_thresh=score_thresh,
label_names=self.label_names,
output_dir=output_dir,
handle_id=handle_id,
visualization=visualization)
res += output
return res
def add_module_config_arg(self):
"""
Add the command config options
"""
self.arg_config_group.add_argument(
'--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU or not")
self.arg_config_group.add_argument('--batch_size', type=int, default=1, help="batch size for prediction")
def add_module_input_arg(self):
"""
Add the command input options
"""
self.arg_input_group.add_argument('--input_path', type=str, default=None, help="input data")
self.arg_input_group.add_argument('--input_file', type=str, default=None, help="file contain input data")
def check_input_data(self, args):
input_data = list()
if args.input_path:
input_data = [args.input_path]
elif args.input_file:
if not os.path.exists(args.input_file):
raise RuntimeError("File %s is not exist." % args.input_file)
else:
input_data = txt_parser.parse(args.input_file, use_strip=True)
return input_data
@serving
def serving_method(self, images, **kwargs):
"""
Run as a service.
"""
images_decode = [base64_to_cv2(image) for image in images]
results = self.object_detection(images=images_decode, **kwargs)
return results
@runnable
def run_cmd(self, argvs):
self.parser = argparse.ArgumentParser(
description="Run the {}".format(self.name),
prog="hub run {}".format(self.name),
usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group(
title="Config options", description="Run configuration for controlling module behavior, not required.")
self.add_module_config_arg()
self.add_module_input_arg()
args = self.parser.parse_args(argvs)
input_data = self.check_input_data(args)
if len(input_data) == 0:
self.parser.print_help()
exit(1)
else:
for image_path in input_data:
if not os.path.exists(image_path):
raise RuntimeError("File %s or %s is not exist." % image_path)
return self.object_detection(paths=input_data, use_gpu=args.use_gpu, batch_size=args.batch_size)
# coding=utf-8
class NameAdapter(object):
"""Fix the backbones variable names for pretrained weight"""
def __init__(self, model):
super(NameAdapter, self).__init__()
self.model = model
@property
def model_type(self):
return getattr(self.model, '_model_type', '')
@property
def variant(self):
return getattr(self.model, 'variant', '')
def fix_conv_norm_name(self, name):
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
# the naming rule is same as pretrained weight
if self.model_type == 'SEResNeXt':
bn_name = name + "_bn"
return bn_name
def fix_shortcut_name(self, name):
if self.model_type == 'SEResNeXt':
name = 'conv' + name + '_prj'
return name
def fix_bottleneck_name(self, name):
if self.model_type == 'SEResNeXt':
conv_name1 = 'conv' + name + '_x1'
conv_name2 = 'conv' + name + '_x2'
conv_name3 = 'conv' + name + '_x3'
shortcut_name = name
else:
conv_name1 = name + "_branch2a"
conv_name2 = name + "_branch2b"
conv_name3 = name + "_branch2c"
shortcut_name = name + "_branch1"
return conv_name1, conv_name2, conv_name3, shortcut_name
def fix_layer_warp_name(self, stage_num, count, i):
name = 'res' + str(stage_num)
if count > 10 and stage_num == 4:
if i == 0:
conv_name = name + "a"
else:
conv_name = name + "b" + str(i)
else:
conv_name = name + chr(ord("a") + i)
if self.model_type == 'SEResNeXt':
conv_name = str(stage_num + 2) + '_' + str(i + 1)
return conv_name
def fix_c1_stage_name(self):
return "res_conv1" if self.model_type == 'ResNeXt' else "conv1"
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
nonlocal_params = {
"use_zero_init_conv": False,
"conv_init_std": 0.01,
"no_bias": True,
"use_maxpool": False,
"use_softmax": True,
"use_bn": False,
"use_scale": True, # vital for the model prformance!!!
"use_affine": False,
"bn_momentum": 0.9,
"bn_epsilon": 1.0000001e-5,
"bn_init_gamma": 0.9,
"weight_decay_bn": 1.e-4,
}
def space_nonlocal(input, dim_in, dim_out, prefix, dim_inner, max_pool_stride=2):
cur = input
theta = fluid.layers.conv2d(input = cur, num_filters = dim_inner, \
filter_size = [1, 1], stride = [1, 1], \
padding = [0, 0], \
param_attr=ParamAttr(name = prefix + '_theta' + "_w", \
initializer = fluid.initializer.Normal(loc = 0.0,
scale = nonlocal_params["conv_init_std"])), \
bias_attr = ParamAttr(name = prefix + '_theta' + "_b", \
initializer = fluid.initializer.Constant(value = 0.)) \
if not nonlocal_params["no_bias"] else False, \
name = prefix + '_theta')
theta_shape = theta.shape
theta_shape_op = fluid.layers.shape(theta)
theta_shape_op.stop_gradient = True
if nonlocal_params["use_maxpool"]:
max_pool = fluid.layers.pool2d(input = cur, \
pool_size = [max_pool_stride, max_pool_stride], \
pool_type = 'max', \
pool_stride = [max_pool_stride, max_pool_stride], \
pool_padding = [0, 0], \
name = prefix + '_pool')
else:
max_pool = cur
phi = fluid.layers.conv2d(input = max_pool, num_filters = dim_inner, \
filter_size = [1, 1], stride = [1, 1], \
padding = [0, 0], \
param_attr = ParamAttr(name = prefix + '_phi' + "_w", \
initializer = fluid.initializer.Normal(loc = 0.0,
scale = nonlocal_params["conv_init_std"])), \
bias_attr = ParamAttr(name = prefix + '_phi' + "_b", \
initializer = fluid.initializer.Constant(value = 0.)) \
if (nonlocal_params["no_bias"] == 0) else False, \
name = prefix + '_phi')
phi_shape = phi.shape
g = fluid.layers.conv2d(input = max_pool, num_filters = dim_inner, \
filter_size = [1, 1], stride = [1, 1], \
padding = [0, 0], \
param_attr = ParamAttr(name = prefix + '_g' + "_w", \
initializer = fluid.initializer.Normal(loc = 0.0, scale = nonlocal_params["conv_init_std"])), \
bias_attr = ParamAttr(name = prefix + '_g' + "_b", \
initializer = fluid.initializer.Constant(value = 0.)) if (nonlocal_params["no_bias"] == 0) else False, \
name = prefix + '_g')
g_shape = g.shape
# we have to use explicit batch size (to support arbitrary spacetime size)
# e.g. (8, 1024, 4, 14, 14) => (8, 1024, 784)
theta = fluid.layers.reshape(theta, shape=(0, 0, -1))
theta = fluid.layers.transpose(theta, [0, 2, 1])
phi = fluid.layers.reshape(phi, [0, 0, -1])
theta_phi = fluid.layers.matmul(theta, phi, name=prefix + '_affinity')
g = fluid.layers.reshape(g, [0, 0, -1])
if nonlocal_params["use_softmax"]:
if nonlocal_params["use_scale"]:
theta_phi_sc = fluid.layers.scale(theta_phi, scale=dim_inner**-.5)
else:
theta_phi_sc = theta_phi
p = fluid.layers.softmax(theta_phi_sc, name=prefix + '_affinity' + '_prob')
    else:
        # only the softmax-normalized affinity is implemented
        raise NotImplementedError("non-softmax affinity normalization is not implemented")
# note g's axis[2] corresponds to p's axis[2]
# e.g. g(8, 1024, 784_2) * p(8, 784_1, 784_2) => (8, 1024, 784_1)
p = fluid.layers.transpose(p, [0, 2, 1])
t = fluid.layers.matmul(g, p, name=prefix + '_y')
# reshape back
# e.g. (8, 1024, 784) => (8, 1024, 4, 14, 14)
t_shape = t.shape
t_re = fluid.layers.reshape(t, shape=list(theta_shape), actual_shape=theta_shape_op)
blob_out = t_re
blob_out = fluid.layers.conv2d(input = blob_out, num_filters = dim_out, \
filter_size = [1, 1], stride = [1, 1], padding = [0, 0], \
param_attr = ParamAttr(name = prefix + '_out' + "_w", \
initializer = fluid.initializer.Constant(value = 0.) \
if nonlocal_params["use_zero_init_conv"] \
else fluid.initializer.Normal(loc = 0.0,
scale = nonlocal_params["conv_init_std"])), \
bias_attr = ParamAttr(name = prefix + '_out' + "_b", \
initializer = fluid.initializer.Constant(value = 0.)) \
if (nonlocal_params["no_bias"] == 0) else False, \
name = prefix + '_out')
blob_out_shape = blob_out.shape
if nonlocal_params["use_bn"]:
bn_name = prefix + "_bn"
blob_out = fluid.layers.batch_norm(blob_out, \
# is_test = test_mode, \
momentum = nonlocal_params["bn_momentum"], \
epsilon = nonlocal_params["bn_epsilon"], \
name = bn_name, \
param_attr = ParamAttr(name = bn_name + "_s", \
initializer = fluid.initializer.Constant(value = nonlocal_params["bn_init_gamma"]), \
regularizer = fluid.regularizer.L2Decay(nonlocal_params["weight_decay_bn"])), \
bias_attr = ParamAttr(name = bn_name + "_b", \
regularizer = fluid.regularizer.L2Decay(nonlocal_params["weight_decay_bn"])), \
moving_mean_name = bn_name + "_rm", \
moving_variance_name = bn_name + "_riv") # add bn
if nonlocal_params["use_affine"]:
affine_scale = fluid.layers.create_parameter(\
shape=[blob_out_shape[1]], dtype = blob_out.dtype, \
attr=ParamAttr(name=prefix + '_affine' + '_s'), \
default_initializer = fluid.initializer.Constant(value = 1.))
affine_bias = fluid.layers.create_parameter(\
shape=[blob_out_shape[1]], dtype = blob_out.dtype, \
attr=ParamAttr(name=prefix + '_affine' + '_b'), \
default_initializer = fluid.initializer.Constant(value = 0.))
blob_out = fluid.layers.affine_channel(blob_out, scale = affine_scale, \
bias = affine_bias, name = prefix + '_affine') # add affine
return blob_out
def add_space_nonlocal(input, dim_in, dim_out, prefix, dim_inner):
'''
add_space_nonlocal:
Non-local Neural Networks: see https://arxiv.org/abs/1711.07971
'''
conv = space_nonlocal(input, dim_in, dim_out, prefix, dim_inner)
output = fluid.layers.elementwise_add(input, conv, name=prefix + '_sum')
return output
# coding=utf-8
import base64
import os
import cv2
import numpy as np
from PIL import Image, ImageDraw
__all__ = [
'base64_to_cv2',
'load_label_info',
'postprocess',
]
def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8'))
    data = np.frombuffer(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
return data
def get_save_image_name(img, output_dir, image_path):
"""Get save image name from source image path.
"""
image_name = os.path.split(image_path)[-1]
name, ext = os.path.splitext(image_name)
if ext == '':
if img.format == 'PNG':
ext = '.png'
elif img.format == 'JPEG':
ext = '.jpg'
elif img.format == 'BMP':
ext = '.bmp'
else:
if img.mode == "RGB" or img.mode == "L":
ext = ".jpg"
elif img.mode == "RGBA" or img.mode == "P":
ext = '.png'
return os.path.join(output_dir, "{}".format(name)) + ext
def draw_bounding_box_on_image(image_path, data_list, save_dir):
image = Image.open(image_path)
draw = ImageDraw.Draw(image)
for data in data_list:
left, right, top, bottom = data['left'], data['right'], data['top'], data['bottom']
# draw bbox
draw.line([(left, top), (left, bottom), (right, bottom), (right, top), (left, top)], width=2, fill='red')
# draw label
if image.mode == 'RGB':
text = data['label'] + ": %.2f%%" % (100 * data['confidence'])
textsize_width, textsize_height = draw.textsize(text=text)
draw.rectangle(
xy=(left, top - (textsize_height + 5), left + textsize_width + 10, top), fill=(255, 255, 255))
draw.text(xy=(left, top - 15), text=text, fill=(0, 0, 0))
save_name = get_save_image_name(image, save_dir, image_path)
if os.path.exists(save_name):
os.remove(save_name)
image.save(save_name)
return save_name
def clip_bbox(bbox, img_width, img_height):
xmin = max(min(bbox[0], img_width), 0.)
ymin = max(min(bbox[1], img_height), 0.)
xmax = max(min(bbox[2], img_width), 0.)
ymax = max(min(bbox[3], img_height), 0.)
return float(xmin), float(ymin), float(xmax), float(ymax)
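# Example: clip_bbox([-5., 10., 700., 900.], img_width=640, img_height=480)
# returns (0.0, 10.0, 640.0, 480.0).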
def load_label_info(file_path):
with open(file_path, 'r') as fr:
text = fr.readlines()
label_names = []
for info in text:
label_names.append(info.strip())
return label_names
def postprocess(paths, images, data_out, score_thresh, label_names, output_dir, handle_id, visualization):
"""
postprocess the lod_tensor produced by fluid.Executor.run
Args:
paths (list[str]): the path of images.
images (list(numpy.ndarray)): list of images, shape of each is [H, W, C].
data_out (lod_tensor): data produced by executor.run.
        score_thresh (float): the confidence threshold for keeping a bounding box.
label_names (list[str]): label names.
output_dir (str): output directory.
handle_id (int): The number of images that have been handled.
visualization (bool): whether to save as images.
Returns:
        res (list[dict]): The results of object detection. keys include 'data' and 'save_path'; the corresponding values are:
data (dict): the result of object detection, keys include 'left', 'top', 'right', 'bottom', 'label', 'confidence', the corresponding value is:
left (float): The X coordinate of the upper left corner of the bounding box;
top (float): The Y coordinate of the upper left corner of the bounding box;
right (float): The X coordinate of the lower right corner of the bounding box;
bottom (float): The Y coordinate of the lower right corner of the bounding box;
label (str): The label of detection result;
confidence (float): The confidence of detection result.
save_path (str): The path to save output images.
"""
lod_tensor = data_out[0]
lod = lod_tensor.lod[0]
results = lod_tensor.as_ndarray()
if handle_id < len(paths):
unhandled_paths = paths[handle_id:]
unhandled_paths_num = len(unhandled_paths)
else:
unhandled_paths_num = 0
output_dir = output_dir if output_dir else os.path.join(os.getcwd(), 'detection_result')
if visualization:
if not os.path.exists(output_dir):
os.makedirs(output_dir)
output = []
for index in range(len(lod) - 1):
output_i = {'data': []}
if index < unhandled_paths_num:
org_img_path = unhandled_paths[index]
org_img = Image.open(org_img_path)
output_i['path'] = org_img_path
else:
org_img = images[index - unhandled_paths_num]
org_img = org_img.astype(np.uint8)
org_img = Image.fromarray(org_img[:, :, ::-1])
if visualization:
org_img_path = get_save_image_name(org_img, output_dir, 'image_numpy_{}'.format((handle_id + index)))
org_img.save(org_img_path)
org_img_height = org_img.height
org_img_width = org_img.width
result_i = results[lod[index]:lod[index + 1]]
for row in result_i:
if len(row) != 6:
continue
if row[1] < score_thresh:
continue
category_id = int(row[0])
confidence = row[1]
bbox = row[2:]
dt = {}
dt['label'] = label_names[category_id]
dt['confidence'] = float(confidence)
dt['left'], dt['top'], dt['right'], dt['bottom'] = clip_bbox(bbox, org_img_width, org_img_height)
output_i['data'].append(dt)
output.append(output_i)
if visualization:
output_i['save_path'] = draw_bounding_box_on_image(org_img_path, output_i['data'], output_dir)
return output
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
from collections import OrderedDict
from numbers import Integral
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.framework import Variable
from paddle.fluid.regularizer import L2Decay
from paddle.fluid.initializer import Constant
from .nonlocal_helper import add_space_nonlocal
from .name_adapter import NameAdapter
__all__ = ['ResNet', 'ResNetC5']
class ResNet(object):
"""
Residual Network, see https://arxiv.org/abs/1512.03385
Args:
depth (int): ResNet depth, should be 34, 50.
freeze_at (int): freeze the backbone at which stage
norm_type (str): normalization type, 'bn'/'sync_bn'/'affine_channel'
freeze_norm (bool): freeze normalization layers
norm_decay (float): weight decay for normalization layer weights
variant (str): ResNet variant, supports 'a', 'b', 'c', 'd' currently
feature_maps (list): index of stages whose feature maps are returned
dcn_v2_stages (list): index of stages who select deformable conv v2
nonlocal_stages (list): index of stages who select nonlocal networks
"""
__shared__ = ['norm_type', 'freeze_norm', 'weight_prefix_name']
def __init__(self,
depth=50,
freeze_at=0,
norm_type='sync_bn',
freeze_norm=False,
norm_decay=0.,
variant='b',
feature_maps=[3, 4, 5],
dcn_v2_stages=[],
weight_prefix_name='',
nonlocal_stages=[],
get_prediction=False,
class_dim=1000):
super(ResNet, self).__init__()
if isinstance(feature_maps, Integral):
feature_maps = [feature_maps]
        assert depth in [34, 50], \
            "depth {} not in [34, 50]".format(depth)
assert variant in ['a', 'b', 'c', 'd'], "invalid ResNet variant"
assert 0 <= freeze_at <= 4, "freeze_at should be 0, 1, 2, 3 or 4"
assert len(feature_maps) > 0, "need one or more feature maps"
assert norm_type in ['bn', 'sync_bn', 'affine_channel']
assert not (len(nonlocal_stages)>0 and depth<50), \
"non-local is not supported for resnet18 or resnet34"
self.depth = depth
self.freeze_at = freeze_at
self.norm_type = norm_type
self.norm_decay = norm_decay
self.freeze_norm = freeze_norm
self.variant = variant
self._model_type = 'ResNet'
self.feature_maps = feature_maps
self.dcn_v2_stages = dcn_v2_stages
self.depth_cfg = {
34: ([3, 4, 6, 3], self.basicblock),
50: ([3, 4, 6, 3], self.bottleneck),
}
self.stage_filters = [64, 128, 256, 512]
self._c1_out_chan_num = 64
self.na = NameAdapter(self)
self.prefix_name = weight_prefix_name
self.nonlocal_stages = nonlocal_stages
self.nonlocal_mod_cfg = {
50: 2,
101: 5,
152: 8,
200: 12,
}
self.get_prediction = get_prediction
self.class_dim = class_dim
def _conv_offset(self, input, filter_size, stride, padding, act=None, name=None):
out_channel = filter_size * filter_size * 3
out = fluid.layers.conv2d(
input,
num_filters=out_channel,
filter_size=filter_size,
stride=stride,
padding=padding,
param_attr=ParamAttr(initializer=Constant(0.0), name=name + ".w_0"),
bias_attr=ParamAttr(initializer=Constant(0.0), name=name + ".b_0"),
act=act,
name=name)
return out
def _conv_norm(self, input, num_filters, filter_size, stride=1, groups=1, act=None, name=None, dcn_v2=False):
_name = self.prefix_name + name if self.prefix_name != '' else name
if not dcn_v2:
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=_name + "_weights"),
bias_attr=False,
name=_name + '.conv2d.output.1')
else:
            # use deformable conv v2
offset_mask = self._conv_offset(
input=input,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
act=None,
name=_name + "_conv_offset")
offset_channel = filter_size**2 * 2
mask_channel = filter_size**2
offset, mask = fluid.layers.split(input=offset_mask, num_or_sections=[offset_channel, mask_channel], dim=1)
mask = fluid.layers.sigmoid(mask)
conv = fluid.layers.deformable_conv(
input=input,
offset=offset,
mask=mask,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
deformable_groups=1,
im2col_step=1,
param_attr=ParamAttr(name=_name + "_weights"),
bias_attr=False,
name=_name + ".conv2d.output.1")
bn_name = self.na.fix_conv_norm_name(name)
bn_name = self.prefix_name + bn_name if self.prefix_name != '' else bn_name
norm_lr = 0. if self.freeze_norm else 1.
norm_decay = self.norm_decay
pattr = ParamAttr(name=bn_name + '_scale', learning_rate=norm_lr, regularizer=L2Decay(norm_decay))
battr = ParamAttr(name=bn_name + '_offset', learning_rate=norm_lr, regularizer=L2Decay(norm_decay))
if self.norm_type in ['bn', 'sync_bn']:
global_stats = True if self.freeze_norm else False
out = fluid.layers.batch_norm(
input=conv,
act=act,
name=bn_name + '.output.1',
param_attr=pattr,
bias_attr=battr,
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance',
use_global_stats=global_stats)
scale = fluid.framework._get_var(pattr.name)
bias = fluid.framework._get_var(battr.name)
elif self.norm_type == 'affine_channel':
scale = fluid.layers.create_parameter(
shape=[conv.shape[1]], dtype=conv.dtype, attr=pattr, default_initializer=fluid.initializer.Constant(1.))
bias = fluid.layers.create_parameter(
shape=[conv.shape[1]], dtype=conv.dtype, attr=battr, default_initializer=fluid.initializer.Constant(0.))
out = fluid.layers.affine_channel(x=conv, scale=scale, bias=bias, act=act)
if self.freeze_norm:
scale.stop_gradient = True
bias.stop_gradient = True
return out
def _shortcut(self, input, ch_out, stride, is_first, name):
max_pooling_in_short_cut = self.variant == 'd'
ch_in = input.shape[1]
# the naming rule is same as pretrained weight
name = self.na.fix_shortcut_name(name)
std_senet = getattr(self, 'std_senet', False)
if ch_in != ch_out or stride != 1 or (self.depth < 50 and is_first):
if std_senet:
if is_first:
return self._conv_norm(input, ch_out, 1, stride, name=name)
else:
return self._conv_norm(input, ch_out, 3, stride, name=name)
if max_pooling_in_short_cut and not is_first:
input = fluid.layers.pool2d(
input=input, pool_size=2, pool_stride=2, pool_padding=0, ceil_mode=True, pool_type='avg')
return self._conv_norm(input, ch_out, 1, 1, name=name)
return self._conv_norm(input, ch_out, 1, stride, name=name)
else:
return input
def bottleneck(self, input, num_filters, stride, is_first, name, dcn_v2=False):
if self.variant == 'a':
stride1, stride2 = stride, 1
else:
stride1, stride2 = 1, stride
# ResNeXt
groups = getattr(self, 'groups', 1)
group_width = getattr(self, 'group_width', -1)
if groups == 1:
expand = 4
elif (groups * group_width) == 256:
expand = 1
else: # FIXME hard code for now, handles 32x4d, 64x4d and 32x8d
num_filters = num_filters // 2
expand = 2
conv_name1, conv_name2, conv_name3, \
shortcut_name = self.na.fix_bottleneck_name(name)
std_senet = getattr(self, 'std_senet', False)
if std_senet:
conv_def = [[int(num_filters / 2), 1, stride1, 'relu', 1, conv_name1],
[num_filters, 3, stride2, 'relu', groups, conv_name2],
[num_filters * expand, 1, 1, None, 1, conv_name3]]
else:
conv_def = [[num_filters, 1, stride1, 'relu', 1, conv_name1],
[num_filters, 3, stride2, 'relu', groups, conv_name2],
[num_filters * expand, 1, 1, None, 1, conv_name3]]
residual = input
for i, (c, k, s, act, g, _name) in enumerate(conv_def):
residual = self._conv_norm(
input=residual,
num_filters=c,
filter_size=k,
stride=s,
act=act,
groups=g,
name=_name,
dcn_v2=(i == 1 and dcn_v2))
short = self._shortcut(input, num_filters * expand, stride, is_first=is_first, name=shortcut_name)
# Squeeze-and-Excitation
if callable(getattr(self, '_squeeze_excitation', None)):
residual = self._squeeze_excitation(input=residual, num_channels=num_filters, name='fc' + name)
return fluid.layers.elementwise_add(x=short, y=residual, act='relu', name=name + ".add.output.5")
def basicblock(self, input, num_filters, stride, is_first, name, dcn_v2=False):
assert dcn_v2 is False, "Not implemented yet."
conv0 = self._conv_norm(
input=input, num_filters=num_filters, filter_size=3, act='relu', stride=stride, name=name + "_branch2a")
conv1 = self._conv_norm(input=conv0, num_filters=num_filters, filter_size=3, act=None, name=name + "_branch2b")
short = self._shortcut(input, num_filters, stride, is_first, name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')
def layer_warp(self, input, stage_num):
"""
Args:
input (Variable): input variable.
stage_num (int): the stage number, should be 2, 3, 4, 5
Returns:
The last variable in endpoint-th stage.
"""
assert stage_num in [2, 3, 4, 5]
stages, block_func = self.depth_cfg[self.depth]
count = stages[stage_num - 2]
ch_out = self.stage_filters[stage_num - 2]
is_first = False if stage_num != 2 else True
dcn_v2 = True if stage_num in self.dcn_v2_stages else False
nonlocal_mod = 1000
if stage_num in self.nonlocal_stages:
nonlocal_mod = self.nonlocal_mod_cfg[self.depth] if stage_num == 4 else 2
# Make the layer name and parameter name consistent
# with ImageNet pre-trained model
conv = input
for i in range(count):
conv_name = self.na.fix_layer_warp_name(stage_num, count, i)
if self.depth < 50:
is_first = True if i == 0 and stage_num == 2 else False
conv = block_func(
input=conv,
num_filters=ch_out,
stride=2 if i == 0 and stage_num != 2 else 1,
is_first=is_first,
name=conv_name,
dcn_v2=dcn_v2)
# add non local model
dim_in = conv.shape[1]
nonlocal_name = "nonlocal_conv{}".format(stage_num)
if i % nonlocal_mod == nonlocal_mod - 1:
conv = add_space_nonlocal(conv, dim_in, dim_in, nonlocal_name + '_{}'.format(i), int(dim_in / 2))
return conv
def c1_stage(self, input):
out_chan = self._c1_out_chan_num
conv1_name = self.na.fix_c1_stage_name()
if self.variant in ['c', 'd']:
conv_def = [
[out_chan // 2, 3, 2, "conv1_1"],
[out_chan // 2, 3, 1, "conv1_2"],
[out_chan, 3, 1, "conv1_3"],
]
else:
conv_def = [[out_chan, 7, 2, conv1_name]]
for (c, k, s, _name) in conv_def:
input = self._conv_norm(input=input, num_filters=c, filter_size=k, stride=s, act='relu', name=_name)
output = fluid.layers.pool2d(input=input, pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
return output
def __call__(self, input):
assert isinstance(input, Variable)
assert not (set(self.feature_maps) - set([2, 3, 4, 5])), \
"feature maps {} not in [2, 3, 4, 5]".format(self.feature_maps)
res_endpoints = []
res = input
feature_maps = self.feature_maps
severed_head = getattr(self, 'severed_head', False)
if not severed_head:
res = self.c1_stage(res)
feature_maps = range(2, max(self.feature_maps) + 1)
for i in feature_maps:
res = self.layer_warp(res, i)
if i in self.feature_maps:
res_endpoints.append(res)
if self.freeze_at >= i:
res.stop_gradient = True
if self.get_prediction:
pool = fluid.layers.pool2d(input=res, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(
input=pool,
size=self.class_dim,
param_attr=fluid.param_attr.ParamAttr(initializer=fluid.initializer.Uniform(-stdv, stdv)))
out = fluid.layers.softmax(out)
return out
return OrderedDict(
[('res{}_sum'.format(self.feature_maps[idx]), feat) for idx, feat in enumerate(res_endpoints)])
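    # Example: with feature_maps=[3, 4, 5] (as used by this module), __call__
    # returns OrderedDict([('res3_sum', C3), ('res4_sum', C4), ('res5_sum', C5)]).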
class ResNetC5(ResNet):
def __init__(self,
depth=50,
freeze_at=2,
norm_type='affine_channel',
freeze_norm=True,
norm_decay=0.,
variant='b',
feature_maps=[5],
weight_prefix_name=''):
super(ResNetC5, self).__init__(depth, freeze_at, norm_type, freeze_norm, norm_decay, variant, feature_maps)
self.severed_head = True
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Normal, Constant
from paddle.fluid.regularizer import L2Decay
__all__ = ['AnchorGenerator', 'RetinaTargetAssign', 'RetinaOutputDecoder', 'RetinaHead']
class AnchorGenerator(object):
# __op__ = fluid.layers.anchor_generator
def __init__(self,
stride=[16.0, 16.0],
anchor_sizes=[32, 64, 128, 256, 512],
aspect_ratios=[0.5, 1., 2.],
variance=[1., 1., 1., 1.]):
self.anchor_sizes = anchor_sizes
self.aspect_ratios = aspect_ratios
self.variance = variance
self.stride = stride
class RetinaTargetAssign(object):
# __op__ = fluid.layers.retinanet_target_assign
def __init__(self, positive_overlap=0.5, negative_overlap=0.4):
self.positive_overlap = positive_overlap
self.negative_overlap = negative_overlap
class RetinaOutputDecoder(object):
# __op__ = fluid.layers.retinanet_detection_output
def __init__(self, score_thresh=0.05, nms_thresh=0.3, pre_nms_top_n=1000, detections_per_im=100, nms_eta=1.0):
super(RetinaOutputDecoder, self).__init__()
self.score_threshold = score_thresh
self.nms_threshold = nms_thresh
self.nms_top_k = pre_nms_top_n
self.keep_top_k = detections_per_im
self.nms_eta = nms_eta
class RetinaHead(object):
"""
Retina Head
Args:
anchor_generator (object): `AnchorGenerator` instance
target_assign (object): `RetinaTargetAssign` instance
output_decoder (object): `RetinaOutputDecoder` instance
num_convs_per_octave (int): Number of convolution layers in each octave
num_chan (int): Number of octave output channels
max_level (int): Highest level of FPN output
min_level (int): Lowest level of FPN output
prior_prob (float): Used to set the bias init for the class prediction layer
base_scale (int): Anchors are generated based on this scale
num_scales_per_octave (int): Number of anchor scales per octave
num_classes (int): Number of classes
gamma (float): The parameter in focal loss
alpha (float): The parameter in focal loss
sigma (float): The parameter in smooth l1 loss
"""
__inject__ = ['anchor_generator', 'target_assign', 'output_decoder']
__shared__ = ['num_classes']
def __init__(self,
anchor_generator=AnchorGenerator(),
target_assign=RetinaTargetAssign(),
output_decoder=RetinaOutputDecoder(),
num_convs_per_octave=4,
num_chan=256,
max_level=7,
min_level=3,
prior_prob=0.01,
base_scale=4,
num_scales_per_octave=3,
num_classes=81,
gamma=2.0,
alpha=0.25,
sigma=3.0151134457776365):
self.anchor_generator = anchor_generator
self.target_assign = target_assign
self.output_decoder = output_decoder
self.num_convs_per_octave = num_convs_per_octave
self.num_chan = num_chan
self.max_level = max_level
self.min_level = min_level
self.prior_prob = prior_prob
self.base_scale = base_scale
self.num_scales_per_octave = num_scales_per_octave
self.num_classes = num_classes
self.gamma = gamma
self.alpha = alpha
self.sigma = sigma
def _class_subnet(self, body_feats, spatial_scale):
"""
        Get class predictions for all FPN levels.
        Args:
            body_feats(dict): A dictionary of the FPN outputs, keyed by name.
            spatial_scale(list): A list of multiplicative spatial scale factors.
Returns:
cls_pred_input(list): Class prediction of all input fpn levels.
"""
assert len(body_feats) == self.max_level - self.min_level + 1
fpn_name_list = list(body_feats.keys())
cls_pred_list = []
for lvl in range(self.min_level, self.max_level + 1):
fpn_name = fpn_name_list[self.max_level - lvl]
subnet_blob = body_feats[fpn_name]
for i in range(self.num_convs_per_octave):
conv_name = 'retnet_cls_conv_n{}_fpn{}'.format(i, lvl)
conv_share_name = 'retnet_cls_conv_n{}_fpn{}'.format(i, self.min_level)
subnet_blob_in = subnet_blob
subnet_blob = fluid.layers.conv2d(
input=subnet_blob_in,
num_filters=self.num_chan,
filter_size=3,
stride=1,
padding=1,
act='relu',
name=conv_name,
param_attr=ParamAttr(name=conv_share_name + '_w', initializer=Normal(loc=0., scale=0.01)),
bias_attr=ParamAttr(name=conv_share_name + '_b', learning_rate=2., regularizer=L2Decay(0.)))
# class prediction
cls_name = 'retnet_cls_pred_fpn{}'.format(lvl)
cls_share_name = 'retnet_cls_pred_fpn{}'.format(self.min_level)
num_anchors = self.num_scales_per_octave * len(self.anchor_generator.aspect_ratios)
cls_dim = num_anchors * (self.num_classes - 1)
            # bias initialization: b = -log((1 - pi) / pi), where pi is the class prior probability
bias_init = float(-np.log((1 - self.prior_prob) / self.prior_prob))
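            # e.g. prior_prob=0.01 gives bias_init = -log(99) ~ -4.595, so the
            # initial sigmoid class scores start near the prior of 0.01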
out_cls = fluid.layers.conv2d(
input=subnet_blob,
num_filters=cls_dim,
filter_size=3,
stride=1,
padding=1,
act=None,
name=cls_name,
param_attr=ParamAttr(name=cls_share_name + '_w', initializer=Normal(loc=0., scale=0.01)),
bias_attr=ParamAttr(
name=cls_share_name + '_b',
initializer=Constant(value=bias_init),
learning_rate=2.,
regularizer=L2Decay(0.)))
cls_pred_list.append(out_cls)
return cls_pred_list
def _bbox_subnet(self, body_feats, spatial_scale):
"""
        Get bounding box predictions for all FPN levels.
        Args:
            body_feats(dict): A dictionary of the FPN outputs, keyed by name.
            spatial_scale(list): A list of multiplicative spatial scale factors.
Returns:
bbox_pred_input(list): Bounding box prediction of all input fpn
levels.
"""
assert len(body_feats) == self.max_level - self.min_level + 1
fpn_name_list = list(body_feats.keys())
bbox_pred_list = []
for lvl in range(self.min_level, self.max_level + 1):
fpn_name = fpn_name_list[self.max_level - lvl]
subnet_blob = body_feats[fpn_name]
for i in range(self.num_convs_per_octave):
conv_name = 'retnet_bbox_conv_n{}_fpn{}'.format(i, lvl)
conv_share_name = 'retnet_bbox_conv_n{}_fpn{}'.format(i, self.min_level)
subnet_blob_in = subnet_blob
subnet_blob = fluid.layers.conv2d(
input=subnet_blob_in,
num_filters=self.num_chan,
filter_size=3,
stride=1,
padding=1,
act='relu',
name=conv_name,
param_attr=ParamAttr(name=conv_share_name + '_w', initializer=Normal(loc=0., scale=0.01)),
bias_attr=ParamAttr(name=conv_share_name + '_b', learning_rate=2., regularizer=L2Decay(0.)))
# bbox prediction
bbox_name = 'retnet_bbox_pred_fpn{}'.format(lvl)
bbox_share_name = 'retnet_bbox_pred_fpn{}'.format(self.min_level)
num_anchors = self.num_scales_per_octave * len(self.anchor_generator.aspect_ratios)
bbox_dim = num_anchors * 4
out_bbox = fluid.layers.conv2d(
input=subnet_blob,
num_filters=bbox_dim,
filter_size=3,
stride=1,
padding=1,
act=None,
name=bbox_name,
param_attr=ParamAttr(name=bbox_share_name + '_w', initializer=Normal(loc=0., scale=0.01)),
bias_attr=ParamAttr(name=bbox_share_name + '_b', learning_rate=2., regularizer=L2Decay(0.)))
bbox_pred_list.append(out_bbox)
return bbox_pred_list
def _anchor_generate(self, body_feats, spatial_scale):
"""
        Get anchor boxes for all FPN levels.
        Args:
            body_feats(dict): A dictionary of the FPN outputs, keyed by name.
            spatial_scale(list): A list of multiplicative spatial scale factors.
Return:
            anchor_input(list): Anchors of all input FPN levels.
            anchor_var_input(list): Anchor variances of all input FPN levels.
"""
assert len(body_feats) == self.max_level - self.min_level + 1
fpn_name_list = list(body_feats.keys())
anchor_list = []
anchor_var_list = []
for lvl in range(self.min_level, self.max_level + 1):
anchor_sizes = []
stride = int(1 / spatial_scale[self.max_level - lvl])
for octave in range(self.num_scales_per_octave):
anchor_size = stride * (2**(float(octave) / float(self.num_scales_per_octave))) * self.base_scale
anchor_sizes.append(anchor_size)
fpn_name = fpn_name_list[self.max_level - lvl]
anchor, anchor_var = fluid.layers.anchor_generator(
input=body_feats[fpn_name],
anchor_sizes=anchor_sizes,
aspect_ratios=self.anchor_generator.aspect_ratios,
stride=[stride, stride],
variance=self.anchor_generator.variance)
anchor_list.append(anchor)
anchor_var_list.append(anchor_var)
return anchor_list, anchor_var_list
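    # Example: at the FPN level with stride 8 (min_level=3), base_scale=4 and
    # num_scales_per_octave=3 yield anchor_sizes of 8*4*2**(0/3), 8*4*2**(1/3),
    # 8*4*2**(2/3), i.e. approximately [32.0, 40.3, 50.8] pixels.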
def _get_output(self, body_feats, spatial_scale):
"""
        Get class predictions, bounding box predictions, and anchor boxes for all FPN levels.
        Args:
            body_feats(dict): A dictionary of the FPN outputs, keyed by name.
            spatial_scale(list): A list of multiplicative spatial scale factors.
Returns:
cls_pred_input(list): Class prediction of all input fpn levels.
bbox_pred_input(list): Bounding box prediction of all input fpn
levels.
            anchor_input(list): Anchors of all input FPN levels.
            anchor_var_input(list): Anchor variances of all input FPN levels.
"""
assert len(body_feats) == self.max_level - self.min_level + 1
# class subnet
cls_pred_list = self._class_subnet(body_feats, spatial_scale)
# bbox subnet
bbox_pred_list = self._bbox_subnet(body_feats, spatial_scale)
#generate anchors
anchor_list, anchor_var_list = self._anchor_generate(body_feats, spatial_scale)
cls_pred_reshape_list = []
bbox_pred_reshape_list = []
anchor_reshape_list = []
anchor_var_reshape_list = []
for i in range(self.max_level - self.min_level + 1):
cls_pred_transpose = fluid.layers.transpose(cls_pred_list[i], perm=[0, 2, 3, 1])
cls_pred_reshape = fluid.layers.reshape(cls_pred_transpose, shape=(0, -1, self.num_classes - 1))
bbox_pred_transpose = fluid.layers.transpose(bbox_pred_list[i], perm=[0, 2, 3, 1])
bbox_pred_reshape = fluid.layers.reshape(bbox_pred_transpose, shape=(0, -1, 4))
anchor_reshape = fluid.layers.reshape(anchor_list[i], shape=(-1, 4))
anchor_var_reshape = fluid.layers.reshape(anchor_var_list[i], shape=(-1, 4))
cls_pred_reshape_list.append(cls_pred_reshape)
bbox_pred_reshape_list.append(bbox_pred_reshape)
anchor_reshape_list.append(anchor_reshape)
anchor_var_reshape_list.append(anchor_var_reshape)
output = {}
output['cls_pred'] = cls_pred_reshape_list
output['bbox_pred'] = bbox_pred_reshape_list
output['anchor'] = anchor_reshape_list
output['anchor_var'] = anchor_var_reshape_list
return output
def get_prediction(self, body_feats, spatial_scale, im_info):
"""
Get prediction bounding box in test stage.
Args:
fpn_dict(dict): A dictionary represents the output of FPN with
their name.
spatial_scale(list): A list of multiplicative spatial scale factor.
im_info (Variable): A 2-D LoDTensor with shape [B, 3]. B is the
number of input images, each element consists of im_height,
im_width, im_scale.
Returns:
pred_result(Variable): Prediction result with shape [N, 6]. Each
row has 6 values: [label, confidence, xmin, ymin, xmax, ymax].
N is the total number of prediction.
"""
output = self._get_output(body_feats, spatial_scale)
cls_pred_reshape_list = output['cls_pred']
bbox_pred_reshape_list = output['bbox_pred']
anchor_reshape_list = output['anchor']
for i in range(self.max_level - self.min_level + 1):
cls_pred_reshape_list[i] = fluid.layers.sigmoid(cls_pred_reshape_list[i])
pred_result = fluid.layers.retinanet_detection_output(
bboxes=bbox_pred_reshape_list,
scores=cls_pred_reshape_list,
anchors=anchor_reshape_list,
im_info=im_info,
score_threshold=self.output_decoder.score_threshold,
nms_threshold=self.output_decoder.nms_threshold,
nms_top_k=self.output_decoder.nms_top_k,
keep_top_k=self.output_decoder.keep_top_k,
nms_eta=self.output_decoder.nms_eta)
return pred_result
def get_loss(self, body_feats, spatial_scale, im_info, gt_box, gt_label, is_crowd):
"""
Calculate the loss of retinanet.
Args:
fpn_dict(dict): A dictionary represents the output of FPN with
their name.
spatial_scale(list): A list of multiplicative spatial scale factor.
im_info(Variable): A 2-D LoDTensor with shape [B, 3]. B is the
number of input images, each element consists of im_height,
im_width, im_scale.
gt_box(Variable): The ground-truth bounding boxes with shape [M, 4].
M is the number of groundtruth.
gt_label(Variable): The ground-truth labels with shape [M, 1].
M is the number of groundtruth.
is_crowd(Variable): Indicates groud-truth is crowd or not with
shape [M, 1]. M is the number of groundtruth.
Returns:
Type: dict
loss_cls(Variable): focal loss.
loss_bbox(Variable): smooth l1 loss.
"""
output = self._get_output(body_feats, spatial_scale)
cls_pred_reshape_list = output['cls_pred']
bbox_pred_reshape_list = output['bbox_pred']
anchor_reshape_list = output['anchor']
anchor_var_reshape_list = output['anchor_var']
cls_pred_input = fluid.layers.concat(cls_pred_reshape_list, axis=1)
bbox_pred_input = fluid.layers.concat(bbox_pred_reshape_list, axis=1)
anchor_input = fluid.layers.concat(anchor_reshape_list, axis=0)
anchor_var_input = fluid.layers.concat(anchor_var_reshape_list, axis=0)
score_pred, loc_pred, score_tgt, loc_tgt, bbox_weight, fg_num = \
fluid.layers.rpn_target_assign(
bbox_pred=bbox_pred_input,
cls_logits=cls_pred_input,
anchor_box=anchor_input,
anchor_var=anchor_var_input,
gt_boxes=gt_box,
gt_labels=gt_label,
is_crowd=is_crowd,
im_info=im_info,
num_classes=self.num_classes - 1,
rpn_batch_size_per_im=self.target_assign.rpn_batch_size_per_im,
rpn_straddle_thresh=self.target_assign.rpn_straddle_thresh,
rpn_fg_fraction=self.target_assign.rpn_fg_fraction,
rpn_positive_overlap=self.target_assign.rpn_positive_overlap,
rpn_negative_overlap=self.target_assign.rpn_negative_overlap,
use_random=self.target_assign.use_random)
fg_num = fluid.layers.reduce_sum(fg_num, name='fg_num')
score_tgt = fluid.layers.cast(score_tgt, 'int32')
loss_cls = fluid.layers.sigmoid_focal_loss(
x=score_pred, label=score_tgt, fg_num=fg_num, gamma=self.gamma, alpha=self.alpha)
loss_cls = fluid.layers.reduce_sum(loss_cls, name='loss_cls')
loss_bbox = fluid.layers.smooth_l1(
x=loc_pred, y=loc_tgt, sigma=self.sigma, inside_weight=bbox_weight, outside_weight=bbox_weight)
loss_bbox = fluid.layers.reduce_sum(loss_bbox, name='loss_bbox')
loss_bbox = loss_bbox / fg_num
return {'loss_cls': loss_cls, 'loss_bbox': loss_bbox}
import os
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import Normal, Constant
from paddle.regularizer import L2Decay
from paddlehub.module.cv_module import Yolov3Module
import paddlehub.process.detect_transforms as T
from paddlehub.module.module import moduleinfo
class ConvBNLayer(nn.Layer):
"""Basic block for Darknet"""
def __init__(self,
ch_in: int,
ch_out: int,
filter_size: int = 3,
stride: int = 1,
groups: int = 1,
padding: int = 0,
                 act: str = 'leaky',
is_test: bool = False):
super(ConvBNLayer, self).__init__()
self.conv = nn.Conv2d(
ch_in,
ch_out,
filter_size,
padding=padding,
stride=stride,
groups=groups,
weight_attr=paddle.ParamAttr(initializer=Normal(0., 0.02)),
bias_attr=False)
self.batch_norm = nn.BatchNorm(
num_channels=ch_out,
is_test=is_test,
param_attr=paddle.ParamAttr(initializer=Normal(0., 0.02), regularizer=L2Decay(0.)))
self.act = act
def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
out = self.conv(inputs)
out = self.batch_norm(out)
if self.act == "leakly":
out = F.leaky_relu(x=out, negative_slope=0.1)
return out
class DownSample(nn.Layer):
"""Downsample block for Darknet"""
def __init__(self,
ch_in: int,
ch_out: int,
filter_size: int = 3,
stride: int = 2,
padding: int = 1,
is_test: bool = False):
super(DownSample, self).__init__()
self.conv_bn_layer = ConvBNLayer(
ch_in=ch_in, ch_out=ch_out, filter_size=filter_size, stride=stride, padding=padding, is_test=is_test)
self.ch_out = ch_out
def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
out = self.conv_bn_layer(inputs)
return out
class BasicBlock(nn.Layer):
"""Basic residual block for Darknet"""
def __init__(self, ch_in: int, ch_out: int, is_test: bool = False):
super(BasicBlock, self).__init__()
self.conv1 = ConvBNLayer(ch_in=ch_in, ch_out=ch_out, filter_size=1, stride=1, padding=0, is_test=is_test)
self.conv2 = ConvBNLayer(ch_in=ch_out, ch_out=ch_out * 2, filter_size=3, stride=1, padding=1, is_test=is_test)
def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
conv1 = self.conv1(inputs)
conv2 = self.conv2(conv1)
out = paddle.elementwise_add(x=inputs, y=conv2, act=None)
return out
class LayerWarp(nn.Layer):
"""Warp layer composed by basic residual blocks"""
def __init__(self, ch_in: int, ch_out: int, count: int, is_test: bool = False):
super(LayerWarp, self).__init__()
self.basicblock0 = BasicBlock(ch_in, ch_out, is_test=is_test)
self.res_out_list = []
for i in range(1, count):
res_out = self.add_sublayer("basic_block_%d" % (i), BasicBlock(ch_out * 2, ch_out, is_test=is_test))
self.res_out_list.append(res_out)
self.ch_out = ch_out
def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
y = self.basicblock0(inputs)
for basic_block_i in self.res_out_list:
y = basic_block_i(y)
return y
class DarkNet53_conv_body(nn.Layer):
"""Darknet53
Args:
ch_in(int): Input channels, default is 3.
        is_test (bool): Set the test mode, default is False.
"""
def __init__(self, ch_in: int = 3, is_test: bool = False):
super(DarkNet53_conv_body, self).__init__()
self.stages = [1, 2, 8, 8, 4]
self.stages = self.stages[0:5]
self.conv0 = ConvBNLayer(ch_in=ch_in, ch_out=32, filter_size=3, stride=1, padding=1, is_test=is_test)
self.downsample0 = DownSample(ch_in=32, ch_out=32 * 2, is_test=is_test)
self.darknet53_conv_block_list = []
self.downsample_list = []
ch_in = [64, 128, 256, 512, 1024]
for i, stage in enumerate(self.stages):
conv_block = self.add_sublayer("stage_%d" % (i),
LayerWarp(int(ch_in[i]), 32 * (2**i), stage, is_test=is_test))
self.darknet53_conv_block_list.append(conv_block)
for i in range(len(self.stages) - 1):
downsample = self.add_sublayer(
"stage_%d_downsample" % i, DownSample(
ch_in=32 * (2**(i + 1)), ch_out=32 * (2**(i + 2)), is_test=is_test))
self.downsample_list.append(downsample)
def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
out = self.conv0(inputs)
out = self.downsample0(out)
blocks = []
for i, conv_block_i in enumerate(self.darknet53_conv_block_list):
out = conv_block_i(out)
blocks.append(out)
if i < len(self.stages) - 1:
out = self.downsample_list[i](out)
return blocks[-1:-4:-1]
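# Note: DarkNet53_conv_body returns the last three stage outputs in reverse
# order (deepest first), i.e. feature maps at strides 32, 16 and 8 relative to
# the input image; these feed the three YOLO detection heads from coarse to fine.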
class YoloDetectionBlock(nn.Layer):
"""Basic block for Yolov3"""
def __init__(self, ch_in: int, channel: int, is_test: bool = True):
super(YoloDetectionBlock, self).__init__()
assert channel % 2 == 0, \
"channel {} cannot be divided by 2".format(channel)
self.conv0 = ConvBNLayer(ch_in=ch_in, ch_out=channel, filter_size=1, stride=1, padding=0, is_test=is_test)
self.conv1 = ConvBNLayer(ch_in=channel, ch_out=channel * 2, filter_size=3, stride=1, padding=1, is_test=is_test)
self.conv2 = ConvBNLayer(ch_in=channel * 2, ch_out=channel, filter_size=1, stride=1, padding=0, is_test=is_test)
self.conv3 = ConvBNLayer(ch_in=channel, ch_out=channel * 2, filter_size=3, stride=1, padding=1, is_test=is_test)
self.route = ConvBNLayer(ch_in=channel * 2, ch_out=channel, filter_size=1, stride=1, padding=0, is_test=is_test)
self.tip = ConvBNLayer(ch_in=channel, ch_out=channel * 2, filter_size=3, stride=1, padding=1, is_test=is_test)
def forward(self, inputs):
out = self.conv0(inputs)
out = self.conv1(out)
out = self.conv2(out)
out = self.conv3(out)
route = self.route(out)
tip = self.tip(route)
return route, tip
class Upsample(nn.Layer):
"""Upsample block for Yolov3"""
def __init__(self, scale: int = 2):
super(Upsample, self).__init__()
self.scale = scale
def forward(self, inputs: paddle.Tensor):
shape_nchw = paddle.to_tensor(inputs.shape)
shape_hw = paddle.slice(shape_nchw, axes=[0], starts=[2], ends=[4])
shape_hw.stop_gradient = True
in_shape = paddle.cast(shape_hw, dtype='int32')
out_shape = in_shape * self.scale
out_shape.stop_gradient = True
out = F.resize_nearest(input=inputs, scale=self.scale, actual_shape=out_shape)
return out
@moduleinfo(
name="yolov3_darknet53_pascalvoc",
type="CV/image_editing",
author="paddlepaddle",
author_email="",
summary="Yolov3 is a detection model, this module is trained with VOC dataset.",
version="1.0.0",
meta=Yolov3Module)
class YOLOv3(nn.Layer):
"""YOLOV3 for detection
Args:
ch_in(int): Input channels, default is 3.
class_num(int): Categories for detection,if dataset is voc, class_num is 20.
ignore_thresh(float): The ignore threshold to ignore confidence loss.
valid_thresh(float): Threshold to filter out bounding boxes with low confidence score.
nms_topk(int): Maximum number of detections to be kept according to the confidences after the filtering
detections based on score_threshold.
nms_posk(int): Number of total bboxes to be kept per image after NMS step. -1 means keeping all bboxes after NMS
step.
nms_thresh (float): The threshold to be used in NMS. Default: 0.3.
is_train (bool): Set the train mode, default is True.
load_checkpoint(str): Whether to load checkpoint.
"""
def __init__(self,
ch_in: int = 3,
class_num: int = 20,
ignore_thresh: float = 0.7,
valid_thresh: float = 0.005,
nms_topk: int = 400,
nms_posk: int = 100,
nms_thresh: float = 0.45,
is_train: bool = True,
load_checkpoint: str = None):
super(YOLOv3, self).__init__()
self.is_train = is_train
self.block = DarkNet53_conv_body(ch_in=ch_in, is_test=not self.is_train)
self.block_outputs = []
self.yolo_blocks = []
self.route_blocks_2 = []
self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326]
self.class_num = class_num
self.ignore_thresh = ignore_thresh
self.valid_thresh = valid_thresh
self.nms_topk = nms_topk
self.nms_posk = nms_posk
self.nms_thresh = nms_thresh
ch_in_list = [1024, 768, 384]
for i in range(3):
yolo_block = self.add_sublayer(
"yolo_detecton_block_%d" % (i),
YoloDetectionBlock(ch_in_list[i], channel=512 // (2**i), is_test=not self.is_train))
self.yolo_blocks.append(yolo_block)
num_filters = len(self.anchor_masks[i]) * (self.class_num + 5)
block_out = self.add_sublayer(
"block_out_%d" % (i),
nn.Conv2d(
1024 // (2**i),
num_filters,
1,
stride=1,
padding=0,
weight_attr=paddle.ParamAttr(initializer=Normal(0., 0.02)),
bias_attr=paddle.ParamAttr(initializer=Constant(0.0), regularizer=L2Decay(0.))))
self.block_outputs.append(block_out)
if i < 2:
route = self.add_sublayer(
"route2_%d" % i,
ConvBNLayer(
ch_in=512 // (2**i),
ch_out=256 // (2**i),
filter_size=1,
stride=1,
padding=0,
is_test=(not self.is_train)))
self.route_blocks_2.append(route)
self.upsample = Upsample()
if load_checkpoint is not None:
model_dict = paddle.load(load_checkpoint)[0]
self.set_dict(model_dict)
print("load custom checkpoint success")
else:
checkpoint = os.path.join(self.directory, 'yolov3_darknet53_voc.pdparams')
if not os.path.exists(checkpoint):
os.system(
'wget https://paddlehub.bj.bcebos.com/dygraph/detection/yolov3_darknet53_voc.pdparams -O ' \
+ checkpoint)
model_dict = paddle.load(checkpoint)[0]
self.set_dict(model_dict)
print("load pretrained checkpoint success")
def transform(self, img):
if self.is_train:
transform = T.Compose([
T.RandomDistort(),
T.RandomExpand(fill=[0.485, 0.456, 0.406]),
T.RandomCrop(),
T.Resize(target_size=416),
T.RandomFlip(),
T.ShuffleBox(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
else:
transform = T.Compose([
T.Resize(target_size=416, interp='CUBIC'),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
return transform(img)
def forward(self, inputs: paddle.Tensor):
outputs = []
blocks = self.block(inputs)
route = None
for i, block in enumerate(blocks):
if i > 0:
block = paddle.concat([route, block], axis=1)
route, tip = self.yolo_blocks[i](block)
block_out = self.block_outputs[i](tip)
outputs.append(block_out)
if i < 2:
route = self.route_blocks_2[i](route)
route = self.upsample(route)
return outputs
## Overview
Tencent_AILab_ChineseEmbedding provides 200-dimensional vector representations for more than 8 million Chinese words and phrases, learned from large-scale Chinese corpora. The embeddings can be used for transfer learning in a wide range of downstream tasks.
For more details, see: https://ai.tencent.com/ailab/nlp/en/embedding.html
Note: this Module was contributed by the third-party developer DesmonDay.
## API
```python
def context(trainable=False, max_seq_len=128, num_slots=1)
```
Gets the pretrained program of this Module together with the program's inputs and outputs.
**Parameters**
* trainable(bool): trainable=True means the parameters in the program will be updated during fine-tuning; otherwise they stay fixed.
* max_seq_len(int): the maximum sequence length used by the model.
* num_slots(int): the number of text inputs fed to the model. For single-sentence classification, num_slots=1; for point-wise text matching, num_slots=2; for pair-wise text matching, num_slots=3.
**Returns**
* inputs(dict): input variables of the program
* outputs(dict): output variables of the program
* main_program(Program): the program with pretrained parameters
### Code Example
```python
import paddlehub as hub
tencent_ailab_chinese_embedding = hub.Module(name="tencent_ailab_chinese_embedding")
inputs, outputs, program = tencent_ailab_chinese_embedding.context(trainable=True, max_seq_len=128, num_slots=1)
```
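The `context` API only builds the program; to obtain concrete embedding values you still need to run it with an executor. Continuing from the example above, here is a minimal sketch (the word-to-id lookup, the fallback id 0 and the padding logic are illustrative assumptions, based on the module's vocab file format of one word and id per line):
```python
import numpy as np
import paddle.fluid as fluid

# Illustrative only: build a word-to-id map from the module's vocab file.
vocab = {}
with open(tencent_ailab_chinese_embedding.get_vocab_path(), encoding="utf8") as f:
    for line in f:
        word, idx = line.rstrip("\n").split("\t")
        vocab[word] = int(idx)

# Convert a short text to ids (unknown words fall back to id 0 here purely
# for illustration) and pad to max_seq_len with the padding id.
ids = [vocab.get(w, 0) for w in ["你好", "世界"]]
ids += [len(vocab) - 1] * (128 - len(ids))

exe = fluid.Executor(fluid.CPUPlace())
emb = exe.run(program,
              feed={inputs["text"].name: np.array([ids], dtype="int64")},
              fetch_list=[outputs["emb"]])[0]
print(emb.shape)  # expected: (1, 128, 200)
```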
## Dependencies
paddlepaddle >= 1.8.2
paddlehub >= 1.8.0
## Release History
* 1.0.0
  Initial release
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import io
import os
import paddle.fluid as fluid
import paddlehub as hub
from paddlehub.common.paddle_helper import add_vars_prefix
from paddlehub.module.module import moduleinfo
def load_vocab(file_path):
"""
load the given vocabulary
"""
vocab = {}
with io.open(file_path, 'r', encoding='utf8') as f:
for line in f:
parts = line.split("\t")
vocab[parts[0]] = int(parts[1])
return vocab
@moduleinfo(
name="tencent_ailab_chinese_embedding",
version="1.0.0",
summary=
"Tencent AI Lab Embedding Corpus for Chinese Words and Phrases and the vocab size is 8,824,331. For more information, please refer to https://ai.tencent.com/ailab/nlp/zh/embedding.html",
author="",
author_email="",
type="nlp/semantic_model")
class TencentAILabChineseEmbedding(hub.Module):
def _initialize(self):
"""
initialize with the necessary elements
"""
self.pretrained_model_path = os.path.join(self.directory, "assets", "model")
self.vocab_path = os.path.join(self.directory, "assets", "vocab.txt")
self.vocab = load_vocab(self.vocab_path)
def context(self, trainable=False, max_seq_len=128, num_slots=1):
"""
Get the input ,output and program of the pretrained tencent_ailab_chinese_embedding
Args:
trainable(bool): whether fine-tune the pretrained parameters of simnet_bow or not
num_slots(int): It's number of slots inputted to the model, selectted as following options:
- 1(default): There's only one data to be feeded in the model, e.g. the module is used for sentence classification task.
- 2: There are two data to be feeded in the model, e.g. the module is used for text matching task (point-wise).
- 3: There are three data to be feeded in the model, e.g. the module is used for text matching task (pair-wise).
Returns:
inputs(dict): the input variables of tencent_ailab_chinese_embedding (words)
outputs(dict): the output variables of input words (word embeddings)
main_program(Program): the main_program of tencent_ailab_chinese_embedding with pretrained prameters
"""
assert num_slots >= 1 and num_slots <= 3, "num_slots must be 1, 2, or 3, but the input is %d" % num_slots
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
with fluid.unique_name.guard():
w_param_attrs = fluid.ParamAttr(
name="embedding_0.w_0",
initializer=fluid.initializer.TruncatedNormal(scale=0.02),
trainable=trainable)
text_1 = fluid.data(name='text', shape=[-1, max_seq_len], dtype='int64', lod_level=0)
emb_1 = fluid.embedding(
input=text_1,
size=[len(self.vocab), 200],
is_sparse=True,
padding_idx=len(self.vocab) - 1,
dtype='float32',
param_attr=w_param_attrs)
emb_1_name = emb_1.name
data_list = [text_1]
emb_name_list = [emb_1_name]
if num_slots > 1:
text_2 = fluid.data(name='text_2', shape=[-1, max_seq_len], dtype='int64', lod_level=0)
emb_2 = fluid.embedding(
input=text_2,
size=[len(self.vocab), 200],
is_sparse=True,
padding_idx=len(self.vocab) - 1,
dtype='float32',
param_attr=w_param_attrs)
emb_2_name = emb_2.name
data_list.append(text_2)
emb_name_list.append(emb_2_name)
if num_slots > 2:
text_3 = fluid.data(name='text_3', shape=[-1, max_seq_len], dtype='int64', lod_level=0)
emb_3 = fluid.embedding(
input=text_3,
size=[len(self.vocab), 200],
is_sparse=True,
padding_idx=len(self.vocab) - 1,
dtype='float32',
param_attr=w_param_attrs)
emb_3_name = emb_3.name
data_list.append(text_3)
emb_name_list.append(emb_3_name)
variable_names = filter(lambda v: v not in ['text', 'text_2', 'text_3'],
list(main_program.global_block().vars.keys()))
prefix_name = "@HUB_{}@".format(self.name)
add_vars_prefix(program=main_program, prefix=prefix_name, vars=variable_names)
for param in main_program.global_block().iter_parameters():
param.trainable = trainable
place = fluid.CPUPlace()
exe = fluid.Executor(place)
# load the pretrained model
def if_exist(var):
return os.path.exists(os.path.join(self.pretrained_model_path, var.name))
fluid.io.load_vars(exe, self.pretrained_model_path, predicate=if_exist)
inputs = {}
outputs = {}
for index, data in enumerate(data_list):
if index == 0:
inputs['text'] = data
outputs['emb'] = main_program.global_block().vars[prefix_name + emb_name_list[0]]
else:
inputs['text_%s' % (index + 1)] = data
outputs['emb_%s' % (index + 1)] = main_program.global_block().vars[prefix_name +
emb_name_list[index]]
return inputs, outputs, main_program
def get_vocab_path(self):
return self.vocab_path
if __name__ == "__main__":
w2v = TencentAILabChineseEmbedding()
inputs, outputs, program = w2v.context(num_slots=3)
print(inputs)
print(outputs)
print(w2v.get_vocab_path())
## Overview
Tencent_AILab_ChineseEmbedding provides 200-dimensional vector representations for more than 8 million Chinese words and phrases, learned from large-scale Chinese corpora.
This Module keeps the top 2 million words of the original vocabulary and can likewise be used for transfer learning in various downstream tasks.
For more details, see: https://ai.tencent.com/ailab/nlp/en/embedding.html
Note: this Module was contributed by the third-party developer DesmonDay.
## API
```python
def context(trainable=False, max_seq_len=128, num_slots=1)
```
Gets the pretrained program of this Module together with the program's inputs and outputs.
**Parameters**
* trainable(bool): trainable=True means the parameters in the program will be updated during fine-tuning; otherwise they stay fixed.
* max_seq_len(int): the maximum sequence length used by the model.
* num_slots(int): the number of text inputs fed to the model. For single-sentence classification, num_slots=1; for point-wise text matching, num_slots=2; for pair-wise text matching, num_slots=3.
**Returns**
* inputs(dict): input variables of the program
* outputs(dict): output variables of the program
* main_program(Program): the program with pretrained parameters
### Code Example
```python
import paddlehub as hub
tencent_ailab_chinese_embedding = hub.Module(name="tencent_ailab_chinese_embedding_small")
inputs, outputs, program = tencent_ailab_chinese_embedding.context(trainable=True, max_seq_len=128, num_slots=1)
```
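With num_slots=2 the same `context` call exposes a second text slot and a second embedding output, which is the shape a point-wise text matching network consumes. A minimal sketch (the mean-pooling and cosine-similarity head below is an assumed downstream design, not part of this Module):
```python
import paddle.fluid as fluid
import paddlehub as hub

module = hub.Module(name="tencent_ailab_chinese_embedding_small")
inputs, outputs, program = module.context(trainable=True, max_seq_len=128, num_slots=2)
# Slots: inputs['text'] / inputs['text_2'], outputs['emb'] / outputs['emb_2'].

with fluid.program_guard(program):
    # Mean-pool each sequence of word embeddings (padding handling omitted
    # for brevity), then score the pair with cosine similarity.
    pooled_1 = fluid.layers.reduce_mean(outputs["emb"], dim=1)
    pooled_2 = fluid.layers.reduce_mean(outputs["emb_2"], dim=1)
    similarity = fluid.layers.cos_sim(pooled_1, pooled_2)
```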
## Dependencies
paddlepaddle >= 1.8.2
paddlehub >= 1.8.0
## Release History
* 1.0.0
  Initial release
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import io
import os
import paddle.fluid as fluid
import paddlehub as hub
from paddlehub.common.paddle_helper import add_vars_prefix
from paddlehub.module.module import moduleinfo
def load_vocab(file_path):
"""
load the given vocabulary
"""
vocab = {}
with io.open(file_path, 'r', encoding='utf8') as f:
for line in f:
parts = line.split("\t")
vocab[parts[0]] = int(parts[1])
return vocab
@moduleinfo(
name="tencent_ailab_chinese_embedding_small",
version="1.0.0",
summary=
"Tencent AI Lab Embedding Corpus for Chinese Words and Phrases and the vocab size is 2,000,002. For more information, please refer to https://ai.tencent.com/ailab/nlp/zh/embedding.html",
author="",
author_email="",
type="nlp/semantic_model")
class TencentAILabChineseEmbeddingSmall(hub.Module):
def _initialize(self):
"""
initialize with the necessary elements
"""
self.pretrained_model_path = os.path.join(self.directory, "assets", "model")
self.vocab_path = os.path.join(self.directory, "assets", "vocab.txt")
self.vocab = load_vocab(self.vocab_path)
def context(self, trainable=False, max_seq_len=128, num_slots=1):
"""
Get the input ,output and program of the pretrained word2vec_skipgram
Args:
trainable(bool): Whether fine-tune the pretrained parameters of tencent_ailab_chinese_embedding_small or not.
num_slots(int): It's number of data inputted to the model, selectted as following options:
- 1(default): There's only one data to be feeded in the model, e.g. the module is used for sentence classification task.
- 2: There are two data to be feeded in the model, e.g. the module is used for text matching task (point-wise).
- 3: There are three data to be feeded in the model, e.g. the module is used for text matching task (pair-wise).
Returns:
inputs(dict): the input variables of tencent_ailab_chinese_embedding_small (words)
outputs(dict): the output variables of input words (word embeddings)
main_program(Program): the main_program of tencent_ailab_chinese_embedding_small with pretrained prameters
"""
assert num_slots >= 1 and num_slots <= 3, "num_slots must be 1, 2, or 3, but the input is %d" % num_slots
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
with fluid.unique_name.guard():
w_param_attrs = fluid.ParamAttr(
name="embedding_0.w_0",
initializer=fluid.initializer.TruncatedNormal(scale=0.02),
trainable=trainable)
text_1 = fluid.data(name='text', shape=[-1, max_seq_len], dtype='int64', lod_level=0)
emb_1 = fluid.embedding(
input=text_1,
size=[len(self.vocab), 200],
is_sparse=True,
padding_idx=len(self.vocab) - 1,
dtype='float32',
param_attr=w_param_attrs)
emb_1_name = emb_1.name
data_list = [text_1]
emb_name_list = [emb_1_name]
if num_slots > 1:
text_2 = fluid.data(name='text_2', shape=[-1, max_seq_len], dtype='int64', lod_level=0)
emb_2 = fluid.embedding(
input=text_2,
size=[len(self.vocab), 200],
is_sparse=True,
padding_idx=len(self.vocab) - 1,
dtype='float32',
param_attr=w_param_attrs)
emb_2_name = emb_2.name
data_list.append(text_2)
emb_name_list.append(emb_2_name)
if num_slots > 2:
text_3 = fluid.data(name='text_3', shape=[-1, max_seq_len], dtype='int64', lod_level=0)
emb_3 = fluid.embedding(
input=text_3,
size=[len(self.vocab), 200],
is_sparse=True,
padding_idx=len(self.vocab) - 1,
dtype='float32',
param_attr=w_param_attrs)
emb_3_name = emb_3.name
data_list.append(text_3)
emb_name_list.append(emb_3_name)
variable_names = filter(lambda v: v not in ['text', 'text_2', 'text_3'],
list(main_program.global_block().vars.keys()))
prefix_name = "@HUB_{}@".format(self.name)
add_vars_prefix(program=main_program, prefix=prefix_name, vars=variable_names)
for param in main_program.global_block().iter_parameters():
param.trainable = trainable
place = fluid.CPUPlace()
exe = fluid.Executor(place)
# load the pretrained model
def if_exist(var):
return os.path.exists(os.path.join(self.pretrained_model_path, var.name))
fluid.io.load_vars(exe, self.pretrained_model_path, predicate=if_exist)
inputs = {}
outputs = {}
for index, data in enumerate(data_list):
if index == 0:
inputs['text'] = data
outputs['emb'] = main_program.global_block().vars[prefix_name + emb_name_list[0]]
else:
inputs['text_%s' % (index + 1)] = data
outputs['emb_%s' % (index + 1)] = main_program.global_block().vars[prefix_name +
emb_name_list[index]]
return inputs, outputs, main_program
def get_vocab_path(self):
return self.vocab_path
if __name__ == "__main__":
w2v = TencentAILabChineseEmbeddingSmall()
inputs, outputs, program = w2v.context(num_slots=3)
print(inputs)
print(outputs)
print(w2v.get_vocab_path())
## Overview
ernie_gen_leave is a model fine-tuned from ERNIE-GEN; its main function is generating leave-request notes. Give it a keyword, and it produces reasons for asking for leave.
## Command-Line Prediction
```shell
$ hub run ernie_gen_leave --input_text="理由" --use_gpu True --beam_width 5
```
## API
```python
def generate(texts, use_gpu=False, beam_width=5):
```
Prediction API: given keywords, it generates leave reasons.
**Parameters**
* texts (list\[str\]): keywords for the leave request;
* use\_gpu (bool): whether to use the GPU; **if you use the GPU, set the CUDA\_VISIBLE\_DEVICES environment variable first**
* beam\_width (int): the beam search width, which determines how many candidate reasons are returned.
**Returns**
* results (list\[list\[str\]\]): the generated leave reasons.
**Code Example**
```python
import paddlehub as hub
module = hub.Module(name="ernie_gen_leave")
test_texts = ["理由"]
results = module.generate(texts=test_texts, use_gpu=False, beam_width=2)
for result in results:
print(result)
```
## View the Code
https://github.com/PaddlePaddle/PaddleHub/tree/release/v2.0.0-rc/modules/text/text_generation/ernie_gen_leave
### Dependencies
paddlepaddle >= 2.0.0rc1
paddlehub >= 2.0.0rc0
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import numpy as np
from collections import namedtuple
import paddle.fluid as F
import paddle.fluid.layers as L
import paddle.fluid.dygraph as D
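# gen_bias builds the attention mask for ERNIE-GEN style infilling decoding:
# the encoder tokens are fully visible to every decoder position, while the
# decoder part gets a lower-triangular (causal) mask, since position i may only
# attend to positions j <= i; for step > 0 the previously cached decoder
# positions are appended as visible as well.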
def gen_bias(encoder_inputs, decoder_inputs, step):
decoder_bsz, decoder_seqlen = decoder_inputs.shape[:2]
attn_bias = L.reshape(L.range(0, decoder_seqlen, 1, dtype='float32') + 1, [1, -1, 1])
decoder_bias = L.cast((L.matmul(attn_bias, 1. / attn_bias, transpose_y=True) >= 1.),
'float32') #[1, 1, decoderlen, decoderlen]
encoder_bias = L.unsqueeze(L.cast(L.ones_like(encoder_inputs), 'float32'), [1]) #[bsz, 1, encoderlen]
encoder_bias = L.expand(encoder_bias, [1, decoder_seqlen, 1]) #[bsz,decoderlen, encoderlen]
decoder_bias = L.expand(decoder_bias, [decoder_bsz, 1, 1]) #[bsz, decoderlen, decoderlen]
if step > 0:
bias = L.concat([encoder_bias, L.ones([decoder_bsz, decoder_seqlen, step], 'float32'), decoder_bias], -1)
else:
bias = L.concat([encoder_bias, decoder_bias], -1)
return bias
@D.no_grad
def greedy_search_infilling(model,
q_ids,
q_sids,
sos_id,
eos_id,
attn_id,
max_encode_len=640,
max_decode_len=100,
tgt_type_id=3):
model.eval()
_, logits, info = model(q_ids, q_sids)
gen_ids = L.argmax(logits, -1)
d_batch, d_seqlen = q_ids.shape
seqlen = L.reduce_sum(L.cast(q_ids != 0, 'int64'), 1, keep_dim=True)
has_stopped = np.zeros([d_batch], dtype=np.bool)
gen_seq_len = np.zeros([d_batch], dtype=np.int64)
output_ids = []
past_cache = info['caches']
cls_ids = L.ones([d_batch], dtype='int64') * sos_id
attn_ids = L.ones([d_batch], dtype='int64') * attn_id
ids = L.stack([cls_ids, attn_ids], -1)
for step in range(max_decode_len):
bias = gen_bias(q_ids, ids, step)
pos_ids = D.to_variable(np.tile(np.array([[step, step + 1]], dtype=np.int64), [d_batch, 1]))
pos_ids += seqlen
_, logits, info = model(
ids, L.ones_like(ids) * tgt_type_id, pos_ids=pos_ids, attn_bias=bias, past_cache=past_cache)
gen_ids = L.argmax(logits, -1)
past_cached_k, past_cached_v = past_cache
cached_k, cached_v = info['caches']
cached_k = [L.concat([pk, k[:, :1, :]], 1) for pk, k in zip(past_cached_k, cached_k)] # concat cached
cached_v = [L.concat([pv, v[:, :1, :]], 1) for pv, v in zip(past_cached_v, cached_v)]
past_cache = (cached_k, cached_v)
gen_ids = gen_ids[:, 1]
ids = L.stack([gen_ids, attn_ids], 1)
gen_ids = gen_ids.numpy()
has_stopped |= (gen_ids == eos_id).astype(np.bool)
gen_seq_len += (1 - has_stopped.astype(np.int64))
output_ids.append(gen_ids.tolist())
if has_stopped.all():
break
output_ids = np.array(output_ids).transpose([1, 0])
return output_ids
BeamSearchState = namedtuple('BeamSearchState', ['log_probs', 'lengths', 'finished'])
BeamSearchOutput = namedtuple('BeamSearchOutput', ['scores', 'predicted_ids', 'beam_parent_ids'])
def log_softmax(x):
e_x = np.exp(x - np.max(x))
return np.log(e_x / e_x.sum())
def mask_prob(p, onehot_eos, finished):
is_finished = L.cast(L.reshape(finished, [-1, 1]) != 0, 'float32')
p = is_finished * (1. - L.cast(onehot_eos, 'float32')) * -9999. + (1. - is_finished) * p
return p
def hyp_score(log_probs, length, length_penalty):
lp = L.pow((5. + L.cast(length, 'float32')) / 6., length_penalty)
return log_probs / lp
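# hyp_score above normalizes accumulated log probabilities by the GNMT-style
# length penalty lp = ((5 + length) / 6) ** length_penalty, so longer
# hypotheses are not unduly penalized. beam_search_step below expands every
# live beam over the full vocabulary, rescores all beam_width * vocab_size
# candidates, and keeps the top beam_width of them.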
def beam_search_step(state, logits, eos_id, beam_width, is_first_step, length_penalty):
"""logits.shape == [B*W, V]"""
    beam_size, vocab_size = logits.shape  # batch size is 1 in this hub module, so the first dim (bsz * beam_width) equals beam_width
logits_np = logits.numpy()
for i in range(beam_size):
logits_np[i][17963] = 0 # make [UNK] prob = 0
logits = D.to_variable(logits_np)
bsz, beam_width = state.log_probs.shape
onehot_eos = L.cast(F.one_hot(L.ones([1], 'int64') * eos_id, vocab_size), 'int64') #[1, V]
probs = L.log(L.softmax(logits)) #[B*W, V]
probs = mask_prob(probs, onehot_eos, state.finished) #[B*W, V]
allprobs = L.reshape(state.log_probs, [-1, 1]) + probs #[B*W, V]
not_finished = 1 - L.reshape(state.finished, [-1, 1]) #[B*W,1]
not_eos = 1 - onehot_eos
length_to_add = not_finished * not_eos #[B*W,V]
alllen = L.reshape(state.lengths, [-1, 1]) + length_to_add
allprobs = L.reshape(allprobs, [-1, beam_width * vocab_size])
alllen = L.reshape(alllen, [-1, beam_width * vocab_size])
allscore = hyp_score(allprobs, alllen, length_penalty)
if is_first_step:
        allscore = L.reshape(allscore, [bsz, beam_width, -1])[:, 0, :]  # on the first step, only beam 0 is considered
scores, idx = L.topk(allscore, k=beam_width) #[B, W]
next_beam_id = idx // vocab_size #[B, W]
next_word_id = idx % vocab_size
gather_idx = L.concat([L.where(idx != -1)[:, :1], L.reshape(idx, [-1, 1])], 1)
next_probs = L.reshape(L.gather_nd(allprobs, gather_idx), idx.shape)
next_len = L.reshape(L.gather_nd(alllen, gather_idx), idx.shape)
gather_idx = L.concat([L.where(next_beam_id != -1)[:, :1], L.reshape(next_beam_id, [-1, 1])], 1)
next_finished = L.reshape(L.gather_nd(state.finished, gather_idx),
state.finished.shape) #[gather new beam state according to new beam id]
next_finished += L.cast(next_word_id == eos_id, 'int64')
next_finished = L.cast(next_finished > 0, 'int64')
next_state = BeamSearchState(log_probs=next_probs, lengths=next_len, finished=next_finished)
output = BeamSearchOutput(scores=scores, predicted_ids=next_word_id, beam_parent_ids=next_beam_id)
return output, next_state
@D.no_grad
def beam_search_infilling(model,
q_ids,
q_sids,
sos_id,
eos_id,
attn_id,
max_encode_len=640,
max_decode_len=100,
beam_width=5,
tgt_type_id=3,
length_penalty=1.0):
model.eval()
_, __, info = model(q_ids, q_sids)
d_batch, d_seqlen = q_ids.shape
state = BeamSearchState(
log_probs=L.zeros([d_batch, beam_width], 'float32'),
lengths=L.zeros([d_batch, beam_width], 'int64'),
finished=L.zeros([d_batch, beam_width], 'int64'))
outputs = []
def reorder_(t, parent_id):
"""reorder cache according to parent beam id"""
gather_idx = L.where(parent_id != -1)[:, 0] * beam_width + L.reshape(parent_id, [-1])
t = L.gather(t, gather_idx)
return t
def tile_(t, times):
_shapes = list(t.shape[1:])
ret = L.reshape(L.expand(L.unsqueeze(t, [1]), [
1,
times,
] + [
1,
] * len(_shapes)), [
-1,
] + _shapes)
return ret
cached_k, cached_v = info['caches']
cached_k = [tile_(k, beam_width) for k in cached_k]
cached_v = [tile_(v, beam_width) for v in cached_v]
past_cache = (cached_k, cached_v)
q_ids = tile_(q_ids, beam_width)
seqlen = L.reduce_sum(L.cast(q_ids != 0, 'int64'), 1, keep_dim=True)
cls_ids = L.ones([d_batch * beam_width], dtype='int64') * sos_id
attn_ids = L.ones([d_batch * beam_width], dtype='int64') * attn_id # SOS
ids = L.stack([cls_ids, attn_ids], -1)
for step in range(max_decode_len):
bias = gen_bias(q_ids, ids, step)
pos_ids = D.to_variable(np.tile(np.array([[step, step + 1]], dtype=np.int64), [d_batch * beam_width, 1]))
pos_ids += seqlen
_, logits, info = model(
ids, L.ones_like(ids) * tgt_type_id, pos_ids=pos_ids, attn_bias=bias, past_cache=past_cache)
output, state = beam_search_step(
state,
logits[:, 1],
eos_id=eos_id,
beam_width=beam_width,
is_first_step=(step == 0),
length_penalty=length_penalty)
outputs.append(output)
past_cached_k, past_cached_v = past_cache
cached_k, cached_v = info['caches']
cached_k = [
reorder_(L.concat([pk, k[:, :1, :]], 1), output.beam_parent_ids) for pk, k in zip(past_cached_k, cached_k)
] # concat cached
cached_v = [
reorder_(L.concat([pv, v[:, :1, :]], 1), output.beam_parent_ids) for pv, v in zip(past_cached_v, cached_v)
]
past_cache = (cached_k, cached_v)
pred_ids_flatten = L.reshape(output.predicted_ids, [d_batch * beam_width])
ids = L.stack([pred_ids_flatten, attn_ids], 1)
if state.finished.numpy().all():
break
final_ids = L.stack([o.predicted_ids for o in outputs], 0)
final_parent_ids = L.stack([o.beam_parent_ids for o in outputs], 0)
    final_ids = L.gather_tree(final_ids, final_parent_ids)  # trace hypotheses back through beam parents; [:, :, 0] would pick the best beam
final_ids = L.transpose(L.reshape(final_ids, [-1, d_batch * 1, beam_width]), [1, 2, 0])
return final_ids
en_pattern = re.compile(r'^[a-zA-Z0-9]*$')
def post_process(token):
if token.startswith('##'):
ret = token[2:]
else:
        if en_pattern.match(token):
ret = ' ' + token
else:
ret = token
return ret
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from tqdm import tqdm
from paddlehub.common.logger import logger
from paddlehub.common.dir import MODULE_HOME
def _fetch_from_remote(url, force_download=False):
import tempfile, requests, tarfile
cached_dir = os.path.join(MODULE_HOME, "ernie_for_gen")
if force_download or not os.path.exists(cached_dir):
with tempfile.NamedTemporaryFile() as f:
#url = 'https://ernie.bj.bcebos.com/ERNIE_stable.tgz'
r = requests.get(url, stream=True)
total_len = int(r.headers.get('content-length'))
for chunk in tqdm(
r.iter_content(chunk_size=1024), total=total_len // 1024, desc='downloading %s' % url, unit='KB'):
if chunk:
f.write(chunk)
f.flush()
            logger.debug('extracting... to %s' % f.name)
with tarfile.open(f.name) as tf:
tf.extractall(path=cached_dir)
logger.debug('%s cached in %s' % (url, cached_dir))
return cached_dir
def add_docstring(doc):
def func(f):
        f.__doc__ += ('\n======other docs from super class ======\n%s' % doc)
return f
return func
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import logging
import paddle.fluid.dygraph as D
import paddle.fluid as F
import paddle.fluid.layers as L
log = logging.getLogger(__name__)
def _build_linear(n_in, n_out, name, init, act=None):
return D.Linear(
n_in,
n_out,
param_attr=F.ParamAttr(name='%s.w_0' % name if name is not None else None, initializer=init),
bias_attr='%s.b_0' % name if name is not None else None,
act=act)
def _build_ln(n_in, name):
return D.LayerNorm(
normalized_shape=n_in,
param_attr=F.ParamAttr(
name='%s_layer_norm_scale' % name if name is not None else None, initializer=F.initializer.Constant(1.)),
bias_attr=F.ParamAttr(
name='%s_layer_norm_bias' % name if name is not None else None, initializer=F.initializer.Constant(1.)),
)
def append_name(name, postfix):
if name is None:
return None
elif name == '':
return postfix
else:
return '%s_%s' % (name, postfix)
class AttentionLayer(D.Layer):
def __init__(self, cfg, name=None):
super(AttentionLayer, self).__init__()
initializer = F.initializer.TruncatedNormal(scale=cfg['initializer_range'])
d_model = cfg['hidden_size']
n_head = cfg['num_attention_heads']
assert d_model % n_head == 0
d_model_q = cfg.get('query_hidden_size_per_head', d_model // n_head) * n_head
d_model_v = cfg.get('value_hidden_size_per_head', d_model // n_head) * n_head
self.n_head = n_head
self.d_key = d_model_q // n_head
self.q = _build_linear(d_model, d_model_q, append_name(name, 'query_fc'), initializer)
self.k = _build_linear(d_model, d_model_q, append_name(name, 'key_fc'), initializer)
self.v = _build_linear(d_model, d_model_v, append_name(name, 'value_fc'), initializer)
self.o = _build_linear(d_model_v, d_model, append_name(name, 'output_fc'), initializer)
self.dropout = lambda i: L.dropout(
i,
dropout_prob=cfg['attention_probs_dropout_prob'],
dropout_implementation="upscale_in_train",
) if self.training else i
def forward(self, queries, keys, values, attn_bias, past_cache):
assert len(queries.shape) == len(keys.shape) == len(values.shape) == 3
q = self.q(queries)
k = self.k(keys)
v = self.v(values)
cache = (k, v)
if past_cache is not None:
cached_k, cached_v = past_cache
k = L.concat([cached_k, k], 1)
v = L.concat([cached_v, v], 1)
q = L.transpose(L.reshape(q, [0, 0, self.n_head, q.shape[-1] // self.n_head]),
[0, 2, 1, 3]) #[batch, head, seq, dim]
k = L.transpose(L.reshape(k, [0, 0, self.n_head, k.shape[-1] // self.n_head]),
[0, 2, 1, 3]) #[batch, head, seq, dim]
v = L.transpose(L.reshape(v, [0, 0, self.n_head, v.shape[-1] // self.n_head]),
[0, 2, 1, 3]) #[batch, head, seq, dim]
q = L.scale(q, scale=self.d_key**-0.5)
score = L.matmul(q, k, transpose_y=True)
if attn_bias is not None:
score += attn_bias
score = L.softmax(score, use_cudnn=True)
score = self.dropout(score)
out = L.matmul(score, v)
out = L.transpose(out, [0, 2, 1, 3])
out = L.reshape(out, [0, 0, out.shape[2] * out.shape[3]])
out = self.o(out)
return out, cache
class PositionwiseFeedForwardLayer(D.Layer):
def __init__(self, cfg, name=None):
super(PositionwiseFeedForwardLayer, self).__init__()
initializer = F.initializer.TruncatedNormal(scale=cfg['initializer_range'])
d_model = cfg['hidden_size']
d_ffn = cfg.get('intermediate_size', 4 * d_model)
assert cfg['hidden_act'] in ['relu', 'gelu']
self.i = _build_linear(d_model, d_ffn, append_name(name, 'fc_0'), initializer, act=cfg['hidden_act'])
self.o = _build_linear(d_ffn, d_model, append_name(name, 'fc_1'), initializer)
prob = cfg.get('intermediate_dropout_prob', 0.)
self.dropout = lambda i: L.dropout(
i,
dropout_prob=prob,
dropout_implementation="upscale_in_train",
) if self.training else i
def forward(self, inputs):
hidden = self.i(inputs)
hidden = self.dropout(hidden)
out = self.o(hidden)
return out
class ErnieBlock(D.Layer):
def __init__(self, cfg, name=None):
super(ErnieBlock, self).__init__()
d_model = cfg['hidden_size']
initializer = F.initializer.TruncatedNormal(scale=cfg['initializer_range'])
self.attn = AttentionLayer(cfg, name=append_name(name, 'multi_head_att'))
self.ln1 = _build_ln(d_model, name=append_name(name, 'post_att'))
self.ffn = PositionwiseFeedForwardLayer(cfg, name=append_name(name, 'ffn'))
self.ln2 = _build_ln(d_model, name=append_name(name, 'post_ffn'))
prob = cfg.get('intermediate_dropout_prob', cfg['hidden_dropout_prob'])
self.dropout = lambda i: L.dropout(
i,
dropout_prob=prob,
dropout_implementation="upscale_in_train",
) if self.training else i
def forward(self, inputs, attn_bias=None, past_cache=None):
attn_out, cache = self.attn(inputs, inputs, inputs, attn_bias, past_cache=past_cache) #self attn
attn_out = self.dropout(attn_out)
hidden = attn_out + inputs
hidden = self.ln1(hidden) # dropout/ add/ norm
ffn_out = self.ffn(hidden)
ffn_out = self.dropout(ffn_out)
hidden = ffn_out + hidden
hidden = self.ln2(hidden)
return hidden, cache
class ErnieEncoderStack(D.Layer):
def __init__(self, cfg, name=None):
super(ErnieEncoderStack, self).__init__()
n_layers = cfg['num_hidden_layers']
self.block = D.LayerList([ErnieBlock(cfg, append_name(name, 'layer_%d' % i)) for i in range(n_layers)])
def forward(self, inputs, attn_bias=None, past_cache=None):
if past_cache is not None:
assert isinstance(
past_cache,
tuple), 'unknown type of `past_cache`, expect tuple or list. got %s' % repr(type(past_cache))
past_cache = list(zip(*past_cache))
else:
past_cache = [None] * len(self.block)
cache_list_k, cache_list_v, hidden_list = [], [], [inputs]
for b, p in zip(self.block, past_cache):
inputs, cache = b(inputs, attn_bias=attn_bias, past_cache=p)
cache_k, cache_v = cache
cache_list_k.append(cache_k)
cache_list_v.append(cache_v)
hidden_list.append(inputs)
return inputs, hidden_list, (cache_list_k, cache_list_v)
class ErnieModel(D.Layer):
def __init__(self, cfg, name=None):
"""
Fundamental pretrained Ernie model
"""
log.debug('init ErnieModel with config: %s' % repr(cfg))
D.Layer.__init__(self)
d_model = cfg['hidden_size']
d_emb = cfg.get('emb_size', cfg['hidden_size'])
d_vocab = cfg['vocab_size']
d_pos = cfg['max_position_embeddings']
d_sent = cfg.get("sent_type_vocab_size") or cfg['type_vocab_size']
self.n_head = cfg['num_attention_heads']
self.return_additional_info = cfg.get('return_additional_info', False)
initializer = F.initializer.TruncatedNormal(scale=cfg['initializer_range'])
self.ln = _build_ln(d_model, name=append_name(name, 'pre_encoder'))
self.word_emb = D.Embedding([d_vocab, d_emb],
param_attr=F.ParamAttr(
name=append_name(name, 'word_embedding'), initializer=initializer))
self.pos_emb = D.Embedding([d_pos, d_emb],
param_attr=F.ParamAttr(
name=append_name(name, 'pos_embedding'), initializer=initializer))
self.sent_emb = D.Embedding([d_sent, d_emb],
param_attr=F.ParamAttr(
name=append_name(name, 'sent_embedding'), initializer=initializer))
prob = cfg['hidden_dropout_prob']
self.dropout = lambda i: L.dropout(
i,
dropout_prob=prob,
dropout_implementation="upscale_in_train",
) if self.training else i
self.encoder_stack = ErnieEncoderStack(cfg, append_name(name, 'encoder'))
if cfg.get('has_pooler', True):
self.pooler = _build_linear(
cfg['hidden_size'], cfg['hidden_size'], append_name(name, 'pooled_fc'), initializer, act='tanh')
else:
self.pooler = None
self.train()
def eval(self):
if F.in_dygraph_mode():
super(ErnieModel, self).eval()
self.training = False
for l in self.sublayers():
l.training = False
def train(self):
if F.in_dygraph_mode():
super(ErnieModel, self).train()
self.training = True
for l in self.sublayers():
l.training = True
def forward(self,
src_ids,
sent_ids=None,
pos_ids=None,
input_mask=None,
attn_bias=None,
past_cache=None,
use_causal_mask=False):
"""
Args:
src_ids (`Variable` of shape `[batch_size, seq_len]`):
Indices of input sequence tokens in the vocabulary.
sent_ids (optional, `Variable` of shape `[batch_size, seq_len]`):
aka token_type_ids, Segment token indices to indicate first and second portions of the inputs.
if None, assume all tokens come from `segment_a`
pos_ids(optional, `Variable` of shape `[batch_size, seq_len]`):
Indices of positions of each input sequence tokens in the position embeddings.
input_mask(optional `Variable` of shape `[batch_size, seq_len]`):
Mask to avoid performing attention on the padding token indices of the encoder input.
            attn_bias(optional, `Variable` of shape `[batch_size, seq_len, seq_len]` or False):
                3D version of `input_mask`; if set, overrides `input_mask`; if set to False, no attention mask is applied
            past_cache(optional, tuple of two lists: cached key and cached value,
                each is a list of `Variable`s of shape `[batch_size, seq_len, hidden_size]`):
                cached key/value tensors that will be concatenated to the generated key/value when performing self attention.
                if set, `attn_bias` should not be None.
Returns:
pooled (`Variable` of shape `[batch_size, hidden_size]`):
output logits of pooler classifier
encoded(`Variable` of shape `[batch_size, seq_len, hidden_size]`):
output logits of transformer stack
"""
        assert len(src_ids.shape) == 2, 'expect src_ids.shape = [batch, sequence], got %s' % (repr(src_ids.shape))
assert attn_bias is not None if past_cache else True, 'if `past_cache` is specified; attn_bias should not be None'
d_batch = L.shape(src_ids)[0]
d_seqlen = L.shape(src_ids)[1]
if pos_ids is None:
pos_ids = L.reshape(L.range(0, d_seqlen, 1, dtype='int32'), [1, -1])
pos_ids = L.cast(pos_ids, 'int64')
if attn_bias is None:
if input_mask is None:
input_mask = L.cast(src_ids != 0, 'float32')
assert len(input_mask.shape) == 2
input_mask = L.unsqueeze(input_mask, axes=[-1])
attn_bias = L.matmul(input_mask, input_mask, transpose_y=True)
if use_causal_mask:
sequence = L.reshape(L.range(0, d_seqlen, 1, dtype='float32') + 1., [1, 1, -1, 1])
causal_mask = L.cast((L.matmul(sequence, 1. / sequence, transpose_y=True) >= 1.), 'float32')
attn_bias *= causal_mask
else:
            assert len(attn_bias.shape) == 3, 'expect attn_bias to be rank 3, got %r' % attn_bias.shape
attn_bias = (1. - attn_bias) * -10000.0
attn_bias = L.unsqueeze(attn_bias, [1])
attn_bias = L.expand(attn_bias, [1, self.n_head, 1, 1]) # avoid broadcast =_=
attn_bias.stop_gradient = True
if sent_ids is None:
sent_ids = L.zeros_like(src_ids)
src_embedded = self.word_emb(src_ids)
pos_embedded = self.pos_emb(pos_ids)
sent_embedded = self.sent_emb(sent_ids)
embedded = src_embedded + pos_embedded + sent_embedded
embedded = self.dropout(self.ln(embedded))
encoded, hidden_list, cache_list = self.encoder_stack(embedded, attn_bias, past_cache=past_cache)
if self.pooler is not None:
pooled = self.pooler(encoded[:, 0, :])
else:
pooled = None
additional_info = {
'hiddens': hidden_list,
'caches': cache_list,
}
if self.return_additional_info:
return pooled, encoded, additional_info
else:
return pooled, encoded
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as F
import paddle.fluid.layers as L
from .modeling_ernie import ErnieModel
from .modeling_ernie import _build_linear, _build_ln, append_name
class ErnieModelForGeneration(ErnieModel):
def __init__(self, cfg, name=None):
cfg['return_additional_info'] = True
cfg['has_pooler'] = False
super(ErnieModelForGeneration, self).__init__(cfg, name=name)
initializer = F.initializer.TruncatedNormal(scale=cfg['initializer_range'])
d_model = cfg['hidden_size']
d_vocab = cfg['vocab_size']
self.mlm = _build_linear(
d_model, d_model, append_name(name, 'mask_lm_trans_fc'), initializer, act=cfg['hidden_act'])
self.mlm_ln = _build_ln(d_model, name=append_name(name, 'mask_lm_trans'))
self.mlm_bias = L.create_parameter(
dtype='float32',
shape=[d_vocab],
attr=F.ParamAttr(
name=append_name(name, 'mask_lm_out_fc.b_0'), initializer=F.initializer.Constant(value=0.0)),
is_bias=True,
)
def forward(self, src_ids, *args, **kwargs):
tgt_labels = kwargs.pop('tgt_labels', None)
tgt_pos = kwargs.pop('tgt_pos', None)
encode_only = kwargs.pop('encode_only', False)
_, encoded, info = ErnieModel.forward(self, src_ids, *args, **kwargs)
if encode_only:
return None, None, info
elif tgt_labels is None:
encoded = self.mlm(encoded)
encoded = self.mlm_ln(encoded)
logits = L.matmul(encoded, self.word_emb.weight, transpose_y=True) + self.mlm_bias
output_ids = L.argmax(logits, -1)
return output_ids, logits, info
else:
encoded_2d = L.gather_nd(encoded, tgt_pos)
encoded_2d = self.mlm(encoded_2d)
encoded_2d = self.mlm_ln(encoded_2d)
logits_2d = L.matmul(encoded_2d, self.word_emb.weight, transpose_y=True) + self.mlm_bias
if len(tgt_labels.shape) == 1:
tgt_labels = L.reshape(tgt_labels, [-1, 1])
loss = L.reduce_mean(
L.softmax_with_cross_entropy(logits_2d, tgt_labels, soft_label=(tgt_labels.shape[-1] != 1)))
return loss, logits_2d, info
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six
import re
import logging
from functools import partial
import numpy as np
import io
open = partial(io.open, encoding='utf8')
log = logging.getLogger(__name__)
_max_input_chars_per_word = 100
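# _wordpiece below performs greedy longest-match-first subword segmentation:
# scanning left to right, it repeatedly takes the longest vocabulary entry
# (continuation pieces carry the '##' prefix) and falls back to the [UNK]
# token if some span cannot be matched at all.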
def _wordpiece(token, vocab, unk_token, prefix='##', sentencepiece_prefix=''):
""" wordpiece: helloworld => [hello, ##world] """
chars = list(token)
if len(chars) > _max_input_chars_per_word:
return [unk_token], [(0, len(chars))]
is_bad = False
start = 0
sub_tokens = []
sub_pos = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start == 0:
substr = sentencepiece_prefix + substr
if start > 0:
substr = prefix + substr
if substr in vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
sub_pos.append((start, end))
start = end
if is_bad:
return [unk_token], [(0, len(chars))]
else:
return sub_tokens, sub_pos
class ErnieTokenizer(object):
def __init__(self,
vocab,
unk_token='[UNK]',
sep_token='[SEP]',
cls_token='[CLS]',
pad_token='[PAD]',
mask_token='[MASK]',
wordpiece_prefix='##',
sentencepiece_prefix='',
lower=True,
encoding='utf8',
special_token_list=[]):
if not isinstance(vocab, dict):
raise ValueError('expect `vocab` to be instance of dict, got %s' % type(vocab))
self.vocab = vocab
self.lower = lower
self.prefix = wordpiece_prefix
self.sentencepiece_prefix = sentencepiece_prefix
self.pad_id = self.vocab[pad_token]
self.cls_id = cls_token and self.vocab[cls_token]
self.sep_id = sep_token and self.vocab[sep_token]
self.unk_id = unk_token and self.vocab[unk_token]
self.mask_id = mask_token and self.vocab[mask_token]
self.unk_token = unk_token
special_tokens = {pad_token, cls_token, sep_token, unk_token, mask_token} | set(special_token_list)
pat_str = ''
for t in special_tokens:
if t is None:
continue
pat_str += '(%s)|' % re.escape(t)
pat_str += r'([a-zA-Z0-9]+|\S)'
log.debug('regex: %s' % pat_str)
self.pat = re.compile(pat_str)
self.encoding = encoding
def tokenize(self, text):
if len(text) == 0:
return []
if six.PY3 and not isinstance(text, six.string_types):
text = text.decode(self.encoding)
if six.PY2 and isinstance(text, str):
text = text.decode(self.encoding)
res = []
for match in self.pat.finditer(text):
match_group = match.group(0)
if match.groups()[-1]:
if self.lower:
match_group = match_group.lower()
words, _ = _wordpiece(
match_group,
vocab=self.vocab,
unk_token=self.unk_token,
prefix=self.prefix,
sentencepiece_prefix=self.sentencepiece_prefix)
else:
words = [match_group]
res += words
return res
def convert_tokens_to_ids(self, tokens):
return [self.vocab.get(t, self.unk_id) for t in tokens]
def truncate(self, id1, id2, seqlen):
len1 = len(id1)
len2 = len(id2)
half = seqlen // 2
if len1 > len2:
len1_truncated, len2_truncated = max(half, seqlen - len2), min(half, len2)
else:
len1_truncated, len2_truncated = min(half, seqlen - len1), max(half, seqlen - len1)
return id1[:len1_truncated], id2[:len2_truncated]
def build_for_ernie(self, text_id, pair_id=[]):
"""build sentence type id, add [CLS] [SEP]"""
text_id_type = np.zeros_like(text_id, dtype=np.int64)
ret_id = np.concatenate([[self.cls_id], text_id, [self.sep_id]], 0)
ret_id_type = np.concatenate([[0], text_id_type, [0]], 0)
if len(pair_id):
pair_id_type = np.ones_like(pair_id, dtype=np.int64)
ret_id = np.concatenate([ret_id, pair_id, [self.sep_id]], 0)
ret_id_type = np.concatenate([ret_id_type, pair_id_type, [1]], 0)
return ret_id, ret_id_type
def encode(self, text, pair=None, truncate_to=None):
text_id = np.array(self.convert_tokens_to_ids(self.tokenize(text)), dtype=np.int64)
text_id_type = np.zeros_like(text_id, dtype=np.int64)
if pair is not None:
pair_id = np.array(self.convert_tokens_to_ids(self.tokenize(pair)), dtype=np.int64)
else:
pair_id = []
if truncate_to is not None:
text_id, pair_id = self.truncate(text_id, [] if pair_id is None else pair_id, truncate_to)
ret_id, ret_id_type = self.build_for_ernie(text_id, pair_id)
return ret_id, ret_id_type
# coding:utf-8
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import json
import paddle.fluid as fluid
import paddlehub as hub
from paddlehub.module.module import runnable
from paddlehub.compat.module.nlp_module import DataFormatError
from paddlehub.common.logger import logger
from paddlehub.module.module import moduleinfo, serving
import argparse
import os
import numpy as np
import paddle.fluid.dygraph as D
from .model.tokenizing_ernie import ErnieTokenizer
from .model.decode import beam_search_infilling
from .model.modeling_ernie_gen import ErnieModelForGeneration
@moduleinfo(
name="ernie_gen_leave",
version="1.0.0",
summary="",
author="彭兆帅,郑博培",
author_email="1084667371@qq.com,2733821739@qq.com",
type="nlp/text_generation",
)
class ErnieGen(hub.NLPPredictionModule):
def _initialize(self):
"""
initialize with the necessary elements
"""
assets_path = os.path.join(self.directory, "assets")
gen_checkpoint_path = os.path.join(assets_path, "ernie_gen")
ernie_cfg_path = os.path.join(assets_path, 'ernie_config.json')
with open(ernie_cfg_path, encoding='utf8') as ernie_cfg_file:
ernie_cfg = dict(json.loads(ernie_cfg_file.read()))
ernie_vocab_path = os.path.join(assets_path, 'vocab.txt')
with open(ernie_vocab_path, encoding='utf8') as ernie_vocab_file:
ernie_vocab = {j.strip().split('\t')[0]: i for i, j in enumerate(ernie_vocab_file.readlines())}
with fluid.dygraph.guard(fluid.CPUPlace()):
with fluid.unique_name.guard():
self.model = ErnieModelForGeneration(ernie_cfg)
finetuned_states, _ = D.load_dygraph(gen_checkpoint_path)
self.model.set_dict(finetuned_states)
self.tokenizer = ErnieTokenizer(ernie_vocab)
self.rev_dict = {v: k for k, v in self.tokenizer.vocab.items()}
self.rev_dict[self.tokenizer.pad_id] = '' # replace [PAD]
        self.rev_dict[self.tokenizer.unk_id] = ''  # replace [UNK]
self.rev_lookup = np.vectorize(lambda i: self.rev_dict[i])
@serving
def generate(self, texts, use_gpu=False, beam_width=5):
"""
Get the predict result from the input texts.
Args:
texts(list): the input texts.
use_gpu(bool): whether use gpu to predict or not
beam_width(int): the beam search width.
Returns:
results(list): the predict result.
"""
if texts and isinstance(texts, list) and all(texts) and all([isinstance(text, str) for text in texts]):
predicted_data = texts
else:
raise ValueError("The input texts should be a list with nonempty string elements.")
if use_gpu and "CUDA_VISIBLE_DEVICES" not in os.environ:
use_gpu = False
logger.warning(
"use_gpu has been set False as you didn't set the environment variable CUDA_VISIBLE_DEVICES while using use_gpu=True"
)
if use_gpu:
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
self.model.eval()
results = []
for text in predicted_data:
sample_results = []
ids, sids = self.tokenizer.encode(text)
src_ids = D.to_variable(np.expand_dims(ids, 0))
src_sids = D.to_variable(np.expand_dims(sids, 0))
output_ids = beam_search_infilling(
self.model,
src_ids,
src_sids,
eos_id=self.tokenizer.sep_id,
sos_id=self.tokenizer.cls_id,
attn_id=self.tokenizer.vocab['[MASK]'],
max_decode_len=50,
max_encode_len=50,
beam_width=beam_width,
tgt_type_id=1)
output_str = self.rev_lookup(output_ids[0].numpy())
for ostr in output_str.tolist():
if '[SEP]' in ostr:
ostr = ostr[:ostr.index('[SEP]')]
sample_results.append("".join(ostr))
results.append(sample_results)
return results
def add_module_config_arg(self):
"""
Add the command config options
"""
self.arg_config_group.add_argument(
'--use_gpu', type=ast.literal_eval, default=False, help="whether use GPU for prediction")
self.arg_config_group.add_argument('--beam_width', type=int, default=5, help="the beam search width")
@runnable
def run_cmd(self, argvs):
"""
Run as a command
"""
self.parser = argparse.ArgumentParser(
description='Run the %s module.' % self.name,
prog='hub run %s' % self.name,
usage='%(prog)s',
add_help=True)
self.arg_input_group = self.parser.add_argument_group(title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group(
title="Config options", description="Run configuration for controlling module behavior, optional.")
self.add_module_config_arg()
self.add_module_input_arg()
args = self.parser.parse_args(argvs)
try:
input_data = self.check_input_data(args)
        except (DataFormatError, RuntimeError):
self.parser.print_help()
return None
results = self.generate(texts=input_data, use_gpu=args.use_gpu, beam_width=args.beam_width)
return results