提交 ac476692 编写于 作者: S SunAhong1993

fix the bug of shape and add optimizer

上级 fc3bc25f
...@@ -93,10 +93,14 @@ def tf2paddle(model_path, save_dir): ...@@ -93,10 +93,14 @@ def tf2paddle(model_path, save_dir):
def caffe2paddle(proto, weight, save_dir, caffe_proto): def caffe2paddle(proto, weight, save_dir, caffe_proto):
from x2paddle.decoder.caffe_decoder import CaffeDecoder from x2paddle.decoder.caffe_decoder import CaffeDecoder
from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper
from x2paddle.optimizer.caffe_optimizer import CaffeOptimizer
print("Now translating model from caffe to paddle.") print("Now translating model from caffe to paddle.")
model = CaffeDecoder(proto, weight, caffe_proto) model = CaffeDecoder(proto, weight, caffe_proto)
mapper = CaffeOpMapper(model) mapper = CaffeOpMapper(model)
optimizer = CaffeOptimizer(mapper)
optimizer.merge_bn_scale()
optimizer.merge_op_activation()
mapper.save_inference_model(save_dir) mapper.save_inference_model(save_dir)
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
from x2paddle.core.graph import GraphNode from x2paddle.core.graph import GraphNode
import collections import collections
from x2paddle.core.util import *
class Layer(object): class Layer(object):
...@@ -81,6 +82,8 @@ class Layer(object): ...@@ -81,6 +82,8 @@ class Layer(object):
param_attr = collections.OrderedDict(self.param_attr) param_attr = collections.OrderedDict(self.param_attr)
for key, value in param_attr.items(): for key, value in param_attr.items():
if '\n' in str(value):
value = string(str(value).replace('\n', ','))
layer_code = layer_code + key + "={}, ".format(value) layer_code = layer_code + key + "={}, ".format(value)
layer_code = layer_code.strip(", ") layer_code = layer_code.strip(", ")
......
...@@ -63,17 +63,6 @@ class CaffeGraphNode(GraphNode): ...@@ -63,17 +63,6 @@ class CaffeGraphNode(GraphNode):
def set_params(self, params): def set_params(self, params):
self.data = params self.data = params
def set_output_shape(self, input_shape, is_input=True):
func_name = 'shape_' + self.layer_type.lower()
if is_input:
self.output_shape = getattr(caffe_shape, func_name)(self.layer,
input_shape)
else:
self.output_shape = input_shape
def set_input_shape(self, input_shape):
self.input_shape = input_shape
class CaffeGraph(Graph): class CaffeGraph(Graph):
def __init__(self, model, params): def __init__(self, model, params):
......
...@@ -14,6 +14,18 @@ def detectionoutput_layer(inputs, ...@@ -14,6 +14,18 @@ def detectionoutput_layer(inputs,
confidence_threshold=0.1, confidence_threshold=0.1,
input_shape=None, input_shape=None,
name=None): name=None):
nms_param_str = nms_param
nms_param = {}
part = nms_param_str.split(',')
for s in part:
if s == '':
break
else:
name, obj = s.split(': ')
if name == 'top_k':
nms_param[name] = int(obj)
else:
nms_param[name] = float(obj)
if nms_param is None: if nms_param is None:
nms_param = {"nms_threshold": 0.3, "top_k": 10, "eta": 1.0} nms_param = {"nms_threshold": 0.3, "top_k": 10, "eta": 1.0}
mbox_conf_flatten = inputs[1] mbox_conf_flatten = inputs[1]
...@@ -24,20 +36,21 @@ def detectionoutput_layer(inputs, ...@@ -24,20 +36,21 @@ def detectionoutput_layer(inputs,
pb = fluid.layers.reshape(x=pb, shape=[-1, 4]) pb = fluid.layers.reshape(x=pb, shape=[-1, 4])
pbv = fluid.layers.reshape(x=pbv, shape=[-1, 4]) pbv = fluid.layers.reshape(x=pbv, shape=[-1, 4])
mbox_loc = inputs[0] mbox_loc = inputs[0]
mbox_loc = fluid.layers.reshape(x=mbox_loc, mbox_loc = fluid.layers.reshape(x=mbox_loc, shape=[-1, pb.shape[0], 4])
shape=[-1, mbox_conf_flatten.shape[1], 4]) mbox_conf_flatten = fluid.layers.reshape(x=mbox_conf_flatten,
shape=[0, pb.shape[0], -1])
default = {"nms_threshold": 0.3, "top_k": 10, "eta": 1.0} default = {"nms_threshold": 0.3, "top_k": 10, "eta": 1.0}
fields = ['eta', 'top_k', 'nms_threshold'] fields = ['eta', 'top_k', 'nms_threshold']
for f in default.keys(): for f in default.keys():
if not nms_param.has_key(f): if f not in nms_param:
nms_param[f] = default[f] nms_param[f] = default[f]
out = fluid.layers.detection_output( out = fluid.layers.detection_output(
scores=mbox_conf_flatten, scores=mbox_conf_flatten,
loc=mbox_loc, loc=mbox_loc,
prior_box=pb, prior_box=pb,
prior_box_var=pbv, prior_box_var=pbv,
background_label=background_label, background_label=background_label_id,
nms_threshold=nms_param["nms_threshold"], nms_threshold=nms_param["nms_threshold"],
nms_top_k=nms_param["top_k"], nms_top_k=nms_param["top_k"],
keep_top_k=keep_top_k, keep_top_k=keep_top_k,
......
...@@ -3,7 +3,7 @@ from x2paddle.core.util import * ...@@ -3,7 +3,7 @@ from x2paddle.core.util import *
def priorbox_shape(input_shape, max_size=None, aspect_ratio=None): def priorbox_shape(input_shape, max_size=None, aspect_ratio=None):
fc_shape = input_shapes[0] fc_shape = input_shape[0]
N = 1 N = 1
if not max_size == None: if not max_size == None:
N += 1 N += 1
...@@ -18,26 +18,27 @@ def priorbox_layer(inputs, ...@@ -18,26 +18,27 @@ def priorbox_layer(inputs,
step=0.0, step=0.0,
offset=0.5, offset=0.5,
min_size=None, min_size=None,
max_size=None, max_size=[],
aspect_ratio=[1.0], aspect_ratio=[1.0],
flip=False, flip=False,
clip=False, clip=False,
variance=[0.1, 0.1, 0.2, 0.2], variance=[0.1, 0.1, 0.2, 0.2],
input_shape=None, input_shape=None,
name=None): name=None):
input = input_shape[0] input = inputs[0]
image = input_shape[1] image = inputs[1]
steps = tuple(step) if type(step) is list or type(step) is tuple else (step, steps = tuple(step) if type(step) is list or type(step) is tuple else (step,
step) step)
box, variance_ = fluid.layers.prior_box(input, box, variance_ = fluid.layers.prior_box(input,
image, image,
min_sizes=list(min_size), min_sizes=min_size,
max_sizes=list(max_size), max_sizes=max_size,
aspect_ratios=list(aspect_ratio), aspect_ratios=aspect_ratio,
variance=list(variance), variance=variance,
flip=flip, flip=flip,
clip=clip, clip=clip,
steps=step, steps=steps,
offset=offset, offset=offset,
name=name, name=name,
min_max_aspect_ratios_order=True) min_max_aspect_ratios_order=True)
......
...@@ -9,12 +9,12 @@ def shufflechannel_shape(input_shape): ...@@ -9,12 +9,12 @@ def shufflechannel_shape(input_shape):
def shufflechannel_layer(inputs, group=None, input_shape=None, name=None): def shufflechannel_layer(inputs, group=None, input_shape=None, name=None):
input = inputs[0] input = inputs[0]
c_fm = fluid.layers.split(input, num_or_sections=input_shape[0][1], dim=1) c_fm = fluid.layers.split(input, num_or_sections=input_shape[0][1], dim=1)
size = int(input_shape[0][1]/group) size = int(input_shape[0][1] / group)
new_c_fm = [] new_c_fm = []
for i in range(size): for i in range(size):
for j in range(group): for j in range(group):
new_c_fm.append(c_fm[j * size + i]) new_c_fm.append(c_fm[j * size + i])
out = fluid.layers.concat(new_c_fm, axis = 1) out = fluid.layers.concat(new_c_fm, axis=1)
return out return out
......
...@@ -13,104 +13,58 @@ ...@@ -13,104 +13,58 @@
# limitations under the License. # limitations under the License.
import math import math
import numbers
from functools import reduce
def get_params_w_h(params):
def get_kernel_parameters(params):
[k_h, k_w] = [1, 1]
if isinstance(params.kernel_size, numbers.Number):
[k_h, k_w] = [params.kernel_size] * 2
elif len(params.kernel_size) > 0:
k_h = params.kernel_h if params.kernel_h else params.kernel_size[0]
k_w = params.kernel_w if params.kernel_w else params.kernel_size[
len(params.kernel_size) - 1]
[s_h, s_w] = [1, 1]
if isinstance(params.stride, numbers.Number):
[s_h, s_w] = [params.stride] * 2
elif len(params.stride) > 0:
s_h = params.stride_h if params.stride_h else params.stride[0]
s_w = params.stride_w if params.stride_w else params.stride[
len(params.stride) - 1]
[p_h, p_w] = [0, 0]
if isinstance(params.pad, numbers.Number):
[p_h, p_w] = [params.pad] * 2
elif len(params.pad) > 0:
p_h = params.pad_h if params.pad_h else params.pad[0]
p_w = params.pad_w if params.pad_w else params.pad[len(params.pad) - 1]
dila_h = dila_w = 1
if hasattr(params, 'dilation'): if hasattr(params, 'dilation'):
if len(params.dilation) == 0: dila_len = len(params.dilation)
dila_h = 1 if dila_len == 2:
dila_w = 1
elif len(params.dilation) == 1:
dila_h = params.dilation[0]
dila_w = params.dilation[0]
else:
dila_h = params.dilation[0] dila_h = params.dilation[0]
dila_w = params.dilation[1] dila_w = params.dilation[1]
elif dila_len == 1:
dila_h = dila_w = params.dilation[0]
else: else:
dila_h = 1 assert dila_len == 0, "invalid length[%s] of dilation in convolution" % (
dila_w = 1 dila_len)
return dila_h, dila_w, p_h, p_w, k_h, k_w, s_h, s_w
if not isinstance(getattr(params, 'pad'), int):
if len(params.pad) == 0:
pad_h = 0
pad_w = 0
elif len(params.pad) == 1:
pad_h = params.pad[0]
pad_w = params.pad[0]
else:
pad_h, pad_w, = params.pad[0]
pad_w = params.pad[1]
if params.pad_h != 0 or params.pad_w != 0:
pad_h = params.pad_h
pad_w = params.pad_w
else:
if params.pad_h != 0 or params.pad_w != 0:
pad_h = params.pad_h
pad_w = params.pad_w
else:
pad_h = getattr(params, 'pad')
pad_w = getattr(params, 'pad')
if not isinstance(getattr(params, 'kernel_size'), int):
if len(params.kernel_size) == 0:
kernel_h = 1
kernel_w = 1
elif len(params.kernel_size) == 1:
kernel_h = params.kernel_size[0]
kernel_w = params.kernel_size[0]
else:
kernel_h = params.kernel_size[0]
kernel_w = params.kernel_size[1]
if params.kernel_h != 0 or params.kernel_w != 0:
kernel_h = params.kernel_h
kernel_w = params.kernel_w
else:
if params.kernel_h != 0 or params.kernel_w != 0:
kernel_h = params.kernel_h
kernel_w = params.kernel_w
else:
kernel_h = getattr(params, 'kernel_size')
kernel_w = getattr(params, 'kernel_size')
if not isinstance(getattr(params, 'stride'), int):
if len(params.stride) == 0:
stride_h = 1
stride_w = 1
elif len(params.stride) == 1:
stride_h = params.stride[0]
stride_w = params.stride[0]
else:
stride_h = params.stride[0]
stride_w = params.stride[1]
if params.stride_h != 0 or params.stride_w != 0:
stride_h = params.stride_h
stride_w = params.stride_w
else:
if params.stride_h != 0 or params.stride_w != 0:
stride_h = params.stride_h
stride_w = params.stride_w
else:
stride_h = getattr(params, 'stride')
stride_w = getattr(params, 'stride')
return dila_h, dila_w, pad_h, pad_w, kernel_h, kernel_w, stride_h, stride_w
def get_filter_output_shape(i_h, i_w, params, round_func): def get_strided_kernel_output_shape(params, input_shape, round_func):
dila_h, dila_w, pad_h, pad_w, kernel_h, kernel_w, stride_h, stride_w = get_params_w_h( i_h = input_shape[2]
i_w = input_shape[3]
dila_h, dila_w, pad_h, pad_w, kernel_h, kernel_w, stride_h, stride_w = get_kernel_parameters(
params) params)
o_h = (i_h + 2 * pad_h - (dila_h * o_h = (i_h + 2 * pad_h - (dila_h *
(kernel_h - 1) + 1)) / float(stride_h) + 1 (kernel_h - 1) + 1)) / float(stride_h) + 1
o_w = (i_w + 2 * pad_w - (dila_w * o_w = (i_w + 2 * pad_w - (dila_w *
(kernel_w - 1) + 1)) / float(stride_w) + 1 (kernel_w - 1) + 1)) / float(stride_w) + 1
return (int(round_func(o_h)), int(round_func(o_w))) o_h = int(round_func(o_h))
o_w = int(round_func(o_w))
def get_strided_kernel_output_shape(params, input_shape, round_func):
o_h, o_w = get_filter_output_shape(input_shape[2], input_shape[3], params,
round_func)
has_c_o = hasattr(params, 'num_output') has_c_o = hasattr(params, 'num_output')
c = params.num_output if has_c_o else input_shape[1] c = params.num_output if has_c_o else input_shape[1]
return [[input_shape[0], c, o_h, o_w]] return [[input_shape[0], c, o_h, o_w]]
...@@ -176,7 +130,9 @@ def shape_concat(layer, input_shape): ...@@ -176,7 +130,9 @@ def shape_concat(layer, input_shape):
output_shape = None output_shape = None
for shape in input_shape: for shape in input_shape:
if output_shape is None: if output_shape is None:
output_shape = shape output_shape = []
for i in range(len(shape)):
output_shape.append(shape[i])
else: else:
output_shape[axis] += shape[axis] output_shape[axis] += shape[axis]
return [output_shape] return [output_shape]
...@@ -191,7 +147,9 @@ def shape_slice(layer, input_shape): ...@@ -191,7 +147,9 @@ def shape_slice(layer, input_shape):
points = [0] + points + [count] points = [0] + points + [count]
output_shape = [] output_shape = []
for i in range(len(points)): for i in range(len(points)):
shape = inshape shape = []
for ii in range(len(inshape)):
shape.append(inshape[ii])
size = points[i + 1] - points[i] size = points[i + 1] - points[i]
shape[axis] = size shape[axis] = size
output_shape.append(shape) output_shape.append(shape)
...@@ -238,8 +196,8 @@ def shape_reshape(layer, input_shape): ...@@ -238,8 +196,8 @@ def shape_reshape(layer, input_shape):
inshape = input_shape[0] inshape = input_shape[0]
params = layer.reshape_param params = layer.reshape_param
axis = params.axis if hasattr(params, axis) else 0 axis = params.axis if hasattr(params, 'axis') else 0
num_axes = params.num_axes if hasattr(params, num_axes) else -1 num_axes = params.num_axes if hasattr(params, 'num_axes') else -1
if inshape[0] == -1: if inshape[0] == -1:
inshape[0] = 1 inshape[0] = 1
input_count = count(inshape) input_count = count(inshape)
...@@ -262,14 +220,14 @@ def shape_reshape(layer, input_shape): ...@@ -262,14 +220,14 @@ def shape_reshape(layer, input_shape):
num_axes_replaced = end_axis - start_axis num_axes_replaced = end_axis - start_axis
num_axes_retained = input_num_axes - num_axes_replaced num_axes_retained = input_num_axes - num_axes_replaced
num_new_axes = len(shape['dim']) num_new_axes = len(list(params.shape.dim))
outshape = [] outshape = []
for i in range(start_axis): for i in range(start_axis):
outshape.append(inshape[i]) outshape.append(inshape[i])
for i in range(num_new_axes): for i in range(num_new_axes):
outshape.append(shape['dim'][i]) outshape.append(params.shape.dim[i])
for i in range(end_axis, input_num_axes): for i in range(end_axis, input_num_axes):
outshape.append(inshape[i]) outshape.append(inshape[i])
...@@ -281,7 +239,7 @@ def shape_reshape(layer, input_shape): ...@@ -281,7 +239,7 @@ def shape_reshape(layer, input_shape):
copy_axes = [] copy_axes = []
constant_count = 1 constant_count = 1
for i in range(num_new_axes): for i in range(num_new_axes):
top_dim = shape['dim'][i] top_dim = params.shape.dim[i]
if top_dim == 0: if top_dim == 0:
copy_axes.append(i) copy_axes.append(i)
copy_axis_index = start_axis + i copy_axis_index = start_axis + i
...@@ -297,23 +255,19 @@ def shape_reshape(layer, input_shape): ...@@ -297,23 +255,19 @@ def shape_reshape(layer, input_shape):
l = inshape[0:start_axis] l = inshape[0:start_axis]
if len(l) > 0: if len(l) > 0:
explicit_count *= count(l) explicit_count *= count(l)
l = inshape[end_axis:] l = inshape[end_axis:]
if len(l) > 0: if len(l) > 0:
explicit_count *= count(l) explicit_count *= count(l)
for i in range(len(copy_axes)): for i in range(len(copy_axes)):
explicit_count *= outshape[start_axis + copy_axes[i]] explicit_count *= outshape[start_axis + copy_axes[i]]
assert input_count % explicit_count == 0, "[Reshape]botom count[%d] "\ assert input_count % explicit_count == 0, "[Reshape]botom count[%d] "\
"must be divisible by product of the specified dimensions[%d] "\ "must be divisible by product of the specified dimensions[%d] "\
% (input_count, explicit_count) % (input_count, explicit_count)
outshape[start_axis + inferred_axis] = input_count / explicit_count outshape[start_axis + inferred_axis] = int(input_count / explicit_count)
output_count = count(outshape) output_count = count(outshape)
assert output_count == input_count, "[Reshape]output count[%d] must match input count[%d]" % ( assert output_count == input_count, "[Reshape]output count[%d] must match input count[%d]" % (
output_count, input_count) output_count, input_count)
if inshape[0] == -1:
outshape[0] = -1 outshape[0] = -1
return [outshape] return [outshape]
...@@ -345,18 +299,22 @@ def shape_crop(layer, input_shape): ...@@ -345,18 +299,22 @@ def shape_crop(layer, input_shape):
def shape_flatten(layer, input_shape): def shape_flatten(layer, input_shape):
assert len(input_shape) == 1, "the number of flatten's inputs must be 1" assert len(input_shape) == 1, "the number of flatten's inputs must be 1"
inshape = input_shape[0]
params = layer.flatten_param params = layer.flatten_param
start_axis = params.axis start_axis = params.axis
end_axis = params.end_axis end_axis = params.end_axis
if start_axis < 0: if start_axis < 0:
start_axis += len(input_shape[0]) start_axis += len(inshape)
if end_axis < 0: if end_axis < 0:
end_axis += len(input_shape[0]) + 1 end_axis += len(inshape) + 1
assert start_axis <= end_axis, 'invalid axis[%d] or end_axis[%d] params'\ assert start_axis <= end_axis, 'invalid axis[%d] or end_axis[%d] params'\
% (start_axis, end_axis) % (start_axis, end_axis)
output_shape = [0] * (start_axis - 0) + [ output_shape = inshape[0:start_axis]
-1 if len(inshape[start_axis:end_axis]) != 0:
] + [0] * (len(input_shape[0]) - end_axis) flat_sz = reduce(lambda a, b: a * b, inshape[start_axis:end_axis])
output_shape += [flat_sz]
output_shape += inshape[end_axis:len(inshape)]
output_shape[0] = -1
return [output_shape] return [output_shape]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册