提交 950914d8 编写于 作者: Y Yibing Liu

Merge branch 'develop' of upstream into transformer_pf

......@@ -13,6 +13,7 @@ import priorbox
import permute
import detection_out
import normalize
import select
#custom layer import ends
......
""" A custom layer for 'detectionout' used in 'SSD' model to produce outputs
Note: Since Paddle's implementation of 'detectionout' applied 'flatten' and 'softmax' ops on the input of 'conf',
while Caffe's implementation do not. Hence, you should ajust generated 'ssd.py' to remove 'softmax' and 'flatten' ops applied on 'conf' input.
"""
from .register import register
def detectionoutput_shape(input_shape):
""" the output shape of this layer is dynamic and not determined by 'input_shape'
Args:
@input_shape (list of int): input shape
Returns:
@output_shape (list of num): a list of numbers represent the output shape
"""
output_shape = [-1, 6]
return output_shape
def detectionoutput_layer(inputs,
name,
background_label=0,
share_location=True,
nms_param=None,
keep_top_k=100,
confidence_threshold=0.1):
""" build a layer of type 'detectionout' using fluid
Args:
@inputs (list of variables): input fluid variables for this layer
@name (str): name for this layer
Returns:
output (variable): output variable for this layer
"""
import paddle.fluid as fluid
if nms_param is None:
nms_param = {"nms_threshold": 0.3, "top_k": 10, "eta": 1.0}
mbox_conf_flatten = inputs[1]
mbox_priorbox = inputs[2]
mbox_priorbox_list = fluid.layers.split(mbox_priorbox, 2, dim=1)
pb = mbox_priorbox_list[0]
pbv = mbox_priorbox_list[1]
pb = fluid.layers.reshape(x=pb, shape=[-1, 4])
pbv = fluid.layers.reshape(x=pbv, shape=[-1, 4])
mbox_loc = inputs[0]
mbox_loc = fluid.layers.reshape(
x=mbox_loc, shape=[-1, mbox_conf_flatten.shape[1], 4])
default = {"nms_threshold": 0.3, "top_k": 10, "eta": 1.0}
fields = ['eta', 'top_k', 'nms_threshold']
for f in default.keys():
if not nms_param.has_key(f):
nms_param[f] = default[f]
nmsed_outs = fluid.layers.detection_output(
scores=mbox_conf_flatten,
loc=mbox_loc,
prior_box=pb,
prior_box_var=pbv,
background_label=background_label,
nms_threshold=nms_param["nms_threshold"],
nms_top_k=nms_param["top_k"],
keep_top_k=keep_top_k,
score_threshold=confidence_threshold,
nms_eta=nms_param["eta"])
return nmsed_outs
register(
kind='DetectionOutput',
shape=detectionoutput_shape,
layer=detectionoutput_layer)
""" A custom layer for 'normalize' op
"""
from .register import register
def normalize_shape(input_shape,
across_spatial=True,
scale_filler=True,
eps=1e-10):
""" calculate the output shape of this layer using input shapes
Args:
@input_shape (list of tuples): input shape
Returns:
@output_shape (list of num): a list of numbers represent the output shape
"""
output_shape = input_shape
return output_shape
def normalize_layer(input,
name,
across_spatial=True,
scale_filler=True,
channel_shared=False,
eps=1e-10):
""" build a layer of type 'normalize' using fluid
Args:
@inputs (list of variables): input fluid variables for this layer
@name (str): name for this layer
Returns:
output (variable): output variable for this layer
"""
import paddle.fluid as fluid
param_prefix = name.split('.')[0]
assert across_spatial == False, "Only support across_spatial == False for Normalize[%s]" % (
name)
l2_norm = fluid.layers.l2_normalize(input, axis=1) # l2 norm along channel
shape = [1] if channel_shared else [input.shape[1]]
scale_attr = fluid.ParamAttr(name=param_prefix + '_scale')
scale_param = fluid.layers.create_parameter(
shape=shape, dtype=input.dtype, name=name, attr=scale_attr)
out = fluid.layers.elementwise_mul(
x=l2_norm, y=scale_param, axis=-1 if channel_shared else 1)
return out
register(kind='Normalize', shape=normalize_shape, layer=normalize_layer)
""" A custom layer for 'Permute' which is equivalent to transpose in paddle
"""
from .register import register
def permute_shape(input_shape, order):
""" calculate the output shape of this layer using input shapes
Args:
@input_shape (list of numbers): input shape
Returns:
@output_shape (list of num): a list of numbers represent the output shape
"""
output_shape = []
for ii in order:
assert ii < len(input_shape), "invalid order for permute[%s]" % (name)
output_shape.append(input_shape[ii])
return output_shape
def permute_layer(input, name, order):
""" build a layer of type 'permute' using fluid
Args:
@input (input variable): input fluid variables for this layer
@name (str): name for this layer
@order (list of int): order to permute the dims
Returns:
output (variable): output variable for this layer
"""
import paddle.fluid as fluid
output = fluid.layers.transpose(input, order, name=name)
return output
register(kind='Permute', shape=permute_shape, layer=permute_layer)
""" A custom layer for 'priorbox' which is used in ssd to generate prior box info
Since the order of prior box is different between caffe and paddle,
we use 'slice' and 'concate' ops to align them.
"""
from .register import register
def priorbox_shape(input_shapes, min_size, max_size=None, aspect_ratio=None):
""" calculate the output shape of this layer using input shapes
Args:
@input_shapes (list of tuples): a list of input shapes
Returns:
@output_shape (list of num): a list of numbers represent the output shape
"""
assert len(input_shapes) == 2, "invalid inputs for Priorbox[%s]" % (name)
fc_shape = input_shapes[0]
N = 1
if not max_size == None:
N += 1
if not aspect_ratio == None:
N += 2 * len(aspect_ratio)
N_bbx = fc_shape[2] * fc_shape[3] * N
output_shape = [1, 2, 4 * N_bbx]
return output_shape
def priorbox_layer(inputs,
name,
min_size,
step,
max_size=None,
aspect_ratio=None,
flip=True,
clip=False,
variance=[],
offset=0.5):
""" build a layer of type 'Priorbox' using fluid
Args:
@inputs (list of variables): input fluid variables for this layer
@name (str): name for this layer
Returns:
output (variable): output variable for this layer
"""
import paddle.fluid as fluid
assert len(inputs) == 2, "invalid inputs for Priorbox[%s]" % (name)
input = inputs[0]
image = inputs[1]
box, variance_ = fluid.layers.prior_box(
input,
image,
min_size,
max_size,
aspect_ratio,
variance,
flip,
clip, (step, step),
offset,
min_max_aspect_ratios_order=True)
"""
#adjust layout when the output is not consistent with caffe's
feat_shape = list(input.shape)
H = feat_shape[2]
W = feat_shape[3]
box_tmp = fluid.layers.reshape(box, [H, W, -1, 4])
nb_prior_bbx = int(box_tmp.shape[2])
tensor_list = fluid.layers.split(box_tmp, nb_prior_bbx, 2)
#TODO:
# current implementation for this layer is not efficient
# and we should fix this bug in future when Paddle support the same prior-box layout with Caffe
index_list = [0]
index_list = index_list * nb_prior_bbx
index_offset = 0
if max_size is not None:
index_list[1] = -1
index_offset = 1
for ii in xrange(2 * len(aspect_ratio)):
index_list[ii + 1 + index_offset] = ii + 1
tensor_list_gathered = [tensor_list[ii] for ii in index_list]
caffe_prior_bbx = fluid.layers.concat(tensor_list_gathered, axis=2)
box = fluid.layers.reshape(caffe_prior_bbx, [1, 1, -1])
"""
box = fluid.layers.reshape(box, [1, 1, -1])
variance_ = fluid.layers.reshape(variance_, [1, 1, -1])
output = fluid.layers.concat([box, variance_], axis=1)
return output
register(kind='PriorBox', shape=priorbox_shape, layer=priorbox_layer)
""" a custom layer for 'ROIPooling', maybe we should implement this in standard way.
more info can be found here: http://caffe.berkeleyvision.org/tutorial/layers/ROIPooling.html
"""
from .register import register
def roipooling_shape(input_shapes, pooled_h, pooled_w, spatial_scale):
""" calculate the output shape of this layer using input shape
Args:
@input_shape (list of num): a list of number which represents the input shape
@out_max_val (bool): parameter from caffe's ROIPooling layer
@top_k (int): parameter from caffe's ROIPooling layer
@axis (int): parameter from caffe's ROIPooling layer
Returns:
@output_shape (list of num): a list of numbers represent the output shape
"""
assert len(input_shapes) == 2, "not valid input shape for roipooling layer"
base_fea_shape = input_shapes[0]
rois_shape = input_shapes[1]
output_shape = base_fea_shape
output_shape[0] = rois_shape[0]
output_shape[2] = pooled_h
output_shape[3] = pooled_w
return output_shape
def roipooling_layer(inputs, name, pooled_h, pooled_w, spatial_scale):
""" build a layer of type 'ROIPooling' using fluid
Args:
@input (variable): input fluid variable for this layer
@name (str): name for this layer
@out_max_val (bool): parameter from caffe's ROIPooling layer
@top_k (int): parameter from caffe's ROIPooling layer
@axis (int): parameter from caffe's ROIPooling layer
Returns:
output (variable): output variable for this layer
"""
import paddle.fluid as fluid
assert len(inputs) == 2, "not valid input shape for roipooling layer"
base_fea = inputs[0]
rois = inputs[1][:, 1:5]
rois_fea = fluid.layers.roi_pool(base_fea, rois, pooled_h, pooled_w,
spatial_scale)
return rois_fea
register(kind='ROIPooling', shape=roipooling_shape, layer=roipooling_layer)
""" a custom layer for 'select' which is used to replace standard 'Slice' layer
for converting layer with multiple different output tensors
"""
from .register import register
def select_shape(input_shape, slice_point, axis=1):
""" calculate the output shape of this layer using input shape
Args:
@input_shape (list of num): a list of number which represents the input shape
@slice_point (list): parameter from caffe's Slice layer
@axis (int): parameter from caffe's Slice layer
Returns:
@output_shape (list of num): a list of numbers represent the output shape
"""
input_shape = list(input_shape)
start = slice_point[0]
if len(slice_point) == 2:
end = slice_point[1]
else:
end = input_shape[axis]
assert end > start, "invalid slice_point with [start:%d, end:%d]"\
% (start, end)
output_shape = input_shape
output_shape[axis] = end - start
return output_shape
def select_layer(input, name, slice_point, axis=1):
""" build a layer of type 'Slice' using fluid
Args:
@input (variable): input fluid variable for this layer
@name (str): name for this layer
@slice_point (list): parameter from caffe's Slice layer
@axis (int): parameter from caffe's Slice layer
Returns:
output (variable): output variable for this layer
"""
import paddle.fluid as fluid
input_shape = list(input.shape)
start = slice_point[0]
if len(slice_point) == 2:
end = slice_point[1]
else:
end = input_shape[axis]
sections = []
if start > 0:
sections.append(start)
pos = len(sections)
sections.append(end - start)
if end != input_shape[axis]:
sections.append(input_shape[axis] - end)
outputs = fluid.layers.split(input, sections, dim=axis, name=name)
return outputs[pos]
register(kind='Select', shape=select_shape, layer=select_layer)
......@@ -16,7 +16,7 @@ LAYER_DESCRIPTORS = {
'Concat': shape_concat,
'ContrastiveLoss': shape_scalar,
'Convolution': shape_convolution,
'Deconvolution': shape_not_implemented,
'Deconvolution': shape_deconvolution,
'Data': shape_data,
'Dropout': shape_identity,
'DummyData': shape_data,
......@@ -181,6 +181,8 @@ class LayerAdapter(object):
name = NodeDispatch.get_handler_name(self.kind)
if self.kind.lower() == "normalize":
name = "norm"
elif self.kind.lower() == "deconvolution":
name = "convolution"
name = '_'.join((name, 'param'))
try:
......@@ -210,7 +212,9 @@ class LayerAdapter(object):
@property
def kernel_parameters(self):
assert self.kind in (NodeKind.Convolution, NodeKind.Pooling)
assert self.kind in (NodeKind.Convolution, NodeKind.Pooling,\
NodeKind.Deconvolution)
params = self.parameters
k_h = self.get_kernel_value(params.kernel_h, params.kernel_size, 0)
k_w = self.get_kernel_value(params.kernel_w, params.kernel_size, 1)
......@@ -222,7 +226,7 @@ class LayerAdapter(object):
p_w = self.get_kernel_value(params.pad_w, params.pad, 1, default=0)
dila_h = dila_w = 1
if self.kind in (NodeKind.Convolution, ):
if self.kind in (NodeKind.Convolution, NodeKind.Deconvolution):
dila_len = len(params.dilation)
if dila_len == 2:
dila_h = params.dilation[0]
......
......@@ -185,6 +185,58 @@ class Network(object):
return output
@layer
def deconv(self,
input,
k_h,
k_w,
c_o,
s_h,
s_w,
name,
relu=True,
relu_negative_slope=0.0,
padding=None,
dilation=1,
biased=True):
if padding is None:
padding = [0, 0]
# Get the number of channels in the input
c_i, h_i, w_i = input.shape[1:]
fluid = import_fluid()
prefix = name + '_'
leaky_relu = False
act = 'relu'
if relu is False:
act = None
elif relu_negative_slope != 0.0:
leaky_relu = True
act = None
p_h = padding[0]
p_w = padding[1]
h_o = (h_i - 1) * s_h - 2 * p_h + dilation * (k_h - 1) + 1
w_o = (w_i - 1) * s_w - 2 * p_w + dilation * (k_w - 1) + 1
output = fluid.layers.conv2d_transpose(
name=self.get_unique_output_name(name, 'conv2d_transpose'),
input=input,
num_filters=c_o,
output_size=[h_o, w_o],
filter_size=[k_h, k_w],
padding=padding,
stride=[s_h, s_w],
dilation=dilation,
param_attr=fluid.ParamAttr(name=prefix + "weights"),
bias_attr=fluid.ParamAttr(name=prefix + "biases"),
act=act)
if leaky_relu:
output = fluid.layers.leaky_relu(output, alpha=relu_negative_slope)
return output
@layer
def relu(self, input, name):
fluid = import_fluid()
......@@ -258,6 +310,12 @@ class Network(object):
return fluid.layers.sigmoid(
input, name=self.get_unique_output_name(name, 'sigmoid'))
@layer
def tanh(self, input, name):
fluid = import_fluid()
return fluid.layers.tanh(
input, name=self.get_unique_output_name(name, 'tanh'))
@layer
def lrn(self, input, radius, alpha, beta, name, bias=1.0):
fluid = import_fluid()
......
......@@ -91,6 +91,24 @@ class PaddleMapper(NodeMapper):
'conv', kernel_params.kernel_h, kernel_params.kernel_w, c_o,
kernel_params.stride_h, kernel_params.stride_w, **kwargs)
def map_deconvolution(self, node):
(kernel_params, kwargs) = self.get_kernel_params(node)
h = kernel_params.kernel_h
w = kernel_params.kernel_w
c_o = node.output_shape[1]
c_i = node.parents[0].output_shape[1]
if not node.parameters.bias_term:
kwargs['biased'] = False
if kernel_params.dila_h != 1 or kernel_params.dila_w != 1:
kwargs['dilation'] = (kernel_params.dila_h, kernel_params.dila_w)
assert kernel_params.kernel_h == h
assert kernel_params.kernel_w == w
return MaybeActivated(node)(
'deconv', kernel_params.kernel_h, kernel_params.kernel_w, c_o,
kernel_params.stride_h, kernel_params.stride_w, **kwargs)
def map_relu(self, node):
return PaddleNode('relu')
......
......@@ -105,6 +105,34 @@ def shape_convolution(node):
return get_strided_kernel_output_shape(node, math.floor)
def shape_deconvolution(node):
assert node.layer is not None
input_shape = node.get_only_parent().output_shape
h_i = input_shape.height
w_i = input_shape.width
params = node.layer.kernel_parameters
p_h = params.pad_h
p_w = params.pad_w
dila_h = params.dila_h
dila_w = params.dila_w
k_h = params.kernel_h
k_w = params.kernel_w
s_h = params.stride_h
s_w = params.stride_w
h_o = (h_i - 1) * s_h - 2 * p_h + dila_h * (k_h - 1) + 1
w_o = (w_i - 1) * s_w - 2 * p_w + dila_w * (k_w - 1) + 1
params = node.layer.parameters
has_c_o = hasattr(params, 'num_output')
c = params.num_output if has_c_o else input_shape.channels
return make_tensor(input_shape.batch_size, c, h_o, w_o)
def shape_pool(node):
global_pool = getattr(node.layer.parameters, 'global_pooling', False)
if global_pool:
......
......@@ -325,7 +325,8 @@ class ParameterNamer(object):
for node in graph.nodes:
if node.data is None:
continue
if node.kind in (NodeKind.Convolution, NodeKind.InnerProduct):
if node.kind in (NodeKind.Convolution, NodeKind.InnerProduct,\
NodeKind.Deconvolution):
names = ('weights', )
if node.parameters.bias_term:
names += ('biases', )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册