Unverified Commit 13320626 authored by Jason, committed by GitHub

Merge pull request #1 from PaddlePaddle/develop

Develop
# X2Paddle
X2Paddle is a toolkit for converting models trained in other deep learning frameworks to PaddlePaddle models.
## Requirements
python >= 3.5
paddlepaddle >= 1.5.0
tensorflow == 1.x
## Installation
```
pip install git+https://github.com/PaddlePaddle/X2Paddle.git@develop
```
## How To Use
```
x2paddle --framework=tensorflow --model=tf_model.pb --save_dir=pd_model
```
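Once conversion finishes, the model under `pd_model` can be loaded with the PaddlePaddle 1.x inference API. The snippet below is a minimal sketch; it assumes the converter emitted a standard fluid inference model, and the 1x3x224x224 input is an illustrative placeholder:
```
# Minimal sketch: run the converted model with PaddlePaddle 1.x.
# Assumes pd_model holds a standard fluid inference model; the input
# shape below is an illustrative placeholder.
import numpy as np
import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
program, feed_names, fetch_targets = fluid.io.load_inference_model(
    dirname='pd_model', executor=exe)

fake_input = np.random.rand(1, 3, 224, 224).astype('float32')
outputs = exe.run(program,
                  feed={feed_names[0]: fake_input},
                  fetch_list=fetch_targets)
```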
## Converting the TensorFlow vgg_16 Model
@@ -40,12 +56,7 @@ with tf.Session() as sess:
### Step 3: Model Conversion
```
-git clone https://github.com/PaddlePaddle/X2Paddle.git
-cd X2Paddle
-git checkout develop
-export PYTHONPATH=${PWD}
-mkdir paddle_model
-python x2paddle/convert.py --framework=tensorflow \
+x2paddle --framework=tensorflow \
          --model=../vgg16.pb \
          --save_dir=paddle_model
```
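The elided section above (the `with tf.Session() as sess:` block) produces the frozen `vgg16.pb` that the converter consumes. As a rough TF 1.x sketch, where the network-building step and the output node name are illustrative assumptions:
```
# Rough sketch of freezing a TF 1.x graph into vgg16.pb.
# The network-building step and the output node name are assumptions.
import tensorflow as tf

with tf.Session() as sess:
    # ... build or restore the vgg_16 network here ...
    sess.run(tf.global_variables_initializer())
    frozen = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_node_names=['vgg_16/fc8/squeezed'])
    with open('vgg16.pb', 'wb') as f:
        f.write(frozen.SerializeToString())
```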
@@ -60,12 +71,7 @@ wget https://github.com/DeepScale/SqueezeNet/blob/master/SqueezeNet_v1.1/deploy.
### Step 2: Model Conversion
```
-git clone https://github.com/PaddlePaddle/X2Paddle.git
-cd X2Paddle
-git checkout develop
-export PYTHONPATH=${PWD}:$PYTHONPATH
-mkdir paddle_model
-python x2paddle/convert.py --framework=caffe \
+x2paddle --framework=caffe \
          --weight=../squeezenet_v1.1.caffemodel \
          --proto=../deploy.prototxt \
          --save_dir=paddle_model
```
@@ -60,11 +60,14 @@ class CaffeResolver(object):
```
class CaffeGraphNode(GraphNode):
    def __init__(self, layer, layer_name=None):
        if layer_name is None:
            super(CaffeGraphNode, self).__init__(layer,
                                                 layer.name.replace('/', '_'))
        else:
            super(CaffeGraphNode, self).__init__(layer,
                                                 layer_name.replace('/', '_'))
        self.layer_type = layer.type
        self.fluid_code = FluidCode()
        self.data = None

    def set_params(self, params):
        self.data = params
```
@@ -117,6 +120,7 @@ class CaffeGraph(Graph):
```
        inputs_num = len(self.model.input)
        if inputs_num != 0:
            input_dims_num = len(self.model.input_dim)
            if input_dims_num != 0:
                if input_dims_num > 0 and input_dims_num != inputs_num * 4:
                    raise Error('invalid input_dim[%d] param in prototxt' %
                                (input_dims_num))
```
@@ -130,11 +134,28 @@ class CaffeGraph(Graph):
```
                                dim=[dims[0], dims[1], dims[2], dims[3]
                                     ]))).to_proto().layer[0])
                    except:
                        raise ImportError(
                            'You must install the caffe first when you use old style prototxt.'
                        )
                    data.name = self.model.input[i]
                    data.top[0] = self.model.input[i]
            else:
                for i in range(inputs_num):
                    dims = self.model.input_shape[i].dim[0:4]
                    data = self.model.layer.add()
                    try:
                        from caffe import layers as L
                        data.CopyFrom(
                            L.Input(input_param=dict(shape=dict(
                                dim=[dims[0], dims[1], dims[2], dims[3]
                                     ]))).to_proto().layer[0])
                    except:
                        raise ImportError(
                            'You must install the caffe first when you use old style prototxt.'
                        )
                    data.name = self.model.input[i]
                    data.top[0] = self.model.input[i]
            layers = [data] + layers

        top_layer = {}
        for layer in layers:
```
@@ -169,6 +169,7 @@ def shape_softmax(layer, input_shape):
```
def shape_input(layer, input_shape):
    return [list(layer.input_param.shape[0].dim)]


def shape_concat(layer, input_shape):
    params = layer.concat_param
    axis = params.axis
```
@@ -179,3 +180,276 @@ def shape_concat(layer, input_shape):
```
    else:
        output_shape[axis] += shape[axis]
    return [output_shape]


def shape_slice(layer, input_shape):
    inshape = input_shape[0]
    params = layer.slice_param
    axis = params.axis
    count = inshape[axis]
    points = list(params.slice_point)
    points = [0] + points + [count]
    output_shape = []
    for i in range(len(points) - 1):
        # copy the shape so each output gets its own list
        shape = list(inshape)
        shape[axis] = points[i + 1] - points[i]
        output_shape.append(shape)
    return output_shape
```
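As a quick sanity check of the slice arithmetic (shapes are hypothetical): a `[1, 100, 10, 10]` blob sliced on `axis=1` at `slice_point [30, 60]` gives points `[0, 30, 60, 100]` and three outputs:
```
# Hypothetical check of shape_slice's arithmetic.
inshape = [1, 100, 10, 10]
points = [0, 30, 60, 100]
shapes = []
for i in range(len(points) - 1):
    s = list(inshape)
    s[1] = points[i + 1] - points[i]
    shapes.append(s)
assert shapes == [[1, 30, 10, 10], [1, 30, 10, 10], [1, 40, 10, 10]]
```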
```
def shape_prelu(layer, input_shape):
    return input_shape


def shape_sigmoid(layer, input_shape):
    return input_shape


def shape_absval(layer, input_shape):
    return input_shape


def shape_accuracy(layer, input_shape):
    return [[1]]


def shape_tanh(layer, input_shape):
    return input_shape


def shape_eltwise(layer, input_shape):
    return [input_shape[0]]


def shape_batchnorm(layer, input_shape):
    return input_shape


def shape_scale(layer, input_shape):
    return input_shape


def shape_reshape(layer, input_shape):
    from functools import reduce  # Python 3: reduce lives in functools

    def count(num_list):
        return reduce(lambda a, b: a * b, num_list)

    inshape = list(input_shape[0])  # copy so the caller's shape is not mutated
    params = layer.reshape_param
    axis = params.axis if hasattr(params, 'axis') else 0
    num_axes = params.num_axes if hasattr(params, 'num_axes') else -1
    has_batch_dim = inshape[0] == -1  # remember a variable batch dimension
    if has_batch_dim:
        inshape[0] = 1
    input_count = count(inshape)

    shape = params.shape  # BlobShape holding the requested dims
    input_num_axes = len(inshape)

    input_start_axis = axis
    start_axis = input_start_axis if input_start_axis >= 0 \
        else input_num_axes + input_start_axis + 1

    assert start_axis >= 0, "[Reshape]axis %d out of range" % (input_start_axis)
    assert start_axis <= input_num_axes, "[Reshape]axis %d out of range for %d-D input data" \
        % (input_start_axis, input_num_axes)

    assert num_axes >= -1, "[Reshape]num_axes must be >= 0, or -1 for all"

    end_axis = input_num_axes if num_axes == -1 else start_axis + num_axes
    assert end_axis <= input_num_axes, "end_axis[%d] = axis[%d] + num_axes[%d] is out of range" \
        % (end_axis, start_axis, num_axes)

    num_axes_replaced = end_axis - start_axis
    num_axes_retained = input_num_axes - num_axes_replaced
    num_new_axes = len(shape.dim)
    outshape = []

    for i in range(start_axis):
        outshape.append(inshape[i])

    for i in range(num_new_axes):
        outshape.append(shape.dim[i])

    for i in range(end_axis, input_num_axes):
        outshape.append(inshape[i])

    assert len(outshape) == num_axes_retained + num_new_axes, \
        "[Reshape]invalid dims of output shape[%s]" % (str(outshape))

    inferred_axis = -1
    copy_axes = []
    constant_count = 1
    for i in range(num_new_axes):
        top_dim = shape.dim[i]
        if top_dim == 0:
            copy_axes.append(i)
            copy_axis_index = start_axis + i
            outshape[copy_axis_index] = inshape[copy_axis_index]
        elif top_dim == -1:
            assert inferred_axis == -1, "[Reshape]new shape contains multiple -1 dims"
            inferred_axis = i
        else:
            constant_count *= top_dim

    if inferred_axis >= 0:
        explicit_count = constant_count
        l = inshape[0:start_axis]
        if len(l) > 0:
            explicit_count *= count(l)
        l = inshape[end_axis:]
        if len(l) > 0:
            explicit_count *= count(l)
        for i in range(len(copy_axes)):
            explicit_count *= outshape[start_axis + copy_axes[i]]
        assert input_count % explicit_count == 0, "[Reshape]bottom count[%d] " \
            "must be divisible by product of the specified dimensions[%d]" \
            % (input_count, explicit_count)
        outshape[start_axis + inferred_axis] = input_count // explicit_count

    output_count = count(outshape)
    assert output_count == input_count, "[Reshape]output count[%d] must match input count[%d]" % (
        output_count, input_count)
    if has_batch_dim:
        outshape[0] = -1
    return [outshape]
```
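A worked example of the Caffe reshape conventions handled above (dims hypothetical): a `0` dim copies the corresponding input dim, and a `-1` dim is inferred from the remaining element count, so reshaping `[1, 8, 3, 3]` with `dim: [0, -1]` yields `[1, 72]`:
```
# Hypothetical check of the 0 / -1 reshape conventions.
inshape = [1, 8, 3, 3]
outshape = [inshape[0], -1]         # dim 0 copies inshape[0]
total = 1
for d in inshape:
    total *= d
outshape[1] = total // outshape[0]  # -1 dim inferred: 8 * 3 * 3 = 72
assert outshape == [1, 72]
```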
```
def shape_argmax(layer, input_shape):
    inshape = input_shape[0]
    params = layer.argmax_param
    out_max_val = params.out_max_val if hasattr(params, 'out_max_val') else False
    top_k = params.top_k if hasattr(params, 'top_k') else 1
    axis = params.axis if hasattr(params, 'axis') else -1
    if axis < 0:
        axis += len(inshape)
    assert (axis + 1 == len(inshape)), \
        'only can be applied on the last dimension[axis:%d, %s] now, ' \
        'make sure you have set axis param in xxx.prototxt file' % (axis, str(inshape))

    outshape = inshape
    outshape[-1] = top_k
    if out_max_val is True:
        outshape[-1] *= 2
    return [outshape]


def shape_axpy(layer, input_shape):
    assert len(input_shape) == 3, "not valid input shape for axpy layer"
    assert len(input_shape[0]) == len(input_shape[1]), 'should have same dims'
    output_shape = input_shape[1]
    assert (input_shape[2] == output_shape), \
        "shape not consistent for axpy[%s <--> %s]" \
        % (str(output_shape), str(input_shape[2]))
    return [output_shape]


def shape_crop(layer, input_shape):
    assert len(input_shape) == 2, "the number of crop's inputs must be 2"
    return [input_shape[1]]


def shape_detectionoutput(layer, input_shape):
    return [[-1, 6]]


def shape_flatten(layer, input_shape):
    assert len(input_shape) == 1, "the number of flatten's inputs must be 1"
    params = layer.flatten_param
    start_axis = params.axis
    end_axis = params.end_axis
    if start_axis < 0:
        start_axis += len(input_shape[0])
    if end_axis < 0:
        end_axis += len(input_shape[0]) + 1
    assert start_axis <= end_axis, 'invalid axis[%d] or end_axis[%d] params' \
        % (start_axis, end_axis)
    output_shape = [0] * start_axis + [-1] + [0] * (len(input_shape[0]) - end_axis)
    return [output_shape]
```
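For instance (dims hypothetical), flattening `[1, 512, 7, 7]` with `axis=1` and `end_axis=-1` normalizes `end_axis` to 4 and produces `[0, -1]`, which the downstream reshape expands to `[1, 25088]`:
```
# Hypothetical check of shape_flatten's output pattern.
input_shape = [[1, 512, 7, 7]]
start_axis, end_axis = 1, 4  # end_axis=-1 normalized to 4
output_shape = [0] * start_axis + [-1] + [0] * (len(input_shape[0]) - end_axis)
assert output_shape == [0, -1]
```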
```
def shape_normalize(layer, input_shape):
    return input_shape


def shape_permute(layer, input_shape):
    params = layer.permute_param
    order = list(params.order)
    inshape = input_shape[0]
    output_shape = []
    for ii in order:
        assert ii < len(inshape), "invalid order for permute[%s]" % (layer.name)
        output_shape.append(inshape[ii])
    return [output_shape]


def shape_power(layer, input_shape):
    return input_shape


def shape_priorbox(layer, input_shape):
    params = layer.prior_box_param
    min_size = list(params.min_size)
    max_size = list(params.max_size)
    aspect_ratio = list(params.aspect_ratio)
    assert len(input_shape) == 2, "invalid inputs for Priorbox[%s]" % (layer.name)
    fc_shape = input_shape[0]
    N = 1
    if max_size:
        N += 1
    if aspect_ratio:
        N += 2 * len(aspect_ratio)
    N_bbx = fc_shape[2] * fc_shape[3] * N
    output_shape = [[1, 2, 4 * N_bbx]]
    return output_shape


def shape_reduction(layer, input_shape):
    params = layer.reduction_param
    axis = params.axis
    if axis < 0:
        axis += len(input_shape[0]) + 1
    assert axis <= len(input_shape[0]), 'invalid axis[%d] error' % (axis)
    return [input_shape[0][0:axis]]


def shape_roipooling(layer, input_shape):
    params = layer.roi_pooling_param
    pooled_w = params.pooled_w
    pooled_h = params.pooled_h
    spatial_scale = params.spatial_scale
    assert len(input_shape) == 2, "not valid input shape for roipooling layer"
    base_fea_shape = input_shape[0]
    rois_shape = input_shape[1]
    output_shape = list(base_fea_shape)  # copy to avoid mutating the input
    output_shape[0] = rois_shape[0]
    output_shape[2] = pooled_h
    output_shape[3] = pooled_w
    return [output_shape]


def shape_select(layer, input_shape):
    input_shape = list(input_shape[0])
    params = layer.select_param
    axis = params.axis
    slice_point = list(params.slice_point)
    start = slice_point[0]
    if len(slice_point) == 2:
        end = slice_point[1]
    else:
        end = input_shape[axis]
    assert end > start, "invalid slice_point with [start:%d, end:%d]" \
        % (start, end)
    output_shape = input_shape
    output_shape[axis] = end - start
    return [output_shape]
```
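And a last sanity check for shape_select (dims hypothetical): selecting the range `[20, 50)` on `axis=1` of a `[1, 100, 10, 10]` blob keeps 30 channels:
```
# Hypothetical check of shape_select's arithmetic.
input_shape = [1, 100, 10, 10]
start, end = 20, 50
output_shape = list(input_shape)
output_shape[1] = end - start
assert output_shape == [1, 30, 10, 10]
```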
@@ -13,6 +13,7 @@
```
# limitations under the License.

import numbers
import numpy as np
from x2paddle.decoder.caffe_decoder import CaffeGraph
from x2paddle.core.op_mapper import OpMapper
from x2paddle.core.util import *
```
@@ -106,11 +107,17 @@ class CaffeOpMapper(OpMapper):
```
        raise ValueError('Unable to determine kernel parameter!')
        return default

    def get_kernel_parameters(self, kind, params):
        assert kind in ['Convolution', 'Pooling', 'Deconvolution']
        k_h = self.get_kernel_value(params.kernel_h,
                                    params.kernel_size,
                                    0,
                                    default=1)
        k_w = self.get_kernel_value(params.kernel_w,
                                    params.kernel_size,
                                    1,
                                    default=1)
        s_h = self.get_kernel_value(params.stride_h,
                                    params.stride,
                                    0,
```
@@ -144,6 +151,18 @@ class CaffeOpMapper(OpMapper):
```
        return c_o, kernel, stride, pad, dilation, group

    def get_input_name(self, node):
        if hasattr(node, "index"):
            return node.layer_name + "[{}]".format(node.index)
        else:
            return node.layer_name

    def is_BN(self, node):
        return node.layer_type == 'BatchNorm'

    def is_Scale(self, node):
        return node.layer_type == 'Scale'

    def Input(self, node):
        shape = list(node.layer.input_param.shape[0].dim)[1:]
        dtype = 'float32'
```
@@ -159,6 +178,8 @@ class CaffeOpMapper(OpMapper):
```
    def Convolution(self, node):
        data = node.data
        assert data is not None, 'The parameter of {} (type is {}) is not set. You need to use python package of caffe to set the default value.'.format(
            node.layer_name, node.layer_type)
        data = self.adjust_parameters(node, data)
        self.weights[node.layer_name + '_weights'] = data[0]
        if len(data) == 2:
```
@@ -169,16 +190,30 @@ class CaffeOpMapper(OpMapper):
```
        assert len(node.inputs
                   ) == 1, 'The count of Convolution node\'s input is not 1.'
        input = self.graph.get_bottom_node(node, idx=0, copy=True)
        if self.is_Scale(input):
            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
            if self.is_BN(tmp):
                input = tmp
        attr = {
            'filter_size': kernel,
            'num_filters': channel,
            'stride': stride,
            'padding': pad,
            'dilation': dilation,
            'groups': group,
            'name': string(node.layer_name),
            'param_attr': string(node.layer_name + '_weights'),
            'bias_attr': False if len(data) == 1 else string(node.layer_name + '_bias'),
        }
        node.fluid_code.add_layer("conv2d",
                                  inputs=input,
```
@@ -187,6 +222,8 @@ class CaffeOpMapper(OpMapper):
```
    def Deconvolution(self, node):
        data = node.data
        assert data is not None, 'The parameter of {} (type is {}) is not set. You need to use python package of caffe to set the default value.'.format(
            node.layer_name, node.layer_type)
        data = self.adjust_parameters(node, data)
        self.weights[node.layer_name + '_weights'] = data[0]
        if len(data) == 2:
```
@@ -197,17 +234,31 @@ class CaffeOpMapper(OpMapper):
```
        assert len(node.inputs
                   ) == 1, 'The count of Deconvolution node\'s input is not 1.'
        input = self.graph.get_bottom_node(node, idx=0, copy=True)
        if self.is_Scale(input):
            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
            if self.is_BN(tmp):
                input = tmp
        attr = {
            'output_size': None,
            'filter_size': kernel,
            'num_filters': channel,
            'stride': stride,
            'padding': pad,
            'dilation': dilation,
            'groups': group,
            'name': string(node.layer_name),
            'param_attr': string(node.layer_name + '_weights'),
            'bias_attr': False if len(data) == 1 else string(node.layer_name + '_bias')
        }
        node.fluid_code.add_layer("conv2d_transpose",
                                  inputs=input,
```
@@ -216,13 +267,11 @@ class CaffeOpMapper(OpMapper):
```
    def Pooling(self, node):
        params = node.layer.pooling_param
        ceil_mode = getattr(params, 'ceil_mode', True)
        global_pool = getattr(params, 'global_pooling', False)
        kernel_default = [1, 1]
        channel, kernel, stride, pad, dilation, group = self.get_kernel_parameters(
            node.layer_type, params)
        if params.pool == 0:
            pool_type = 'max'
        else:
```
@@ -230,13 +279,18 @@ class CaffeOpMapper(OpMapper):
```
        assert len(
            node.inputs) == 1, 'The count of Pooling node\'s input is not 1.'
        input = self.graph.get_bottom_node(node, idx=0, copy=True)
        if self.is_Scale(input):
            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
            if self.is_BN(tmp):
                input = tmp
        attr = {
            'pool_size': kernel,
            'pool_stride': stride,
            'pool_padding': pad,
            'ceil_mode': ceil_mode,
            'pool_type': string(pool_type),
            'exclusive': True,
            'global_pooling': global_pool,
            'name': string(node.layer_name)
        }
        node.fluid_code.add_layer("pool2d",
```
@@ -248,6 +302,10 @@ class CaffeOpMapper(OpMapper):
```
        assert len(
            node.inputs) == 1, 'The count of ReLU node\'s input is not 1.'
        input = self.graph.get_bottom_node(node, idx=0, copy=True)
        if self.is_Scale(input):
            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
            if self.is_BN(tmp):
                input = tmp
        attr = {'name': string(node.layer_name)}
        node.fluid_code.add_layer("relu",
                                  inputs=input,
```
@@ -265,6 +323,10 @@ class CaffeOpMapper(OpMapper):
```
        # We'll account for that here.
        alpha = params.alpha / float(params.local_size)
        input = self.graph.get_bottom_node(node, idx=0, copy=True)
        if self.is_Scale(input):
            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
            if self.is_BN(tmp):
                input = tmp
        attr = {
            'n': params.local_size,
            'k': 1.0,
```
@@ -279,6 +341,8 @@ class CaffeOpMapper(OpMapper):
```
    def InnerProduct(self, node):
        data = node.data
        assert data is not None, 'The parameter of {} (type is {}) is not set. You need to use python package of caffe to set the default value.'.format(
            node.layer_name, node.layer_type)
        data = self.adjust_parameters(node, data)
        # Reshape the parameters to Paddle's ordering
        transpose_order = (1, 0)
```
@@ -298,12 +362,21 @@ class CaffeOpMapper(OpMapper):
```
        assert params.axis == 1
        assert params.bias_term == True
        input = self.graph.get_bottom_node(node, idx=0, copy=True)
        if self.is_Scale(input):
            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
            if self.is_BN(tmp):
                input = tmp
        attr = {
            'size': params.num_output,
            'name': string(node.layer_name),
            'act': None,
            'param_attr': string(node.layer_name + '_weights'),
            'bias_attr': False if len(data) == 1 else string(node.layer_name + '_bias')
        }
        node.fluid_code.add_layer("fc",
                                  inputs=input,
```
@@ -314,6 +387,10 @@ class CaffeOpMapper(OpMapper):
```
        assert len(
            node.inputs) == 1, 'The count of Softmax node\'s input is not 1.'
        input = self.graph.get_bottom_node(node, idx=0, copy=True)
        if self.is_Scale(input):
            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
            if self.is_BN(tmp):
                input = tmp
        params = node.layer.softmax_param
        axis = params.axis
        shape = node.input_shape[0]
```
@@ -358,50 +435,895 @@ class CaffeOpMapper(OpMapper):
```
        assert len(
            node.inputs) == 1, 'The count of Slice node\'s input is not 1.'
        input = self.graph.get_bottom_node(node, idx=0, copy=True)
        if self.is_Scale(input):
            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
            if self.is_BN(tmp):
                input = tmp
        params = node.layer.slice_param
        axis = params.axis
        points = list(params.slice_point)
        maxint32 = 2147483647
        points = [0] + points
        points.append(maxint32)
        node.fluid_code.add_note('{} = []'.format(node.layer_name))
        for i in range(len(points) - 1):
            attr = {
                'axes': [axis],
                'starts': [points[i]],
                'ends': [points[i + 1]],
                'name': string(node.layer_name + '_' + str(i))
            }
            node.fluid_code.add_layer("slice",
                                      inputs=input,
                                      output=string(node.layer_name + '_' + str(i)),
                                      param_attr=attr)
            node.fluid_code.add_note('{}.append({})'.format(
                node.layer_name, node.layer_name + '_' + str(i)))
```
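To illustrate, for a Slice layer named `s` with `slice_point [30, 60]` on `axis=1`, the mapper above would emit fluid code along these lines (the input variable `x` and all names are assumptions, not mapper output captured verbatim):
```
# Illustrative shape of the generated code for the Slice mapper.
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[100, 10, 10], dtype='float32')
s = []
s_0 = fluid.layers.slice(x, axes=[1], starts=[0], ends=[30])
s.append(s_0)
s_1 = fluid.layers.slice(x, axes=[1], starts=[30], ends=[60])
s.append(s_1)
s_2 = fluid.layers.slice(x, axes=[1], starts=[60], ends=[2147483647])
s.append(s_2)
```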
```
    def Concat(self, node):
        assert len(
            node.inputs
        ) > 1, 'The count of Concat node\'s input is not more than 1.'
        inputs = []
        for i in range(len(node.inputs)):
            input = self.graph.get_bottom_node(node, idx=i, copy=True)
            if self.is_Scale(input):
                tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
                if self.is_BN(tmp):
                    input = tmp
            inputs.append(input)
        params = node.layer.concat_param
        axis = params.axis
        attr = {'axis': axis, 'name': string(node.layer_name)}
        node.fluid_code.add_layer("concat",
                                  inputs=inputs,
                                  output=node,
                                  param_attr=attr)

    def PReLU(self, node):
        assert len(
            node.inputs) == 1, 'The count of PReLU node\'s input is not 1.'
        input = self.graph.get_bottom_node(node, idx=0, copy=True)
        if self.is_Scale(input):
            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
            if self.is_BN(tmp):
                input = tmp
        params = node.layer.prelu_param
        mode_bool = params.channel_shared
        if mode_bool:
            mode = 'all'
        else:
            mode = 'channel'
        data = node.data
        assert data is not None, 'The parameter of {} (type is {}) is not set. You need to use python package of caffe to set the default value.'.format(
            node.layer_name, node.layer_type)
        self.weights[node.layer_name + '_weights'] = data[0]
        attr = {
            'mode': string(mode),
            'param_attr': string(node.layer_name + '_weights'),
            'name': string(node.layer_name)
        }
        node.fluid_code.add_layer("prelu",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Sigmoid(self, node):
        assert len(
            node.inputs) == 1, 'The count of Sigmoid node\'s input is not 1.'
        input = self.graph.get_bottom_node(node, idx=0, copy=True)
        if self.is_Scale(input):
            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
            if self.is_BN(tmp):
                input = tmp
        attr = {'name': string(node.layer_name)}
        node.fluid_code.add_layer("sigmoid",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def AbsVal(self, node):
        assert len(
            node.inputs) == 1, 'The count of AbsVal node\'s input is not 1.'
        input = self.graph.get_bottom_node(node, idx=0, copy=True)
        if self.is_Scale(input):
            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
            if self.is_BN(tmp):
                input = tmp
        attr = {'name': string(node.layer_name)}
        node.fluid_code.add_layer("absval",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Accuracy(self, node):
        assert len(
            node.inputs) == 2, 'The count of Accuracy node\'s input is not 2.'
        inputs = [None, None]
        i = 0
        for shape in node.input_shape:
            if shape[1] == 1:
                input = self.graph.get_bottom_node(node, idx=i, copy=True)
                if self.is_Scale(input):
                    tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
                    if self.is_BN(tmp):
                        input = tmp
                inputs[1] = input
            else:
                input = self.graph.get_bottom_node(node, idx=i, copy=True)
                if self.is_Scale(input):
                    tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
                    if self.is_BN(tmp):
                        input = tmp
                inputs[0] = input
            i += 1
        params = node.layer.accuracy_param
        top_k = params.top_k
        axis = params.axis
        ignore_label = params.ignore_label
        # TODO(syf)
        assert axis == 1, 'PaddlePaddle can not support the situation when the axis is not 1.'
        assert not ignore_label >= 0, 'PaddlePaddle can not support the situation when the model has ignore label.'
        attr = {'k': top_k}
        node.fluid_code.add_layer("accuracy",
                                  inputs=inputs,
                                  output=node,
                                  param_attr=attr)

    def TanH(self, node):
        assert len(
            node.inputs) == 1, 'The count of TanH node\'s input is not 1.'
        input = self.graph.get_bottom_node(node, idx=0, copy=True)
        if self.is_Scale(input):
            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
            if self.is_BN(tmp):
                input = tmp
        attr = {'name': string(node.layer_name)}
        node.fluid_code.add_layer("tanh",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Eltwise(self, node):
        assert len(
            node.inputs) == 2, 'The count of Eltwise node\'s input is not 2.'
        params = node.layer.eltwise_param
        mode = params.operation
        inputs = []
        input0 = self.graph.get_bottom_node(node, idx=0, copy=True)
        if self.is_Scale(input0):
            tmp = self.graph.get_bottom_node(input0, idx=0, copy=True)
            if self.is_BN(tmp):
                input0 = tmp
        inputs.append(input0)
        input1 = self.graph.get_bottom_node(node, idx=1, copy=True)
        if self.is_Scale(input1):
            tmp = self.graph.get_bottom_node(input1, idx=0, copy=True)
            if self.is_BN(tmp):
                input1 = tmp
        inputs.append(input1)
        if mode == 0:
            inputs_dict = {}
            inputs_dict['x'] = inputs[0]
            inputs_dict['y'] = inputs[1]
            attr = {'act': None, 'name': string(node.layer_name)}
            node.fluid_code.add_layer("elementwise_mul",
                                      inputs=inputs_dict,
                                      output=node,
                                      param_attr=attr)
        elif mode == 1:
            if hasattr(params, 'coeff') and len(params.coeff) == 2:
                coeff = params.coeff
                input1_name = self.get_input_name(inputs[0])
                attr = {
                    'shape': [1],
                    'value': coeff[0],
                    'dtype': '{}.dtype'.format(input1_name)
                }
                node.fluid_code.add_layer("fill_constant",
                                          inputs=None,
                                          output=node.layer_name + '_const1',
                                          param_attr=attr)
                attr = {'act': None, 'name': string(node.layer_name + '_mul1')}
                node.fluid_code.add_layer("elementwise_mul",
                                          inputs=input1_name + ', ' +
                                          node.layer_name + '_const1',
                                          output=node.layer_name + '_mul1',
                                          param_attr=attr)
                input2_name = self.get_input_name(inputs[1])
                attr = {
                    'shape': [1],
                    'value': coeff[1],
                    'dtype': '{}.dtype'.format(input2_name)
                }
                node.fluid_code.add_layer("fill_constant",
                                          inputs=None,
                                          output=node.layer_name + '_const2',
                                          param_attr=attr)
                attr = {'act': None, 'name': string(node.layer_name + '_mul2')}
                node.fluid_code.add_layer("elementwise_mul",
                                          inputs=input2_name + ', ' +
                                          node.layer_name + '_const2',
                                          output=node.layer_name + '_mul2',
                                          param_attr=attr)
                attr = {'act': None, 'name': string(node.layer_name)}
                node.fluid_code.add_layer("elementwise_add",
                                          inputs='{}_mul1, {}_mul2'.format(
                                              node.layer_name, node.layer_name),
                                          output=node,
                                          param_attr=attr)
            else:
                inputs_dict = {}
                inputs_dict['x'] = inputs[0]
                inputs_dict['y'] = inputs[1]
                attr = {'act': None, 'name': string(node.layer_name)}
                node.fluid_code.add_layer("elementwise_add",
                                          inputs=inputs_dict,
                                          output=node,
                                          param_attr=attr)
        else:
            inputs_dict = {}
            inputs_dict['x'] = inputs[0]
            inputs_dict['y'] = inputs[1]
            attr = {'act': None, 'name': string(node.layer_name)}
            node.fluid_code.add_layer("elementwise_max",
                                      inputs=inputs_dict,
                                      output=node,
                                      param_attr=attr)
```
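For the coefficient branch above, the emitted fluid code for an Eltwise SUM with `coeff = [2.0, 0.5]` is roughly the following (inputs `x`, `y` and all names are assumptions):
```
# Illustrative shape of the generated code for Eltwise SUM with coeff.
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[16, 8, 8], dtype='float32')
y = fluid.layers.data(name='y', shape=[16, 8, 8], dtype='float32')
e_const1 = fluid.layers.fill_constant(shape=[1], value=2.0, dtype='float32')
e_mul1 = fluid.layers.elementwise_mul(x, e_const1)
e_const2 = fluid.layers.fill_constant(shape=[1], value=0.5, dtype='float32')
e_mul2 = fluid.layers.elementwise_mul(y, e_const2)
e = fluid.layers.elementwise_add(e_mul1, e_mul2)
```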
```
    def BatchNorm(self, node):
        assert len(node.inputs) == 1 and len(
            node.outputs
        ) == 1, 'The count of BatchNorm node\'s input and output is not 1.'
        input = self.graph.get_bottom_node(node, idx=0, copy=True)
        if self.is_Scale(input):
            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
            if self.is_BN(tmp):
                input = tmp
        params = node.layer.batch_norm_param
        if hasattr(params, 'eps'):
            eps = params.eps
        else:
            eps = 1e-5
        assert len(node.data) == 3
        node.data = [np.squeeze(i) for i in node.data]
        mean, variance, scale = node.data
        # Prescale the stats
        scaling_factor = 1.0 / scale if scale != 0 else 0
        mean *= scaling_factor
        variance *= scaling_factor
        self.weights[node.layer_name + '_mean'] = mean
        self.weights[node.layer_name + '_variance'] = variance
        if self.graph.get_node(node.outputs[0]).layer_type == 'Scale':
            data = self.graph.get_node(node.outputs[0]).data
            self.weights[node.layer_name + '_scale'] = np.squeeze(data[0])
            self.weights[node.layer_name + '_offset'] = np.squeeze(data[1])
            attr = {
                'is_test': True,
                'param_attr': string(node.layer_name + '_scale'),
                'bias_attr': string(node.layer_name + '_offset'),
                'moving_mean_name': string(node.layer_name + '_mean'),
                'moving_variance_name': string(node.layer_name + '_variance'),
                'epsilon': eps,
                'name': string(node.layer_name)
            }
        else:
            attr = {
                'is_test': True,
                'param_attr': None,
                'bias_attr': None,
                'moving_mean_name': string(node.layer_name + '_mean'),
                'moving_variance_name': string(node.layer_name + '_variance'),
                'epsilon': eps,
                'name': string(node.layer_name)
            }
        node.fluid_code.add_layer("batch_norm",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Scale(self, node):
        assert len(
            node.outputs) == 1, 'The count of Scale node\'s output is not 1.'
        if len(node.inputs) == 1 and self.graph.get_node(
                node.inputs[0]).layer_type == 'BatchNorm':
            return
        else:
            self.weights[node.layer_name + '_scale'] = np.squeeze(node.data[0])
            self.weights[node.layer_name + '_offset'] = np.squeeze(node.data[1])
        params = node.layer.scale_param
        axis = params.axis
        num_axes = params.num_axes
        assert num_axes == 1, "layer scale not support this num_axes[%d] now" % (
            num_axes)
        inputs = []
        if len(node.inputs) == 2:
            # For two input tensors, reset axis to 1 here; this may be wrong
            # for unknown cases.
            axis = 1
            bias_shape = node.input_shape[0][axis:axis + num_axes]
            input0 = self.graph.get_bottom_node(node, idx=0, copy=True)
            if self.is_Scale(input0):
                tmp = self.graph.get_bottom_node(input0, idx=0, copy=True)
                if self.is_BN(tmp):
                    input0 = tmp
            input1 = self.graph.get_bottom_node(node, idx=1, copy=True)
            if self.is_Scale(input1):
                tmp = self.graph.get_bottom_node(input1, idx=0, copy=True)
                if self.is_BN(tmp):
                    input1 = tmp
            inputs.append(input0)
            inputs.append(input1)
            attr = {'axis': axis, 'name': string(node.layer_name + '_mul')}
            node.fluid_code.add_layer("elementwise_mul",
                                      inputs=inputs,
                                      output=node.layer_name + '_mul',
                                      param_attr=attr)
        else:
            bias_shape = node.input_shape[0][axis:axis + num_axes]
            input0 = self.graph.get_bottom_node(node, idx=0, copy=True)
            if self.is_Scale(input0):
                tmp = self.graph.get_bottom_node(input0, idx=0, copy=True)
                if self.is_BN(tmp):
                    input0 = tmp
            input0_name = self.get_input_name(input0)
            attr = {
                'dtype': '{}.dtype'.format(input0_name),
                'shape': bias_shape,
                'name': string(node.layer_name + '_cparam1'),
                'attr': string(node.layer_name + '_scale'),
                'is_bias': True,
                'default_initializer': 'Constant(value=1.0)'
            }
            node.fluid_code.add_layer("create_parameter",
                                      inputs=None,
                                      output=node,
                                      param_attr=attr)
            inputs.append(input0)
            inputs.append(node)
            attr = {'axis': axis, 'name': string(node.layer_name + '_mul')}
            node.fluid_code.add_layer("elementwise_mul",
                                      inputs=inputs,
                                      output=node.layer_name + '_mul',
                                      param_attr=attr)
        scale_shape = bias_shape
        input0_name = self.get_input_name(input0)
        attr = {
            'dtype': '{}.dtype'.format(input0_name),
            'shape': scale_shape,
            'name': string(node.layer_name + '_cparam2'),
            'attr': string(node.layer_name + '_offset'),
            'is_bias': True,
            'default_initializer': 'Constant(value=1.0)'
        }
        node.fluid_code.add_layer("create_parameter",
                                  inputs=None,
                                  output=node.layer_name + '_offset_param',
                                  param_attr=attr)
        attr = {'axis': axis, 'name': string(node.layer_name + '_add')}
        node.fluid_code.add_layer("elementwise_add",
                                  inputs='{}_mul, {}_offset_param'.format(
                                      node.layer_name, node.layer_name),
                                  output=node,
                                  param_attr=attr)

    def Reshape(self, node):
        assert len(node.inputs) == 1 and len(
            node.outputs
        ) == 1, 'The count of Reshape node\'s input and output is not 1.'
        input = self.graph.get_bottom_node(node, idx=0, copy=True)
        top_count = len(input.layer.top)
        if self.is_Scale(input):
            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
            if self.is_BN(tmp):
                input = tmp
        is_inplace = False if top_count == 1 else True
        output_shape = node.output_shape[0]
        attr = {
            'shape': output_shape,
            'inplace': is_inplace,
            'name': string(node.layer_name)
        }
        node.fluid_code.add_layer("reshape",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def ArgMax(self, node):
        assert len(node.inputs) == 1 and len(
            node.outputs
        ) == 1, 'The count of ArgMax node\'s input and output is not 1.'
        input = self.graph.get_bottom_node(node, idx=0, copy=True)
        if self.is_Scale(input):
            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
            if self.is_BN(tmp):
                input = tmp
        input_shape = node.input_shape[0]
        params = node.layer.argmax_param
        out_max_val = params.out_max_val if hasattr(params,
                                                    'out_max_val') else False
        top_k = params.top_k if hasattr(params, 'top_k') else 1
        axis = params.axis if hasattr(params, 'axis') else -1
        if axis < 0:
            axis += len(input_shape)
        if out_max_val is True:
            attr = {'k': top_k, 'name': string(node.layer_name + '_topk')}
            node.fluid_code.add_layer("topk",
                                      inputs=input,
                                      output='{}_topk_var, {}_index_var'.format(
                                          node.layer_name, node.layer_name),
                                      param_attr=attr)
            attr = {'dtype': '{}_topk_var.dtype'.format(node.layer_name)}
            node.fluid_code.add_layer(
                "cast",
                inputs='{}_index_var'.format(node.layer_name),
                output='{}_index_var'.format(node.layer_name),
                param_attr=attr)
            attr = {'axis': axis, 'name': string(node.layer_name)}
            node.fluid_code.add_layer("concat",
                                      inputs='{}_topk_var, {}_index_var'.format(
                                          node.layer_name, node.layer_name),
                                      output=node,
                                      param_attr=attr)
        else:
            attr = {'k': top_k, 'name': string(node.layer_name)}
            node.fluid_code.add_layer("topk",
                                      inputs=input,
                                      output='_, {}'.format(node.layer_name),
                                      param_attr=attr)

    def Axpy(self, node):
        assert len(
            node.inputs) == 3, 'The count of Axpy node\'s input is not 3.'
        alpha = self.graph.get_bottom_node(node, idx=0, copy=True)
        if self.is_Scale(alpha):
            tmp = self.graph.get_bottom_node(alpha, idx=0, copy=True)
            if self.is_BN(tmp):
                alpha = tmp
        x = self.graph.get_bottom_node(node, idx=1, copy=True)
        if self.is_Scale(x):
            tmp = self.graph.get_bottom_node(x, idx=0, copy=True)
            if self.is_BN(tmp):
                x = tmp
        y = self.graph.get_bottom_node(node, idx=2, copy=True)
        if self.is_Scale(y):
            tmp = self.graph.get_bottom_node(y, idx=0, copy=True)
            if self.is_BN(tmp):
                y = tmp
        attr = {'axis': 0, 'name': string(node.layer_name + '_mul')}
        node.fluid_code.add_layer("elementwise_mul",
                                  inputs={
                                      'x': alpha,
                                      'y': x
                                  },
                                  output=node,
                                  param_attr=attr)
        attr = {'name': string(node.layer_name + '_add')}
        node.fluid_code.add_layer("elementwise_add",
                                  inputs={
                                      'x': node,
                                      'y': y
                                  },
                                  output=node,
                                  param_attr=attr)

    def Crop(self, node):
        assert len(
            node.inputs) == 2, 'The count of Crop node\'s input is not 2.'
        input = self.graph.get_bottom_node(node, idx=0, copy=True)
        if self.is_Scale(input):
            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
            if self.is_BN(tmp):
                input = tmp
        example = self.graph.get_bottom_node(node, idx=1, copy=True)
        if self.is_Scale(example):
            tmp = self.graph.get_bottom_node(example, idx=0, copy=True)
            if self.is_BN(tmp):
                example = tmp
        params = node.layer.crop_param
        axis = params.axis
        input_shape = node.input_shape[0]
        if axis < 0:
            axis += len(input_shape)
        offset_real = [0] * len(input_shape)
        if hasattr(params, 'offset'):
            offset = list(params.offset)
            assert (len(input_shape) - axis) == len(
                offset), "invalid offset[%s] in crop layer" % (str(offset))
            offset_real = [0] * axis + offset
        attr = {'offsets': offset_real, 'name': string(node.layer_name)}
        node.fluid_code.add_layer("crop",
                                  inputs={
                                      'x': input,
                                      'y': example
                                  },
                                  output=node,
                                  param_attr=attr)

    def DetectionOutput(self, node):
        assert len(
            node.inputs
        ) == 3, 'The count of DetectionOutput node\'s input is not 3.'
        mbox_loc = self.graph.get_bottom_node(node, idx=0, copy=True)
        if self.is_Scale(mbox_loc):
            tmp = self.graph.get_bottom_node(mbox_loc, idx=0, copy=True)
            if self.is_BN(tmp):
                mbox_loc = tmp
        mbox_conf_flatten = self.graph.get_bottom_node(node, idx=1, copy=True)
        if self.is_Scale(mbox_conf_flatten):
            tmp = self.graph.get_bottom_node(mbox_conf_flatten,
                                             idx=0,
                                             copy=True)
            if self.is_BN(tmp):
                mbox_conf_flatten = tmp
        mbox_priorbox = self.graph.get_bottom_node(node, idx=2, copy=True)
        if self.is_Scale(mbox_priorbox):
            tmp = self.graph.get_bottom_node(mbox_priorbox, idx=0, copy=True)
            if self.is_BN(tmp):
                mbox_priorbox = tmp
        params = node.layer.detection_output_param
        nms_threshold = 0.3
        top_k = 10
        eta = 1.0
        if hasattr(params, 'nms_param'):
            nms_threshold = getattr(params.nms_param, 'nms_threshold', 0.3)
            top_k = getattr(params.nms_param, 'top_k', 10)
            eta = getattr(params.nms_param, 'eta', 1.0)
        background_label = getattr(params, 'background_label_id', 0)
        share_location = getattr(params, 'share_location', True)
        keep_top_k = getattr(params, 'keep_top_k', 100)
        confidence_threshold = getattr(params, 'confidence_threshold', 0.1)
        attr = {
            'num_or_sections': 2,
            'dim': 1,
            'name': string(node.layer_name + '_split')
        }
        node.fluid_code.add_layer("split",
                                  inputs=mbox_priorbox,
                                  output='mbox_priorbox_list',
                                  param_attr=attr)
        node.fluid_code.add_note('pb = mbox_priorbox_list[0]')
        node.fluid_code.add_note('pbv = mbox_priorbox_list[1]')
        attr = {'shape': [-1, 4], 'name': string(node.layer_name + '_reshape1')}
        node.fluid_code.add_layer("reshape",
                                  inputs='pb',
                                  output='pb',
                                  param_attr=attr)
        attr = {'shape': [-1, 4], 'name': string(node.layer_name + '_reshape2')}
        node.fluid_code.add_layer("reshape",
                                  inputs='pbv',
                                  output='pbv',
                                  param_attr=attr)
        # TODO(syf): needs check
        attr = {
            'shape': [-1, node.input_shape[1][1], 4],
            'name': string(node.layer_name + '_reshape3')
        }
        node.fluid_code.add_layer("reshape",
                                  inputs=mbox_loc,
                                  output='mbox_loc',
                                  param_attr=attr)
        attr = {
            'background_label': background_label,
            'nms_threshold': nms_threshold,
            'nms_top_k': top_k,
            'keep_top_k': keep_top_k,
            'score_threshold': confidence_threshold,
            'nms_eta': eta
        }
        inputs_str = self.get_input_name(mbox_conf_flatten) + ', mbox_loc, pb, pbv'
        node.fluid_code.add_layer("detection_output",
                                  inputs=inputs_str,
                                  output=node,
                                  param_attr=attr)

    def Flatten(self, node):
        assert len(
            node.inputs) == 1, 'The count of Flatten node\'s input is not 1.'
        input = self.graph.get_bottom_node(node, idx=0, copy=True)
        if self.is_Scale(input):
            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
            if self.is_BN(tmp):
                input = tmp
        shape = node.output_shape[0]
        attr = {'shape': shape, 'name': string(node.layer_name)}
        node.fluid_code.add_layer("reshape",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Normalize(self, node):
        assert len(
            node.inputs) == 1, 'The count of Normalize node\'s input is not 1.'
        input = self.graph.get_bottom_node(node, idx=0, copy=True)
        if self.is_Scale(input):
            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
            if self.is_BN(tmp):
                input = tmp
        params = node.layer.norm_param
        across_spatial = params.across_spatial
        channel_shared = params.channel_shared
        assert across_spatial == False, "Only support across_spatial == False for Normalize"
        attr = {'axis': 1, 'name': string(node.layer_name + '_l2')}
        node.fluid_code.add_layer("l2_normalize",
                                  inputs=input,
                                  output=node.layer_name + '_l2',
                                  param_attr=attr)
        input_name = self.get_input_name(input)
        data = node.data
        data = self.adjust_parameters(node, data)
        self.weights[node.layer_name + '_scale'] = data[0]
        node.fluid_code.add_note(
            '{}_scale_attr = ParamAttr(name=\'{}\')'.format(
                node.layer_name, node.layer_name + '_scale'))
        attr = {
            'shape': [1] if channel_shared else [node.input_shape[0][1]],
            'dtype': '{}.dtype'.format(input_name),
            'attr': '{}_scale_attr'.format(node.layer_name),
            'name': string(node.layer_name + '_param')
        }
        node.fluid_code.add_layer("create_parameter",
                                  inputs=None,
                                  output=node.layer_name + '_scale_param',
                                  param_attr=attr)
        attr = {
            'axis': -1 if channel_shared else 1,
            'name': string(node.layer_name + '_mul')
        }
        node.fluid_code.add_layer("elementwise_mul",
                                  inputs=node.layer_name + '_l2, ' +
                                  node.layer_name + '_scale_param',
                                  output=node,
                                  param_attr=attr)

    def Permute(self, node):
        assert len(
            node.inputs) == 1, 'The count of Permute node\'s input is not 1.'
        input = self.graph.get_bottom_node(node, idx=0, copy=True)
        if self.is_Scale(input):
            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
            if self.is_BN(tmp):
                input = tmp
        params = node.layer.permute_param
        order = list(params.order)
        attr = {'order': order, 'name': string(node.layer_name)}
        node.fluid_code.add_layer("transpose",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Power(self, node):
        assert len(
            node.inputs) == 1, 'The count of Power node\'s input is not 1.'
        input = self.graph.get_bottom_node(node, idx=0, copy=True)
        if self.is_Scale(input):
            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
            if self.is_BN(tmp):
                input = tmp
        params = node.layer.power_param
        power = params.power
        scale = params.scale
        shift = params.shift
        attr = {
            'scale': scale,
            'bias': shift,
            'bias_after_scale': True,
            'name': string(node.layer_name + '_scale')
        }
        node.fluid_code.add_layer("scale",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)
        attr = {'factor': power, 'name': string(node.layer_name)}
        node.fluid_code.add_layer("pow",
                                  inputs=node,
                                  output=node,
                                  param_attr=attr)

    def PriorBox(self, node):
        assert len(
            node.inputs) == 2, 'The count of PriorBox node\'s input is not 2.'
        input1 = self.graph.get_bottom_node(node, idx=0, copy=True)
        if self.is_Scale(input1):
            tmp = self.graph.get_bottom_node(input1, idx=0, copy=True)
            if self.is_BN(tmp):
                input1 = tmp
        input2 = self.graph.get_bottom_node(node, idx=1, copy=True)
        if self.is_Scale(input2):
            tmp = self.graph.get_bottom_node(input2, idx=0, copy=True)
            if self.is_BN(tmp):
                input2 = tmp
        input_dict = {'input': input1, 'image': input2}
        params = node.layer.prior_box_param
        step = getattr(params, 'step', 0.0)
        offset = getattr(params, 'offset', 0.5)
        min_size = list(params.min_size)
        max_size = list(params.max_size)
        aspect_ratio = list(params.aspect_ratio)
        flip = getattr(params, 'flip', False)
        clip = getattr(params, 'clip', False)
        variance = list(getattr(params, 'variance', [0.1, 0.1, 0.2, 0.2]))
        steps = tuple(step) if isinstance(step, (list, tuple)) else (step, step)
        attr = {
            'min_sizes': min_size,
            'max_sizes': max_size,
            'aspect_ratios': aspect_ratio,
            'variance': variance,
            'flip': flip,
            'clip': clip,
            'step': steps,
            'offset': offset,
            'min_max_aspect_ratios_order': True,
            'name': string(node.layer_name)
        }
        node.fluid_code.add_layer("prior_box",
                                  inputs=input_dict,
                                  output='{}_box, {}_var'.format(
                                      node.layer_name, node.layer_name),
                                  param_attr=attr)
        attr = {
            'shape': [1, 1, -1],
        }
        node.fluid_code.add_layer("reshape",
                                  inputs='{}_box'.format(node.layer_name),
                                  output='{}_box'.format(node.layer_name),
                                  param_attr=attr)
        attr = {
            'shape': [1, 1, -1],
        }
        node.fluid_code.add_layer("reshape",
                                  inputs='{}_var'.format(node.layer_name),
                                  output='{}_var'.format(node.layer_name),
                                  param_attr=attr)
        attr = {'axis': 1, 'name': string(node.layer_name + '_concat')}
        node.fluid_code.add_layer("concat",
                                  inputs='[{}_box, {}_var]'.format(
                                      node.layer_name, node.layer_name),
                                  output=node,
                                  param_attr=attr)

    def Reduction(self, node):
        assert len(
            node.inputs) == 1, 'The count of Reduction node\'s input is not 1.'
        input = self.graph.get_bottom_node(node, idx=0, copy=True)
        if self.is_Scale(input):
            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
            if self.is_BN(tmp):
                input = tmp
        params = node.layer.reduction_param
        operation = params.operation
        axis = params.axis
        coeff = params.coeff
        assert operation >= 1 and operation <= 4, "invalid reduction operation [%s]" % (
            operation)
        input_len = len(node.input_shape[0])
        if axis < 0:
            axis += input_len + 1
        dim = list(range(input_len))
        if operation == 1:  ## operation = SUM
            attr = {
                'dim': dim[axis:],
                'keep_dim': False,
                'name': string(node.layer_name)
            }
            node.fluid_code.add_layer("reduce_sum",
                                      inputs=input,
                                      output=node,
                                      param_attr=attr)
        elif operation == 2:  ## operation = ASUM
            attr = {'name': string(node.layer_name + '_abs')}
            node.fluid_code.add_layer("abs",
                                      inputs=input,
                                      output=node,
                                      param_attr=attr)
            attr = {
                'dim': dim[axis:],
                'keep_dim': False,
                'name': string(node.layer_name)
            }
            node.fluid_code.add_layer("reduce_sum",
                                      inputs=node,
                                      output=node,
                                      param_attr=attr)
        elif operation == 3:  ## operation = SUMSQ
            attr = {'factor': 2.0, 'name': string(node.layer_name + '_pow')}
            node.fluid_code.add_layer("pow",
                                      inputs=input,
                                      output=node,
                                      param_attr=attr)
            attr = {
                'dim': dim[axis:],
                'keep_dim': False,
                'name': string(node.layer_name)
            }
            node.fluid_code.add_layer("reduce_sum",
                                      inputs=node,
                                      output=node,
                                      param_attr=attr)
        else:  ## operation = MEAN
            attr = {
                'dim': dim[axis:],
                'keep_dim': False,
                'name': string(node.layer_name)
            }
            node.fluid_code.add_layer("reduce_mean",
                                      inputs=input,
                                      output=node,
                                      param_attr=attr)
        attr = {'scale': coeff}
        node.fluid_code.add_layer("scale",
                                  inputs=node,
                                  output=node,
                                  param_attr=attr)

    def ROIPooling(self, node):
        assert len(
            node.inputs) == 2, 'The count of ROIPooling node\'s input is not 2.'
        input1 = self.graph.get_bottom_node(node, idx=0, copy=True)
        if self.is_Scale(input1):
            tmp = self.graph.get_bottom_node(input1, idx=0, copy=True)
            if self.is_BN(tmp):
                input1 = tmp
        input2 = self.graph.get_bottom_node(node, idx=1, copy=True)
        if self.is_Scale(input2):
            tmp = self.graph.get_bottom_node(input2, idx=0, copy=True)
            if self.is_BN(tmp):
                input2 = tmp
        attr = {'axes': [1], 'starts': [1], 'ends': [5]}
        node.fluid_code.add_layer("slice",
                                  inputs=input2,
                                  output=input2,
                                  param_attr=attr)
        input_dict = {'input': input1, 'rois': input2}
        params = node.layer.roi_pooling_param
        attr = {
            'pooled_w': params.pooled_w,
            'pooled_h': params.pooled_h,
            'spatial_scale': params.spatial_scale,
            'name': string(node.layer_name)
        }
        node.fluid_code.add_layer("roi_pool",
                                  inputs=input_dict,
                                  output=node,
                                  param_attr=attr)

    def Select(self, node):
        assert len(
            node.inputs) == 1, 'The count of Select node\'s input is not 1.'
        input = self.graph.get_bottom_node(node, idx=0, copy=True)
        if self.is_Scale(input):
            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
            if self.is_BN(tmp):
                input = tmp
        params = node.layer.select_param
        slice_point = list(params.slice_point)
        axis = params.axis
        maxint32 = 2147483647
        slice_point = [0] + slice_point
        slice_point.append(maxint32)
        node.fluid_code.add_note('{} = []'.format(node.layer_name))
        for i in range(len(slice_point) - 1):
            attr = {
                'axes': [axis],
                'starts': [slice_point[i]],
                'ends': [slice_point[i + 1]],
                'name': string(node.layer_name + '_' + str(i))
            }
            node.fluid_code.add_layer("slice",
                                      inputs=input,
                                      output=string(node.layer_name + '_' + str(i)),
                                      param_attr=attr)
            node.fluid_code.add_note('{}.append({})'.format(
                node.layer_name, node.layer_name + '_' + str(i)))
```