diff --git a/x2paddle/convert.py b/x2paddle/convert.py
index 18accb882b5d4b185816cc0a0538951555830f11..2edbe391c803914253e3356e673d8c60f12ad59f 100644
--- a/x2paddle/convert.py
+++ b/x2paddle/convert.py
@@ -92,10 +92,14 @@ def tf2paddle(model_path, save_dir):
 
 def caffe2paddle(proto, weight, save_dir, caffe_proto):
     from x2paddle.decoder.caffe_decoder import CaffeDecoder
     from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper
+    from x2paddle.optimizer.caffe_optimizer import CaffeOptimizer
     print("Now translating model from caffe to paddle.")
     model = CaffeDecoder(proto, weight, caffe_proto)
     mapper = CaffeOpMapper(model)
+    optimizer = CaffeOptimizer(mapper)
+    optimizer.merge_bn_scale()
+    optimizer.merge_op_activation()
     mapper.save_inference_model(save_dir)
 
 
diff --git a/x2paddle/core/fluid_code.py b/x2paddle/core/fluid_code.py
index f77d88e2eff3115198e8da6c61e622050e2689b8..c971e5a61244371bb51c629b03dad0e9b35e725b 100644
--- a/x2paddle/core/fluid_code.py
+++ b/x2paddle/core/fluid_code.py
@@ -14,6 +14,7 @@
 
 from x2paddle.core.graph import GraphNode
 import collections
+from x2paddle.core.util import *
 
 
 class Layer(object):
@@ -81,6 +82,8 @@ class Layer(object):
             param_attr = collections.OrderedDict(self.param_attr)
 
             for key, value in param_attr.items():
+                if '\n' in str(value):
+                    value = string(str(value).replace('\n', ','))
                 layer_code = layer_code + key + "={}, ".format(value)
             layer_code = layer_code.strip(", ")
 
diff --git a/x2paddle/decoder/caffe_decoder.py b/x2paddle/decoder/caffe_decoder.py
index e4bd86c4833944ea325868f5770b61e836b90843..e9601337f676d476237ec51d992080c3b193f260 100644
--- a/x2paddle/decoder/caffe_decoder.py
+++ b/x2paddle/decoder/caffe_decoder.py
@@ -63,17 +63,6 @@ class CaffeGraphNode(GraphNode):
     def set_params(self, params):
         self.data = params
 
-    def set_output_shape(self, input_shape, is_input=True):
-        func_name = 'shape_' + self.layer_type.lower()
-        if is_input:
-            self.output_shape = getattr(caffe_shape, func_name)(self.layer,
-                                                                input_shape)
-        else:
-            self.output_shape = input_shape
-
-    def set_input_shape(self, input_shape):
-        self.input_shape = input_shape
-
 
 class CaffeGraph(Graph):
     def __init__(self, model, params):
diff --git a/x2paddle/op_mapper/caffe_custom_layer/detectionoutput.py b/x2paddle/op_mapper/caffe_custom_layer/detectionoutput.py
index b87aa937cba3f5a5eaa95d5805efcfe4c810d62e..173f5f31d5f26545a112d11b0994d73097ebb16b 100644
--- a/x2paddle/op_mapper/caffe_custom_layer/detectionoutput.py
+++ b/x2paddle/op_mapper/caffe_custom_layer/detectionoutput.py
@@ -14,6 +14,18 @@ def detectionoutput_layer(inputs,
                           confidence_threshold=0.1,
                           input_shape=None,
                           name=None):
+    nms_param_str = nms_param
+    nms_param = {}
+    part = nms_param_str.split(',')
+    for s in part:
+        if s == '':
+            break
+        else:
+            key, obj = s.split(': ')
+            if key == 'top_k':
+                nms_param[key] = int(obj)
+            else:
+                nms_param[key] = float(obj)
     if nms_param is None:
         nms_param = {"nms_threshold": 0.3, "top_k": 10, "eta": 1.0}
     mbox_conf_flatten = inputs[1]
@@ -24,20 +36,21 @@ def detectionoutput_layer(inputs,
     pb = fluid.layers.reshape(x=pb, shape=[-1, 4])
     pbv = fluid.layers.reshape(x=pbv, shape=[-1, 4])
     mbox_loc = inputs[0]
-    mbox_loc = fluid.layers.reshape(x=mbox_loc,
-                                    shape=[-1, mbox_conf_flatten.shape[1], 4])
+    mbox_loc = fluid.layers.reshape(x=mbox_loc, shape=[-1, pb.shape[0], 4])
+    mbox_conf_flatten = fluid.layers.reshape(x=mbox_conf_flatten,
+                                             shape=[0, pb.shape[0], -1])
 
     default = {"nms_threshold": 0.3, "top_k": 10, "eta": 1.0}
     fields = ['eta', 'top_k', 'nms_threshold']
     for f in default.keys():
-        if not nms_param.has_key(f):
+        if f not in nms_param:
             nms_param[f] = default[f]
     out = fluid.layers.detection_output(
         scores=mbox_conf_flatten,
         loc=mbox_loc,
         prior_box=pb,
         prior_box_var=pbv,
-        background_label=background_label,
+        background_label=background_label_id,
         nms_threshold=nms_param["nms_threshold"],
         nms_top_k=nms_param["top_k"],
         keep_top_k=keep_top_k,
diff --git a/x2paddle/op_mapper/caffe_custom_layer/priorbox.py b/x2paddle/op_mapper/caffe_custom_layer/priorbox.py
index 9c1cb67ead9cf338fab2ffec548672ee26574dc3..829c3e365fefbfc81b9e4efb382726f1d473ecab 100644
--- a/x2paddle/op_mapper/caffe_custom_layer/priorbox.py
+++ b/x2paddle/op_mapper/caffe_custom_layer/priorbox.py
@@ -3,7 +3,7 @@ from x2paddle.core.util import *
 
 
 def priorbox_shape(input_shape, max_size=None, aspect_ratio=None):
-    fc_shape = input_shapes[0]
+    fc_shape = input_shape[0]
     N = 1
     if not max_size == None:
         N += 1
@@ -18,26 +18,27 @@ def priorbox_layer(inputs,
                    step=0.0,
                    offset=0.5,
                    min_size=None,
-                   max_size=None,
+                   max_size=[],
                    aspect_ratio=[1.0],
                    flip=False,
                    clip=False,
                    variance=[0.1, 0.1, 0.2, 0.2],
                    input_shape=None,
                    name=None):
-    input = input_shape[0]
-    image = input_shape[1]
+    input = inputs[0]
+    image = inputs[1]
     steps = tuple(step) if type(step) is list or type(step) is tuple else (step,
                                                                            step)
+
     box, variance_ = fluid.layers.prior_box(input,
                                             image,
-                                            min_sizes=list(min_size),
-                                            max_sizes=list(max_size),
-                                            aspect_ratios=list(aspect_ratio),
-                                            variance=list(variance),
+                                            min_sizes=min_size,
+                                            max_sizes=max_size,
+                                            aspect_ratios=aspect_ratio,
+                                            variance=variance,
                                             flip=flip,
                                             clip=clip,
-                                            steps=step,
+                                            steps=steps,
                                             offset=offset,
                                             name=name,
                                             min_max_aspect_ratios_order=True)
diff --git a/x2paddle/op_mapper/caffe_custom_layer/shufflechannel.py b/x2paddle/op_mapper/caffe_custom_layer/shufflechannel.py
index c6321f2d415d16a56eecdfe9e2287616a3fb7f9e..bc56c2239fee4f6c368921bc1297c1a2eb1c07f6 100644
--- a/x2paddle/op_mapper/caffe_custom_layer/shufflechannel.py
+++ b/x2paddle/op_mapper/caffe_custom_layer/shufflechannel.py
@@ -9,12 +9,12 @@ def shufflechannel_shape(input_shape):
 
 
 def shufflechannel_layer(inputs, group=None, input_shape=None, name=None):
     input = inputs[0]
     c_fm = fluid.layers.split(input, num_or_sections=input_shape[0][1], dim=1)
-    size = int(input_shape[0][1]/group)
+    size = int(input_shape[0][1] / group)
     new_c_fm = []
     for i in range(size):
         for j in range(group):
             new_c_fm.append(c_fm[j * size + i])
-    out = fluid.layers.concat(new_c_fm, axis = 1)
+    out = fluid.layers.concat(new_c_fm, axis=1)
     return out
diff --git a/x2paddle/op_mapper/caffe_op_mapper.py b/x2paddle/op_mapper/caffe_op_mapper.py
index 6121edf9df7065a2d910d6bfd2e7aeb1b222c3fc..ffac5bae9806a9f2dd7d130b4be99557a2c6fbfa 100644
--- a/x2paddle/op_mapper/caffe_op_mapper.py
+++ b/x2paddle/op_mapper/caffe_op_mapper.py
@@ -17,6 +17,7 @@ import numpy as np
 from x2paddle.decoder.caffe_decoder import CaffeGraph
 from x2paddle.core.op_mapper import OpMapper
 from x2paddle.core.util import *
+from x2paddle.op_mapper import caffe_shape
 from x2paddle.op_mapper.caffe_custom_layer import *
 
 
@@ -33,11 +34,11 @@ class CaffeOpMapper(OpMapper):
             node = self.graph.get_node(node_name)
             op = node.layer_type
             if hasattr(self, op):
-                self.set_shape(node)
+                self.set_node_shape(node)
                 func = getattr(self, op)
                 func(node)
             elif op in custom_layers:
-                self.set_shape(node, is_fluid_op=False)
+                self.set_node_shape(node, is_fluid_op=False)
                 self.deal_custom_layer(node)
             else:
                 raise Exception("Model are not supported yet.")
@@ -58,7 +59,7 @@ class CaffeOpMapper(OpMapper):
                 print(op)
             return False
 
-    def set_shape(self, node, is_fluid_op=True):
+    def set_node_shape(self, node, is_fluid_op=True):
         inputs = node.inputs
         input_shape = []
         for i, nm in enumerate(inputs):
@@ -66,12 +67,15 @@ class CaffeOpMapper(OpMapper):
             tmp = node.layer.bottom[i]
             idx = list(last_node.layer.top).index(tmp)
             input_shape.append(last_node.output_shape[idx])
-        node.set_input_shape(input_shape)
+
+        node.input_shape = input_shape
+
+        func_name = 'shape_' + node.layer_type.lower()
         if is_fluid_op:
-            node.set_output_shape(input_shape)
+            node.output_shape = getattr(caffe_shape, func_name)(node.layer,
+                                                                input_shape)
         else:
-            node.set_output_shape(compute_output_shape(node),
-                                  is_input=is_fluid_op)
+            node.output_shape = compute_output_shape(node)
 
     def adjust_parameters(self, node):
         data = node.data
@@ -87,8 +91,6 @@ class CaffeOpMapper(OpMapper):
                 squeeze_indices.append(0)  # Squeeze FC.
 
         for idx in squeeze_indices:
-            print('Transform the weights of {}...'.format(node.layer_name +
-                                                          str(idx)))
            if idx >= len(data):
                continue
 
@@ -140,7 +142,7 @@ class CaffeOpMapper(OpMapper):
         dila_h = dila_w = 1
         group = 1
         c_o = 1
-        if kind in ['Convolution', 'Deconvolution', 'ConvolutionDepthwise']:
+        if kind in ['Convolution', 'Deconvolution']:
             c_o = params.num_output
             dila_len = len(params.dilation)
             if dila_len == 2:
@@ -165,12 +167,6 @@ class CaffeOpMapper(OpMapper):
         else:
             return node.layer_name
 
-    def is_BN(self, node):
-        return True if node.layer_type == 'BatchNorm' else False
-
-    def is_Scale(self, node):
-        return True if node.layer_type == 'Scale' else False
-
     def Input(self, node):
         shape = list(node.layer.input_param.shape[0].dim)[1:]
         dtype = 'float32'
@@ -198,10 +194,6 @@ class CaffeOpMapper(OpMapper):
         assert len(node.inputs
                    ) == 1, 'The count of Convolution node\'s input is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
 
         attr = {
             'filter_size':
@@ -242,10 +234,6 @@ class CaffeOpMapper(OpMapper):
         assert len(node.inputs
                    ) == 1, 'The count of Deconvolution node\'s input is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
 
         attr = {
             'output_size': None,
@@ -287,10 +275,6 @@ class CaffeOpMapper(OpMapper):
         assert len(
             node.inputs) == 1, 'The count of Pooling node\'s input is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         attr = {
             'pool_size': kernel,
             'pool_stride': stride,
@@ -310,10 +294,6 @@ class CaffeOpMapper(OpMapper):
         assert len(
             node.inputs) == 1, 'The count of ReLU node\'s input is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         attr = {'name': string(node.layer_name)}
         node.fluid_code.add_layer("relu",
                                   inputs=input,
@@ -331,10 +311,6 @@ class CaffeOpMapper(OpMapper):
         # We'll account for that here.
         alpha = params.alpha / float(params.local_size)
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         attr = {
             'n': params.local_size,
             'k': 1.0,
@@ -370,10 +346,6 @@ class CaffeOpMapper(OpMapper):
         assert params.axis == 1
         assert params.bias_term == True
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
 
         attr = {
             'size': params.num_output,
@@ -395,10 +367,6 @@ class CaffeOpMapper(OpMapper):
         assert len(
             node.inputs) == 1, 'The count of Softmax node\'s input is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         params = node.layer.softmax_param
         axis = params.axis
         shape = node.input_shape[0]
@@ -414,10 +382,6 @@ class CaffeOpMapper(OpMapper):
         assert len(
             node.inputs) == 1, 'The count of Slice node\'s input is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         params = node.layer.slice_param
         axis = params.axis
         points = list(params.slice_point)
@@ -448,10 +412,6 @@ class CaffeOpMapper(OpMapper):
         inputs = []
         for i in range(len(node.inputs)):
             input = self.graph.get_bottom_node(node, idx=i, copy=True)
-            if self.is_Scale(input):
-                tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-                if self.is_BN(tmp):
-                    input = tmp
             inputs.append(input)
         params = node.layer.concat_param
         axis = params.axis
@@ -465,10 +425,6 @@ class CaffeOpMapper(OpMapper):
         assert len(
             node.inputs) == 1, 'The count of PReLU node\'s input is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         params = node.layer.prelu_param
         mode_bool = params.channel_shared
         if mode_bool:
@@ -493,10 +449,6 @@ class CaffeOpMapper(OpMapper):
         assert len(
             node.inputs) == 1, 'The count of PReLU node\'s input is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         attr = {'name': string(node.layer_name)}
         node.fluid_code.add_layer("sigmoid",
                                   inputs=input,
@@ -507,10 +459,6 @@ class CaffeOpMapper(OpMapper):
         assert len(
             node.inputs) == 1, 'The count of PReLU node\'s input is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         attr = {'name': string(node.layer_name)}
         node.fluid_code.add_layer("absval",
                                   inputs=input,
@@ -527,24 +475,15 @@ class CaffeOpMapper(OpMapper):
         for shape in node.input_shape:
             if shape[1] == 1:
                 input = self.graph.get_bottom_node(node, idx=i, copy=True)
-                if self.is_Scale(input):
-                    tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-                    if self.is_BN(tmp):
-                        input = tmp
                 inputs[1] = input
             else:
                 input = self.graph.get_bottom_node(node, idx=i, copy=True)
-                if self.is_Scale(input):
-                    tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-                    if self.is_BN(tmp):
-                        input = tmp
                 inputs[0] = input
             i += 1
         params = node.layer.accuracy_param
         top_k = params.top_k
         axis = params.axis
         ignore_label = params.ignore_label
-        # TODO(syf)
         assert axis == 1, 'PaddlePaddle can not support the situation when the axis is not 1.'
         assert not ignore_label >= 0, 'PaddlePaddle can not support the situation when the model has ignore label.'
         attr = {'k': top_k}
@@ -557,10 +496,6 @@ class CaffeOpMapper(OpMapper):
         assert len(
             node.inputs) == 1, 'The count of TanH node\'s input is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         attr = {'name': string(node.layer_name)}
         node.fluid_code.add_layer("tanh",
                                   inputs=input,
@@ -574,16 +509,8 @@ class CaffeOpMapper(OpMapper):
         mode = params.operation
         inputs = []
         input0 = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input0):
-            tmp = self.graph.get_bottom_node(input0, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input0 = tmp
         inputs.append(input0)
         input1 = self.graph.get_bottom_node(node, idx=1, copy=True)
-        if self.is_Scale(input1):
-            tmp = self.graph.get_bottom_node(input1, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input1 = tmp
         inputs.append(input1)
         if mode == 0:
             inputs_dict = {}
@@ -660,10 +587,6 @@ class CaffeOpMapper(OpMapper):
             node.outputs
         ) == 1, 'The count of BatchNorm node\'s input and output is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         params = node.layer.batch_norm_param
         if hasattr(params, 'eps'):
             eps = params.eps
@@ -678,133 +601,96 @@ class CaffeOpMapper(OpMapper):
             variance *= scaling_factor
         self.weights[node.layer_name + '_mean'] = mean
         self.weights[node.layer_name + '_variance'] = variance
-        if self.graph.get_node(node.outputs[0]).layer_type == 'Scale':
-            data = self.graph.get_node(node.outputs[0]).data
-            self.weights[node.layer_name + '_scale'] = np.squeeze(data[0])
-            self.weights[node.layer_name + '_offset'] = np.squeeze(data[1])
-            attr = {
-                'is_test': True,
-                'param_attr': string(node.layer_name + '_scale'),
-                'bias_attr': string(node.layer_name + '_offset'),
-                'moving_mean_name': string(node.layer_name + '_mean'),
-                'moving_variance_name': string(node.layer_name + '_variance'),
-                'epsilon': eps,
-                'name': string(node.layer_name)
-            }
-        else:
-            attr = {
-                'is_test': True,
-                'param_attr': None,
-                'bias_attr': None,
-                'moving_mean_name': string(node.layer_name + '_mean'),
-                'moving_variance_name': string(node.layer_name + '_variance'),
-                'epsilon': eps,
-                'name': string(node.layer_name)
-            }
+        attr = {
+            'is_test': True,
+            'param_attr': None,
+            'bias_attr': None,
+            'moving_mean_name': string(node.layer_name + '_mean'),
+            'moving_variance_name': string(node.layer_name + '_variance'),
+            'epsilon': eps,
+            'name': string(node.layer_name)
+        }
         node.fluid_code.add_layer("batch_norm",
                                   inputs=input,
                                   output=node,
                                   param_attr=attr)
 
     def Scale(self, node):
-        assert len(
-            node.inputs) == 1, 'The count of Scale node\'s input is not 1.'
-        if len(node.inputs) == 1 and self.graph.get_node(
-                node.inputs[0]).layer_type == 'BatchNorm':
-            return
+
+        self.weights[node.layer_name + '_scale'] = np.squeeze(node.data[0])
+        self.weights[node.layer_name + '_offset'] = np.squeeze(node.data[1])
+        params = node.layer.scale_param
+        axis = params.axis
+        num_axes = params.num_axes
+        inputs = []
+        if len(node.inputs) == 2:
+            # For two input tensors, reset axis to 1. This may be wrong for unknown cases.
+            axis = 1
+            bias_shape = node.input_shape[0][axis:axis + num_axes]
+            input0 = self.graph.get_bottom_node(node, idx=0, copy=True)
+            input1 = self.graph.get_bottom_node(node, idx=1, copy=True)
+            inputs_dict = {}
+            inputs_dict['x'] = input0
+            inputs_dict['y'] = input1
+            attr = {'axis': axis, 'name': string(node.layer_name + '_mul')}
+            node.fluid_code.add_layer("elementwise_mul",
+                                      inputs=inputs_dict,
+                                      output=node.layer_name + '_mul',
+                                      param_attr=attr)
         else:
-            self.weights[node.layer_name + '_scale'] = np.squeeze(nose.data[0])
-            self.weights[node.layer_name + '_offset'] = np.squeeze(node.data[1])
-            params = node.layer.scale_param
-            axis = params.axis
-            num_axes = params.num_axes
-            assert num_axes == 1, "layer scale not support this num_axes[%d] now" % (
-                num_axes)
-            inputs = []
-            if len(node.inputs) == 2:
-                # for two tensor, here resets axis to 1. Maybe there is a bug for unkown case.
-                axis = 1
-                bias_shape = node.input_shape[0][axis:axis + num_axes]
-                input0 = self.graph.get_bottom_node(node, idx=0, copy=True)
-                if self.is_Scale(input0):
-                    tmp = self.graph.get_bottom_node(input0, idx=0, copy=True)
-                    if self.is_BN(tmp):
-                        input0 = tmp
-                input1 = self.graph.get_bottom_node(node, idx=1, copy=True)
-                if self.is_Scale(input1):
-                    tmp = self.graph.get_bottom_node(input1, idx=0, copy=True)
-                    if self.is_BN(tmp):
-                        input1 = tmp
-                inputs.append(input0)
-                inputs.append(input1)
-                attr = {'axis': axis, 'name': string(node.layer_name + '_mul')}
-                node.fluid_code.add_layer("elementwise_mul",
-                                          inputs=inputs,
-                                          output=node.layer_name + '_mul',
-                                          param_attr=attr)
-            else:
-                bias_shape = node.input_shape[0][axis:axis + num_axes]
-                input0 = self.graph.get_bottom_node(node, idx=0, copy=True)
-                if self.is_Scale(input0):
-                    tmp = self.graph.get_bottom_node(input0, idx=0, copy=True)
-                    if self.is_BN(tmp):
-                        input0 = tmp
-                input0_name = self.get_input_name(input0)
-                attr = {
-                    'dtype': '{}.dtype'.formatr(input0_name),
-                    'shape': bias_shape,
-                    'name': string(node.layer_name + '_cparam1'),
-                    'attr': string(node.layer_name + '_scale'),
-                    'is_bias': True,
-                    'default_initializer': 'Constant(value=1.0)'
-                }
-                node.fluid_code.add_layer("create_parameter",
-                                          inputs=None,
-                                          output=node,
-                                          param_attr=attr)
-                inputs.append(input0)
-                inputs.append(node)
-                attr = {'axis': axis, 'name': string(node.layer_name + '_mul')}
-                node.fluid_code.add_layer("elementwise_mul",
-                                          inputs=inputs,
-                                          output=node.layer_name + '_mul',
-                                          param_attr=attr)
-            scale_shape = bias_shape
+            bias_shape = node.input_shape[0][axis:axis + num_axes]
+            input0 = self.graph.get_bottom_node(node, idx=0, copy=True)
             input0_name = self.get_input_name(input0)
             attr = {
-                'dtype': '{}.dtype'.formatr(input0_name),
+                'dtype': '{}.dtype'.format(input0_name),
                 'shape': bias_shape,
+                'name': string(node.layer_name + '_cparam1'),
+                'attr': string(node.layer_name + '_scale'),
                 'is_bias': True,
                 'default_initializer': 'Constant(value=1.0)'
             }
             node.fluid_code.add_layer("create_parameter",
                                       inputs=None,
-                                      output=node.layer_name + '_offset_param',
-                                      param_attr=attr)
-            attr = {'axis': axis, 'name': string(node.layer_name + '_add')}
-            node.fluid_code.add_layer("elementwise_add",
-                                      inputs='{}_mul, {}_offset_param'.format(
-                                          node.layer_name, node.layer_name),
                                       output=node,
                                       param_attr=attr)
+            inputs_dict = {}
+            inputs_dict['x'] = input0
+            inputs_dict['y'] = node
+            attr = {'axis': axis, 'name': string(node.layer_name + '_mul')}
+            node.fluid_code.add_layer("elementwise_mul",
+                                      inputs=inputs_dict,
+                                      output=node.layer_name + '_mul',
+                                      param_attr=attr)
+        scale_shape = bias_shape
+        input0_name = self.get_input_name(input0)
+        attr = {
+            'dtype': '{}.dtype'.format(input0_name),
+            'shape': scale_shape,
+            'name': string(node.layer_name + '_cparam2'),
+            'attr': string(node.layer_name + '_offset'),
+            'is_bias': True,
+            'default_initializer': 'Constant(value=1.0)'
+        }
+        node.fluid_code.add_layer("create_parameter",
+                                  inputs=None,
+                                  output=node.layer_name + '_offset_param',
+                                  param_attr=attr)
+        attr = {'axis': axis, 'name': string(node.layer_name + '_add')}
+        node.fluid_code.add_layer("elementwise_add",
+                                  inputs='{}_mul, {}_offset_param'.format(
+                                      node.layer_name, node.layer_name),
+                                  output=node,
+                                  param_attr=attr)
 
     def Reshape(self, node):
-        assert len(node.inputs) == 1 and len(
-            node.outputs
-        ) == 1, 'The count of Reshape node\'s input and output is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
         top_count = len(input.layer.top)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
-        is_inplace, = False if top_count == 1 else True
+        is_inplace = False if top_count == 1 else True
         output_shape = node.output_shape[0]
         attr = {
             'shape': output_shape,
             'inplace': is_inplace,
+            'act': None,
             'name': string(node.layer_name)
         }
         node.fluid_code.add_layer("reshape",
@@ -817,10 +703,6 @@ class CaffeOpMapper(OpMapper):
             node.outputs
         ) == 1, 'The count of ArgMax node\'s input and output is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         input_shape = node.input_shape[0]
         params = node.layer.argmax_param
         out_max_val = params.out_max_val if hasattr(params,
@@ -859,15 +741,7 @@ class CaffeOpMapper(OpMapper):
         assert len(
             node.inputs) == 2, 'The count of Crop node\'s input is not 2.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         example = self.graph.get_bottom_node(node, idx=1, copy=True)
-        if self.is_Scale(example):
-            tmp = self.graph.get_bottom_node(example, idx=0, copy=True)
-            if self.is_BN(tmp):
-                example = tmp
         params = node.layer.crop_param
         axis = parmas.axis
         input_shape = node.input_shape[0]
@@ -893,10 +767,6 @@ class CaffeOpMapper(OpMapper):
             node.inputs
         ) == 1, 'The count of DetectionOutput node\'s input is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         shape = node.output_shape[0]
         attr = {'shape': shape, 'name': string(node.layer_name)}
         node.fluid_code.add_layer("reshape",
@@ -908,10 +778,6 @@ class CaffeOpMapper(OpMapper):
         assert len(
             node.inputs) == 1, 'The count of Permute node\'s input is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         params = node.layer.power_param
         power = params.power
         scale = params.scale
@@ -936,10 +802,6 @@ class CaffeOpMapper(OpMapper):
         assert len(
             node.inputs) == 1, 'The count of Reduction node\'s input is not 1.'
         input = self.graph.get_bottom_node(node, idx=0, copy=True)
-        if self.is_Scale(input):
-            tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-            if self.is_BN(tmp):
-                input = tmp
         params = node.layer.reduction_param
         operation = params.operation
         axis = params.axis
@@ -1022,10 +884,6 @@ class CaffeOpMapper(OpMapper):
         inputs_node = []
         for i in range(len(node.inputs)):
             input = self.graph.get_bottom_node(node, idx=i, copy=True)
-            if self.is_Scale(input):
-                tmp = self.graph.get_bottom_node(input, idx=0, copy=True)
-                if self.is_BN(tmp):
-                    input = tmp
             inputs_node.append(input)
         node.fluid_code.add_layer(func.__code__.co_name,
                                   inputs=inputs_node,
diff --git a/x2paddle/op_mapper/caffe_shape.py b/x2paddle/op_mapper/caffe_shape.py
index 9fd89fe7ad95a49f588888e448d45a74f18d6fc2..a38d3045a74c23de37cb9123dfc2654c21de9467 100644
--- a/x2paddle/op_mapper/caffe_shape.py
+++ b/x2paddle/op_mapper/caffe_shape.py
@@ -13,104 +13,58 @@
 # limitations under the License.
 
 import math
-
-
-def get_params_w_h(params):
+import numbers
+from functools import reduce
+
+
+def get_kernel_parameters(params):
+    [k_h, k_w] = [1, 1]
+    if isinstance(params.kernel_size, numbers.Number):
+        [k_h, k_w] = [params.kernel_size] * 2
+    elif len(params.kernel_size) > 0:
+        k_h = params.kernel_h if params.kernel_h else params.kernel_size[0]
+        k_w = params.kernel_w if params.kernel_w else params.kernel_size[
+            len(params.kernel_size) - 1]
+    [s_h, s_w] = [1, 1]
+    if isinstance(params.stride, numbers.Number):
+        [s_h, s_w] = [params.stride] * 2
+    elif len(params.stride) > 0:
+        s_h = params.stride_h if params.stride_h else params.stride[0]
+        s_w = params.stride_w if params.stride_w else params.stride[
+            len(params.stride) - 1]
+    [p_h, p_w] = [0, 0]
+    if isinstance(params.pad, numbers.Number):
+        [p_h, p_w] = [params.pad] * 2
+    elif len(params.pad) > 0:
+        p_h = params.pad_h if params.pad_h else params.pad[0]
+        p_w = params.pad_w if params.pad_w else params.pad[len(params.pad) - 1]
+    dila_h = dila_w = 1
     if hasattr(params, 'dilation'):
-        if len(params.dilation) == 0:
-            dila_h = 1
-            dila_w = 1
-        elif len(params.dilation) == 1:
-            dila_h = params.dilation[0]
-            dila_w = params.dilation[0]
-        else:
+        dila_len = len(params.dilation)
+        if dila_len == 2:
             dila_h = params.dilation[0]
             dila_w = params.dilation[1]
-    else:
-        dila_h = 1
-        dila_w = 1
-
-    if not isinstance(getattr(params, 'pad'), int):
-        if len(params.pad) == 0:
-            pad_h = 0
-            pad_w = 0
-        elif len(params.pad) == 1:
-            pad_h = params.pad[0]
-            pad_w = params.pad[0]
-        else:
-            pad_h, pad_w, = params.pad[0]
-            pad_w = params.pad[1]
-        if params.pad_h != 0 or params.pad_w != 0:
-            pad_h = params.pad_h
-            pad_w = params.pad_w
-    else:
-        if params.pad_h != 0 or params.pad_w != 0:
-            pad_h = params.pad_h
-            pad_w = params.pad_w
-        else:
-            pad_h = getattr(params, 'pad')
-            pad_w = getattr(params, 'pad')
-
-    if not isinstance(getattr(params, 'kernel_size'), int):
-        if len(params.kernel_size) == 0:
-            kernel_h = 1
-            kernel_w = 1
-        elif len(params.kernel_size) == 1:
-            kernel_h = params.kernel_size[0]
-            kernel_w = params.kernel_size[0]
-        else:
-            kernel_h = params.kernel_size[0]
-            kernel_w = params.kernel_size[1]
-        if params.kernel_h != 0 or params.kernel_w != 0:
-            kernel_h = params.kernel_h
-            kernel_w = params.kernel_w
-    else:
-        if params.kernel_h != 0 or params.kernel_w != 0:
-            kernel_h = params.kernel_h
-            kernel_w = params.kernel_w
-        else:
-            kernel_h = getattr(params, 'kernel_size')
-            kernel_w = getattr(params, 'kernel_size')
-    if not isinstance(getattr(params, 'stride'), int):
-        if len(params.stride) == 0:
-            stride_h = 1
-            stride_w = 1
-        elif len(params.stride) == 1:
-            stride_h = params.stride[0]
-            stride_w = params.stride[0]
-        else:
-            stride_h = params.stride[0]
-            stride_w = params.stride[1]
-        if params.stride_h != 0 or params.stride_w != 0:
-            stride_h = params.stride_h
-            stride_w = params.stride_w
-    else:
-        if params.stride_h != 0 or params.stride_w != 0:
-            stride_h = params.stride_h
-            stride_w = params.stride_w
+    elif dila_len == 1:
+        dila_h = dila_w = params.dilation[0]
     else:
-        stride_h = getattr(params, 'stride')
-        stride_w = getattr(params, 'stride')
-    return dila_h, dila_w, pad_h, pad_w, kernel_h, kernel_w, stride_h, stride_w
+        assert dila_len == 0, "invalid length[%s] of dilation in convolution" % (
+            dila_len)
+    return dila_h, dila_w, p_h, p_w, k_h, k_w, s_h, s_w
 
 
-def get_filter_output_shape(i_h, i_w, params, round_func):
-    dila_h, dila_w, pad_h, pad_w, kernel_h, kernel_w, stride_h, stride_w = get_params_w_h(
+def get_strided_kernel_output_shape(params, input_shape, round_func):
+    i_h = input_shape[2]
+    i_w = input_shape[3]
+    dila_h, dila_w, pad_h, pad_w, kernel_h, kernel_w, stride_h, stride_w = get_kernel_parameters(
         params)
     o_h = (i_h + 2 * pad_h - (dila_h * (kernel_h - 1) + 1)) / float(stride_h) + 1
     o_w = (i_w + 2 * pad_w - (dila_w * (kernel_w - 1) + 1)) / float(stride_w) + 1
-    return (int(round_func(o_h)), int(round_func(o_w)))
-
-
-def get_strided_kernel_output_shape(params, input_shape, round_func):
-
-    o_h, o_w = get_filter_output_shape(input_shape[2], input_shape[3], params,
-                                       round_func)
+    o_h = int(round_func(o_h))
+    o_w = int(round_func(o_w))
     has_c_o = hasattr(params, 'num_output')
     c = params.num_output if has_c_o else input_shape[1]
-
     return [[input_shape[0], c, o_h, o_w]]
@@ -176,7 +130,9 @@ def shape_concat(layer, input_shape):
     output_shape = None
     for shape in input_shape:
         if output_shape is None:
-            output_shape = shape
+            output_shape = []
+            for i in range(len(shape)):
+                output_shape.append(shape[i])
         else:
             output_shape[axis] += shape[axis]
     return [output_shape]
@@ -191,7 +147,9 @@ def shape_slice(layer, input_shape):
     points = [0] + points + [count]
     output_shape = []
     for i in range(len(points)):
-        shape = inshape
+        shape = []
+        for ii in range(len(inshape)):
+            shape.append(inshape[ii])
         size = points[i + 1] - points[i]
         shape[axis] = size
         output_shape.append(shape)
@@ -238,8 +196,8 @@ def shape_reshape(layer, input_shape):
 
     inshape = input_shape[0]
     params = layer.reshape_param
-    axis = params.axis if hasattr(params, axis) else 0
-    num_axes = params.num_axes if hasattr(params, num_axes) else -1
+    axis = params.axis if hasattr(params, 'axis') else 0
+    num_axes = params.num_axes if hasattr(params, 'num_axes') else -1
     if inshape[0] == -1:
         inshape[0] = 1
     input_count = count(inshape)
@@ -262,14 +220,14 @@ def shape_reshape(layer, input_shape):
 
     num_axes_replaced = end_axis - start_axis
     num_axes_retained = input_num_axes - num_axes_replaced
-    num_new_axes = len(shape['dim'])
+    num_new_axes = len(list(params.shape.dim))
    outshape = []
 
    for i in range(start_axis):
        outshape.append(inshape[i])
 
    for i in range(num_new_axes):
-        outshape.append(shape['dim'][i])
+        outshape.append(params.shape.dim[i])
 
    for i in range(end_axis, input_num_axes):
        outshape.append(inshape[i])
@@ -281,7 +239,7 @@ def shape_reshape(layer, input_shape):
     copy_axes = []
     constant_count = 1
     for i in range(num_new_axes):
-        top_dim = shape['dim'][i]
+        top_dim = params.shape.dim[i]
         if top_dim == 0:
             copy_axes.append(i)
             copy_axis_index = start_axis + i
@@ -297,24 +255,20 @@ def shape_reshape(layer, input_shape):
         l = inshape[0:start_axis]
         if len(l) > 0:
             explicit_count *= count(l)
-
         l = inshape[end_axis:]
         if len(l) > 0:
             explicit_count *= count(l)
-
         for i in range(len(copy_axes)):
             explicit_count *= outshape[start_axis + copy_axes[i]]
-
         assert input_count % explicit_count == 0, "[Reshape]botom count[%d] "\
             "must be divisible by product of the specified dimensions[%d] "\
             % (input_count, explicit_count)
-        outshape[start_axis + inferred_axis] = input_count / explicit_count
+        outshape[start_axis + inferred_axis] = int(input_count / explicit_count)
 
     output_count = count(outshape)
     assert output_count == input_count, "[Reshape]output count[%d] must match input count[%d]" % (
         output_count, input_count)
-    if inshape[0] == -1:
-        outshape[0] = -1
+    outshape[0] = -1
     return [outshape]
@@ -345,18 +299,22 @@ def shape_crop(layer, input_shape):
 
 
 def shape_flatten(layer, input_shape):
     assert len(input_shape) == 1, "the number of flatten's inputs must be 1"
+    inshape = input_shape[0]
     params = layer.flatten_param
     start_axis = params.axis
     end_axis = params.end_axis
     if start_axis < 0:
-        start_axis += len(input_shape[0])
+        start_axis += len(inshape)
     if end_axis < 0:
-        end_axis += len(input_shape[0]) + 1
+        end_axis += len(inshape) + 1
     assert start_axis <= end_axis, 'invalid axis[%d] or end_axis[%d] params'\
         % (start_axis, end_axis)
-    output_shape = [0] * (start_axis - 0) + [
-        -1
-    ] + [0] * (len(input_shape[0]) - end_axis)
+    output_shape = inshape[0:start_axis]
+    if len(inshape[start_axis:end_axis]) != 0:
+        flat_sz = reduce(lambda a, b: a * b, inshape[start_axis:end_axis])
+        output_shape += [flat_sz]
+    output_shape += inshape[end_axis:len(inshape)]
+    output_shape[0] = -1
     return [output_shape]
 
diff --git a/x2paddle/optimizer/caffe_optimizer.py b/x2paddle/optimizer/caffe_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..52786ce5b3c03e5eda5e830262e96bf388f37e76
--- /dev/null
+++ b/x2paddle/optimizer/caffe_optimizer.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from x2paddle.decoder.caffe_decoder import CaffeGraph
+from x2paddle.core.util import *
+
+
+class CaffeOptimizer(object):
+    layers_with_act = ['Convolution', 'Deconvolution', 'InnerProduct']
+    activation_ops = ['ReLU', 'Sigmoid']
+
+    def __init__(self, mapper):
+        self.graph = mapper.graph
+
+    def merge_bn_scale(self):
+        for node_name in self.graph.topo_sort:
+            node = self.graph.get_node(node_name)
+            if node.layer_type == 'Scale':
+                parent_node = self.graph.get_bottom_node(node, idx=0)
+                if parent_node.layer_type == 'BatchNorm':
+                    is_delete_node = True if len(
+                        parent_node.outputs) == 1 else False
+                    parent_fluid_layer = parent_node.fluid_code.layers[0]
+                    input = parent_fluid_layer.inputs
+                    parent_param_attr = parent_fluid_layer.param_attr
+                    parent_param_attr['param_attr'] = string(node.layer_name +
+                                                             '_scale')
+                    parent_param_attr['bias_attr'] = string(node.layer_name +
+                                                            '_offset')
+                    if is_delete_node:
+                        parent_node.fluid_code.clear()
+                    node.fluid_code.clear()
+                    node.fluid_code.add_layer("batch_norm",
+                                              inputs=input,
+                                              output=node,
+                                              param_attr=parent_param_attr)
+
+    def merge_op_activation(self):
+        for node_name in self.graph.topo_sort:
+            node = self.graph.get_node(node_name)
+            if node.layer_type in self.activation_ops:
+                parent_node = self.graph.get_bottom_node(node, idx=0)
+                if parent_node.layer_type in self.layers_with_act:
+                    is_delete_node = True if len(
+                        parent_node.outputs) == 1 else False
+                    parent_fluid_layer = parent_node.fluid_code.layers[0]
+                    input = parent_fluid_layer.inputs
+                    parent_param_attr = parent_fluid_layer.param_attr
+                    parent_param_attr['act'] = string(node.layer_type.lower())
+                    op = parent_fluid_layer.op
+                    if is_delete_node:
+                        parent_node.fluid_code.clear()
+                    node.fluid_code.clear()
+                    node.fluid_code.add_layer(op,
+                                              inputs=input,
+                                              output=node,
+                                              param_attr=parent_param_attr)
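
Note: taken together, the convert.py changes and the new caffe_optimizer.py give the Caffe conversion path the shape sketched below. This is an illustrative sketch only, mirroring the updated caffe2paddle() above; the model paths are hypothetical placeholders, not files from this change.

    # Sketch of the conversion pipeline after this change; every call used
    # here appears in the diff above. The .prototxt/.caffemodel paths are
    # hypothetical placeholders.
    from x2paddle.decoder.caffe_decoder import CaffeDecoder
    from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper
    from x2paddle.optimizer.caffe_optimizer import CaffeOptimizer

    model = CaffeDecoder("deploy.prototxt", "model.caffemodel", None)
    mapper = CaffeOpMapper(model)

    # Graph rewrites introduced by this change: fold a Scale layer into its
    # parent BatchNorm, then fuse ReLU/Sigmoid into the parent op's 'act' attr.
    optimizer = CaffeOptimizer(mapper)
    optimizer.merge_bn_scale()
    optimizer.merge_op_activation()

    mapper.save_inference_model("inference_model")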