Commit ac476692 authored by SunAhong1993

fix the bug of shape and add optimizer

Parent fc3bc25f
@@ -93,10 +93,14 @@ def tf2paddle(model_path, save_dir):
 def caffe2paddle(proto, weight, save_dir, caffe_proto):
     from x2paddle.decoder.caffe_decoder import CaffeDecoder
     from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper
+    from x2paddle.optimizer.caffe_optimizer import CaffeOptimizer
     print("Now translating model from caffe to paddle.")
     model = CaffeDecoder(proto, weight, caffe_proto)
     mapper = CaffeOpMapper(model)
+    optimizer = CaffeOptimizer(mapper)
+    optimizer.merge_bn_scale()
+    optimizer.merge_op_activation()
     mapper.save_inference_model(save_dir)
......
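Note on the new optimizer passes: merge_bn_scale and merge_op_activation fold a following Scale (or activation) layer into its producer so the mapper emits one fused op instead of two. A minimal sketch of the BatchNorm+Scale merge, using hypothetical stand-in node objects rather than X2Paddle's actual decoder classes:

```python
# Minimal sketch of a BatchNorm+Scale merge pass. The Node class here is a
# hypothetical stand-in, not X2Paddle's actual graph node type.
class Node:
    def __init__(self, name, layer_type, inputs, data=None):
        self.name = name
        self.layer_type = layer_type
        self.inputs = inputs          # names of upstream nodes
        self.data = data or []        # layer weights


def merge_bn_scale(nodes):
    """Fold each Scale that directly follows a BatchNorm into that BatchNorm."""
    by_name = {n.name: n for n in nodes}
    merged = []
    for node in nodes:
        if node.layer_type == 'Scale' and len(node.inputs) == 1:
            parent = by_name[node.inputs[0]]
            if parent.layer_type == 'BatchNorm':
                # the Scale's gamma/beta become the BN's param/bias weights
                parent.data += node.data
                continue  # drop the Scale node itself
        merged.append(node)
    return merged
```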
@@ -14,6 +14,7 @@
 from x2paddle.core.graph import GraphNode
 import collections
+from x2paddle.core.util import *


 class Layer(object):
@@ -81,6 +82,8 @@ class Layer(object):
         param_attr = collections.OrderedDict(self.param_attr)
         for key, value in param_attr.items():
+            if '\n' in str(value):
+                value = string(str(value).replace('\n', ','))
             layer_code = layer_code + key + "={}, ".format(value)
         layer_code = layer_code.strip(", ")
......
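The '\n' check added above matters because repeated protobuf fields can stringify with embedded newlines, which would split a generated keyword argument across source lines. A small illustration; string() here is a stand-in for the helper from x2paddle.core.util, assumed to wrap its argument in quotes:

```python
def string(s):
    # stand-in for x2paddle.core.util.string, which emits a quoted literal
    return "'{}'".format(s)

value = "dim: 1\ndim: 3"              # a repeated protobuf field repr
if '\n' in str(value):
    value = string(str(value).replace('\n', ','))
print("shape={}, ".format(value))     # shape='dim: 1,dim: 3',  -- one line
```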
@@ -63,17 +63,6 @@ class CaffeGraphNode(GraphNode):
     def set_params(self, params):
         self.data = params

-    def set_output_shape(self, input_shape, is_input=True):
-        func_name = 'shape_' + self.layer_type.lower()
-        if is_input:
-            self.output_shape = getattr(caffe_shape, func_name)(self.layer,
-                                                                input_shape)
-        else:
-            self.output_shape = input_shape
-
-    def set_input_shape(self, input_shape):
-        self.input_shape = input_shape


 class CaffeGraph(Graph):
     def __init__(self, model, params):
......
@@ -14,6 +14,18 @@ def detectionoutput_layer(inputs,
                           confidence_threshold=0.1,
                           input_shape=None,
                           name=None):
+    nms_param_str = nms_param
+    nms_param = {}
+    part = nms_param_str.split(',')
+    for s in part:
+        if s == '':
+            break
+        else:
+            name, obj = s.split(': ')
+            if name == 'top_k':
+                nms_param[name] = int(obj)
+            else:
+                nms_param[name] = float(obj)
     if nms_param is None:
         nms_param = {"nms_threshold": 0.3, "top_k": 10, "eta": 1.0}
     mbox_conf_flatten = inputs[1]
@@ -24,20 +36,21 @@ def detectionoutput_layer(inputs,
     pb = fluid.layers.reshape(x=pb, shape=[-1, 4])
     pbv = fluid.layers.reshape(x=pbv, shape=[-1, 4])
     mbox_loc = inputs[0]
-    mbox_loc = fluid.layers.reshape(x=mbox_loc,
-                                    shape=[-1, mbox_conf_flatten.shape[1], 4])
+    mbox_loc = fluid.layers.reshape(x=mbox_loc, shape=[-1, pb.shape[0], 4])
+    mbox_conf_flatten = fluid.layers.reshape(x=mbox_conf_flatten,
+                                             shape=[0, pb.shape[0], -1])

     default = {"nms_threshold": 0.3, "top_k": 10, "eta": 1.0}
     fields = ['eta', 'top_k', 'nms_threshold']
     for f in default.keys():
-        if not nms_param.has_key(f):
+        if f not in nms_param:
             nms_param[f] = default[f]
     out = fluid.layers.detection_output(
         scores=mbox_conf_flatten,
         loc=mbox_loc,
         prior_box=pb,
         prior_box_var=pbv,
-        background_label=background_label,
+        background_label=background_label_id,
         nms_threshold=nms_param["nms_threshold"],
         nms_top_k=nms_param["top_k"],
         keep_top_k=keep_top_k,
......
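The new block turns a Caffe-style nms_param string into a dict before filling in defaults. A standalone repro of the same loop (with a strip() added, since splitting on ',' leaves a leading space that the bare name == 'top_k' comparison above would miss):

```python
nms_param_str = "nms_threshold: 0.45, top_k: 400"
nms_param = {}
for s in nms_param_str.split(','):
    if s == '':
        break
    name, obj = s.split(': ')
    name = name.strip()               # drop the space left after the comma
    nms_param[name] = int(obj) if name == 'top_k' else float(obj)
print(nms_param)  # {'nms_threshold': 0.45, 'top_k': 400}
```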
@@ -3,7 +3,7 @@ from x2paddle.core.util import *

 def priorbox_shape(input_shape, max_size=None, aspect_ratio=None):
-    fc_shape = input_shapes[0]
+    fc_shape = input_shape[0]
     N = 1
     if not max_size == None:
         N += 1
@@ -18,26 +18,27 @@ def priorbox_layer(inputs,
                    step=0.0,
                    offset=0.5,
                    min_size=None,
-                   max_size=None,
+                   max_size=[],
                    aspect_ratio=[1.0],
                    flip=False,
                    clip=False,
                    variance=[0.1, 0.1, 0.2, 0.2],
                    input_shape=None,
                    name=None):
-    input = input_shape[0]
-    image = input_shape[1]
+    input = inputs[0]
+    image = inputs[1]
     steps = tuple(step) if type(step) is list or type(step) is tuple else (step,
                                                                            step)
     box, variance_ = fluid.layers.prior_box(input,
                                             image,
-                                            min_sizes=list(min_size),
-                                            max_sizes=list(max_size),
-                                            aspect_ratios=list(aspect_ratio),
-                                            variance=list(variance),
+                                            min_sizes=min_size,
+                                            max_sizes=max_size,
+                                            aspect_ratios=aspect_ratio,
+                                            variance=variance,
                                             flip=flip,
                                             clip=clip,
-                                            steps=step,
+                                            steps=steps,
                                             offset=offset,
                                             name=name,
                                             min_max_aspect_ratios_order=True)
......
@@ -9,12 +9,12 @@ def shufflechannel_shape(input_shape):

 def shufflechannel_layer(inputs, group=None, input_shape=None, name=None):
     input = inputs[0]
     c_fm = fluid.layers.split(input, num_or_sections=input_shape[0][1], dim=1)
-    size = int(input_shape[0][1]/group)
+    size = int(input_shape[0][1] / group)
     new_c_fm = []
     for i in range(size):
         for j in range(group):
             new_c_fm.append(c_fm[j * size + i])
-    out = fluid.layers.concat(new_c_fm, axis = 1)
+    out = fluid.layers.concat(new_c_fm, axis=1)
     return out
......
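The split/interleave/concat above is the standard channel shuffle. A quick numpy check, assuming NCHW layout, that the interleaved pick order equals the usual reshape-transpose-reshape formulation:

```python
import numpy as np

def channel_shuffle(x, group):
    n, c, h, w = x.shape
    size = c // group
    # pick channel i from every group, exactly as the layer above does
    order = [j * size + i for i in range(size) for j in range(group)]
    return x[:, order, :, :]

x = np.arange(2 * 6 * 1 * 1).reshape(2, 6, 1, 1)
ref = x.reshape(2, 3, 2, 1, 1).transpose(0, 2, 1, 3, 4).reshape(2, 6, 1, 1)
assert (channel_shuffle(x, group=3) == ref).all()
```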
@@ -17,6 +17,7 @@ import numpy as np
 from x2paddle.decoder.caffe_decoder import CaffeGraph
 from x2paddle.core.op_mapper import OpMapper
 from x2paddle.core.util import *
+from x2paddle.op_mapper import caffe_shape
 from x2paddle.op_mapper.caffe_custom_layer import *
@@ -33,11 +34,11 @@ class CaffeOpMapper(OpMapper):
             node = self.graph.get_node(node_name)
             op = node.layer_type
             if hasattr(self, op):
-                self.set_shape(node)
+                self.set_node_shape(node)
                 func = getattr(self, op)
                 func(node)
             elif op in custom_layers:
-                self.set_shape(node, is_fluid_op=False)
+                self.set_node_shape(node, is_fluid_op=False)
                 self.deal_custom_layer(node)
             else:
                 raise Exception("Model are not supported yet.")
@@ -58,7 +59,7 @@ class CaffeOpMapper(OpMapper):
                 print(op)
             return False

-    def set_shape(self, node, is_fluid_op=True):
+    def set_node_shape(self, node, is_fluid_op=True):
         inputs = node.inputs
         input_shape = []
         for i, nm in enumerate(inputs):
@@ -66,12 +67,15 @@ class CaffeOpMapper(OpMapper):
             tmp = node.layer.bottom[i]
             idx = list(last_node.layer.top).index(tmp)
             input_shape.append(last_node.output_shape[idx])
-        node.set_input_shape(input_shape)
+
+        node.input_shape = input_shape
+
+        func_name = 'shape_' + node.layer_type.lower()
         if is_fluid_op:
-            node.set_output_shape(input_shape)
+            node.output_shape = getattr(caffe_shape, func_name)(node.layer,
+                                                                input_shape)
         else:
-            node.set_output_shape(compute_output_shape(node),
-                                  is_input=is_fluid_op)
+            node.output_shape = compute_output_shape(node)

     def adjust_parameters(self, node):
         data = node.data
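set_node_shape now resolves the shape function by name, so each layer type is served by a shape_<type> function in caffe_shape. A toy version of that getattr dispatch, with a stand-in namespace instead of the real module (shape_relu here is illustrative, not necessarily the module's actual function):

```python
# Toy version of the getattr-based dispatch used by set_node_shape.
class caffe_shape:                                  # stand-in namespace
    @staticmethod
    def shape_relu(layer, input_shape):
        return input_shape                          # ReLU keeps shapes

layer_type = 'ReLU'
func_name = 'shape_' + layer_type.lower()
output_shape = getattr(caffe_shape, func_name)(None, [[1, 3, 224, 224]])
print(output_shape)  # [[1, 3, 224, 224]]
```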
@@ -87,8 +91,6 @@ class CaffeOpMapper(OpMapper):
                 squeeze_indices.append(0)  # Squeeze FC.

         for idx in squeeze_indices:
-            print('Transform the weights of {}...'.format(node.layer_name +
-                                                          str(idx)))
             if idx >= len(data):
                 continue
@@ -140,7 +142,7 @@ class CaffeOpMapper(OpMapper):
         dila_h = dila_w = 1
         group = 1
         c_o = 1
-        if kind in ['Convolution', 'Deconvolution', 'ConvolutionDepthwise']:
+        if kind in ['Convolution', 'Deconvolution']:
             c_o = params.num_output
         dila_len = len(params.dilation)
         if dila_len == 2:
@@ -165,12 +167,6 @@ class CaffeOpMapper(OpMapper):
         else:
             return node.layer_name

-    def is_BN(self, node):
-        return True if node.layer_type == 'BatchNorm' else False
-
-    def is_Scale(self, node):
-        return True if node.layer_type == 'Scale' else False
-
     def Input(self, node):
         shape = list(node.layer.input_param.shape[0].dim)[1:]
         dtype = 'float32'
@@ -198,10 +194,6 @@ class CaffeOpMapper(OpMapper):
         assert len(node.inputs
                    ) == 1, 'The count of Convolution node\'s input is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         attr = {
             'filter_size':
@@ -242,10 +234,6 @@ class CaffeOpMapper(OpMapper):
         assert len(node.inputs
                    ) == 1, 'The count of Deconvolution node\'s input is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         attr = {
             'output_size':
             None,
@@ -287,10 +275,6 @@ class CaffeOpMapper(OpMapper):
         assert len(
             node.inputs) == 1, 'The count of Pooling node\'s input is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         attr = {
             'pool_size': kernel,
             'pool_stride': stride,
@@ -310,10 +294,6 @@ class CaffeOpMapper(OpMapper):
         assert len(
             node.inputs) == 1, 'The count of ReLU node\'s input is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         attr = {'name': string(node.layer_name)}
         node.fluid_code.add_layer("relu",
                                   inputs=input,
@@ -331,10 +311,6 @@ class CaffeOpMapper(OpMapper):
         # We'll account for that here.
         alpha = params.alpha / float(params.local_size)
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         attr = {
             'n': params.local_size,
             'k': 1.0,
@@ -370,10 +346,6 @@ class CaffeOpMapper(OpMapper):
         assert params.axis == 1
         assert params.bias_term == True
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         attr = {
             'size':
             params.num_output,
@@ -395,10 +367,6 @@ class CaffeOpMapper(OpMapper):
         assert len(
             node.inputs) == 1, 'The count of Softmax node\'s input is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         params = node.layer.softmax_param
         axis = params.axis
         shape = node.input_shape[0]
@@ -414,10 +382,6 @@ class CaffeOpMapper(OpMapper):
         assert len(
             node.inputs) == 1, 'The count of Slice node\'s input is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         params = node.layer.slice_param
         axis = params.axis
         points = list(params.slice_point)
@@ -448,10 +412,6 @@ class CaffeOpMapper(OpMapper):
         inputs = []
         for i in range(len(node.inputs)):
             input = self.graph.get_bottom_node(node, idx=i, copy=True)
-            if self.is_Scale(input):
-                tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-                if self.is_BN(tmp):
-                    input = tmp
             inputs.append(input)
         params = node.layer.concat_param
         axis = params.axis
@@ -465,10 +425,6 @@ class CaffeOpMapper(OpMapper):
         assert len(
             node.inputs) == 1, 'The count of PReLU node\'s input is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         params = node.layer.prelu_param
         mode_bool = params.channel_shared
         if mode_bool:
@@ -493,10 +449,6 @@ class CaffeOpMapper(OpMapper):
         assert len(
             node.inputs) == 1, 'The count of PReLU node\'s input is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         attr = {'name': string(node.layer_name)}
         node.fluid_code.add_layer("sigmoid",
                                   inputs=input,
@@ -507,10 +459,6 @@ class CaffeOpMapper(OpMapper):
         assert len(
             node.inputs) == 1, 'The count of PReLU node\'s input is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         attr = {'name': string(node.layer_name)}
         node.fluid_code.add_layer("absval",
                                   inputs=input,
@@ -527,24 +475,15 @@ class CaffeOpMapper(OpMapper):
         for shape in node.input_shape:
             if shape[1] == 1:
                 input = self.graph.get_bottom_node(node, idx=i, copy=True)
-                if self.is_Scale(input):
-                    tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-                    if self.is_BN(tmp):
-                        input = tmp
                 inputs[1] = input
             else:
                 input = self.graph.get_bottom_node(node, idx=i, copy=True)
-                if self.is_Scale(input):
-                    tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-                    if self.is_BN(tmp):
-                        input = tmp
                 inputs[0] = input
             i += 1
         params = node.layer.accuracy_param
         top_k = params.top_k
         axis = params.axis
         ignore_label = params.ignore_label
+        # TODO(syf)
         assert axis == 1, 'PaddlePaddle can not support the situation when the axis is not 1.'
         assert not ignore_label >= 0, 'PaddlePaddle can not support the situation when the model has ignore label.'
         attr = {'k': top_k}
@@ -557,10 +496,6 @@ class CaffeOpMapper(OpMapper):
         assert len(
             node.inputs) == 1, 'The count of TanH node\'s input is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         attr = {'name': string(node.layer_name)}
         node.fluid_code.add_layer("tanh",
                                   inputs=input,
@@ -574,16 +509,8 @@ class CaffeOpMapper(OpMapper):
         mode = params.operation
         inputs = []
         input0 = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input0):
-            tmp = self.graph.get_bottom_node(input0, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input0 = tmp
         inputs.append(input0)
         input1 = self.graph.get_bottom_node(node, idx=1, copy=True)
-        if self.is_Scale(input1):
-            tmp = self.graph.get_bottom_node(input1, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input1 = tmp
         inputs.append(input1)
         if mode == 0:
             inputs_dict = {}
@@ -660,10 +587,6 @@ class CaffeOpMapper(OpMapper):
             node.outputs
         ) == 1, 'The count of BatchNorm node\'s input and output is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         params = node.layer.batch_norm_param
         if hasattr(params, 'eps'):
             eps = params.eps
@@ -678,133 +601,96 @@ class CaffeOpMapper(OpMapper):
             variance *= scaling_factor
         self.weights[node.layer_name + '_mean'] = mean
         self.weights[node.layer_name + '_variance'] = variance
-        if self.graph.get_node(node.outputs[0]).layer_type == 'Scale':
-            data = self.graph.get_node(node.outputs[0]).data
-            self.weights[node.layer_name + '_scale'] = np.squeeze(data[0])
-            self.weights[node.layer_name + '_offset'] = np.squeeze(data[1])
-            attr = {
-                'is_test': True,
-                'param_attr': string(node.layer_name + '_scale'),
-                'bias_attr': string(node.layer_name + '_offset'),
-                'moving_mean_name': string(node.layer_name + '_mean'),
-                'moving_variance_name': string(node.layer_name + '_variance'),
-                'epsilon': eps,
-                'name': string(node.layer_name)
-            }
-        else:
-            attr = {
-                'is_test': True,
-                'param_attr': None,
-                'bias_attr': None,
-                'moving_mean_name': string(node.layer_name + '_mean'),
-                'moving_variance_name': string(node.layer_name + '_variance'),
-                'epsilon': eps,
-                'name': string(node.layer_name)
-            }
+        attr = {
+            'is_test': True,
+            'param_attr': None,
+            'bias_attr': None,
+            'moving_mean_name': string(node.layer_name + '_mean'),
+            'moving_variance_name': string(node.layer_name + '_variance'),
+            'epsilon': eps,
+            'name': string(node.layer_name)
+        }
         node.fluid_code.add_layer("batch_norm",
                                   inputs=input,
                                   output=node,
                                   param_attr=attr)
     def Scale(self, node):
-        assert len(
-            node.inputs) == 1, 'The count of Scale node\'s input is not 1.'
-        if len(node.inputs) == 1 and self.graph.get_node(
-                node.inputs[0]).layer_type == 'BatchNorm':
-            return
-        else:
-            self.weights[node.layer_name + '_scale'] = np.squeeze(nose.data[0])
-            self.weights[node.layer_name + '_offset'] = np.squeeze(node.data[1])
-            params = node.layer.scale_param
-            axis = params.axis
-            num_axes = params.num_axes
-            assert num_axes == 1, "layer scale not support this num_axes[%d] now" % (
-                num_axes)
-            inputs = []
-            if len(node.inputs) == 2:
-                # for two tensor, here resets axis to 1. Maybe there is a bug for unkown case.
-                axis = 1
-                bias_shape = node.input_shape[0][axis:axis + num_axes]
-                input0 = self.graph.get_bottom_node(node, idx=0, copy=True)
-                if self.is_Scale(input0):
-                    tmp = self.graph.get_bottom_node(input0, idx=0, copy=True)
-                    if self.is_BN(tmp):
-                        input0 = tmp
-                input1 = self.graph.get_bottom_node(node, idx=1, copy=True)
-                if self.is_Scale(input1):
-                    tmp = self.graph.get_bottom_node(input1, idx=0, copy=True)
-                    if self.is_BN(tmp):
-                        input1 = tmp
-                inputs.append(input0)
-                inputs.append(input1)
-                attr = {'axis': axis, 'name': string(node.layer_name + '_mul')}
-                node.fluid_code.add_layer("elementwise_mul",
-                                          inputs=inputs,
-                                          output=node.layer_name + '_mul',
-                                          param_attr=attr)
-            else:
-                bias_shape = node.input_shape[0][axis:axis + num_axes]
-                input0 = self.graph.get_bottom_node(node, idx=0, copy=True)
-                if self.is_Scale(input0):
-                    tmp = self.graph.get_bottom_node(input0, idx=0, copy=True)
-                    if self.is_BN(tmp):
-                        input0 = tmp
-                input0_name = self.get_input_name(input0)
-                attr = {
-                    'dtype': '{}.dtype'.formatr(input0_name),
-                    'shape': bias_shape,
-                    'name': string(node.layer_name + '_cparam1'),
-                    'attr': string(node.layer_name + '_scale'),
-                    'is_bias': True,
-                    'default_initializer': 'Constant(value=1.0)'
-                }
-                node.fluid_code.add_layer("create_parameter",
-                                          inputs=None,
-                                          output=node,
-                                          param_attr=attr)
-                inputs.append(input0)
-                inputs.append(node)
-                attr = {'axis': axis, 'name': string(node.layer_name + '_mul')}
-                node.fluid_code.add_layer("elementwise_mul",
-                                          inputs=inputs,
-                                          output=node.layer_name + '_mul',
-                                          param_attr=attr)
-            scale_shape = bias_shape
-            input0_name = self.get_input_name(input0)
-            attr = {
-                'dtype': '{}.dtype'.formatr(input0_name),
-                'shape': scale_shape,
-                'name': string(node.layer_name + '_cparam2'),
-                'attr': string(node.layer_name + '_offset'),
-                'is_bias': True,
-                'default_initializer': 'Constant(value=1.0)'
-            }
-            node.fluid_code.add_layer("create_parameter",
-                                      inputs=None,
-                                      output=node.layer_name + '_offset_param',
-                                      param_attr=attr)
-            attr = {'axis': axis, 'name': string(node.layer_name + '_add')}
-            node.fluid_code.add_layer("elementwise_add",
-                                      inputs='{}_mul, {}_offset_param'.format(
-                                          node.layer_name, node.layer_name),
-                                      output=node,
-                                      param_attr=attr)
+        self.weights[node.layer_name + '_scale'] = np.squeeze(node.data[0])
+        self.weights[node.layer_name + '_offset'] = np.squeeze(node.data[1])
+        params = node.layer.scale_param
+        axis = params.axis
+        num_axes = params.num_axes
+        inputs = []
+        if len(node.inputs) == 2:
+            # for two tensor, here resets axis to 1. Maybe there is a bug for unkown case.
+            axis = 1
+            bias_shape = node.input_shape[0][axis:axis + num_axes]
+            input0 = self.graph.get_bottom_node(node, idx=0, copy=True)
+            input1 = self.graph.get_bottom_node(node, idx=1, copy=True)
+            inputs_dict = {}
+            inputs_dict['x'] = input0
+            inputs_dict['y'] = input1
+            attr = {'axis': axis, 'name': string(node.layer_name + '_mul')}
+            node.fluid_code.add_layer("elementwise_mul",
+                                      inputs=inputs_dict,
+                                      output=node.layer_name + '_mul',
+                                      param_attr=attr)
+        else:
+            bias_shape = node.input_shape[0][axis:axis + num_axes]
+            input0 = self.graph.get_bottom_node(node, idx=0, copy=True)
+            input0_name = self.get_input_name(input0)
+            attr = {
+                'dtype': '{}.dtype'.format(input0_name),
+                'shape': bias_shape,
+                'name': string(node.layer_name + '_cparam1'),
+                'attr': string(node.layer_name + '_scale'),
+                'is_bias': True,
+                'default_initializer': 'Constant(value=1.0)'
+            }
+            node.fluid_code.add_layer("create_parameter",
+                                      inputs=None,
+                                      output=node,
+                                      param_attr=attr)
+            inputs_dict = {}
+            inputs_dict['x'] = input0
+            inputs_dict['y'] = node
+            attr = {'axis': axis, 'name': string(node.layer_name + '_mul')}
+            node.fluid_code.add_layer("elementwise_mul",
+                                      inputs=inputs_dict,
+                                      output=node.layer_name + '_mul',
+                                      param_attr=attr)
+        scale_shape = bias_shape
+        input0_name = self.get_input_name(input0)
+        attr = {
+            'dtype': '{}.dtype'.format(input0_name),
+            'shape': scale_shape,
+            'name': string(node.layer_name + '_cparam2'),
+            'attr': string(node.layer_name + '_offset'),
+            'is_bias': True,
+            'default_initializer': 'Constant(value=1.0)'
+        }
+        node.fluid_code.add_layer("create_parameter",
+                                  inputs=None,
+                                  output=node.layer_name + '_offset_param',
+                                  param_attr=attr)
+        attr = {'axis': axis, 'name': string(node.layer_name + '_add')}
+        node.fluid_code.add_layer("elementwise_add",
+                                  inputs='{}_mul, {}_offset_param'.format(
+                                      node.layer_name, node.layer_name),
+                                  output=node,
+                                  param_attr=attr)
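After this rewrite a standalone Scale always lowers to create_parameter plus elementwise ops. Roughly, the generated code corresponds to the pattern below; a sketch against the Paddle 1.x fluid API from memory, with illustrative names and shapes:

```python
import paddle.fluid as fluid

# x plays the role of the Scale layer's bottom blob (NCHW, C=16 here).
x = fluid.layers.data(name='x', shape=[16, 32, 32], dtype='float32')
scale = fluid.layers.create_parameter(
    dtype=x.dtype, shape=[16], name='s_cparam1', is_bias=True,
    default_initializer=fluid.initializer.Constant(value=1.0))
offset = fluid.layers.create_parameter(
    dtype=x.dtype, shape=[16], name='s_cparam2', is_bias=True,
    default_initializer=fluid.initializer.Constant(value=1.0))
# axis=1 broadcasts the per-channel parameters over H and W.
out = fluid.layers.elementwise_mul(x=x, y=scale, axis=1)
out = fluid.layers.elementwise_add(x=out, y=offset, axis=1)
```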
     def Reshape(self, node):
-        assert len(node.inputs) == 1 and len(
-            node.outputs
-        ) == 1, 'The count of Reshape node\'s input and output is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
         top_count = len(input.layer.top)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
-        is_inplace, = False if top_count == 1 else True
+        is_inplace = False if top_count == 1 else True
         output_shape = node.output_shape[0]
         attr = {
             'shape': output_shape,
             'inplace': is_inplace,
+            'act': None,
             'name': string(node.layer_name)
         }
         node.fluid_code.add_layer("reshape",
@@ -817,10 +703,6 @@ class CaffeOpMapper(OpMapper):
             node.outputs
         ) == 1, 'The count of ArgMax node\'s input and output is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         input_shape = node.input_shape[0]
         params = node.layer.argmax_param
         out_max_val = params.out_max_val if hasattr(params,
@@ -859,15 +741,7 @@ class CaffeOpMapper(OpMapper):
         assert len(
             node.inputs) == 2, 'The count of Crop node\'s input is not 2.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         example = self.graph.get_bottom_node(node, idx=1, copy=True)
-        if self.is_Scale(example):
-            tmp = self.graph.get_bottom_node(example, idx=0, copy=True)
-            if self.is_BN(tmp):
-                example = tmp
         params = node.layer.crop_param
         axis = parmas.axis
         input_shape = node.input_shape[0]
@@ -893,10 +767,6 @@ class CaffeOpMapper(OpMapper):
             node.inputs
         ) == 1, 'The count of DetectionOutput node\'s input is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         shape = node.output_shape[0]
         attr = {'shape': shape, 'name': string(node.layer_name)}
         node.fluid_code.add_layer("reshape",
@@ -908,10 +778,6 @@ class CaffeOpMapper(OpMapper):
         assert len(
             node.inputs) == 1, 'The count of Permute node\'s input is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         params = node.layer.power_param
         power = params.power
         scale = params.scale
@@ -936,10 +802,6 @@ class CaffeOpMapper(OpMapper):
         assert len(
             node.inputs) == 1, 'The count of Reduction node\'s input is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         params = node.layer.reduction_param
         operation = params.operation
         axis = params.axis
@@ -1022,10 +884,6 @@ class CaffeOpMapper(OpMapper):
         inputs_node = []
         for i in range(len(node.inputs)):
             input = self.graph.get_bottom_node(node, idx=i, copy=True)
-            if self.is_Scale(input):
-                tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-                if self.is_BN(tmp):
-                    input = tmp
             inputs_node.append(input)
         node.fluid_code.add_layer(func.__code__.co_name,
                                   inputs=inputs_node,
......
@@ -13,104 +13,58 @@
 # limitations under the License.

 import math
+import numbers
+from functools import reduce


-def get_params_w_h(params):
-    if hasattr(params, 'dilation'):
-        if len(params.dilation) == 0:
-            dila_h = 1
-            dila_w = 1
-        elif len(params.dilation) == 1:
-            dila_h = params.dilation[0]
-            dila_w = params.dilation[0]
-        else:
-            dila_h = params.dilation[0]
-            dila_w = params.dilation[1]
-    else:
-        dila_h = 1
-        dila_w = 1
-    if not isinstance(getattr(params, 'pad'), int):
-        if len(params.pad) == 0:
-            pad_h = 0
-            pad_w = 0
-        elif len(params.pad) == 1:
-            pad_h = params.pad[0]
-            pad_w = params.pad[0]
-        else:
-            pad_h, pad_w, = params.pad[0]
-            pad_w = params.pad[1]
-        if params.pad_h != 0 or params.pad_w != 0:
-            pad_h = params.pad_h
-            pad_w = params.pad_w
-    else:
-        if params.pad_h != 0 or params.pad_w != 0:
-            pad_h = params.pad_h
-            pad_w = params.pad_w
-        else:
-            pad_h = getattr(params, 'pad')
-            pad_w = getattr(params, 'pad')
-    if not isinstance(getattr(params, 'kernel_size'), int):
-        if len(params.kernel_size) == 0:
-            kernel_h = 1
-            kernel_w = 1
-        elif len(params.kernel_size) == 1:
-            kernel_h = params.kernel_size[0]
-            kernel_w = params.kernel_size[0]
-        else:
-            kernel_h = params.kernel_size[0]
-            kernel_w = params.kernel_size[1]
-        if params.kernel_h != 0 or params.kernel_w != 0:
-            kernel_h = params.kernel_h
-            kernel_w = params.kernel_w
-    else:
-        if params.kernel_h != 0 or params.kernel_w != 0:
-            kernel_h = params.kernel_h
-            kernel_w = params.kernel_w
-        else:
-            kernel_h = getattr(params, 'kernel_size')
-            kernel_w = getattr(params, 'kernel_size')
-    if not isinstance(getattr(params, 'stride'), int):
-        if len(params.stride) == 0:
-            stride_h = 1
-            stride_w = 1
-        elif len(params.stride) == 1:
-            stride_h = params.stride[0]
-            stride_w = params.stride[0]
-        else:
-            stride_h = params.stride[0]
-            stride_w = params.stride[1]
-        if params.stride_h != 0 or params.stride_w != 0:
-            stride_h = params.stride_h
-            stride_w = params.stride_w
-    else:
-        if params.stride_h != 0 or params.stride_w != 0:
-            stride_h = params.stride_h
-            stride_w = params.stride_w
-        else:
-            stride_h = getattr(params, 'stride')
-            stride_w = getattr(params, 'stride')
-    return dila_h, dila_w, pad_h, pad_w, kernel_h, kernel_w, stride_h, stride_w
+def get_kernel_parameters(params):
+    [k_h, k_w] = [1, 1]
+    if isinstance(params.kernel_size, numbers.Number):
+        [k_h, k_w] = [params.kernel_size] * 2
+    elif len(params.kernel_size) > 0:
+        k_h = params.kernel_h if params.kernel_h else params.kernel_size[0]
+        k_w = params.kernel_w if params.kernel_w else params.kernel_size[
+            len(params.kernel_size) - 1]
+    [s_h, s_w] = [1, 1]
+    if isinstance(params.stride, numbers.Number):
+        [s_h, s_w] = [params.stride] * 2
+    elif len(params.stride) > 0:
+        s_h = params.stride_h if params.stride_h else params.stride[0]
+        s_w = params.stride_w if params.stride_w else params.stride[
+            len(params.stride) - 1]
+    [p_h, p_w] = [0, 0]
+    if isinstance(params.pad, numbers.Number):
+        [p_h, p_w] = [params.pad] * 2
+    elif len(params.pad) > 0:
+        p_h = params.pad_h if params.pad_h else params.pad[0]
+        p_w = params.pad_w if params.pad_w else params.pad[len(params.pad) - 1]
+    dila_h = dila_w = 1
+    if hasattr(params, 'dilation'):
+        dila_len = len(params.dilation)
+        if dila_len == 2:
+            dila_h = params.dilation[0]
+            dila_w = params.dilation[1]
+        elif dila_len == 1:
+            dila_h = dila_w = params.dilation[0]
+        else:
+            assert dila_len == 0, "invalid length[%s] of dilation in convolution" % (
+                dila_len)
+    return dila_h, dila_w, p_h, p_w, k_h, k_w, s_h, s_w


-def get_filter_output_shape(i_h, i_w, params, round_func):
-    dila_h, dila_w, pad_h, pad_w, kernel_h, kernel_w, stride_h, stride_w = get_params_w_h(
+def get_strided_kernel_output_shape(params, input_shape, round_func):
+    i_h = input_shape[2]
+    i_w = input_shape[3]
+    dila_h, dila_w, pad_h, pad_w, kernel_h, kernel_w, stride_h, stride_w = get_kernel_parameters(
         params)
     o_h = (i_h + 2 * pad_h - (dila_h *
                               (kernel_h - 1) + 1)) / float(stride_h) + 1
     o_w = (i_w + 2 * pad_w - (dila_w *
                               (kernel_w - 1) + 1)) / float(stride_w) + 1
-    return (int(round_func(o_h)), int(round_func(o_w)))
-
-
-def get_strided_kernel_output_shape(params, input_shape, round_func):
-    o_h, o_w = get_filter_output_shape(input_shape[2], input_shape[3], params,
-                                       round_func)
+    o_h = int(round_func(o_h))
+    o_w = int(round_func(o_w))
     has_c_o = hasattr(params, 'num_output')
     c = params.num_output if has_c_o else input_shape[1]
     return [[input_shape[0], c, o_h, o_w]]
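The o_h/o_w expressions implement the usual strided-convolution output relation o = (i + 2*pad - (dila*(kernel-1)+1)) / stride + 1, with the rounding function supplied by the caller (floor for convolution, ceil for Caffe pooling). A quick numeric check:

```python
import math

def conv_out(i, pad, dila, k, s, round_func=math.floor):
    # o = (i + 2*pad - (dila*(k-1)+1)) / s + 1, rounded per layer type
    return int(round_func((i + 2 * pad - (dila * (k - 1) + 1)) / float(s) + 1))

# 224x224 input, 7x7 kernel, stride 2, pad 3 (a ResNet stem): 112
print(conv_out(224, pad=3, dila=1, k=7, s=2))                         # 112
# pooling rounds up (math.ceil) instead of down
print(conv_out(224, pad=0, dila=1, k=3, s=2, round_func=math.ceil))   # 112
```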
@@ -176,7 +130,9 @@ def shape_concat(layer, input_shape):
     output_shape = None
     for shape in input_shape:
         if output_shape is None:
-            output_shape = shape
+            output_shape = []
+            for i in range(len(shape)):
+                output_shape.append(shape[i])
         else:
             output_shape[axis] += shape[axis]
     return [output_shape]
@@ -191,7 +147,9 @@ def shape_slice(layer, input_shape):
     points = [0] + points + [count]
     output_shape = []
     for i in range(len(points)):
-        shape = inshape
+        shape = []
+        for ii in range(len(inshape)):
+            shape.append(inshape[ii])
         size = points[i + 1] - points[i]
         shape[axis] = size
         output_shape.append(shape)
@@ -238,8 +196,8 @@ def shape_reshape(layer, input_shape):
     inshape = input_shape[0]
     params = layer.reshape_param
-    axis = params.axis if hasattr(params, axis) else 0
-    num_axes = params.num_axes if hasattr(params, num_axes) else -1
+    axis = params.axis if hasattr(params, 'axis') else 0
+    num_axes = params.num_axes if hasattr(params, 'num_axes') else -1
     if inshape[0] == -1:
         inshape[0] = 1
     input_count = count(inshape)
@@ -262,14 +220,14 @@ def shape_reshape(layer, input_shape):
     num_axes_replaced = end_axis - start_axis
     num_axes_retained = input_num_axes - num_axes_replaced
-    num_new_axes = len(shape['dim'])
+    num_new_axes = len(list(params.shape.dim))
     outshape = []

     for i in range(start_axis):
         outshape.append(inshape[i])

     for i in range(num_new_axes):
-        outshape.append(shape['dim'][i])
+        outshape.append(params.shape.dim[i])

     for i in range(end_axis, input_num_axes):
         outshape.append(inshape[i])
@@ -281,7 +239,7 @@ def shape_reshape(layer, input_shape):
     copy_axes = []
     constant_count = 1
     for i in range(num_new_axes):
-        top_dim = shape['dim'][i]
+        top_dim = params.shape.dim[i]
         if top_dim == 0:
             copy_axes.append(i)
             copy_axis_index = start_axis + i
@@ -297,24 +255,20 @@ def shape_reshape(layer, input_shape):
     l = inshape[0:start_axis]
     if len(l) > 0:
         explicit_count *= count(l)
     l = inshape[end_axis:]
     if len(l) > 0:
         explicit_count *= count(l)
     for i in range(len(copy_axes)):
         explicit_count *= outshape[start_axis + copy_axes[i]]

     assert input_count % explicit_count == 0, "[Reshape]botom count[%d] "\
         "must be divisible by product of the specified dimensions[%d] "\
         % (input_count, explicit_count)
-    outshape[start_axis + inferred_axis] = input_count / explicit_count
+    outshape[start_axis + inferred_axis] = int(input_count / explicit_count)
     output_count = count(outshape)
     assert output_count == input_count, "[Reshape]output count[%d] must match input count[%d]" % (
         output_count, input_count)
-    if inshape[0] == -1:
-        outshape[0] = -1
+    outshape[0] = -1
     return [outshape]
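shape_reshape resolves Caffe's special dims, 0 (copy the input axis) and -1 (infer from the remaining element count), which is what the copy_axes and inferred_axis bookkeeping above computes. A hand-worked example for dim: [0, -1] on a (1, 3, 32, 32) blob:

```python
from functools import reduce

inshape = [1, 3, 32, 32]              # batch handled as 1 during inference
dims = [0, -1]                        # Caffe ReshapeParameter: 0=copy, -1=infer
outshape = [inshape[i] if d == 0 else d for i, d in enumerate(dims)]
total = reduce(lambda a, b: a * b, inshape)
explicit = reduce(lambda a, b: a * b, [d for d in outshape if d > 0], 1)
outshape[outshape.index(-1)] = total // explicit
print(outshape)  # [1, 3072]
```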
@@ -345,18 +299,22 @@ def shape_crop(layer, input_shape):

 def shape_flatten(layer, input_shape):
     assert len(input_shape) == 1, "the number of flatten's inputs must be 1"
+    inshape = input_shape[0]
     params = layer.flatten_param
     start_axis = params.axis
     end_axis = params.end_axis
     if start_axis < 0:
-        start_axis += len(input_shape[0])
+        start_axis += len(inshape)
     if end_axis < 0:
-        end_axis += len(input_shape[0]) + 1
+        end_axis += len(inshape) + 1
     assert start_axis <= end_axis, 'invalid axis[%d] or end_axis[%d] params'\
         % (start_axis, end_axis)
-    output_shape = [0] * (start_axis - 0) + [
-        -1
-    ] + [0] * (len(input_shape[0]) - end_axis)
+    output_shape = inshape[0:start_axis]
+    if len(inshape[start_axis:end_axis]) != 0:
+        flat_sz = reduce(lambda a, b: a * b, inshape[start_axis:end_axis])
+        output_shape += [flat_sz]
+    output_shape += inshape[end_axis:len(inshape)]
+    output_shape[0] = -1
     return [output_shape]
......
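The rewritten shape_flatten collapses the axes in [start_axis, end_axis) into a single flattened dimension via reduce, then pins the batch axis to -1. Worked through for Caffe's default FlattenParameter (axis=1, end_axis=-1, i.e. 4 after the negative-index normalization above) on a (1, 3, 32, 32) blob:

```python
from functools import reduce

inshape = [1, 3, 32, 32]
start_axis, end_axis = 1, 4           # defaults after normalization
output_shape = inshape[0:start_axis]
flat = inshape[start_axis:end_axis]
if len(flat) != 0:
    output_shape += [reduce(lambda a, b: a * b, flat)]
output_shape += inshape[end_axis:]
output_shape[0] = -1                  # keep the batch dim symbolic
print(output_shape)  # [-1, 3072]
```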